mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Add RedisDataStore (#295)

* GraphManager.set_session also sets self.sess

* make sure that GraphManager.fetch_from_worker uses the training phase

* remove unnecessary phase setting in training worker

* reorganize rollout worker

* provide a default name to GlobalVariableSaver.__init__ since it isn't really used anyway

* allow dividing TrainingSteps and EnvironmentSteps (see the illustrative sketch after this list)

* add timestamps to the log

* add Redis data store (an illustrative sketch follows the commit metadata below)

* fix merge conflict
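The step-division change mentioned above is not part of the diff excerpt below. As a rough, hedged illustration of the idea only, the following Python sketch assumes a step-count class that implements __truediv__ so a step budget can be split evenly, for example across rollout workers; the names StepsSketch and EnvironmentStepsSketch are hypothetical and are not the classes touched by this commit.

# Illustrative sketch only; not Coach's TrainingSteps/EnvironmentSteps implementation.
class StepsSketch:
    """Hypothetical step-count container that supports division."""

    def __init__(self, num_steps: int):
        self.num_steps = num_steps

    def __truediv__(self, divisor: int) -> "StepsSketch":
        # Split the step budget into equal integer shares.
        return type(self)(self.num_steps // divisor)

    def __repr__(self):
        return "{}({})".format(type(self).__name__, self.num_steps)


class EnvironmentStepsSketch(StepsSketch):
    pass


total = EnvironmentStepsSketch(10000)
print(total / 4)  # EnvironmentStepsSketch(2500)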
Author: Zach Dwiel
Date: 2019-08-28 14:15:58 -04:00
Committed by: shadiendrawis
Parent: 34e1c04f29
Commit: 7b0fccb041
18 changed files with 528 additions and 120 deletions
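The RedisDataStore added by this commit is not shown in the excerpt below. As a non-authoritative sketch of the general idea only, the following assumes the redis-py client and a simple pickle-based key/value layout; the class RedisDataStoreSketch and its save/load methods are hypothetical and do not reflect Coach's actual RedisDataStore API. In Coach's distributed setup a data store broadly serves to share data such as checkpoints between the training worker and rollout workers, while the memory backend used in fetch_from_worker below carries transitions in the other direction.

# Illustrative sketch only; not the RedisDataStore added in this commit.
import pickle

import redis  # assumes the redis-py package and a reachable Redis server


class RedisDataStoreSketch:
    """Hypothetical key/value store backed by Redis."""

    def __init__(self, host: str = "localhost", port: int = 6379, prefix: str = "coach"):
        self._redis = redis.Redis(host=host, port=port)
        self._prefix = prefix

    def _key(self, name: str) -> str:
        # Namespace every key so several experiments can share one server.
        return "{}:{}".format(self._prefix, name)

    def save(self, name: str, obj) -> None:
        # Serialize the object and write it under the namespaced key.
        self._redis.set(self._key(name), pickle.dumps(obj))

    def load(self, name: str):
        # Return None when the key has not been written yet.
        raw = self._redis.get(self._key(name))
        return pickle.loads(raw) if raw is not None else None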


@@ -232,18 +232,14 @@ class GraphManager(object):
             else:
                 checkpoint_dir = task_parameters.checkpoint_save_dir
 
-            self.sess = create_monitored_session(target=task_parameters.worker_target,
-                                                 task_index=task_parameters.task_index,
-                                                 checkpoint_dir=checkpoint_dir,
-                                                 checkpoint_save_secs=task_parameters.checkpoint_save_secs,
-                                                 config=config)
-            # set the session for all the modules
-            self.set_session(self.sess)
+            self.set_session(create_monitored_session(target=task_parameters.worker_target,
+                                                      task_index=task_parameters.task_index,
+                                                      checkpoint_dir=checkpoint_dir,
+                                                      checkpoint_save_secs=task_parameters.checkpoint_save_secs,
+                                                      config=config))
         else:
             # regular session
-            self.sess = tf.Session(config=config)
-            # set the session for all the modules
-            self.set_session(self.sess)
+            self.set_session(tf.Session(config=config))
 
         # the TF graph is static, and therefore is saved once - in the beginning of the experiment
         if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
@@ -366,6 +362,8 @@ class GraphManager(object):
         Set the deep learning framework session for all the modules in the graph
         :return: None
         """
+        self.sess = sess
+
         [manager.set_session(sess) for manager in self.level_managers]
 
     def heatup(self, steps: PlayingStepsType) -> None:
@@ -710,8 +708,9 @@ class GraphManager(object):
     def fetch_from_worker(self, num_consecutive_playing_steps=None):
         if hasattr(self, 'memory_backend'):
-            for transition in self.memory_backend.fetch(num_consecutive_playing_steps):
-                self.emulate_act_on_trainer(EnvironmentSteps(1), transition)
+            with self.phase_context(RunPhase.TRAIN):
+                for transition in self.memory_backend.fetch(num_consecutive_playing_steps):
+                    self.emulate_act_on_trainer(EnvironmentSteps(1), transition)
 
     def setup_memory_backend(self) -> None:
         if hasattr(self.agent_params.memory, 'memory_backend_params'):
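The fetch_from_worker change above wraps transition replay in phase_context(RunPhase.TRAIN) so the trainer processes fetched transitions in the training phase. As a hedged illustration of how such a context manager typically works (not necessarily Coach's exact implementation), the sketch below saves the current phase, switches to the requested one, and restores the previous phase on exit; PhaseHolderSketch is a hypothetical stand-in for GraphManager, and the RunPhase enum here is a minimal stand-in for Coach's.

# Illustrative sketch only; not necessarily Coach's phase_context implementation.
from contextlib import contextmanager
from enum import Enum


class RunPhase(Enum):
    # Minimal stand-in for Coach's RunPhase enum.
    HEATUP = 0
    TRAIN = 1
    TEST = 2


class PhaseHolderSketch:
    """Hypothetical object with a phase attribute, standing in for GraphManager."""

    def __init__(self):
        self.phase = RunPhase.HEATUP

    @contextmanager
    def phase_context(self, phase):
        # Remember the current phase, switch to the requested one,
        # and always restore the original phase on exit.
        old_phase = self.phase
        self.phase = phase
        try:
            yield
        finally:
            self.phase = old_phase


holder = PhaseHolderSketch()
with holder.phase_context(RunPhase.TRAIN):
    assert holder.phase is RunPhase.TRAIN
assert holder.phase is RunPhase.HEATUP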