mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Tf checkpointing using saver mechanism (#134)
This commit is contained in:
committed by
Gal Leibovich
parent
dd18959e53
commit
16cdd9a9c1
@@ -88,7 +88,7 @@ class SaverCollection(object):
|
||||
"""
|
||||
paths = list()
|
||||
for saver in self:
|
||||
paths.extend(saver.save(sess, "{}.{}".format(save_path, saver.path)))
|
||||
paths.extend(saver.save(sess, self._full_path(save_path, saver)))
|
||||
return paths
|
||||
|
||||
def restore(self, sess: Any, restore_path: str) -> None:
|
||||
@@ -98,8 +98,7 @@ class SaverCollection(object):
|
||||
:param restore_path: path for restoring checkpoint using savers.
|
||||
"""
|
||||
for saver in self:
|
||||
restore_path = "{}.{}".format(restore_path, saver.path)
|
||||
saver.restore(sess, restore_path)
|
||||
saver.restore(sess, self._full_path(restore_path, saver))
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
@@ -108,5 +107,16 @@ class SaverCollection(object):
|
||||
"""
|
||||
return (v for v in self._saver_dict.values())
|
||||
|
||||
@staticmethod
|
||||
def _full_path(path_prefix: str, saver: Saver) -> str:
|
||||
"""
|
||||
Concatenates path of the saver to parent prefix to create full save path
|
||||
:param path_prefix: prefix of the path
|
||||
:param saver: saver object to get unique path extension from
|
||||
:return: full path
|
||||
"""
|
||||
if saver.path == "":
|
||||
return path_prefix
|
||||
return "{}.{}".format(path_prefix, saver.path)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user