diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py
index d13a59b..b9013fd 100644
--- a/rl_coach/graph_managers/graph_manager.py
+++ b/rl_coach/graph_managers/graph_manager.py
@@ -562,7 +562,8 @@ class GraphManager(object):
             screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
         else:
             screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
-            self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
+            if not hasattr(self.agent_params.memory, 'memory_backend_params') or self.agent_params.memory.memory_backend_params.run_type != str(RunType.ROLLOUT_WORKER):
+                self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
 
             [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
 
diff --git a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py
index 214ef31..a572fd9 100644
--- a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py
+++ b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py
@@ -1,8 +1,10 @@
 import os
 import sys
+import gc
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
 import tensorflow as tf
-from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks
+from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks, RunType
+from rl_coach.memories.backend.memory import MemoryBackendParameters
 from rl_coach.utils import get_open_port
 from multiprocessing import Process
 from tensorflow import logging
@@ -41,12 +43,34 @@ def test_basic_rl_graph_manager_with_cartpole_dqn():
                                                              experiment_path="./experiments/test"))
     # graph_manager.improve()
 
+# Test for identifying memory leak in restore_checkpoint
+@pytest.mark.unit_test
+def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore():
+    tf.reset_default_graph()
+    from rl_coach.presets.CartPole_DQN import graph_manager
+    assert graph_manager
+    graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
+                                                              experiment_path="./experiments/test",
+                                                              apply_stop_condition=True))
+    graph_manager.improve()
+    graph_manager.save_checkpoint()
+
+    graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
+    graph_manager.agent_params.memory.register_var('memory_backend_params',
+                                                   MemoryBackendParameters(store_type=None,
+                                                                           orchestrator_type=None,
+                                                                           run_type=str(RunType.ROLLOUT_WORKER)))
+    for _ in range(1000):
+        graph_manager.restore_checkpoint()
+        gc.collect()
+
 if __name__ == '__main__':
     pass
     # test_basic_rl_graph_manager_with_pong_a3c()
     # test_basic_rl_graph_manager_with_ant_a3c()
     # test_basic_rl_graph_manager_with_pong_nec()
+    # test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore()
     # test_basic_rl_graph_manager_with_cartpole_dqn()
     #test_basic_rl_graph_manager_multithreaded_with_pong_a3c()
     #test_basic_rl_graph_manager_with_doom_basic_dqn()
\ No newline at end of file