mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Tf checkpointing using saver mechanism (#134)
This commit is contained in:
committed by
Gal Leibovich
parent
dd18959e53
commit
16cdd9a9c1
@@ -574,24 +574,30 @@ class CheckpointState(object):
|
||||
return str(self._checkpoints)
|
||||
|
||||
|
||||
COACH_CHECKPOINT_PATTERN = r'\A([0-9]+)[^0-9].*?\.ckpt'
|
||||
COACH_CHECKPOINT_PATTERN = r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?'
|
||||
|
||||
|
||||
def get_checkpoint_state(checkpoint_dir: Union[str, List[str]], filename_pattern: str=COACH_CHECKPOINT_PATTERN) ->\
|
||||
CheckpointState:
|
||||
Union[CheckpointState, None]:
|
||||
"""
|
||||
Finds the latest checkpoint file. It uses the first group of filename_pattern (i.e. group(1)) to sort
|
||||
Finds the latest checkpoint file. It uses the numeric groups of filename_pattern (i.e. group(2) or group(4)) to sort
|
||||
the checkpoint names and find the latest checkpoint
|
||||
:param checkpoint_dir: directory where checkpoints are saved or list of all files in a directory
|
||||
:param filename_pattern: regex pattern for checkpoint filenames
|
||||
:return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoint names
|
||||
:return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoint names. If no matching
|
||||
files are found, returns None.
|
||||
"""
|
||||
prog = re.compile(filename_pattern)
|
||||
checkpoints = dict()
|
||||
filenames = os.listdir(checkpoint_dir) if isinstance(checkpoint_dir, str) else checkpoint_dir
|
||||
for name in filenames:
|
||||
m = prog.search(name)
|
||||
if m is not None and m.group(1) is not None:
|
||||
if m is not None and (m.group(2) is not None or m.group(4) is not None):
|
||||
if m.group(2) is not None and m.group(4) is not None:
|
||||
assert m.group(2) == m.group(4)
|
||||
checkpoint_count = int(m.group(2) if m.group(2) is not None else m.group(4))
|
||||
full_path = os.path.join(checkpoint_dir, m.group(0)) if isinstance(checkpoint_dir, str) else m.group(0)
|
||||
checkpoints[int(m.group(1))] = full_path
|
||||
checkpoints[checkpoint_count] = full_path
|
||||
if len(checkpoints) == 0:
|
||||
return None
|
||||
return CheckpointState([checkpoints[k] for k in sorted(checkpoints.keys())])
|
||||
|
||||
Reference in New Issue
Block a user