mirror of https://github.com/gryf/coach.git (synced 2025-12-18 19:50:17 +01:00)
Revert "Avoid Memory Leak in Rollout worker"
This reverts commit c694766fad.
@@ -953,12 +953,6 @@ class Agent(AgentInterface):
         self.input_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
         self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
 
-        if self.ap.task_parameters.framework_type == Frameworks.tensorflow:
-            import tensorflow as tf
-            tf.reset_default_graph()
-            # Recreate all the networks of the agent
-            self.networks = self.create_networks()
-
         # no output filters currently have an internal state to restore
         # self.output_filter.restore_state_from_checkpoint(checkpoint_dir)
 
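For context on the hunk above: the removed block was the heart of the reverted fix. Under TF1 every op added to the global default graph stays there, so recreating an agent's networks on each checkpoint restore without a graph reset makes the graph, and therefore memory, grow monotonically. A minimal sketch of that pattern, assuming TF1.x (tf.compat.v1 under TF2); the names here are illustrative, not code from this repo:

import tensorflow as tf  # assumes TF1.x; under TF2 use tf.compat.v1


def rebuild_network(reset=True):
    # Drop the old graph so the rebuilt ops do not accumulate in memory.
    if reset:
        tf.reset_default_graph()
    x = tf.placeholder(tf.float32, shape=[None, 4], name="x")
    w = tf.Variable(tf.zeros([4, 2]), name="w")
    return tf.matmul(x, w)


for _ in range(3):
    rebuild_network(reset=True)
    # With reset=False, len(tf.get_default_graph().get_operations())
    # would grow on every iteration -- the leak the reverted fix targeted.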
@@ -150,7 +150,7 @@ class GraphManager(object):
 
         # create a session (it needs to be created after all the graph ops were created)
         self.sess = None
-        self.restore_checkpoint()
+        self.create_session(task_parameters=task_parameters)
 
         self._phase = self.phase = RunPhase.UNDEFINED
 
@@ -261,6 +261,8 @@ class GraphManager(object):
         self.checkpoint_saver = SaverCollection()
         for level in self.level_managers:
             self.checkpoint_saver.update(level.collect_savers())
+        # restore from checkpoint if given
+        self.restore_checkpoint()
 
     def save_graph(self) -> None:
         """
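Taken together, the two hunks above restore the original flow in create_graph: build all graph ops, create the session right after (the comment in the first hunk says why), and run the checkpoint restore as a separate final step. A minimal TF1-style sketch of that ordering; the checkpoint directory and variable are hypothetical:

import tensorflow as tf  # assumes TF1.x

tf.reset_default_graph()
w = tf.Variable(tf.zeros([2, 2]), name="w")  # all graph ops created first
init = tf.global_variables_initializer()

# create a session (after all the graph ops were created)
sess = tf.Session()
sess.run(init)

# restore from checkpoint if given (stand-in for restore_checkpoint())
state = tf.train.get_checkpoint_state("./checkpoints")  # hypothetical dir
if state is not None:
    tf.train.Saver().restore(sess, state.model_checkpoint_path)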
@@ -556,20 +558,14 @@ class GraphManager(object):
             else:
                 checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
 
-            # As part of this restore, Agent recreates the global, target and online networks
-            [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
-
-            # Recreate the session to use the new TF Graphs
-            self.create_session(self.task_parameters)
-
             if checkpoint is None:
                 screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
             else:
                 screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
-                self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
-        else:
-            # Create the session to use the new TF Graphs
-            self.create_session(self.task_parameters)
+                if not hasattr(self.agent_params.memory, 'memory_backend_params') or self.agent_params.memory.memory_backend_params.run_type != str(RunType.ROLLOUT_WORKER):
+                    self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
+
+            [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
 
     def _get_checkpoint_state_tf(self):
         import tensorflow as tf
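The right-hand side that this revert reinstates restores saver state only for non-rollout-worker runs; rollout workers fall through to the per-level-manager restore on the last added line. A minimal sketch of that guard with hypothetical stand-in objects (SimpleNamespace instead of the real agent parameters):

from types import SimpleNamespace


def should_restore_saver(agent_params, rollout_tag="rollout_worker"):
    # Mirrors the reinstated condition: restore unless this process is a
    # rollout worker with a memory backend configured.
    memory = agent_params.memory
    if not hasattr(memory, "memory_backend_params"):
        return True
    return memory.memory_backend_params.run_type != rollout_tag


trainer = SimpleNamespace(memory=SimpleNamespace())
worker = SimpleNamespace(memory=SimpleNamespace(
    memory_backend_params=SimpleNamespace(run_type="rollout_worker")))
assert should_restore_saver(trainer) is True
assert should_restore_saver(worker) is False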
@@ -56,6 +56,10 @@ def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restor
     # graph_manager.save_checkpoint()
     #
     # graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
+    # graph_manager.agent_params.memory.register_var('memory_backend_params',
+    #                                                MemoryBackendParameters(store_type=None,
+    #                                                                        orchestrator_type=None,
+    #                                                                        run_type=str(RunType.ROLLOUT_WORKER)))
     # while True:
     #     graph_manager.restore_checkpoint()
     #     gc.collect()
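The commented-out block re-added above sketches a manual leak check: configure the memory backend as a rollout worker, then loop restore_checkpoint() with gc.collect(). A self-contained sketch of turning such a loop into a measurable check, using live-object counts as a rough growth proxy; object_growth is a hypothetical helper, not part of the test suite:

import gc


def object_growth(fn, iterations=5):
    # Count objects the garbage collector tracks before and after
    # repeatedly calling fn; a leak shows up as steady growth.
    gc.collect()
    baseline = len(gc.get_objects())
    for _ in range(iterations):
        fn()  # e.g. graph_manager.restore_checkpoint() in the real test
        gc.collect()
    return len(gc.get_objects()) - baseline


leaky = []
print(object_growth(lambda: leaky.append([0] * 1000)))  # keeps growing
print(object_growth(lambda: None))                      # stays near zero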