mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Implement frame-work agnostic rollout and training workers (#137)
* Added checkpoint state file to coach checkpointing. * Removed TF specific code from rollout_worker, training_worker, and s3_data_store
This commit is contained in:
committed by
Balaji Subramaniam
parent
4a6c404070
commit
5332013bd1
298
rl_coach/checkpoint.py
Normal file
298
rl_coach/checkpoint.py
Normal file
@@ -0,0 +1,298 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
Module providing helper classes and functions for reading/writing checkpoint state
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import List, Union, Tuple
|
||||
|
||||
|
||||
class SingleCheckpoint(object):
|
||||
"""
|
||||
Helper class for storing checkpoint name and number
|
||||
"""
|
||||
def __init__(self, num: int, name: str):
|
||||
"""
|
||||
:param num: checkpoint number
|
||||
:param name: checkpoint name (i.e. the prefix for all checkpoint files)
|
||||
"""
|
||||
self._num = num
|
||||
self._name = name
|
||||
|
||||
@property
|
||||
def num(self) -> int:
|
||||
return self._num
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def __str__(self):
|
||||
return self._name
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
def __eq__(self, other: 'SingleCheckpoint'):
|
||||
if not isinstance(other, SingleCheckpoint):
|
||||
return False
|
||||
return self._name == other._name and self._num == other._num
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
|
||||
class CheckpointState(object):
|
||||
"""
|
||||
Helper class for checkpoint directory information. It replicates
|
||||
the CheckpointState protobuf class in tensorflow with addition of
|
||||
two new functions: last_checkpoint() and all_checkpoints()
|
||||
"""
|
||||
def __init__(self, checkpoints: List[SingleCheckpoint], checkpoint_dir: str):
|
||||
"""
|
||||
:param checkpoints: sorted list of checkpoints from oldest to newest. checkpoint[-1] is
|
||||
considered to be the most recent checkpoint.
|
||||
:param checkpoint_dir: checkpoint directory which is added to the paths
|
||||
"""
|
||||
self._checkpoints = checkpoints
|
||||
self._checkpoin_dir = checkpoint_dir
|
||||
|
||||
@property
|
||||
def all_checkpoints(self) -> List[SingleCheckpoint]:
|
||||
"""
|
||||
:return: list of all checkpoints
|
||||
"""
|
||||
return self._checkpoints
|
||||
|
||||
@property
|
||||
def last_checkpoint(self) -> SingleCheckpoint:
|
||||
"""
|
||||
:return: the most recent checkpoint
|
||||
"""
|
||||
return self._checkpoints[-1]
|
||||
|
||||
@property
|
||||
def all_model_checkpoint_paths(self) -> List[str]:
|
||||
"""
|
||||
TF compatible function call to get all checkpoints
|
||||
:return: list of all available model checkpoint paths
|
||||
"""
|
||||
return [os.path.join(self._checkpoin_dir, c.name) for c in self._checkpoints]
|
||||
|
||||
@property
|
||||
def model_checkpoint_path(self) -> str:
|
||||
"""
|
||||
TF compatible call to get most recent checkpoint
|
||||
:return: path of the most recent model checkpoint
|
||||
"""
|
||||
return os.path.join(self._checkpoin_dir, self._checkpoints[-1].name)
|
||||
|
||||
def __str__(self):
|
||||
out_str = 'model_checkpoint_path: {}\n'.format(self.model_checkpoint_path)
|
||||
for c in self.all_model_checkpoint_paths:
|
||||
out_str += 'all_model_checkpoint_paths: {}\n'.format(c)
|
||||
return out_str
|
||||
|
||||
def __repr__(self):
|
||||
return str(self._checkpoints)
|
||||
|
||||
|
||||
class CheckpointStateFile(object):
|
||||
"""
|
||||
Helper class for reading from and writing to the checkpoint state file
|
||||
"""
|
||||
checkpoint_state_filename = '.coach_checkpoint'
|
||||
|
||||
def __init__(self, checkpoint_dir: str):
|
||||
self._checkpoint_state_path = os.path.join(checkpoint_dir, self.checkpoint_state_filename)
|
||||
|
||||
def exists(self) -> bool:
|
||||
"""
|
||||
:return: True if checkpoint state file exists, false otherwise
|
||||
"""
|
||||
return os.path.exists(self._checkpoint_state_path)
|
||||
|
||||
def read(self) -> Union[None, SingleCheckpoint]:
|
||||
"""
|
||||
Read checkpoint state file and interpret its content
|
||||
:return:
|
||||
"""
|
||||
if not self.exists():
|
||||
return None
|
||||
with open(self._checkpoint_state_path, 'r') as fd:
|
||||
return CheckpointFilenameParser().parse(fd.read(256))
|
||||
|
||||
def write(self, data: SingleCheckpoint) -> None:
|
||||
"""
|
||||
Writes data to checkpoint state file
|
||||
:param data: string data
|
||||
"""
|
||||
with open(self._checkpoint_state_path, 'w') as fd:
|
||||
fd.write(data.name)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.checkpoint_state_filename
|
||||
|
||||
@property
|
||||
def path(self) -> str:
|
||||
return self._checkpoint_state_path
|
||||
|
||||
|
||||
class CheckpointStateReader(object):
|
||||
"""
|
||||
Class for scanning checkpoint directory and updating the checkpoint state
|
||||
"""
|
||||
def __init__(self, checkpoint_dir: str, checkpoint_state_optional: bool=True):
|
||||
"""
|
||||
:param checkpoint_dir: path to checkpoint directory
|
||||
:param checkpoint_state_optional: If True, checkpoint state file is optional and if not found,
|
||||
directory is scanned to find the latest checkpoint. Default is True for backward compatibility
|
||||
"""
|
||||
self._checkpoint_dir = checkpoint_dir
|
||||
self._checkpoint_state_file = CheckpointStateFile(self._checkpoint_dir)
|
||||
self._checkpoint_state_optional = checkpoint_state_optional
|
||||
|
||||
def get_latest(self) -> SingleCheckpoint:
|
||||
"""
|
||||
Tries to read the checkpoint state file. If that fails, discovers latest by reading the entire directory.
|
||||
:return: checkpoint object representing the latest checkpoint
|
||||
"""
|
||||
latest = self._checkpoint_state_file.read()
|
||||
if latest is None and self._checkpoint_state_optional:
|
||||
all_checkpoints = _filter_checkpoint_files(os.listdir(self._checkpoint_dir))
|
||||
if len(all_checkpoints) > 0:
|
||||
latest = all_checkpoints[-1]
|
||||
return latest
|
||||
|
||||
def get_all(self) -> List[SingleCheckpoint]:
|
||||
"""
|
||||
Reads both the checkpoint state file as well as contents of the directory and merges them into one list.
|
||||
:return: list of checkpoint objects
|
||||
"""
|
||||
# discover all checkpoint files in directory if requested or if a valid checkpoint-state file doesn't exist
|
||||
all_checkpoints = _filter_checkpoint_files(os.listdir(self._checkpoint_dir))
|
||||
last_checkpoint = self._checkpoint_state_file.read()
|
||||
if last_checkpoint is not None:
|
||||
# remove excess checkpoints: higher checkpoint number, but not recent (e.g. from a previous run)
|
||||
all_checkpoints = all_checkpoints[: all_checkpoints.index(last_checkpoint) + 1]
|
||||
elif not self._checkpoint_state_optional:
|
||||
# if last_checkpoint is not discovered from the checkpoint-state file and it isn't optional, then
|
||||
# all checkpoint files discovered must be partial or invalid, so don't return anything
|
||||
all_checkpoints.clear()
|
||||
return all_checkpoints
|
||||
|
||||
|
||||
class CheckpointStateUpdater(object):
|
||||
"""
|
||||
Class for scanning checkpoint directory and updating the checkpoint state
|
||||
"""
|
||||
def __init__(self, checkpoint_dir: str, read_all: bool=False):
|
||||
"""
|
||||
:param checkpoint_dir: path to checkpoint directory
|
||||
:param read_all: whether to scan the directory for existing checkpoints
|
||||
"""
|
||||
self._checkpoint_dir = checkpoint_dir
|
||||
self._checkpoint_state_file = CheckpointStateFile(checkpoint_dir)
|
||||
self._all_checkpoints = list()
|
||||
# Read checkpoint state and initialize
|
||||
state_reader = CheckpointStateReader(checkpoint_dir)
|
||||
if read_all:
|
||||
self._all_checkpoints = state_reader.get_all()
|
||||
else:
|
||||
latest = state_reader.get_latest()
|
||||
if latest is not None:
|
||||
self._all_checkpoints = [latest]
|
||||
|
||||
def update(self, checkpoint: SingleCheckpoint) -> None:
|
||||
"""
|
||||
Update the checkpoint state with the latest checkpoint.
|
||||
:param checkpoint: SingleCheckpoint object containing name and number of checkpoint
|
||||
"""
|
||||
self._all_checkpoints.append(checkpoint)
|
||||
# Simply write checkpoint_name to checkpoint-state file
|
||||
self._checkpoint_state_file.write(checkpoint)
|
||||
|
||||
@property
|
||||
def last_checkpoint(self) -> Union[None, SingleCheckpoint]:
|
||||
if len(self._all_checkpoints) == 0:
|
||||
return None
|
||||
return self._all_checkpoints[-1]
|
||||
|
||||
@property
|
||||
def all_checkpoints(self) -> List[SingleCheckpoint]:
|
||||
return self._all_checkpoints
|
||||
|
||||
def get_checkpoint_state(self) -> Union[None, CheckpointState]:
|
||||
"""
|
||||
:return: The most recent checkpoint state
|
||||
"""
|
||||
if len(self._all_checkpoints) == 0:
|
||||
return None
|
||||
return CheckpointState(self._all_checkpoints, self._checkpoint_dir)
|
||||
|
||||
|
||||
class CheckpointFilenameParser(object):
|
||||
"""
|
||||
Helper object for parsing filenames that are potentially checkpoints
|
||||
"""
|
||||
coach_checkpoint_filename_pattern = r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?'
|
||||
|
||||
def __init__(self):
|
||||
self._prog = re.compile(self.coach_checkpoint_filename_pattern)
|
||||
|
||||
def parse(self, filename: str) -> Union[None, SingleCheckpoint]:
|
||||
"""
|
||||
Tries to parse the filename using the checkpoint filename pattern. If successful,
|
||||
it returns tuple of (checkpoint-number, checkpoint-name). Otherwise it returns None.
|
||||
:param filename: filename to be parsed
|
||||
:return: None or (checkpoint-number, checkpoint-name)
|
||||
"""
|
||||
m = self._prog.search(filename)
|
||||
if m is not None and (m.group(2) is not None or m.group(4) is not None):
|
||||
assert m.group(2) is None or m.group(4) is None # Only one group must be valid
|
||||
checkpoint_num = int(m.group(2) if m.group(2) is not None else m.group(4))
|
||||
return SingleCheckpoint(checkpoint_num, m.group(0))
|
||||
return None
|
||||
|
||||
|
||||
def _filter_checkpoint_files(filenames: List[str], sort_by_num: bool=True) -> List[SingleCheckpoint]:
|
||||
"""
|
||||
Given a list of potential file names, return the ones that match checkpoint pattern along with
|
||||
the checkpoint number of each file name.
|
||||
:param filenames: list of all filenames
|
||||
:param sort_by_num: whether to sort the output result by checkpoint number
|
||||
:return: list of (checkpoint-number, checkpoint-filename) tuples
|
||||
"""
|
||||
parser = CheckpointFilenameParser()
|
||||
checkpoints = [ckp for ckp in [parser.parse(fn) for fn in filenames] if ckp is not None]
|
||||
if sort_by_num:
|
||||
checkpoints.sort(key=lambda x: x.num)
|
||||
return checkpoints
|
||||
|
||||
|
||||
def get_checkpoint_state(checkpoint_dir: str, all_checkpoints=False) ->Union[CheckpointState, None]:
|
||||
"""
|
||||
Scan checkpoint directory and find the list of checkpoint files.
|
||||
:param checkpoint_dir: directory where checkpoints are saved
|
||||
:param all_checkpoints: if True, scan the directory and return list of all checkpoints
|
||||
as well as the most recent one
|
||||
:return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoints by checkpoint-number.
|
||||
If no matching files are found, returns None.
|
||||
"""
|
||||
return CheckpointStateUpdater(checkpoint_dir, read_all=all_checkpoints).get_checkpoint_state()
|
||||
Reference in New Issue
Block a user