hacky way to resolve the checkpointing issue (#154)
@@ -541,28 +541,22 @@ class GraphManager(object):
             if self.evaluate(self.evaluation_steps):
                 break
 
-    def _restore_checkpoint_tf(self, checkpoint_path: str):
-        import tensorflow as tf
-        variables = {}
-        reader = tf.contrib.framework.load_checkpoint(checkpoint_path)
-        for var_name, _ in reader.get_variable_to_shape_map().items():
-            # Load the variable
-            var = reader.get_tensor(var_name)
-
-            # Set the new name
-            new_name = var_name
-            new_name = new_name.replace('global/', 'online/')
-            variables[new_name] = var
-
-        for v in self.variables_to_restore:
-            self.sess.run(v.assign(variables[v.name.split(':')[0]]))
-
     def restore_checkpoint(self):
         self.verify_graph_was_created()
 
         # TODO: find better way to load checkpoints that were saved with a global network into the online network
         if self.task_parameters.checkpoint_restore_dir:
-            checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
+            if self.task_parameters.framework_type == Frameworks.tensorflow:
+                # TODO-fixme checkpointing
+                # MonitoredTrainingSession manages save/restore checkpoints autonomously. Doing so,
+                # it creates its own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt"
+                # filename pattern. The names used are maintained in a CheckpointState protobuf file named
+                # 'checkpoint'. Using Coach's '.coach_checkpoint' protobuf file results in an error when trying
+                # to restore the model, as the checkpoint names defined do not match the actual checkpoint names.
+                checkpoint = self._get_checkpoint_state_tf()
+            else:
+                checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
 
             if checkpoint is None:
                 screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
             else:
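The comment block above is the heart of the workaround: MonitoredTrainingSession records the checkpoint names it actually wrote in a CheckpointState protobuf file named 'checkpoint', so on the TensorFlow path the commit resolves checkpoints through that file instead of Coach's own '.coach_checkpoint' file. A minimal sketch of what that lookup returns, not part of the commit and using a made-up directory path:

    import tensorflow as tf

    # Hypothetical directory written by a MonitoredTrainingSession.
    ckpt_dir = '/tmp/coach_experiment/checkpoints'

    # Reads the 'checkpoint' CheckpointState protobuf the session maintains,
    # so the resolved path names a checkpoint that actually exists on disk.
    state = tf.train.get_checkpoint_state(ckpt_dir)
    if state is not None:
        print(state.model_checkpoint_path)       # e.g. '.../model.ckpt-1234'
        print(state.all_model_checkpoint_paths)  # all checkpoints still on disk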
@@ -571,6 +565,10 @@ class GraphManager(object):
                 [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
 
+    def _get_checkpoint_state_tf(self):
+        import tensorflow as tf
+        return tf.train.get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
+
     def occasionally_save_checkpoint(self):
         # only the chief process saves checkpoints
         if self.task_parameters.checkpoint_save_secs \
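For context on the deleted _restore_checkpoint_tf: it read every tensor out of the checkpoint, renamed the 'global/' scope to 'online/', and assigned the values one variable at a time through sess.run. A more idiomatic TF1 way to achieve the same renamed restore is to hand tf.train.Saver an explicit name map. The sketch below is an illustration, not Coach code; the helper name is made up and it assumes the in-graph variables live under an 'online/' scope:

    import tensorflow as tf

    def restore_global_into_online(sess, checkpoint_path, online_variables):
        # Saver restores each variable by its name in the checkpoint, so map
        # the saved 'global/...' names onto the in-graph 'online/...' variables.
        name_map = {v.op.name.replace('online/', 'global/'): v
                    for v in online_variables}
        saver = tf.train.Saver(var_list=name_map)
        saver.restore(sess, checkpoint_path)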