diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py index 059cae8..a6b5297 100644 --- a/rl_coach/agents/agent.py +++ b/rl_coach/agents/agent.py @@ -953,12 +953,6 @@ class Agent(AgentInterface): self.input_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix) self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix) - if self.ap.task_parameters.framework_type == Frameworks.tensorflow: - import tensorflow as tf - tf.reset_default_graph() - # Recreate all the networks of the agent - self.networks = self.create_networks() - # no output filters currently have an internal state to restore # self.output_filter.restore_state_from_checkpoint(checkpoint_dir) diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index 10f314a..b9013fd 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -150,7 +150,7 @@ class GraphManager(object): # create a session (it needs to be created after all the graph ops were created) self.sess = None - self.restore_checkpoint() + self.create_session(task_parameters=task_parameters) self._phase = self.phase = RunPhase.UNDEFINED @@ -261,6 +261,8 @@ class GraphManager(object): self.checkpoint_saver = SaverCollection() for level in self.level_managers: self.checkpoint_saver.update(level.collect_savers()) + # restore from checkpoint if given + self.restore_checkpoint() def save_graph(self) -> None: """ @@ -556,20 +558,14 @@ class GraphManager(object): else: checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir) - # As part of this restore, Agent recreates the global, target and online networks - [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers] - - # Recreate the session to use the new TF Graphs - self.create_session(self.task_parameters) - if checkpoint is None: screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir)) else: screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path)) - self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path) - else: - # Create the session to use the new TF Graphs - self.create_session(self.task_parameters) + if not hasattr(self.agent_params.memory, 'memory_backend_params') or self.agent_params.memory.memory_backend_params.run_type != str(RunType.ROLLOUT_WORKER): + self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path) + + [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers] def _get_checkpoint_state_tf(self): import tensorflow as tf diff --git a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py index e0d9585..fb5ba9d 100644 --- a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py +++ b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py @@ -56,6 +56,10 @@ def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restor # graph_manager.save_checkpoint() # # graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint" + # graph_manager.agent_params.memory.register_var('memory_backend_params', + # MemoryBackendParameters(store_type=None, + # orchestrator_type=None, + # run_type=str(RunType.ROLLOUT_WORKER))) # while True: # graph_manager.restore_checkpoint() # gc.collect()