diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index b9013fd..d13a59b 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -562,8 +562,7 @@ class GraphManager(object): screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir)) else: screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path)) - if not hasattr(self.agent_params.memory, 'memory_backend_params') or self.agent_params.memory.memory_backend_params.run_type != str(RunType.ROLLOUT_WORKER): - self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path) + self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path) [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers] diff --git a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py index a572fd9..214ef31 100644 --- a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py +++ b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py @@ -1,10 +1,8 @@ import os import sys -import gc sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) import tensorflow as tf -from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks, RunType -from rl_coach.memories.backend.memory import MemoryBackendParameters +from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks from rl_coach.utils import get_open_port from multiprocessing import Process from tensorflow import logging @@ -43,34 +41,12 @@ def test_basic_rl_graph_manager_with_cartpole_dqn(): experiment_path="./experiments/test")) # graph_manager.improve() -# Test for identifying memory leak in restore_checkpoint -@pytest.mark.unit_test -def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore(): - tf.reset_default_graph() - from rl_coach.presets.CartPole_DQN import graph_manager - assert graph_manager - graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow, - experiment_path="./experiments/test", - apply_stop_condition=True)) - graph_manager.improve() - graph_manager.save_checkpoint() - - graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint" - graph_manager.agent_params.memory.register_var('memory_backend_params', - MemoryBackendParameters(store_type=None, - orchestrator_type=None, - run_type=str(RunType.ROLLOUT_WORKER))) - while True: - graph_manager.restore_checkpoint() - gc.collect() - if __name__ == '__main__': pass # test_basic_rl_graph_manager_with_pong_a3c() # test_basic_rl_graph_manager_with_ant_a3c() # test_basic_rl_graph_manager_with_pong_nec() - # test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore() # test_basic_rl_graph_manager_with_cartpole_dqn() #test_basic_rl_graph_manager_multithreaded_with_pong_a3c() #test_basic_rl_graph_manager_with_doom_basic_dqn() \ No newline at end of file