1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Tf checkpointing using saver mechanism (#134)

This commit is contained in:
Sina Afrooze
2018-11-22 04:08:10 -08:00
committed by Gal Leibovich
parent dd18959e53
commit 16cdd9a9c1
6 changed files with 110 additions and 50 deletions

View File

@@ -232,19 +232,11 @@ class GraphManager(object):
# set the session for all the modules
self.set_session(self.sess)
else:
self.variables_to_restore = tf.global_variables()
# self.variables_to_restore = [v for v in self.variables_to_restore if '/online' in v.name] TODO: is this necessary?
self.checkpoint_saver = tf.train.Saver(self.variables_to_restore)
# regular session
self.sess = tf.Session(config=config)
# set the session for all the modules
self.set_session(self.sess)
# restore from checkpoint if given
self.restore_checkpoint()
# the TF graph is static, and therefore is saved once - in the beginning of the experiment
if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
self.save_graph()
@@ -254,11 +246,6 @@ class GraphManager(object):
Call set_session to initialize parameters and construct checkpoint_saver
"""
self.set_session(sess=None) # Initialize all modules
self.checkpoint_saver = SaverCollection()
for level in self.level_managers:
self.checkpoint_saver.update(level.collect_savers())
# restore from checkpoint if given
self.restore_checkpoint()
def create_session(self, task_parameters: TaskParameters):
if task_parameters.framework_type == Frameworks.tensorflow:
@@ -268,6 +255,13 @@ class GraphManager(object):
else:
raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))
# Create parameter saver
self.checkpoint_saver = SaverCollection()
for level in self.level_managers:
self.checkpoint_saver.update(level.collect_savers())
# restore from checkpoint if given
self.restore_checkpoint()
def save_graph(self) -> None:
"""
Save the TF graph to a protobuf description file in the experiment directory
@@ -566,16 +560,12 @@ class GraphManager(object):
# TODO: find better way to load checkpoints that were saved with a global network into the online network
if self.task_parameters.checkpoint_restore_dir:
checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
if self.task_parameters.framework_type == Frameworks.tensorflow:
self._restore_checkpoint_tf(checkpoint.model_checkpoint_path)
elif self.task_parameters.framework_type == Frameworks.mxnet:
self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
if checkpoint is None:
screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
else:
raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type))
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
[manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
@@ -598,10 +588,7 @@ class GraphManager(object):
if not os.path.exists(os.path.dirname(checkpoint_path)):
os.mkdir(os.path.dirname(checkpoint_path)) # Create directory structure
if not isinstance(self.task_parameters, DistributedTaskParameters):
if self.checkpoint_saver is not None:
saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
else:
saved_checkpoint_path = "<Not Saved>"
saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
else:
saved_checkpoint_path = checkpoint_path