mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
* Added checkpoint state file to coach checkpointing. * Removed TF specific code from rollout_worker, training_worker, and s3_data_store
299 lines
11 KiB
Python
299 lines
11 KiB
Python
#
|
|
# Copyright (c) 2017 Intel Corporation
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
"""
|
|
Module providing helper classes and functions for reading/writing checkpoint state
|
|
"""
|
|
|
|
import os
|
|
import re
|
|
from typing import List, Union, Tuple
|
|
|
|
|
|
class SingleCheckpoint(object):
    """
    Helper class for storing checkpoint name and number.

    Instances behave as value objects: two checkpoints compare equal (and hash alike)
    when both their name and their number match.
    """
    def __init__(self, num: int, name: str):
        """
        :param num: checkpoint number
        :param name: checkpoint name (i.e. the prefix for all checkpoint files)
        """
        self._num = num
        self._name = name

    @property
    def num(self) -> int:
        """
        :return: checkpoint number
        """
        return self._num

    @property
    def name(self) -> str:
        """
        :return: checkpoint name
        """
        return self._name

    def __str__(self):
        return self._name

    def __repr__(self):
        return str(self)

    def __eq__(self, other: 'SingleCheckpoint'):
        if not isinstance(other, SingleCheckpoint):
            # Let Python try the reflected comparison; `==` still evaluates to False
            # when the other operand does not know how to compare either.
            return NotImplemented
        return self._name == other._name and self._num == other._num

    def __ne__(self, other):
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __hash__(self):
        # Defining __eq__ alone makes a class unhashable in Python 3. Hash on the same
        # fields that __eq__ compares so that equal checkpoints hash alike.
        return hash((self._num, self._name))
|
|
|
|
|
|
class CheckpointState(object):
    """
    Helper class for checkpoint directory information. It replicates
    the CheckpointState protobuf class in tensorflow with addition of
    two new properties: last_checkpoint and all_checkpoints
    """
    def __init__(self, checkpoints: 'List[SingleCheckpoint]', checkpoint_dir: str):
        """
        :param checkpoints: sorted list of checkpoints from oldest to newest. checkpoint[-1] is
            considered to be the most recent checkpoint.
        :param checkpoint_dir: checkpoint directory which is added to the paths
        """
        self._checkpoints = checkpoints
        self._checkpoint_dir = checkpoint_dir

    @property
    def all_checkpoints(self) -> 'List[SingleCheckpoint]':
        """
        :return: list of all checkpoints (the same list object that was passed to the constructor)
        """
        return self._checkpoints

    @property
    def last_checkpoint(self) -> 'SingleCheckpoint':
        """
        :return: the most recent checkpoint, i.e. the last element of the checkpoint list.
            Raises IndexError if the list is empty.
        """
        return self._checkpoints[-1]

    @property
    def all_model_checkpoint_paths(self) -> List[str]:
        """
        TF compatible function call to get all checkpoints
        :return: list of all available model checkpoint paths (checkpoint directory joined with each name)
        """
        return [os.path.join(self._checkpoint_dir, c.name) for c in self._checkpoints]

    @property
    def model_checkpoint_path(self) -> str:
        """
        TF compatible call to get most recent checkpoint
        :return: path of the most recent model checkpoint. Raises IndexError if there are no checkpoints.
        """
        return os.path.join(self._checkpoint_dir, self._checkpoints[-1].name)

    def __str__(self):
        # One line for the most recent checkpoint followed by one line per known checkpoint path.
        lines = ['model_checkpoint_path: {}\n'.format(self.model_checkpoint_path)]
        lines.extend('all_model_checkpoint_paths: {}\n'.format(path) for path in self.all_model_checkpoint_paths)
        return ''.join(lines)

    def __repr__(self):
        return str(self._checkpoints)
