From deb0251367d46598e8c5382331f79b6af442e885 Mon Sep 17 00:00:00 2001 From: Gal Leibovich Date: Sun, 12 May 2019 23:42:45 +0300 Subject: [PATCH] bug fix following PR #191 (#313) --- rl_coach/graph_managers/graph_manager.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index dd98731..aef08cf 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -575,6 +575,11 @@ class GraphManager(object): self.task_parameters.checkpoint_restore_path)) model_checkpoint_path = checkpoint.model_checkpoint_path checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path + + # Set the last checkpoint ID - only in the case of the path being a dir + chkpt_state_reader = CheckpointStateReader(self.task_parameters.checkpoint_restore_path, + checkpoint_state_optional=False) + self.checkpoint_id = chkpt_state_reader.get_latest().num + 1 else: # a checkpoint file if self.task_parameters.framework_type == Frameworks.tensorflow: @@ -590,10 +595,6 @@ class GraphManager(object): [manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers] - # Set the last checkpoint ID - chkpt_state_reader = CheckpointStateReader(self.task_parameters.checkpoint_restore_path, checkpoint_state_optional=False) - self.checkpoint_id = chkpt_state_reader.get_latest().num + 1 - def _get_checkpoint_state_tf(self, checkpoint_restore_dir): import tensorflow as tf return tf.train.get_checkpoint_state(checkpoint_restore_dir)