diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index dd98731..aef08cf 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -575,6 +575,11 @@ class GraphManager(object): self.task_parameters.checkpoint_restore_path)) model_checkpoint_path = checkpoint.model_checkpoint_path checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path + + # Set the last checkpoint ID - only in the case of the path being a dir + chkpt_state_reader = CheckpointStateReader(self.task_parameters.checkpoint_restore_path, + checkpoint_state_optional=False) + self.checkpoint_id = chkpt_state_reader.get_latest().num + 1 else: # a checkpoint file if self.task_parameters.framework_type == Frameworks.tensorflow: @@ -590,10 +595,6 @@ class GraphManager(object): [manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers] - # Set the last checkpoint ID - chkpt_state_reader = CheckpointStateReader(self.task_parameters.checkpoint_restore_path, checkpoint_state_optional=False) - self.checkpoint_id = chkpt_state_reader.get_latest().num + 1 - def _get_checkpoint_state_tf(self, checkpoint_restore_dir): import tensorflow as tf return tf.train.get_checkpoint_state(checkpoint_restore_dir)