From 801aed5e1000e9a741034163e3fdea2b0d245f57 Mon Sep 17 00:00:00 2001 From: gouravr Date: Sat, 15 Dec 2018 12:26:31 -0800 Subject: [PATCH] Changes to avoid memory leak in rollout worker Currently in rollout worker, we call restore_checkpoint repeatedly to load the latest model in memory. The restore checkpoint functions calls checkpoint_saver. Checkpoint saver uses GlobalVariablesSaver which does not release the references of the previous model variables. This leads to the situation where the memory keeps on growing before crashing the rollout worker. This change avoid using the checkpoint saver in the rollout worker as I believe it is not needed in this code path. Also added a test to easily reproduce the issue using CartPole example. We were also seeing this issue with the AWS DeepRacer implementation and the current implementation avoid the memory leak there as well. --- rl_coach/graph_managers/graph_manager.py | 3 ++- .../test_basic_rl_graph_manager.py | 26 ++++++++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index d13a59b..b9013fd 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -562,7 +562,8 @@ class GraphManager(object): screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir)) else: screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path)) - self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path) + if not hasattr(self.agent_params.memory, 'memory_backend_params') or self.agent_params.memory.memory_backend_params.run_type != str(RunType.ROLLOUT_WORKER): + self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path) [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers] diff --git a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py index 214ef31..a572fd9 100644 --- a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py +++ b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py @@ -1,8 +1,10 @@ import os import sys +import gc sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) import tensorflow as tf -from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks +from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks, RunType +from rl_coach.memories.backend.memory import MemoryBackendParameters from rl_coach.utils import get_open_port from multiprocessing import Process from tensorflow import logging @@ -41,12 +43,34 @@ def test_basic_rl_graph_manager_with_cartpole_dqn(): experiment_path="./experiments/test")) # graph_manager.improve() +# Test for identifying memory leak in restore_checkpoint +@pytest.mark.unit_test +def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore(): + tf.reset_default_graph() + from rl_coach.presets.CartPole_DQN import graph_manager + assert graph_manager + graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow, + experiment_path="./experiments/test", + apply_stop_condition=True)) + graph_manager.improve() + graph_manager.save_checkpoint() + + graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint" + graph_manager.agent_params.memory.register_var('memory_backend_params', + MemoryBackendParameters(store_type=None, + orchestrator_type=None, + run_type=str(RunType.ROLLOUT_WORKER))) + while True: + graph_manager.restore_checkpoint() + gc.collect() + if __name__ == '__main__': pass # test_basic_rl_graph_manager_with_pong_a3c() # test_basic_rl_graph_manager_with_ant_a3c() # test_basic_rl_graph_manager_with_pong_nec() + # test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore() # test_basic_rl_graph_manager_with_cartpole_dqn() #test_basic_rl_graph_manager_multithreaded_with_pong_a3c() #test_basic_rl_graph_manager_with_doom_basic_dqn() \ No newline at end of file