mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
committed by
Scott Leishman
parent
aa9f3cefaf
commit
deb0251367
@@ -575,6 +575,11 @@ class GraphManager(object):
|
||||
self.task_parameters.checkpoint_restore_path))
|
||||
model_checkpoint_path = checkpoint.model_checkpoint_path
|
||||
checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path
|
||||
|
||||
# Set the last checkpoint ID - only in the case of the path being a dir
|
||||
chkpt_state_reader = CheckpointStateReader(self.task_parameters.checkpoint_restore_path,
|
||||
checkpoint_state_optional=False)
|
||||
self.checkpoint_id = chkpt_state_reader.get_latest().num + 1
|
||||
else:
|
||||
# a checkpoint file
|
||||
if self.task_parameters.framework_type == Frameworks.tensorflow:
|
||||
@@ -590,10 +595,6 @@ class GraphManager(object):
|
||||
|
||||
[manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers]
|
||||
|
||||
# Set the last checkpoint ID
|
||||
chkpt_state_reader = CheckpointStateReader(self.task_parameters.checkpoint_restore_path, checkpoint_state_optional=False)
|
||||
self.checkpoint_id = chkpt_state_reader.get_latest().num + 1
|
||||
|
||||
def _get_checkpoint_state_tf(self, checkpoint_restore_dir):
|
||||
import tensorflow as tf
|
||||
return tf.train.get_checkpoint_state(checkpoint_restore_dir)
|
||||
|
||||
Reference in New Issue
Block a user