Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)
Add RedisDataStore (#295)
* GraphManager.set_session also sets self.sess
* make sure that GraphManager.fetch_from_worker uses training phase
* remove unnecessary phase setting in training worker
* reorganize rollout worker
* provide default name to GlobalVariableSaver.__init__ since it isn't really used anyway
* allow dividing TrainingSteps and EnvironmentSteps
* add timestamps to the log
* added redis data store
* conflict merge fix
committed by shadiendrawis
parent 34e1c04f29
commit 7b0fccb041
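The headline change is a Redis-backed data store for passing policies from the training worker to rollout workers; the diff below only shows the consumer side of that interface (`data_store.save_policy(graph_manager)`). As orientation, a minimal sketch of such a store follows. It assumes the redis-py client and pickle serialization; the key name, the default connection settings, and the `get_checkpoint_state`/`set_checkpoint_state` accessors on the graph manager are hypothetical and are not the API this commit actually adds.

import pickle

import redis


class RedisDataStore(object):
    """Illustrative sketch of a Redis-backed policy store, not the rl_coach implementation."""

    def __init__(self, host='localhost', port=6379, key='policy'):
        # Plain redis-py client; host, port and key are illustrative defaults.
        self.redis = redis.Redis(host=host, port=port)
        self.key = key

    def save_policy(self, graph_manager):
        # Assumption: serialize the trainer's current checkpoint state and publish it
        # under a single key so rollout workers can pick it up.
        state = graph_manager.get_checkpoint_state()  # hypothetical accessor
        self.redis.set(self.key, pickle.dumps(state))

    def load_policy(self, graph_manager):
        # Assumption: restore the most recently published policy, if any.
        blob = self.redis.get(self.key)
        if blob is not None:
            graph_manager.set_checkpoint_state(pickle.loads(blob))  # hypothetical accessor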
@@ -1,4 +1,4 @@
#
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,12 +15,7 @@
#

"""
"""
import time

from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
from rl_coach import core_types
from rl_coach.base_parameters import DistributedCoachSynchronizationType
from rl_coach.logger import screen
@@ -32,22 +27,26 @@ def data_store_ckpt_load(data_store):
def training_worker(graph_manager, task_parameters, data_store, is_multi_node_test):
    """
    restore a checkpoint then perform rollouts using the restored model

    :param graph_manager: An instance of the graph manager
    :param data_store: An instance of DataStore which can be used to communicate policies to rollout workers
    :param task_parameters: An instance of task parameters
    :param is_multi_node_test: Whether this is a multi-node test instead of a normal run.
    """
    # Load checkpoint if provided
    if task_parameters.checkpoint_restore_path:
        data_store_ckpt_load(data_store)

        # initialize graph
        graph_manager.create_graph(task_parameters)

    else:
        # initialize graph
        graph_manager.create_graph(task_parameters)

        # save randomly initialized graph
        graph_manager.save_checkpoint()
        data_store.save_policy(graph_manager)

    # training loop
    steps = 0
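Either way the function leaves this block with a usable graph: with a restore path, the checkpoint is pulled from the data store before `create_graph` runs; without one, the freshly initialized graph is checkpointed and published so rollout workers never have to wait for the first training step. The `data_store_ckpt_load` helper named in the hunk header is not shown in this diff; a plausible sketch, assuming the coach `DataStore` interface exposes a `load_from_store()` method that copies the latest checkpoint into the local checkpoint directory:

def data_store_ckpt_load(data_store):
    # Assumption: load_from_store() fetches the newest checkpoint from the shared
    # store (e.g. Redis or S3) into the directory that create_graph() restores from.
    data_store.load_from_store()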
@@ -60,21 +59,17 @@ def training_worker(graph_manager, task_parameters, data_store, is_multi_node_test):

    while steps < graph_manager.improve_steps.num_steps:

        graph_manager.phase = core_types.RunPhase.TRAIN
        if is_multi_node_test and graph_manager.get_current_episodes_count() > graph_manager.preset_validation_params.max_episodes_to_achieve_reward:
            # Test failed as it has not reached the required success rate
            graph_manager.flush_finished()
            screen.error("Could not reach required success by {} episodes.".format(graph_manager.preset_validation_params.max_episodes_to_achieve_reward), crash=True)

        graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
        graph_manager.phase = core_types.RunPhase.UNDEFINED

        if graph_manager.should_train():
            steps += 1

            graph_manager.phase = core_types.RunPhase.TRAIN
            graph_manager.train()
            graph_manager.phase = core_types.RunPhase.UNDEFINED

            if steps * graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps > graph_manager.steps_between_evaluation_periods.num_steps * eval_offset:
                eval_offset += 1
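The commit message says that `fetch_from_worker` should run in the training phase itself and that the explicit phase setting in this worker is unnecessary, so the `RunPhase.TRAIN`/`RunPhase.UNDEFINED` assignments around `fetch_from_worker` and `train()` above are exactly the kind of toggling that change removes. One way to express that intent, offered here as a sketch of the idea rather than the code this commit adds, is a context manager that owns the phase transition and always restores it:

from contextlib import contextmanager

from rl_coach import core_types


@contextmanager
def run_phase(graph_manager, phase):
    # Switch the graph manager into the given phase, and reset it to UNDEFINED
    # on exit even if the wrapped block raises.
    graph_manager.phase = phase
    try:
        yield
    finally:
        graph_manager.phase = core_types.RunPhase.UNDEFINED

With a helper like this, `fetch_from_worker` and `train` can wrap their own bodies (e.g. `with run_phase(self, core_types.RunPhase.TRAIN): ...`) instead of relying on every caller to set and clear the phase.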
@@ -82,6 +77,10 @@ def training_worker(graph_manager, task_parameters, data_store, is_multi_node_test):
                break

            if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                graph_manager.save_checkpoint()
                data_store.save_policy(graph_manager)
            else:
                graph_manager.occasionally_save_checkpoint()
                # NOTE: this implementation conflated occasionally saving checkpoints for later use
                # in production with checkpoints saved for communication to rollout workers.
                # TODO: this should be implemented with a new parameter: distributed_coach_synchronization_frequency or similar
                # graph_manager.occasionally_save_checkpoint()
                raise NotImplementedError()
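In SYNC mode the trainer saves a checkpoint and publishes the policy through the data store on every training step, so rollout workers always act on the newest weights; the non-SYNC branch is replaced here with the NOTE/TODO and a NotImplementedError, since occasional production checkpoints and worker-communication checkpoints should not share one mechanism. On the preset side, opting into the synchronous behaviour would look roughly like this (a sketch; ClippedPPO is only an example agent, and the surrounding preset boilerplate is omitted):

from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
from rl_coach.base_parameters import DistributedCoachSynchronizationType

# Example agent parameters; the training worker reads this field from
# agent_params.algorithm, as the diff above shows.
agent_params = ClippedPPOAgentParameters()

# Publish a checkpoint after every training step so rollout workers stay in sync
# with the trainer's latest policy.
agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC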