1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Tf checkpointing using saver mechanism (#134)

This commit is contained in:
Sina Afrooze
2018-11-22 04:08:10 -08:00
committed by Gal Leibovich
parent dd18959e53
commit 16cdd9a9c1
6 changed files with 110 additions and 50 deletions

View File

@@ -88,7 +88,7 @@ class SaverCollection(object):
"""
paths = list()
for saver in self:
paths.extend(saver.save(sess, "{}.{}".format(save_path, saver.path)))
paths.extend(saver.save(sess, self._full_path(save_path, saver)))
return paths
def restore(self, sess: Any, restore_path: str) -> None:
@@ -98,8 +98,7 @@ class SaverCollection(object):
:param restore_path: path for restoring checkpoint using savers.
"""
for saver in self:
restore_path = "{}.{}".format(restore_path, saver.path)
saver.restore(sess, restore_path)
saver.restore(sess, self._full_path(restore_path, saver))
def __iter__(self):
"""
@@ -108,5 +107,16 @@ class SaverCollection(object):
"""
return (v for v in self._saver_dict.values())
@staticmethod
def _full_path(path_prefix: str, saver: Saver) -> str:
"""
Concatenates path of the saver to parent prefix to create full save path
:param path_prefix: prefix of the path
:param saver: saver object to get unique path extension from
:return: full path
"""
if saver.path == "":
return path_prefix
return "{}.{}".format(path_prefix, saver.path)