|
|
|
|
|
|
class CheckpointStateFile(object):
    """
    Helper class for reading from and writing to the checkpoint state file
    """
    # Name of the state file that lives inside the checkpoint directory
    checkpoint_state_filename = '.coach_checkpoint'

    def __init__(self, checkpoint_dir: str):
        """
        :param checkpoint_dir: directory in which the checkpoint state file is located
        """
        self._checkpoint_state_path = os.path.join(checkpoint_dir, self.checkpoint_state_filename)

    def exists(self) -> bool:
        """
        :return: True if checkpoint state file exists, false otherwise
        """
        return os.path.exists(self._checkpoint_state_path)

    def read(self) -> 'Union[None, SingleCheckpoint]':
        """
        Read checkpoint state file and interpret its content
        :return: the checkpoint parsed from the (at most 256) leading characters of the file, or None
            if the file does not exist or its content does not parse as a checkpoint name
        """
        if not self.exists():
            return None
        try:
            with open(self._checkpoint_state_path, 'r') as fd:
                content = fd.read(256)
        except FileNotFoundError:
            # The file disappeared between the existence check and the open call;
            # treat this the same as a file that was never there.
            return None
        return CheckpointFilenameParser().parse(content)

    def write(self, data: 'SingleCheckpoint') -> None:
        """
        Writes the name of the given checkpoint to the checkpoint state file, replacing previous content
        :param data: checkpoint whose name is to be stored
        """
        with open(self._checkpoint_state_path, 'w') as fd:
            fd.write(data.name)

    @property
    def filename(self) -> str:
        """
        :return: name of the checkpoint state file (without directory)
        """
        return self.checkpoint_state_filename

    @property
    def path(self) -> str:
        """
        :return: path of the checkpoint state file (checkpoint directory joined with the file name)
        """
        return self._checkpoint_state_path
|
|
|
|
|
|
class CheckpointStateReader(object):
    """
    Class for discovering checkpoints by combining the checkpoint state file with a scan
    of the checkpoint directory
    """
    def __init__(self, checkpoint_dir: str, checkpoint_state_optional: bool=True):
        """
        :param checkpoint_dir: path to checkpoint directory
        :param checkpoint_state_optional: If True, checkpoint state file is optional and if not found,
            directory is scanned to find the latest checkpoint. Default is True for backward compatibility
        """
        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_state_file = CheckpointStateFile(self._checkpoint_dir)
        self._checkpoint_state_optional = checkpoint_state_optional

    def get_latest(self) -> SingleCheckpoint:
        """
        Returns the checkpoint recorded in the checkpoint state file. When nothing could be read from
        the state file and the state file is optional, falls back to the highest numbered checkpoint
        found in the directory.
        :return: checkpoint object representing the latest checkpoint, or None if none was found
        """
        recorded = self._checkpoint_state_file.read()
        if recorded is not None or not self._checkpoint_state_optional:
            return recorded
        # Nothing usable in the state file: fall back to scanning the directory
        discovered = _filter_checkpoint_files(os.listdir(self._checkpoint_dir))
        return discovered[-1] if discovered else None

    def get_all(self) -> List[SingleCheckpoint]:
        """
        Scans the directory for checkpoint files and trims the result using the checkpoint state file.
        :return: list of checkpoint objects sorted by checkpoint number
        """
        discovered = _filter_checkpoint_files(os.listdir(self._checkpoint_dir))
        recorded = self._checkpoint_state_file.read()
        if recorded is None:
            if not self._checkpoint_state_optional:
                # The state file is mandatory but yielded nothing, so whatever was discovered
                # must be partial or invalid: report no checkpoints at all
                discovered.clear()
            return discovered
        # Drop everything listed after the recorded checkpoint: higher checkpoint number,
        # but not recent (e.g. from a previous run).
        # NOTE(review): list.index raises ValueError if the recorded checkpoint is not among
        # the discovered ones.
        cutoff = discovered.index(recorded) + 1
        return discovered[:cutoff]
