1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Implement framework-agnostic rollout and training workers (#137)

* Added checkpoint state file to coach checkpointing.

* Removed TF-specific code from rollout_worker, training_worker, and s3_data_store
This commit is contained in:
Sina Afrooze
2018-11-23 18:05:44 -08:00
committed by Balaji Subramaniam
parent 4a6c404070
commit 5332013bd1
7 changed files with 350 additions and 117 deletions

View File

@@ -546,58 +546,3 @@ def start_shell_command_and_wait(command):
def indent_string(string):
    """
    Indent a (possibly multi-line) string by one tab.
    :param string: the text to indent; lines are separated by newline characters
    :return: the text with a tab inserted at its start and after every newline
    """
    lines = string.split('\n')
    return '\n'.join('\t' + line for line in lines)
class CheckpointState(object):
    """
    Holds the checkpoint paths found for a checkpoint directory. It replicates
    the CheckpointState protobuf class in tensorflow by exposing the same
    attribute names.
    """
    def __init__(self, checkpoints: List[str]):
        """
        :param checkpoints: checkpoint paths; the last entry is reported as the current checkpoint
        """
        self._checkpoints = checkpoints

    @property
    def all_model_checkpoint_paths(self):
        """
        :return: the list of checkpoint paths given at construction (the same object, not a copy)
        """
        return self._checkpoints

    @property
    def model_checkpoint_path(self):
        """
        :return: the last path in the checkpoint list
        """
        latest = self._checkpoints[-1]
        return latest

    def __str__(self):
        # One header line for the current checkpoint, then one line per known checkpoint
        pieces = ['model_checkpoint_path: {}\n'.format(self.model_checkpoint_path)]
        pieces.extend('all_model_checkpoint_paths: {}\n'.format(path) for path in self._checkpoints)
        return ''.join(pieces)

    def __repr__(self):
        checkpoint_list = self._checkpoints
        return str(checkpoint_list)
COACH_CHECKPOINT_PATTERN = r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?'


def get_checkpoint_state(checkpoint_dir: Union[str, List[str]], filename_pattern: str = COACH_CHECKPOINT_PATTERN) ->\
        Union[CheckpointState, None]:
    """
    Finds the latest checkpoint file. The checkpoint number used for sorting is taken from group(2) of
    filename_pattern or, when that group did not match, from group(4). With the default pattern, group(2) is a
    run of digits at the very start of the filename (followed by a non-digit) and group(4) is the run of digits
    after '.ckpt-'. A custom pattern must therefore define at least four groups.
    :param checkpoint_dir: directory where checkpoints are saved or list of all files in a directory
    :param filename_pattern: regex pattern for checkpoint filenames
    :return: a CheckpointState for checkpoint_dir containing a list of checkpoint names sorted by checkpoint
             number. Each name is the part of the filename matched by filename_pattern, joined to checkpoint_dir
             when a directory was given. If several files share a checkpoint number, the last one seen is
             kept. If no matching files are found, returns None.
    :raises AssertionError: if a filename carries both numbers and they denote different checkpoints
    """
    prog = re.compile(filename_pattern)
    # The form of the input cannot change while scanning, so decide once how to list files and build paths
    is_directory = isinstance(checkpoint_dir, str)
    filenames = os.listdir(checkpoint_dir) if is_directory else checkpoint_dir

    checkpoints = {}
    for name in filenames:
        m = prog.search(name)
        if m is None:
            continue
        prefix_num, suffix_num = m.group(2), m.group(4)
        if prefix_num is None and suffix_num is None:
            # The name matches the pattern but carries no checkpoint number to sort by
            continue
        checkpoint_count = int(prefix_num if prefix_num is not None else suffix_num)
        if prefix_num is not None and suffix_num is not None and int(suffix_num) != checkpoint_count:
            # Compared as integers so that zero-padded numbers (e.g. '01' and '1') agree. Raised explicitly
            # instead of through an assert statement so the check is not stripped by `python -O`; the
            # exception type is kept as AssertionError for backward compatibility.
            raise AssertionError('Conflicting checkpoint numbers {} and {} in filename {}'.format(
                prefix_num, suffix_num, name))
        matched = m.group(0)
        checkpoints[checkpoint_count] = os.path.join(checkpoint_dir, matched) if is_directory else matched

    if not checkpoints:
        return None
    return CheckpointState([checkpoints[k] for k in sorted(checkpoints)])