1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Add RedisDataStore (#295)

* GraphManager.set_session also sets self.sess

* make sure that GraphManager.fetch_from_worker uses training phase

* remove unnecessary phase setting in training worker

* reorganize rollout worker

* provide default name to GlobalVariableSaver.__init__ since it isn't really used anyway

* allow dividing TrainingSteps and EnvironmentSteps

* add timestamps to the log

* added redis data store

* conflict merge fix
This commit is contained in:
Zach Dwiel
2019-08-28 14:15:58 -04:00
committed by shadiendrawis
parent 34e1c04f29
commit 7b0fccb041
18 changed files with 528 additions and 120 deletions

View File

@@ -1,4 +1,4 @@
#
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,12 +15,7 @@
#
"""
"""
import time
from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
from rl_coach import core_types
from rl_coach.base_parameters import DistributedCoachSynchronizationType
from rl_coach.logger import screen
@@ -32,22 +27,26 @@ def data_store_ckpt_load(data_store):
def training_worker(graph_manager, task_parameters, data_store, is_multi_node_test):
"""
restore a checkpoint then perform rollouts using the restored model
:param graph_manager: An instance of the graph manager
:param data_store: An instance of DataStore which can be used to communicate policies to roll out workers
:param task_parameters: An instance of task parameters
:param is_multi_node_test: If this is a multi node test insted of a normal run.
"""
# Load checkpoint if provided
if task_parameters.checkpoint_restore_path:
data_store_ckpt_load(data_store)
# initialize graph
graph_manager.create_graph(task_parameters)
else:
# initialize graph
graph_manager.create_graph(task_parameters)
# save randomly initialized graph
graph_manager.save_checkpoint()
data_store.save_policy(graph_manager)
# training loop
steps = 0
@@ -60,21 +59,17 @@ def training_worker(graph_manager, task_parameters, data_store, is_multi_node_te
while steps < graph_manager.improve_steps.num_steps:
graph_manager.phase = core_types.RunPhase.TRAIN
if is_multi_node_test and graph_manager.get_current_episodes_count() > graph_manager.preset_validation_params.max_episodes_to_achieve_reward:
# Test failed as it has not reached the required success rate
graph_manager.flush_finished()
screen.error("Could not reach required success by {} episodes.".format(graph_manager.preset_validation_params.max_episodes_to_achieve_reward), crash=True)
graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
graph_manager.phase = core_types.RunPhase.UNDEFINED
if graph_manager.should_train():
steps += 1
graph_manager.phase = core_types.RunPhase.TRAIN
graph_manager.train()
graph_manager.phase = core_types.RunPhase.UNDEFINED
if steps * graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps > graph_manager.steps_between_evaluation_periods.num_steps * eval_offset:
eval_offset += 1
@@ -82,6 +77,10 @@ def training_worker(graph_manager, task_parameters, data_store, is_multi_node_te
break
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
graph_manager.save_checkpoint()
data_store.save_policy(graph_manager)
else:
graph_manager.occasionally_save_checkpoint()
# NOTE: this implementation conflated occasionally saving checkpoints for later use
# in production with checkpoints saved for communication to rollout workers.
# TODO: this should be implemented with a new parameter: distributed_coach_synchronization_frequency or similar
# graph_manager.occasionally_save_checkpoint()
raise NotImplementedError()