1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-23 14:42:31 +01:00

Revert "Avoid Memory Leak in Rollout worker"

This reverts commit c694766fad.
This commit is contained in:
Gourav Roy
2019-01-02 22:35:06 -08:00
parent 2461892c9e
commit 6dd7ae2343
3 changed files with 11 additions and 17 deletions

View File

@@ -150,7 +150,7 @@ class GraphManager(object):
# create a session (it needs to be created after all the graph ops were created)
self.sess = None
self.restore_checkpoint()
self.create_session(task_parameters=task_parameters)
self._phase = self.phase = RunPhase.UNDEFINED
@@ -261,6 +261,8 @@ class GraphManager(object):
self.checkpoint_saver = SaverCollection()
for level in self.level_managers:
self.checkpoint_saver.update(level.collect_savers())
# restore from checkpoint if given
self.restore_checkpoint()
def save_graph(self) -> None:
"""
@@ -556,20 +558,14 @@ class GraphManager(object):
else:
checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
# As part of this restore, Agent recreates the global, target and online networks
[manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
# Recreate the session to use the new TF Graphs
self.create_session(self.task_parameters)
if checkpoint is None:
screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
else:
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
else:
# Create the session to use the new TF Graphs
self.create_session(self.task_parameters)
if not hasattr(self.agent_params.memory, 'memory_backend_params') or self.agent_params.memory.memory_backend_params.run_type != str(RunType.ROLLOUT_WORKER):
self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
[manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
def _get_checkpoint_state_tf(self):
import tensorflow as tf