From 619ea0944e1f652a426045038ee2977996733758 Mon Sep 17 00:00:00 2001 From: Gourav Roy Date: Wed, 2 Jan 2019 23:06:44 -0800 Subject: [PATCH] Avoid Memory Leak in Rollout worker ISSUE: When we restore checkpoints, we create new nodes in the Tensorflow graph. This happens when we assign new value (op node) to RefVariable in GlobalVariableSaver. With every restore the size of TF graph increases as new nodes are created and old unused nodes are not removed from the graph. This causes the memory leak in restore_checkpoint codepath. FIX: We use TF placeholder to update the variables which avoids the memory leak. --- .../tensorflow_components/savers.py | 15 +++++++++++-- .../test_basic_rl_graph_manager.py | 21 +++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/rl_coach/architectures/tensorflow_components/savers.py b/rl_coach/architectures/tensorflow_components/savers.py index 38a36ee..0f7ddbf 100644 --- a/rl_coach/architectures/tensorflow_components/savers.py +++ b/rl_coach/architectures/tensorflow_components/savers.py @@ -32,6 +32,14 @@ class GlobalVariableSaver(Saver): # target network is never saved or restored directly from checkpoint, so we are removing all its variables from the list # the target network would be synched back from the online network in graph_manager.improve(...), at the beginning of the run flow. self._variables = [v for v in self._variables if '/target' not in v.name] + + # Using a placeholder to update the variable during restore to avoid memory leak. + # Ref: https://github.com/tensorflow/tensorflow/issues/4151 + self._variable_placeholders = [tf.placeholder(v.dtype, shape=v.get_shape()) for v in self._variables] + self._variable_update_ops = [] + for i in range(len(self._variables)): + self._variable_update_ops.append(self._variables[i].assign(self._variable_placeholders[i])) + self._saver = tf.train.Saver(self._variables) @property @@ -66,8 +74,11 @@ class GlobalVariableSaver(Saver): # TODO: Can this be more generic so that `global/` and `online/` are not hardcoded here? new_name = var_name.replace('global/', 'online/') variables[new_name] = reader.get_tensor(var_name) - # Assign all variables - sess.run([v.assign(variables[v.name.split(':')[0]]) for v in self._variables]) + + # Assign all variables using placeholder + for i in range(len(self._variables)): + variable_name = self._variables[i].name.split(':')[0] + sess.run(self._variable_update_ops[i], {self._variable_placeholders[i]: variables[variable_name]}) def merge(self, other: 'Saver'): """ diff --git a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py index 214ef31..fbdc098 100644 --- a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py +++ b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py @@ -1,8 +1,10 @@ +import gc import os import sys sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) import tensorflow as tf from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks +from rl_coach.core_types import EnvironmentSteps from rl_coach.utils import get_open_port from multiprocessing import Process from tensorflow import logging @@ -41,6 +43,24 @@ def test_basic_rl_graph_manager_with_cartpole_dqn(): experiment_path="./experiments/test")) # graph_manager.improve() +# Test for identifying memory leak in restore_checkpoint +@pytest.mark.unit_test +def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore(): + tf.reset_default_graph() + from rl_coach.presets.CartPole_DQN import graph_manager + assert graph_manager + graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow, + experiment_path="./experiments/test", + apply_stop_condition=True)) + # graph_manager.improve() + # graph_manager.evaluate(EnvironmentSteps(1000)) + # graph_manager.save_checkpoint() + # + # graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint" + # while True: + # graph_manager.restore_checkpoint() + # graph_manager.evaluate(EnvironmentSteps(1000)) + # gc.collect() if __name__ == '__main__': pass @@ -48,5 +68,6 @@ if __name__ == '__main__': # test_basic_rl_graph_manager_with_ant_a3c() # test_basic_rl_graph_manager_with_pong_nec() # test_basic_rl_graph_manager_with_cartpole_dqn() + # test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore() #test_basic_rl_graph_manager_multithreaded_with_pong_a3c() #test_basic_rl_graph_manager_with_doom_basic_dqn() \ No newline at end of file