mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Tf checkpointing using saver mechanism (#134)
This commit is contained in:
committed by
Gal Leibovich
parent
dd18959e53
commit
16cdd9a9c1
@@ -232,19 +232,11 @@ class GraphManager(object):
|
||||
# set the session for all the modules
|
||||
self.set_session(self.sess)
|
||||
else:
|
||||
self.variables_to_restore = tf.global_variables()
|
||||
# self.variables_to_restore = [v for v in self.variables_to_restore if '/online' in v.name] TODO: is this necessary?
|
||||
self.checkpoint_saver = tf.train.Saver(self.variables_to_restore)
|
||||
|
||||
# regular session
|
||||
self.sess = tf.Session(config=config)
|
||||
|
||||
# set the session for all the modules
|
||||
self.set_session(self.sess)
|
||||
|
||||
# restore from checkpoint if given
|
||||
self.restore_checkpoint()
|
||||
|
||||
# the TF graph is static, and therefore is saved once - in the beginning of the experiment
|
||||
if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
|
||||
self.save_graph()
|
||||
@@ -254,11 +246,6 @@ class GraphManager(object):
|
||||
Call set_session to initialize parameters and construct checkpoint_saver
|
||||
"""
|
||||
self.set_session(sess=None) # Initialize all modules
|
||||
self.checkpoint_saver = SaverCollection()
|
||||
for level in self.level_managers:
|
||||
self.checkpoint_saver.update(level.collect_savers())
|
||||
# restore from checkpoint if given
|
||||
self.restore_checkpoint()
|
||||
|
||||
def create_session(self, task_parameters: TaskParameters):
|
||||
if task_parameters.framework_type == Frameworks.tensorflow:
|
||||
@@ -268,6 +255,13 @@ class GraphManager(object):
|
||||
else:
|
||||
raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))
|
||||
|
||||
# Create parameter saver
|
||||
self.checkpoint_saver = SaverCollection()
|
||||
for level in self.level_managers:
|
||||
self.checkpoint_saver.update(level.collect_savers())
|
||||
# restore from checkpoint if given
|
||||
self.restore_checkpoint()
|
||||
|
||||
def save_graph(self) -> None:
|
||||
"""
|
||||
Save the TF graph to a protobuf description file in the experiment directory
|
||||
@@ -566,16 +560,12 @@ class GraphManager(object):
|
||||
|
||||
# TODO: find better way to load checkpoints that were saved with a global network into the online network
|
||||
if self.task_parameters.checkpoint_restore_dir:
|
||||
|
||||
checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
|
||||
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
|
||||
|
||||
if self.task_parameters.framework_type == Frameworks.tensorflow:
|
||||
self._restore_checkpoint_tf(checkpoint.model_checkpoint_path)
|
||||
elif self.task_parameters.framework_type == Frameworks.mxnet:
|
||||
self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
|
||||
if checkpoint is None:
|
||||
screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
|
||||
else:
|
||||
raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type))
|
||||
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
|
||||
self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
|
||||
|
||||
[manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
|
||||
|
||||
@@ -598,10 +588,7 @@ class GraphManager(object):
|
||||
if not os.path.exists(os.path.dirname(checkpoint_path)):
|
||||
os.mkdir(os.path.dirname(checkpoint_path)) # Create directory structure
|
||||
if not isinstance(self.task_parameters, DistributedTaskParameters):
|
||||
if self.checkpoint_saver is not None:
|
||||
saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
|
||||
else:
|
||||
saved_checkpoint_path = "<Not Saved>"
|
||||
saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
|
||||
else:
|
||||
saved_checkpoint_path = checkpoint_path
|
||||
|
||||
|
||||
Reference in New Issue
Block a user