1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Avoid Memory Leak in Rollout worker

ISSUE: When we restore checkpoints, we create new nodes in the
Tensorflow graph. This happens when we assign new value (op node) to
RefVariable in GlobalVariableSaver. With every restore the size of TF
graph increases as new nodes are created and old unused nodes are not
removed from the graph. This causes a memory leak in the
restore_checkpoint codepath.

FIX: We reset the Tensorflow graph and recreate the Global, Online and
Target networks on every restore. This ensures that the old unused nodes
in the TF graph are dropped.
This commit is contained in:
Gourav Roy
2018-12-25 20:50:34 -08:00
parent 02f2db1264
commit c694766fad
3 changed files with 17 additions and 11 deletions

View File

@@ -953,6 +953,12 @@ class Agent(AgentInterface):
self.input_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
if self.ap.task_parameters.framework_type == Frameworks.tensorflow:
import tensorflow as tf
tf.reset_default_graph()
# Recreate all the networks of the agent
self.networks = self.create_networks()
# no output filters currently have an internal state to restore
# self.output_filter.restore_state_from_checkpoint(checkpoint_dir)

View File

@@ -150,7 +150,7 @@ class GraphManager(object):
# create a session (it needs to be created after all the graph ops were created)
self.sess = None
self.create_session(task_parameters=task_parameters) self.restore_checkpoint()
self._phase = self.phase = RunPhase.UNDEFINED
@@ -261,8 +261,6 @@ class GraphManager(object):
self.checkpoint_saver = SaverCollection()
for level in self.level_managers:
self.checkpoint_saver.update(level.collect_savers())
# restore from checkpoint if given
self.restore_checkpoint()
def save_graph(self) -> None:
"""
@@ -558,14 +556,20 @@ class GraphManager(object):
else:
checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
# As part of this restore, Agent recreates the global, target and online networks
[manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
# Recreate the session to use the new TF Graphs
self.create_session(self.task_parameters)
if checkpoint is None:
screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
else:
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
if not hasattr(self.agent_params.memory, 'memory_backend_params') or self.agent_params.memory.memory_backend_params.run_type != str(RunType.ROLLOUT_WORKER):
self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
else:
[manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers] # Create the session to use the new TF Graphs
self.create_session(self.task_parameters)
def _get_checkpoint_state_tf(self):
import tensorflow as tf

View File

@@ -56,10 +56,6 @@ def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restor
# graph_manager.save_checkpoint()
#
# graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
# graph_manager.agent_params.memory.register_var('memory_backend_params',
# MemoryBackendParameters(store_type=None,
# orchestrator_type=None,
# run_type=str(RunType.ROLLOUT_WORKER)))
# while True:
# graph_manager.restore_checkpoint()
# gc.collect()