Mirror of https://github.com/gryf/coach.git, synced 2025-12-17 19:20:19 +01:00
hacky way to resolve the checkpointing issue (#154)
@@ -541,28 +541,22 @@ class GraphManager(object):
            if self.evaluate(self.evaluation_steps):
                break

    def _restore_checkpoint_tf(self, checkpoint_path: str):
        import tensorflow as tf
        variables = {}
        reader = tf.contrib.framework.load_checkpoint(checkpoint_path)
        for var_name, _ in reader.get_variable_to_shape_map().items():
            # Load the variable
            var = reader.get_tensor(var_name)

            # Set the new name: checkpoints saved with a global network use the
            # 'global/' scope, which has to be mapped onto the online network
            new_name = var_name
            new_name = new_name.replace('global/', 'online/')
            variables[new_name] = var

        for v in self.variables_to_restore:
            self.sess.run(v.assign(variables[v.name.split(':')[0]]))
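
As context for the rename above, here is a standalone sketch for previewing the 'global/' -> 'online/' mapping before restoring. It assumes a TF1-era checkpoint; 'my_checkpoint_dir' is a placeholder, and tf.train.load_checkpoint returns the same reader type as the tf.contrib.framework call used in the method:

    import tensorflow as tf

    # Placeholder path; point it at a directory containing a TF1 checkpoint.
    reader = tf.train.load_checkpoint('my_checkpoint_dir')
    for var_name, shape in sorted(reader.get_variable_to_shape_map().items()):
        # Preview the same 'global/' -> 'online/' mapping applied above.
        print('{:40s} {} -> {}'.format(var_name, shape, var_name.replace('global/', 'online/')))
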
    def restore_checkpoint(self):
        self.verify_graph_was_created()

        # TODO: find a better way to load checkpoints that were saved with a global network into the online network
        if self.task_parameters.checkpoint_restore_dir:
            checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
            if self.task_parameters.framework_type == Frameworks.tensorflow:
                # TODO-fixme checkpointing
                # MonitoredTrainingSession manages save/restore checkpoints autonomously. In doing so,
                # it creates its own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt"
                # filename pattern. The names used are maintained in a CheckpointState protobuf file named
                # 'checkpoint'. Using Coach's '.coach_checkpoint' protobuf file results in an error when trying
                # to restore the model, as the checkpoint names defined do not match the actual checkpoint names.
                checkpoint = self._get_checkpoint_state_tf()
            else:
                checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)

            if checkpoint is None:
                screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
            else:
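
The comment block above is the heart of the hack: tf.train.Saver.restore only succeeds when the name it is given matches the files MonitoredTrainingSession actually wrote. A minimal TF1-style sketch of the intended restore path, assuming the graph that produced the checkpoint has been rebuilt in the default graph; 'train_dir' is a placeholder, not Coach's code:

    import tensorflow as tf  # TF1-style API (tf.compat.v1 in TF2)

    ckpt = tf.train.get_checkpoint_state('train_dir')  # reads the 'checkpoint' protobuf
    if ckpt is not None:
        saver = tf.train.Saver()
        with tf.Session() as sess:
            # model_checkpoint_path is the name the session really used,
            # e.g. 'train_dir/model.ckpt-1234', not the '{}_Step-{}.ckpt' pattern.
            saver.restore(sess, ckpt.model_checkpoint_path)
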
@@ -571,6 +565,10 @@ class GraphManager(object):

        [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]

    def _get_checkpoint_state_tf(self):
        import tensorflow as tf
        return tf.train.get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)

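For reference: get_checkpoint_state parses a CheckpointState protobuf named 'checkpoint' by default, which is exactly why calling it plainly picks up MonitoredTrainingSession's own bookkeeping. A sketch of both variants; 'train_dir' is a placeholder, while latest_filename and the fields printed are TF's actual API:

    import tensorflow as tf

    # Default: parse the 'checkpoint' file that MonitoredTrainingSession maintains.
    ckpt = tf.train.get_checkpoint_state('train_dir')

    # The same call can be pointed at a differently named CheckpointState file,
    # such as Coach's '.coach_checkpoint' -- but the names recorded there do not
    # match what the session wrote, which is the error this commit works around.
    stale = tf.train.get_checkpoint_state('train_dir', latest_filename='.coach_checkpoint')

    if ckpt is not None:
        print(ckpt.model_checkpoint_path)       # latest checkpoint, e.g. 'train_dir/model.ckpt-1234'
        print(ckpt.all_model_checkpoint_paths)  # every checkpoint still tracked on disk
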
    def occasionally_save_checkpoint(self):
        # only the chief process saves checkpoints
        if self.task_parameters.checkpoint_save_secs \
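
The hunk is truncated mid-condition; as the method name and the checkpoint_save_secs parameter suggest, this is time-based gating. A hypothetical sketch of such a gate, not Coach's actual body ('last_checkpoint_saved_time' and 'save_checkpoint' are illustrative names):

    import time

    def occasionally_save_checkpoint(self):
        # only the chief process saves checkpoints
        if self.task_parameters.checkpoint_save_secs \
                and time.time() - self.last_checkpoint_saved_time \
                    >= self.task_parameters.checkpoint_save_secs:
            self.save_checkpoint()
            self.last_checkpoint_saved_time = time.time()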