mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Add RedisDataStore (#295)
* GraphManager.set_session also sets self.sess * make sure that GraphManager.fetch_from_worker uses training phase * remove unnecessary phase setting in training worker * reorganize rollout worker * provide default name to GlobalVariableSaver.__init__ since it isn't really used anyway * allow dividing TrainingSteps and EnvironmentSteps * add timestamps to the log * added redis data store * conflict merge fix
This commit is contained in:
committed by
shadiendrawis
parent
34e1c04f29
commit
7b0fccb041
@@ -232,18 +232,14 @@ class GraphManager(object):
|
||||
else:
|
||||
checkpoint_dir = task_parameters.checkpoint_save_dir
|
||||
|
||||
self.sess = create_monitored_session(target=task_parameters.worker_target,
|
||||
task_index=task_parameters.task_index,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
checkpoint_save_secs=task_parameters.checkpoint_save_secs,
|
||||
config=config)
|
||||
# set the session for all the modules
|
||||
self.set_session(self.sess)
|
||||
self.set_session(create_monitored_session(target=task_parameters.worker_target,
|
||||
task_index=task_parameters.task_index,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
checkpoint_save_secs=task_parameters.checkpoint_save_secs,
|
||||
config=config))
|
||||
else:
|
||||
# regular session
|
||||
self.sess = tf.Session(config=config)
|
||||
# set the session for all the modules
|
||||
self.set_session(self.sess)
|
||||
self.set_session(tf.Session(config=config))
|
||||
|
||||
# the TF graph is static, and therefore is saved once - in the beginning of the experiment
|
||||
if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
|
||||
@@ -366,6 +362,8 @@ class GraphManager(object):
|
||||
Set the deep learning framework session for all the modules in the graph
|
||||
:return: None
|
||||
"""
|
||||
self.sess = sess
|
||||
|
||||
[manager.set_session(sess) for manager in self.level_managers]
|
||||
|
||||
def heatup(self, steps: PlayingStepsType) -> None:
|
||||
@@ -710,8 +708,9 @@ class GraphManager(object):
|
||||
|
||||
def fetch_from_worker(self, num_consecutive_playing_steps=None):
|
||||
if hasattr(self, 'memory_backend'):
|
||||
for transition in self.memory_backend.fetch(num_consecutive_playing_steps):
|
||||
self.emulate_act_on_trainer(EnvironmentSteps(1), transition)
|
||||
with self.phase_context(RunPhase.TRAIN):
|
||||
for transition in self.memory_backend.fetch(num_consecutive_playing_steps):
|
||||
self.emulate_act_on_trainer(EnvironmentSteps(1), transition)
|
||||
|
||||
def setup_memory_backend(self) -> None:
|
||||
if hasattr(self.agent_params.memory, 'memory_backend_params'):
|
||||
|
||||
Reference in New Issue
Block a user