diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py
index a6b5297..059cae8 100644
--- a/rl_coach/agents/agent.py
+++ b/rl_coach/agents/agent.py
@@ -953,6 +953,12 @@ class Agent(AgentInterface):
         self.input_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
         self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
 
+        if self.ap.task_parameters.framework_type == Frameworks.tensorflow:
+            import tensorflow as tf
+            tf.reset_default_graph()
+            # Recreate all the networks of the agent
+            self.networks = self.create_networks()
+
         # no output filters currently have an internal state to restore
         # self.output_filter.restore_state_from_checkpoint(checkpoint_dir)
 
diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py
index b9013fd..10f314a 100644
--- a/rl_coach/graph_managers/graph_manager.py
+++ b/rl_coach/graph_managers/graph_manager.py
@@ -150,7 +150,7 @@ class GraphManager(object):
 
         # create a session (it needs to be created after all the graph ops were created)
         self.sess = None
-        self.create_session(task_parameters=task_parameters)
+        self.restore_checkpoint()
 
         self._phase = self.phase = RunPhase.UNDEFINED
 
@@ -261,8 +261,6 @@ class GraphManager(object):
             self.checkpoint_saver = SaverCollection()
             for level in self.level_managers:
                 self.checkpoint_saver.update(level.collect_savers())
-        # restore from checkpoint if given
-        self.restore_checkpoint()
 
     def save_graph(self) -> None:
         """
@@ -558,14 +556,20 @@ class GraphManager(object):
             else:
                 checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
 
+            # As part of this restore, Agent recreates the global, target and online networks
+            [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
+
+            # Recreate the session to use the new TF Graphs
+            self.create_session(self.task_parameters)
+
             if checkpoint is None:
                 screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
             else:
                 screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
-                if not hasattr(self.agent_params.memory, 'memory_backend_params') or self.agent_params.memory.memory_backend_params.run_type != str(RunType.ROLLOUT_WORKER):
-                    self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
-
-                [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
+                self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
+        else:
+            # Create the session to use the new TF Graphs
+            self.create_session(self.task_parameters)
 
     def _get_checkpoint_state_tf(self):
         import tensorflow as tf
diff --git a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py
index fb5ba9d..e0d9585 100644
--- a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py
+++ b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py
@@ -56,10 +56,6 @@ def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restor
 #     graph_manager.save_checkpoint()
 #
 #     graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
-#     graph_manager.agent_params.memory.register_var('memory_backend_params',
-#                                                    MemoryBackendParameters(store_type=None,
-#                                                                            orchestrator_type=None,
-#                                                                            run_type=str(RunType.ROLLOUT_WORKER)))
 #     while True:
 #         graph_manager.restore_checkpoint()
 #         gc.collect()