diff --git a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py index fb5ba9d..a572fd9 100644 --- a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py +++ b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py @@ -52,17 +52,17 @@ def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restor graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow, experiment_path="./experiments/test", apply_stop_condition=True)) - # graph_manager.improve() - # graph_manager.save_checkpoint() - # - # graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint" - # graph_manager.agent_params.memory.register_var('memory_backend_params', - # MemoryBackendParameters(store_type=None, - # orchestrator_type=None, - # run_type=str(RunType.ROLLOUT_WORKER))) - # while True: - # graph_manager.restore_checkpoint() - # gc.collect() + graph_manager.improve() + graph_manager.save_checkpoint() + + graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint" + graph_manager.agent_params.memory.register_var('memory_backend_params', + MemoryBackendParameters(store_type=None, + orchestrator_type=None, + run_type=str(RunType.ROLLOUT_WORKER))) + while True: + graph_manager.restore_checkpoint() + gc.collect() if __name__ == '__main__':