1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

hacky way to resolve the checkpointing issue (#154)

This commit is contained in:
Gal Leibovich
2018-11-25 16:14:15 +02:00
committed by Gal Novik
parent 11170d5ba3
commit ab10852ad9

View File

@@ -541,28 +541,22 @@ class GraphManager(object):
if self.evaluate(self.evaluation_steps):
break
def _restore_checkpoint_tf(self, checkpoint_path: str):
import tensorflow as tf
variables = {}
reader = tf.contrib.framework.load_checkpoint(checkpoint_path)
for var_name, _ in reader.get_variable_to_shape_map().items():
# Load the variable
var = reader.get_tensor(var_name)
# Set the new name
new_name = var_name
new_name = new_name.replace('global/', 'online/')
variables[new_name] = var
for v in self.variables_to_restore:
self.sess.run(v.assign(variables[v.name.split(':')[0]]))
def restore_checkpoint(self):
self.verify_graph_was_created()
# TODO: find better way to load checkpoints that were saved with a global network into the online network
if self.task_parameters.checkpoint_restore_dir:
checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
if self.task_parameters.framework_type == Frameworks.tensorflow:
# TODO-fixme checkpointing
# MonitoredTrainingSession manages save/restore checkpoints autonomously. Doing so,
# it creates it own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt" filename
# pattern. The names used are maintained in a CheckpointState protobuf file named 'checkpoint'. Using
# Coach's '.coach_checkpoint' protobuf file, results in an error when trying to restore the model, as
# the checkpoint names defined do not match the actual checkpoint names.
checkpoint = self._get_checkpoint_state_tf()
else:
checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
if checkpoint is None:
screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
else:
@@ -571,6 +565,10 @@ class GraphManager(object):
[manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
def _get_checkpoint_state_tf(self):
import tensorflow as tf
return tf.train.get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
def occasionally_save_checkpoint(self):
# only the chief process saves checkpoints
if self.task_parameters.checkpoint_save_secs \