1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

bug fix following PR #191 (#313)

This commit is contained in:
Gal Leibovich
2019-05-12 23:42:45 +03:00
committed by Scott Leishman
parent aa9f3cefaf
commit deb0251367

View File

@@ -575,6 +575,11 @@ class GraphManager(object):
self.task_parameters.checkpoint_restore_path))
model_checkpoint_path = checkpoint.model_checkpoint_path
checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path
# Set the last checkpoint ID - only in the case of the path being a dir
chkpt_state_reader = CheckpointStateReader(self.task_parameters.checkpoint_restore_path,
checkpoint_state_optional=False)
self.checkpoint_id = chkpt_state_reader.get_latest().num + 1
else:
# a checkpoint file
if self.task_parameters.framework_type == Frameworks.tensorflow:
@@ -590,10 +595,6 @@ class GraphManager(object):
[manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers]
# Set the last checkpoint ID
chkpt_state_reader = CheckpointStateReader(self.task_parameters.checkpoint_restore_path, checkpoint_state_optional=False)
self.checkpoint_id = chkpt_state_reader.get_latest().num + 1
def _get_checkpoint_state_tf(self, checkpoint_restore_dir):
import tensorflow as tf
return tf.train.get_checkpoint_state(checkpoint_restore_dir)