1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00
Files
coach/rl_coach/checkpoint.py
Sina Afrooze 5332013bd1 Implement frame-work agnostic rollout and training workers (#137)
* Added checkpoint state file to coach checkpointing.

* Removed TF specific code from rollout_worker, training_worker, and s3_data_store
2018-11-23 18:05:44 -08:00

299 lines
11 KiB
Python

#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Module providing helper classes and functions for reading/writing checkpoint state
"""
import os
import re
from typing import List, Union, Tuple
class SingleCheckpoint(object):
    """
    Helper class for storing checkpoint name and number
    """
    def __init__(self, num: int, name: str):
        """
        :param num: checkpoint number
        :param name: checkpoint name (i.e. the prefix for all checkpoint files)
        """
        self._num = num
        self._name = name

    @property
    def num(self) -> int:
        """
        :return: checkpoint number
        """
        return self._num

    @property
    def name(self) -> str:
        """
        :return: checkpoint name (i.e. the prefix for all checkpoint files)
        """
        return self._name

    def __str__(self):
        return self._name

    def __repr__(self):
        return str(self)

    def __eq__(self, other: 'SingleCheckpoint'):
        """
        Two checkpoints are equal when both their name and their number match.
        :param other: object to compare with; anything that is not a SingleCheckpoint is unequal
        :return: True if other is a SingleCheckpoint with the same name and number
        """
        if not isinstance(other, SingleCheckpoint):
            return False
        return self._name == other._name and self._num == other._num

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        """
        Hash over the same fields that __eq__ compares. Defining __eq__ without __hash__ makes
        a class unhashable in Python 3, which would prevent using checkpoints in sets or as
        dictionary keys.
        :return: hash of (number, name)
        """
        return hash((self._num, self._name))
class CheckpointState(object):
    """
    Helper class for checkpoint directory information. It replicates
    the CheckpointState protobuf class in tensorflow with addition of
    two new properties: last_checkpoint and all_checkpoints
    """
    def __init__(self, checkpoints: List['SingleCheckpoint'], checkpoint_dir: str):
        """
        :param checkpoints: sorted list of checkpoints from oldest to newest. checkpoint[-1] is
            considered to be the most recent checkpoint.
        :param checkpoint_dir: checkpoint directory which is added to the paths
        """
        self._checkpoints = checkpoints
        self._checkpoint_dir = checkpoint_dir

    @property
    def all_checkpoints(self) -> List['SingleCheckpoint']:
        """
        :return: list of all checkpoints
        """
        return self._checkpoints

    @property
    def last_checkpoint(self) -> 'SingleCheckpoint':
        """
        :return: the most recent checkpoint (the last element of the checkpoint list)
        """
        return self._checkpoints[-1]

    @property
    def all_model_checkpoint_paths(self) -> List[str]:
        """
        TF compatible function call to get all checkpoints
        :return: list of all available model checkpoint paths
        """
        return [os.path.join(self._checkpoint_dir, c.name) for c in self._checkpoints]

    @property
    def model_checkpoint_path(self) -> str:
        """
        TF compatible call to get most recent checkpoint
        :return: path of the most recent model checkpoint
        """
        return os.path.join(self._checkpoint_dir, self._checkpoints[-1].name)

    def __str__(self):
        # One line for the most recent checkpoint followed by one line per known checkpoint path
        lines = ['model_checkpoint_path: {}\n'.format(self.model_checkpoint_path)]
        lines.extend('all_model_checkpoint_paths: {}\n'.format(c) for c in self.all_model_checkpoint_paths)
        return ''.join(lines)

    def __repr__(self):
        return str(self._checkpoints)
class CheckpointStateFile(object):
    """
    Helper class for reading from and writing to the checkpoint state file
    """
    # Name of the state file that is kept inside the checkpoint directory
    checkpoint_state_filename = '.coach_checkpoint'

    def __init__(self, checkpoint_dir: str):
        """
        :param checkpoint_dir: directory in which the checkpoint state file lives
        """
        self._checkpoint_state_path = os.path.join(checkpoint_dir, self.checkpoint_state_filename)

    def exists(self) -> bool:
        """
        :return: True if checkpoint state file exists, false otherwise
        """
        return os.path.exists(self._checkpoint_state_path)

    def read(self) -> Union[None, SingleCheckpoint]:
        """
        Read checkpoint state file and interpret its content as a checkpoint name
        :return: the checkpoint parsed from the first 256 characters of the state file, or None if
            the file doesn't exist or its content doesn't match the checkpoint filename pattern
        """
        if not self.exists():
            return None
        try:
            with open(self._checkpoint_state_path, 'r') as fd:
                content = fd.read(256)
        except FileNotFoundError:
            # The file was removed between the existence check and the open call; treat it
            # the same way as a file that was never there instead of failing.
            return None
        return CheckpointFilenameParser().parse(content)

    def write(self, data: SingleCheckpoint) -> None:
        """
        Writes the name of the given checkpoint to the checkpoint state file, replacing its content
        :param data: checkpoint whose name is stored in the state file
        """
        with open(self._checkpoint_state_path, 'w') as fd:
            fd.write(data.name)

    @property
    def filename(self) -> str:
        """
        :return: name of the checkpoint state file (without directory)
        """
        return self.checkpoint_state_filename

    @property
    def path(self) -> str:
        """
        :return: path of the checkpoint state file inside the checkpoint directory
        """
        return self._checkpoint_state_path
class CheckpointStateReader(object):
    """
    Class for discovering checkpoints, using the checkpoint state file and/or the contents
    of the checkpoint directory
    """
    def __init__(self, checkpoint_dir: str, checkpoint_state_optional: bool=True):
        """
        :param checkpoint_dir: path to checkpoint directory
        :param checkpoint_state_optional: If True, checkpoint state file is optional and if not found,
            directory is scanned to find the latest checkpoint. Default is True for backward compatibility
        """
        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_state_optional = checkpoint_state_optional
        self._checkpoint_state_file = CheckpointStateFile(checkpoint_dir)

    def get_latest(self) -> SingleCheckpoint:
        """
        Returns the checkpoint recorded in the checkpoint state file. If nothing could be read from
        the state file and the state file is optional, falls back to scanning the directory and
        picking the checkpoint with the highest number.
        :return: checkpoint object representing the latest checkpoint, or None if none was found
        """
        recorded = self._checkpoint_state_file.read()
        if recorded is not None or not self._checkpoint_state_optional:
            # either the state file gave an answer, or scanning the directory is not allowed
            return recorded

        discovered = _filter_checkpoint_files(os.listdir(self._checkpoint_dir))
        return discovered[-1] if discovered else None

    def get_all(self) -> List[SingleCheckpoint]:
        """
        Reads both the checkpoint state file as well as contents of the directory and merges them into one list.
        :return: list of checkpoint objects sorted by checkpoint number
        :raises ValueError: if the checkpoint recorded in the state file is not among the checkpoints
            discovered in the directory
        """
        # the directory is always scanned; the state file only decides how much of the scan is kept
        discovered = _filter_checkpoint_files(os.listdir(self._checkpoint_dir))
        recorded = self._checkpoint_state_file.read()

        if recorded is None:
            if not self._checkpoint_state_optional:
                # the state file is mandatory but gave no checkpoint, so all checkpoint files
                # discovered must be partial or invalid: don't return any of them
                discovered.clear()
            return discovered

        # drop excess checkpoints: higher checkpoint number, but not recent (e.g. from a previous run)
        cutoff = discovered.index(recorded) + 1
        return discovered[:cutoff]
class CheckpointStateUpdater(object):
    """
    Class that keeps track of the known checkpoints of a checkpoint directory and records
    the most recent one in the checkpoint state file
    """
    def __init__(self, checkpoint_dir: str, read_all: bool=False):
        """
        :param checkpoint_dir: path to checkpoint directory
        :param read_all: whether to scan the directory for existing checkpoints. If False, only
            the latest checkpoint (if any) is loaded.
        """
        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_state_file = CheckpointStateFile(checkpoint_dir)

        # Initialize the list of known checkpoints from what is already on disk
        reader = CheckpointStateReader(checkpoint_dir)
        if read_all:
            known = reader.get_all()
        else:
            newest = reader.get_latest()
            known = [] if newest is None else [newest]
        self._all_checkpoints = known

    def update(self, checkpoint: SingleCheckpoint) -> None:
        """
        Update the checkpoint state with the latest checkpoint.
        :param checkpoint: SingleCheckpoint object containing name and number of checkpoint
        """
        self._all_checkpoints.append(checkpoint)
        # Only the name of the newest checkpoint is persisted in the checkpoint-state file
        self._checkpoint_state_file.write(checkpoint)

    @property
    def last_checkpoint(self) -> Union[None, SingleCheckpoint]:
        """
        :return: the most recently added checkpoint, or None if no checkpoint is known
        """
        return self._all_checkpoints[-1] if self._all_checkpoints else None

    @property
    def all_checkpoints(self) -> List[SingleCheckpoint]:
        """
        :return: list of all known checkpoints, oldest first
        """
        return self._all_checkpoints

    def get_checkpoint_state(self) -> Union[None, CheckpointState]:
        """
        :return: The most recent checkpoint state, or None if no checkpoint is known
        """
        if not self._all_checkpoints:
            return None
        return CheckpointState(self._all_checkpoints, self._checkpoint_dir)
class CheckpointFilenameParser(object):
    """
    Helper object for parsing filenames that are potentially checkpoints
    """
    # Optional leading digits followed by a single non-digit character, then anything up to
    # the first '.ckpt', optionally followed by '-<digits>'. Group 2 holds the leading number,
    # group 4 the trailing one.
    coach_checkpoint_filename_pattern = r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?'

    def __init__(self):
        self._prog = re.compile(self.coach_checkpoint_filename_pattern)

    def parse(self, filename: str) -> Union[None, SingleCheckpoint]:
        """
        Tries to parse the filename using the checkpoint filename pattern. If successful, it
        returns a SingleCheckpoint holding the checkpoint number and the checkpoint name, where
        the name is the matched part of the filename (anything following it is dropped).
        Otherwise it returns None.
        :param filename: filename to be parsed
        :return: None or SingleCheckpoint(checkpoint-number, checkpoint-name)
        """
        match = self._prog.search(filename)
        if match is None:
            return None

        leading_num, trailing_num = match.group(2), match.group(4)
        if leading_num is None and trailing_num is None:
            # looks like a checkpoint file, but carries no checkpoint number
            return None

        assert leading_num is None or trailing_num is None  # Only one group must be valid
        digits = leading_num if leading_num is not None else trailing_num
        return SingleCheckpoint(int(digits), match.group(0))
def _filter_checkpoint_files(filenames: List[str], sort_by_num: bool=True) -> List[SingleCheckpoint]:
    """
    Given a list of potential file names, keep the ones that match the checkpoint pattern and
    convert each of them to a checkpoint object. Every matching file name yields one entry.
    :param filenames: list of all filenames
    :param sort_by_num: whether to sort the output result by checkpoint number
    :return: list of SingleCheckpoint objects, one per matching file name
    """
    parser = CheckpointFilenameParser()
    parsed = (parser.parse(fn) for fn in filenames)
    matched = [ckp for ckp in parsed if ckp is not None]
    if sort_by_num:
        matched.sort(key=lambda ckp: ckp.num)
    return matched
def get_checkpoint_state(checkpoint_dir: str, all_checkpoints=False) -> Union[CheckpointState, None]:
    """
    Build the checkpoint state of a checkpoint directory.
    :param checkpoint_dir: directory where checkpoints are saved
    :param all_checkpoints: if True, scan the directory and return list of all checkpoints
        as well as the most recent one. If False, only the latest checkpoint is looked up.
    :return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoints by checkpoint-number.
        If no checkpoint is found, returns None.
    """
    updater = CheckpointStateUpdater(checkpoint_dir, read_all=all_checkpoints)
    return updater.get_checkpoint_state()