#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Module providing helper classes and functions for reading/writing checkpoint state
"""

import os
import re
from typing import List, Union, Tuple


class SingleCheckpoint(object):
    """
    Helper class for storing checkpoint name and number
    """
    def __init__(self, num: int, name: str):
        """
        :param num: checkpoint number
        :param name: checkpoint name (i.e. the prefix for all checkpoint files)
        """
        self._num = num
        self._name = name

    @property
    def num(self) -> int:
        """
        :return: checkpoint number
        """
        return self._num

    @property
    def name(self) -> str:
        """
        :return: checkpoint name (the common prefix of all files of this checkpoint)
        """
        return self._name

    def __str__(self):
        return self._name

    def __repr__(self):
        return str(self)

    def __eq__(self, other: 'SingleCheckpoint'):
        if not isinstance(other, SingleCheckpoint):
            return False
        return self._name == other._name and self._num == other._num

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable in Python 3; keep the hash consistent
        # with equality so checkpoints can be stored in sets and used as dictionary keys.
        return hash((self._num, self._name))


class CheckpointState(object):
    """
    Helper class for checkpoint directory information. It replicates
    the CheckpointState protobuf class in tensorflow with addition of
    two new properties: last_checkpoint and all_checkpoints
    """
    def __init__(self, checkpoints: List[SingleCheckpoint], checkpoint_dir: str):
        """
        :param checkpoints: sorted list of checkpoints from oldest to newest. checkpoint[-1] is
            considered to be the most recent checkpoint.
        :param checkpoint_dir: checkpoint directory which is added to the paths
        """
        self._checkpoints = checkpoints
        self._checkpoint_dir = checkpoint_dir

    @property
    def all_checkpoints(self) -> List[SingleCheckpoint]:
        """
        :return: list of all checkpoints
        """
        return self._checkpoints

    @property
    def last_checkpoint(self) -> SingleCheckpoint:
        """
        :return: the most recent checkpoint
        """
        return self._checkpoints[-1]

    @property
    def all_model_checkpoint_paths(self) -> List[str]:
        """
        TF compatible function call to get all checkpoints
        :return: list of all available model checkpoint paths
        """
        return [os.path.join(self._checkpoint_dir, c.name) for c in self._checkpoints]

    @property
    def model_checkpoint_path(self) -> str:
        """
        TF compatible call to get most recent checkpoint
        :return: path of the most recent model checkpoint
        """
        return os.path.join(self._checkpoint_dir, self._checkpoints[-1].name)

    def __str__(self):
        lines = ['model_checkpoint_path: {}\n'.format(self.model_checkpoint_path)]
        lines.extend('all_model_checkpoint_paths: {}\n'.format(c) for c in self.all_model_checkpoint_paths)
        return ''.join(lines)

    def __repr__(self):
        return str(self._checkpoints)


class CheckpointStateFile(object):
    """
    Helper class for reading from and writing to the checkpoint state file
    """
    checkpoint_state_filename = '.coach_checkpoint'

    def __init__(self, checkpoint_dir: str):
        """
        :param checkpoint_dir: directory in which the checkpoint state file lives
        """
        self._checkpoint_state_path = os.path.join(checkpoint_dir, self.checkpoint_state_filename)

    def exists(self) -> bool:
        """
        :return: True if checkpoint state file exists, false otherwise
        """
        return os.path.exists(self._checkpoint_state_path)

    def read(self) -> Union[None, SingleCheckpoint]:
        """
        Read checkpoint state file and interpret its content
        :return: the checkpoint named in the state file, or None if the file does not exist or its
            content is not a valid checkpoint name
        """
        if not self.exists():
            return None
        try:
            with open(self._checkpoint_state_path, 'r') as fd:
                # only the first 256 characters are considered when parsing the checkpoint name
                content = fd.read(256)
        except FileNotFoundError:
            # the file was removed between the existence check and the open call
            return None
        return CheckpointFilenameParser().parse(content)

    def write(self, data: SingleCheckpoint) -> None:
        """
        Writes the checkpoint name to the checkpoint state file, replacing any previous content
        :param data: checkpoint whose name is written to the state file
        """
        with open(self._checkpoint_state_path, 'w') as fd:
            fd.write(data.name)

    @property
    def filename(self) -> str:
        """
        :return: name of the checkpoint state file (without directory)
        """
        return self.checkpoint_state_filename

    @property
    def path(self) -> str:
        """
        :return: path of the checkpoint state file (checkpoint directory joined with the file name)
        """
        return self._checkpoint_state_path


class CheckpointStateReader(object):
    """
    Class for reading the checkpoint state, either from the checkpoint state file or by scanning
    the checkpoint directory
    """
    def __init__(self, checkpoint_dir: str, checkpoint_state_optional: bool=True):
        """
        :param checkpoint_dir: path to checkpoint directory
        :param checkpoint_state_optional: If True, checkpoint state file is optional and if not found,
            directory is scanned to find the latest checkpoint. Default is True for backward compatibility
        """
        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_state_file = CheckpointStateFile(self._checkpoint_dir)
        self._checkpoint_state_optional = checkpoint_state_optional

    def get_latest(self) -> SingleCheckpoint:
        """
        Tries to read the checkpoint state file. If that fails and the state file is optional,
        discovers latest by reading the entire directory.
        :return: checkpoint object representing the latest checkpoint, or None if none was found
        """
        latest = self._checkpoint_state_file.read()
        if latest is None and self._checkpoint_state_optional:
            all_checkpoints = _filter_checkpoint_files(os.listdir(self._checkpoint_dir))
            if len(all_checkpoints) > 0:
                latest = all_checkpoints[-1]
        return latest

    def get_all(self) -> List[SingleCheckpoint]:
        """
        Reads both the checkpoint state file as well as contents of the directory and merges them into one list.
        :return: list of checkpoint objects, sorted by checkpoint number
        """
        # the directory is always scanned; the state file (if any) is then used to trim the result
        all_checkpoints = _filter_checkpoint_files(os.listdir(self._checkpoint_dir))
        last_checkpoint = self._checkpoint_state_file.read()
        if last_checkpoint is not None:
            # remove excess checkpoints: higher checkpoint number, but not recent (e.g. from a previous run)
            if last_checkpoint in all_checkpoints:
                all_checkpoints = all_checkpoints[: all_checkpoints.index(last_checkpoint) + 1]
            else:
                # the state file names a checkpoint whose files are not present in the directory;
                # instead of failing, keep only the checkpoints that are not newer than it
                all_checkpoints = [c for c in all_checkpoints if c.num <= last_checkpoint.num]
        elif not self._checkpoint_state_optional:
            # if last_checkpoint is not discovered from the checkpoint-state file and it isn't optional, then
            # all checkpoint files discovered must be partial or invalid, so don't return anything
            all_checkpoints = []
        return all_checkpoints


class CheckpointStateUpdater(object):
    """
    Class for tracking the checkpoints of a directory and updating the checkpoint state file
    """
    def __init__(self, checkpoint_dir: str, read_all: bool=False):
        """
        :param checkpoint_dir: path to checkpoint directory
        :param read_all: whether to scan the directory for existing checkpoints
        """
        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_state_file = CheckpointStateFile(checkpoint_dir)
        self._all_checkpoints = list()
        # Read checkpoint state and initialize
        state_reader = CheckpointStateReader(checkpoint_dir)
        if read_all:
            self._all_checkpoints = state_reader.get_all()
        else:
            latest = state_reader.get_latest()
            if latest is not None:
                self._all_checkpoints = [latest]

    def update(self, checkpoint: SingleCheckpoint) -> None:
        """
        Update the checkpoint state with the latest checkpoint.
        :param checkpoint: SingleCheckpoint object containing name and number of checkpoint
        """
        self._all_checkpoints.append(checkpoint)
        # Simply write checkpoint_name to checkpoint-state file
        self._checkpoint_state_file.write(checkpoint)

    @property
    def last_checkpoint(self) -> Union[None, SingleCheckpoint]:
        """
        :return: the most recent checkpoint, or None if no checkpoint is known
        """
        if len(self._all_checkpoints) == 0:
            return None
        return self._all_checkpoints[-1]

    @property
    def all_checkpoints(self) -> List[SingleCheckpoint]:
        """
        :return: list of all known checkpoints, oldest first
        """
        return self._all_checkpoints

    def get_checkpoint_state(self) -> Union[None, CheckpointState]:
        """
        :return: The most recent checkpoint state, or None if no checkpoint is known
        """
        if len(self._all_checkpoints) == 0:
            return None
        return CheckpointState(self._all_checkpoints, self._checkpoint_dir)


class CheckpointFilenameParser(object):
    """
    Helper object for parsing filenames that are potentially checkpoints
    """
    coach_checkpoint_filename_pattern = r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?'

    def __init__(self):
        self._prog = re.compile(self.coach_checkpoint_filename_pattern)

    def parse(self, filename: str) -> Union[None, SingleCheckpoint]:
        """
        Tries to parse the filename using the checkpoint filename pattern. The checkpoint number is
        taken either from the leading digits of the name (group 2) or from the '-<number>' suffix
        following '.ckpt' (group 4). The checkpoint name is the matched part of the filename, i.e.
        everything up to and including '.ckpt' (and the numeric suffix, if present).
        :param filename: filename to be parsed
        :return: None if the filename is not a checkpoint, otherwise a SingleCheckpoint holding the
            checkpoint number and checkpoint name
        """
        m = self._prog.search(filename)
        if m is not None and (m.group(2) is not None or m.group(4) is not None):
            assert m.group(2) is None or m.group(4) is None  # Only one group must be valid
            checkpoint_num = int(m.group(2) if m.group(2) is not None else m.group(4))
            return SingleCheckpoint(checkpoint_num, m.group(0))
        return None


def _filter_checkpoint_files(filenames: List[str], sort_by_num: bool=True) -> List[SingleCheckpoint]:
    """
    Given a list of potential file names, return the checkpoints they belong to. Several files may
    belong to the same checkpoint (they share the checkpoint name as a prefix); each checkpoint is
    returned only once.
    :param filenames: list of all filenames
    :param sort_by_num: whether to sort the output result by checkpoint number
    :return: list of SingleCheckpoint objects without duplicates
    """
    parser = CheckpointFilenameParser()
    checkpoints = [ckp for ckp in [parser.parse(fn) for fn in filenames] if ckp is not None]
    if sort_by_num:
        checkpoints.sort(key=lambda x: x.num)
    # drop repeated checkpoints while preserving the (possibly sorted) order
    seen = set()
    unique_checkpoints = []
    for ckp in checkpoints:
        key = (ckp.num, ckp.name)
        if key not in seen:
            seen.add(key)
            unique_checkpoints.append(ckp)
    return unique_checkpoints


def get_checkpoint_state(checkpoint_dir: str, all_checkpoints=False) ->Union[CheckpointState, None]:
    """
    Scan checkpoint directory and find the list of checkpoint files.
    :param checkpoint_dir: directory where checkpoints are saved
    :param all_checkpoints: if True, scan the directory and return list of all checkpoints
        as well as the most recent one
    :return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoints by checkpoint-number.
        If no matching files are found, returns None.
    """
    return CheckpointStateUpdater(checkpoint_dir, read_all=all_checkpoints).get_checkpoint_state()
os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint")) + state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir)) + # wait until lock is removed while True: objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.LOCKFILE.value) if next(objects, None) is None: try: - self.mc.fget_object(self.params.bucket_name, "checkpoint", filename) + # fetch checkpoint state file from S3 + self.mc.fget_object(self.params.bucket_name, state_file.filename, state_file.path) except Exception as e: continue break @@ -101,13 +105,9 @@ class S3DataStore(DataStore): except Exception as e: pass - ckpt = CheckpointState() - if os.path.exists(filename): - contents = open(filename, 'r').read() - text_format.Merge(contents, ckpt) - rel_path = os.path.relpath(ckpt.model_checkpoint_path, self.params.checkpoint_dir) - - objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=rel_path, recursive=True) + checkpoint_state = state_file.read() + if checkpoint_state is not None: + objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True) for obj in objects: filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name)) if not os.path.exists(filename): diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index 518c388..44a5df8 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -25,6 +25,7 @@ import contextlib from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \ VisualizationParameters, \ Parameters, PresetValidationParameters, RunType +from rl_coach.checkpoint import CheckpointStateUpdater, get_checkpoint_state, SingleCheckpoint from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \ EnvironmentSteps, \ StepMethod, Transition @@ -32,7 +33,7 @@ from 
rl_coach.environments.environment import Environment from rl_coach.level_manager import LevelManager from rl_coach.logger import screen, Logger from rl_coach.saver import SaverCollection -from rl_coach.utils import get_checkpoint_state, set_cpu, start_shell_command_and_wait +from rl_coach.utils import set_cpu, start_shell_command_and_wait from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator from rl_coach.memories.backend.memory_impl import get_memory_backend from rl_coach.data_stores.data_store import SyncFiles @@ -115,6 +116,7 @@ class GraphManager(object): self.checkpoint_id = 0 self.checkpoint_saver = None + self.checkpoint_state_updater = None self.graph_logger = Logger() self.data_store = None @@ -495,7 +497,7 @@ class GraphManager(object): self.act(EnvironmentEpisodes(1)) self.sync() if self.should_stop(): - if self.task_parameters.checkpoint_save_dir: + if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir): open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close() if hasattr(self, 'data_store_params'): data_store = self.get_data_store(self.data_store_params) @@ -579,29 +581,35 @@ class GraphManager(object): self.save_checkpoint() def save_checkpoint(self): + # create current session's checkpoint directory if self.task_parameters.checkpoint_save_dir is None: self.task_parameters.checkpoint_save_dir = os.path.join(self.task_parameters.experiment_path, 'checkpoint') - filename = "{}_Step-{}.ckpt".format( - self.checkpoint_id, - self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]) + if not os.path.exists(self.task_parameters.checkpoint_save_dir): + os.mkdir(self.task_parameters.checkpoint_save_dir) # Create directory structure + + if self.checkpoint_state_updater is None: + self.checkpoint_state_updater = CheckpointStateUpdater(self.task_parameters.checkpoint_save_dir) + + checkpoint_name = "{}_Step-{}.ckpt".format( + 
self.checkpoint_id, self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]) + checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir, checkpoint_name) - checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir, - filename) - if not os.path.exists(os.path.dirname(checkpoint_path)): - os.mkdir(os.path.dirname(checkpoint_path)) # Create directory structure if not isinstance(self.task_parameters, DistributedTaskParameters): saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path) else: saved_checkpoint_path = checkpoint_path # this is required in order for agents to save additional information like a DND for example - [manager.save_checkpoint(filename) for manager in self.level_managers] + [manager.save_checkpoint(checkpoint_name) for manager in self.level_managers] # the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used if self.task_parameters.export_onnx_graph: self.save_onnx_graph() + # write the new checkpoint name to a file to signal this checkpoint has been fully saved + self.checkpoint_state_updater.update(SingleCheckpoint(self.checkpoint_id, checkpoint_name)) + screen.log_dict( OrderedDict([ ("Saving in path", saved_checkpoint_path), diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py index 82045ed..549b0a1 100644 --- a/rl_coach/rollout_worker.py +++ b/rl_coach/rollout_worker.py @@ -12,37 +12,26 @@ import os import math from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType +from rl_coach.checkpoint import CheckpointStateFile, CheckpointStateReader from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes -from google.protobuf import text_format -from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState from rl_coach.data_stores.data_store import SyncFiles -def has_checkpoint(checkpoint_dir): - """ - True if a checkpoint is present in checkpoint_dir - """ - if 
os.path.isdir(checkpoint_dir): - if len(os.listdir(checkpoint_dir)) > 0: - return os.path.isfile(os.path.join(checkpoint_dir, "checkpoint")) - - return False - - def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10): """ block until there is a checkpoint in checkpoint_dir """ + chkpt_state_file = CheckpointStateFile(checkpoint_dir) for i in range(timeout): if data_store: data_store.load_from_store() - if has_checkpoint(checkpoint_dir): + if chkpt_state_file.read() is not None: return time.sleep(10) # one last time - if has_checkpoint(checkpoint_dir): + if chkpt_state_file.read() is not None: return raise ValueError(( @@ -54,21 +43,6 @@ def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10): )) -def data_store_ckpt_load(data_store): - while True: - data_store.load_from_store() - time.sleep(10) - - -def get_latest_checkpoint(checkpoint_dir): - if os.path.exists(os.path.join(checkpoint_dir, 'checkpoint')): - ckpt = CheckpointState() - contents = open(os.path.join(checkpoint_dir, 'checkpoint'), 'r').read() - text_format.Merge(contents, ckpt) - rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir) - return int(rel_path.split('_Step')[0]) - - def should_stop(checkpoint_dir): return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value)) @@ -83,6 +57,7 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters): graph_manager.create_graph(task_parameters) with graph_manager.phase_context(RunPhase.TRAIN): + chkpt_state_reader = CheckpointStateReader(checkpoint_dir, checkpoint_state_optional=False) last_checkpoint = 0 act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps)/num_workers) @@ -97,20 +72,20 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters): elif type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentEpisodes: graph_manager.act(EnvironmentEpisodes(num_steps=act_steps)) - 
new_checkpoint = get_latest_checkpoint(checkpoint_dir) - + new_checkpoint = chkpt_state_reader.get_latest() if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC: - while new_checkpoint < last_checkpoint + 1: + while new_checkpoint is None or new_checkpoint.num < last_checkpoint + 1: if should_stop(checkpoint_dir): break if data_store: data_store.load_from_store() - new_checkpoint = get_latest_checkpoint(checkpoint_dir) + new_checkpoint = chkpt_state_reader.get_latest() graph_manager.restore_checkpoint() if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.ASYNC: - if new_checkpoint > last_checkpoint: + if new_checkpoint is not None and new_checkpoint.num > last_checkpoint: graph_manager.restore_checkpoint() - last_checkpoint = new_checkpoint + if new_checkpoint is not None: + last_checkpoint = new_checkpoint.num diff --git a/rl_coach/tests/test_utils.py b/rl_coach/tests/test_checkpoint.py similarity index 54% rename from rl_coach/tests/test_utils.py rename to rl_coach/tests/test_checkpoint.py index face47b..5fb8991 100644 --- a/rl_coach/tests/test_utils.py +++ b/rl_coach/tests/test_checkpoint.py @@ -2,7 +2,7 @@ import os import pytest import tempfile -from rl_coach import utils +from rl_coach import checkpoint @pytest.mark.unit_test @@ -10,8 +10,15 @@ def test_get_checkpoint_state(): files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext', '1.test.ckpt.ext', 'prefix.10.test.ckpt.ext'] with tempfile.TemporaryDirectory() as temp_dir: [open(os.path.join(temp_dir, fn), 'a').close() for fn in files] - checkpoint_state = utils.get_checkpoint_state(temp_dir) + checkpoint_state = checkpoint.get_checkpoint_state(temp_dir, all_checkpoints=True) assert checkpoint_state.model_checkpoint_path == os.path.join(temp_dir, '4.test.ckpt') assert checkpoint_state.all_model_checkpoint_paths == \ [os.path.join(temp_dir, f[:-4]) for f in 
sorted(files[:-1])] + reader = checkpoint.CheckpointStateReader(temp_dir, checkpoint_state_optional=False) + assert reader.get_latest() is None + assert len(reader.get_all()) == 0 + + reader = checkpoint.CheckpointStateReader(temp_dir) + assert reader.get_latest().num == 4 + assert [ckp.num for ckp in reader.get_all()] == [1, 2, 3, 4] diff --git a/rl_coach/training_worker.py b/rl_coach/training_worker.py index c43306a..17d6d1e 100644 --- a/rl_coach/training_worker.py +++ b/rl_coach/training_worker.py @@ -30,7 +30,7 @@ def training_worker(graph_manager, task_parameters): graph_manager.setup_memory_backend() - while(steps < graph_manager.improve_steps.num_steps): + while steps < graph_manager.improve_steps.num_steps: graph_manager.phase = core_types.RunPhase.TRAIN graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) diff --git a/rl_coach/utils.py b/rl_coach/utils.py index 7b5b95d..a8ace7a 100644 --- a/rl_coach/utils.py +++ b/rl_coach/utils.py @@ -546,58 +546,3 @@ def start_shell_command_and_wait(command): def indent_string(string): return '\t' + string.replace('\n', '\n\t') - - -class CheckpointState(object): - """ - Helper class for checkpoint directory information. It replicates - the CheckpointState protobuf class in tensorflow. - """ - def __init__(self, checkpoints: List[str]): - self._checkpoints = checkpoints - - @property - def all_model_checkpoint_paths(self): - return self._checkpoints - - @property - def model_checkpoint_path(self): - return self._checkpoints[-1] - - def __str__(self): - out_str = 'model_checkpoint_path: {}\n'.format(self.model_checkpoint_path) - for c in self._checkpoints: - out_str += 'all_model_checkpoint_paths: {}\n'.format(c) - return out_str - - def __repr__(self): - return str(self._checkpoints) - - -COACH_CHECKPOINT_PATTERN = r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?' 
- - -def get_checkpoint_state(checkpoint_dir: Union[str, List[str]], filename_pattern: str=COACH_CHECKPOINT_PATTERN) ->\ - Union[CheckpointState, None]: - """ - Finds the latest checkpoint file. It uses the first group of filename_pattern (i.e. group(2) or group(4) to sort - the checkpoint names and find the latest checkpoint - :param checkpoint_dir: directory where checkpoints are saved or list of all files in a directory - :param filename_pattern: regex pattern for checkpoint filenames - :return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoint names. If no matching - files are found, returns None. - """ - prog = re.compile(filename_pattern) - checkpoints = dict() - filenames = os.listdir(checkpoint_dir) if isinstance(checkpoint_dir, str) else checkpoint_dir - for name in filenames: - m = prog.search(name) - if m is not None and (m.group(2) is not None or m.group(4) is not None): - if m.group(2) is not None and m.group(4) is not None: - assert m.group(2) == m.group(4) - checkpoint_count = int(m.group(2) if m.group(2) is not None else m.group(4)) - full_path = os.path.join(checkpoint_dir, m.group(0)) if isinstance(checkpoint_dir, str) else m.group(0) - checkpoints[checkpoint_count] = full_path - if len(checkpoints) == 0: - return None - return CheckpointState([checkpoints[k] for k in sorted(checkpoints.keys())])