From c694766faddd72b2b3966430f9f56ace27693606 Mon Sep 17 00:00:00 2001
From: Gourav Roy
Date: Tue, 25 Dec 2018 20:50:34 -0800
Subject: [PATCH] Avoid memory leak in rollout worker

ISSUE:
Restoring a checkpoint creates new nodes in the TensorFlow graph. This
happens when we assign a new value (an assign op node) to a RefVariable in
GlobalVariableSaver. With every restore the TF graph grows, because new
nodes are created while the old, unused nodes are never removed from the
graph. This causes a memory leak in the restore_checkpoint code path.

FIX:
We reset the TensorFlow graph and recreate the Global, Online and Target
networks on every restore. This ensures that the old, unused nodes in the
TF graph are dropped.
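For illustration, a minimal TF 1.x sketch of the growth pattern (not
rl_coach code; the variable and the loop below are hypothetical):

    import tensorflow as tf

    v = tf.Variable(0.0)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(3):
            # Each restore-style assignment adds a new assign op (plus its
            # const input) to the default graph; nothing removes old ones.
            sess.run(tf.assign(v, float(i)))
            print(len(tf.get_default_graph().as_graph_def().node))  # grows

tf.reset_default_graph() drops these stale nodes, at the cost of having to
rebuild the networks and the session, which is what this patch does.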
---
 rl_coach/agents/agent.py                 |  6 ++++++
 rl_coach/graph_managers/graph_manager.py | 18 +++++++++++-------
 .../test_basic_rl_graph_manager.py       |  4 ----
 3 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py
index a6b5297..059cae8 100644
--- a/rl_coach/agents/agent.py
+++ b/rl_coach/agents/agent.py
@@ -953,6 +953,12 @@ class Agent(AgentInterface):
         self.input_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
         self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
 
+        if self.ap.task_parameters.framework_type == Frameworks.tensorflow:
+            import tensorflow as tf
+            tf.reset_default_graph()
+            # Recreate all the networks of the agent
+            self.networks = self.create_networks()
+
         # no output filters currently have an internal state to restore
         # self.output_filter.restore_state_from_checkpoint(checkpoint_dir)
 
diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py
index b9013fd..10f314a 100644
--- a/rl_coach/graph_managers/graph_manager.py
+++ b/rl_coach/graph_managers/graph_manager.py
@@ -150,7 +150,7 @@ class GraphManager(object):
 
         # create a session (it needs to be created after all the graph ops were created)
         self.sess = None
-        self.create_session(task_parameters=task_parameters)
+        self.restore_checkpoint()
 
         self._phase = self.phase = RunPhase.UNDEFINED
 
@@ -261,8 +261,6 @@ class GraphManager(object):
         self.checkpoint_saver = SaverCollection()
         for level in self.level_managers:
             self.checkpoint_saver.update(level.collect_savers())
-        # restore from checkpoint if given
-        self.restore_checkpoint()
 
     def save_graph(self) -> None:
         """
@@ -558,14 +556,20 @@ class GraphManager(object):
             else:
                 checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
 
+            # As part of this restore, Agent recreates the global, target and online networks
+            [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
+
+            # Recreate the session to use the new TF Graphs
+            self.create_session(self.task_parameters)
+
             if checkpoint is None:
                 screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
             else:
                 screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
-                if not hasattr(self.agent_params.memory, 'memory_backend_params') or self.agent_params.memory.memory_backend_params.run_type != str(RunType.ROLLOUT_WORKER):
-                    self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
-
-            [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
+                self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
+        else:
+            # Create the session to use the new TF Graphs
+            self.create_session(self.task_parameters)
 
     def _get_checkpoint_state_tf(self):
         import tensorflow as tf
diff --git a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py
index fb5ba9d..e0d9585 100644
--- a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py
+++ b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py
@@ -56,10 +56,6 @@ def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restor
 #     graph_manager.save_checkpoint()
 #
 #     graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
-#     graph_manager.agent_params.memory.register_var('memory_backend_params',
-#                                                    MemoryBackendParameters(store_type=None,
-#                                                                            orchestrator_type=None,
-#                                                                            run_type=str(RunType.ROLLOUT_WORKER)))
 #     while True:
 #         graph_manager.restore_checkpoint()
 #         gc.collect()
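For reference, a minimal sketch of how the fix could be verified, in the
spirit of the commented-out restore loop in the test above (assuming a
TF 1.x rl_coach GraphManager with a restorable checkpoint; the helper
check_restore_leak and its iteration count are hypothetical):

    import gc
    import tensorflow as tf

    def check_restore_leak(graph_manager, iterations=5):
        """Repeatedly restore and confirm the default graph stops growing."""
        sizes = []
        for _ in range(iterations):
            graph_manager.restore_checkpoint()
            gc.collect()
            # With the fix, restore_checkpoint() resets the default graph and
            # rebuilds the networks, so the node count should stay constant.
            sizes.append(len(tf.get_default_graph().as_graph_def().node))
        assert max(sizes) == min(sizes), "TF graph grew across restores: {}".format(sizes)

Before this patch, the node count should increase on every iteration; with
the graph reset in place it should stay flat.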