mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
committed by
Scott Leishman
parent
aa9f3cefaf
commit
deb0251367
@@ -575,6 +575,11 @@ class GraphManager(object):
|
|||||||
self.task_parameters.checkpoint_restore_path))
|
self.task_parameters.checkpoint_restore_path))
|
||||||
model_checkpoint_path = checkpoint.model_checkpoint_path
|
model_checkpoint_path = checkpoint.model_checkpoint_path
|
||||||
checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path
|
checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path
|
||||||
|
|
||||||
|
# Set the last checkpoint ID - only in the case of the path being a dir
|
||||||
|
chkpt_state_reader = CheckpointStateReader(self.task_parameters.checkpoint_restore_path,
|
||||||
|
checkpoint_state_optional=False)
|
||||||
|
self.checkpoint_id = chkpt_state_reader.get_latest().num + 1
|
||||||
else:
|
else:
|
||||||
# a checkpoint file
|
# a checkpoint file
|
||||||
if self.task_parameters.framework_type == Frameworks.tensorflow:
|
if self.task_parameters.framework_type == Frameworks.tensorflow:
|
||||||
@@ -590,10 +595,6 @@ class GraphManager(object):
|
|||||||
|
|
||||||
[manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers]
|
[manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers]
|
||||||
|
|
||||||
# Set the last checkpoint ID
|
|
||||||
chkpt_state_reader = CheckpointStateReader(self.task_parameters.checkpoint_restore_path, checkpoint_state_optional=False)
|
|
||||||
self.checkpoint_id = chkpt_state_reader.get_latest().num + 1
|
|
||||||
|
|
||||||
def _get_checkpoint_state_tf(self, checkpoint_restore_dir):
|
def _get_checkpoint_state_tf(self, checkpoint_restore_dir):
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
return tf.train.get_checkpoint_state(checkpoint_restore_dir)
|
return tf.train.get_checkpoint_state(checkpoint_restore_dir)
|
||||||
|
|||||||
Reference in New Issue
Block a user