|
|
|
|
|
|
class CheckpointStateUpdater(object):
    """
    Class that keeps track of the known checkpoints of a checkpoint directory and records
    each new checkpoint in the checkpoint state file
    """
    def __init__(self, checkpoint_dir: str, read_all: bool=False):
        """
        :param checkpoint_dir: path to checkpoint directory
        :param read_all: if True, start from all existing checkpoints reported by the state reader;
            otherwise start from the latest one only (or from nothing if there is none)
        """
        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_state_file = CheckpointStateFile(checkpoint_dir)
        # Initialize the list of known checkpoints from what is currently on disk
        reader = CheckpointStateReader(checkpoint_dir)
        if read_all:
            known = reader.get_all()
        else:
            newest = reader.get_latest()
            known = list() if newest is None else [newest]
        self._all_checkpoints = known

    def update(self, checkpoint: SingleCheckpoint) -> None:
        """
        Register the given checkpoint as the most recent one and persist it.
        :param checkpoint: SingleCheckpoint object containing name and number of checkpoint
        """
        self._all_checkpoints.append(checkpoint)
        # Only the newest checkpoint is handed to the checkpoint-state file
        self._checkpoint_state_file.write(checkpoint)

    @property
    def last_checkpoint(self) -> Union[None, SingleCheckpoint]:
        """
        :return: the most recently added checkpoint, or None if no checkpoint is known
        """
        return self._all_checkpoints[-1] if self._all_checkpoints else None

    @property
    def all_checkpoints(self) -> List[SingleCheckpoint]:
        """
        :return: list of all known checkpoints, oldest first
        """
        return self._all_checkpoints

    def get_checkpoint_state(self) -> Union[None, CheckpointState]:
        """
        :return: The most recent checkpoint state, or None if no checkpoint is known
        """
        if not self._all_checkpoints:
            return None
        return CheckpointState(self._all_checkpoints, self._checkpoint_dir)
|
|
|
|
|
|
class CheckpointFilenameParser(object):
    """
    Helper object for parsing filenames that are potentially checkpoints
    """
    # The checkpoint number is either a leading run of digits (group 2) or the digits
    # following '.ckpt-' (group 4)
    coach_checkpoint_filename_pattern = r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?'

    def __init__(self):
        self._prog = re.compile(self.coach_checkpoint_filename_pattern)

    def parse(self, filename: str) -> Union[None, SingleCheckpoint]:
        """
        Tries to parse the filename using the checkpoint filename pattern. If successful, it returns
        a SingleCheckpoint holding the checkpoint number and the matched part of the filename as
        the checkpoint name. Otherwise it returns None.
        :param filename: filename to be parsed
        :return: None or the parsed SingleCheckpoint
        """
        found = self._prog.search(filename)
        if found is None:
            return None
        prefix_digits, suffix_digits = found.group(2, 4)
        if prefix_digits is None and suffix_digits is None:
            # Pattern matched, but no checkpoint number is present in either position
            return None
        # NOTE(review): a filename carrying a number in both positions trips this assert
        assert prefix_digits is None or suffix_digits is None  # Only one group must be valid
        digits = suffix_digits if prefix_digits is None else prefix_digits
        return SingleCheckpoint(int(digits), found.group(0))
|
|
|
|
|
|
def _filter_checkpoint_files(filenames: List[str], sort_by_num: bool=True) -> List[SingleCheckpoint]:
    """
    Given a list of potential file names, parse each of them and keep the ones that match
    the checkpoint pattern.
    :param filenames: list of all filenames
    :param sort_by_num: whether to sort the output result by checkpoint number
    :return: list of SingleCheckpoint objects, one for every filename that parsed successfully
    """
    name_parser = CheckpointFilenameParser()
    # parse() yields None for names that are not checkpoints; those are dropped here
    matched = [parsed for parsed in map(name_parser.parse, filenames) if parsed is not None]
    if sort_by_num:
        matched.sort(key=lambda x: x.num)
    return matched
|
|
|
|
|
|
def get_checkpoint_state(checkpoint_dir: str, all_checkpoints=False) ->Union[CheckpointState, None]:
    """
    Build the checkpoint state of a checkpoint directory by delegating to CheckpointStateUpdater.
    :param checkpoint_dir: directory where checkpoints are saved
    :param all_checkpoints: if True, scan the directory and return list of all checkpoints
        as well as the most recent one
    :return: a CheckpointState for checkpoint_dir, or None when the updater knows of no checkpoint.
    """
    updater = CheckpointStateUpdater(checkpoint_dir, read_all=all_checkpoints)
    return updater.get_checkpoint_state()
|