1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Implement framework-agnostic rollout and training workers (#137)

* Added checkpoint state file to coach checkpointing.

* Removed TF specific code from rollout_worker, training_worker, and s3_data_store
This commit is contained in:
Sina Afrooze
2018-11-23 18:05:44 -08:00
committed by Balaji Subramaniam
parent 4a6c404070
commit 5332013bd1
7 changed files with 350 additions and 117 deletions

298
rl_coach/checkpoint.py Normal file
View File

@@ -0,0 +1,298 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Module providing helper classes and functions for reading/writing checkpoint state
"""
import os
import re
from typing import List, Union, Tuple
class SingleCheckpoint(object):
    """
    Value object pairing a checkpoint number with a checkpoint name
    (the common prefix of all files belonging to that checkpoint).
    """
    def __init__(self, num: int, name: str):
        """
        :param num: checkpoint number
        :param name: checkpoint name (i.e. the prefix for all checkpoint files)
        """
        self._num = num
        self._name = name

    @property
    def num(self) -> int:
        """:return: checkpoint number, used for ordering checkpoints"""
        return self._num

    @property
    def name(self) -> str:
        """:return: checkpoint name (file prefix)"""
        return self._name

    def __str__(self):
        return self._name

    def __repr__(self):
        return str(self)

    def __eq__(self, other: 'SingleCheckpoint'):
        # Checkpoints are equal only when both number and name match;
        # any non-SingleCheckpoint compares unequal.
        if not isinstance(other, SingleCheckpoint):
            return False
        return self._name == other._name and self._num == other._num

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # Defining __eq__ implicitly sets __hash__ to None; restore
        # hashability (consistent with __eq__) so checkpoints can live
        # in sets and be used as dict keys.
        return hash((self._num, self._name))
class CheckpointState(object):
    """
    Helper class for checkpoint directory information. It replicates
    the CheckpointState protobuf class in tensorflow with the addition of
    two new properties: last_checkpoint and all_checkpoints
    """
    def __init__(self, checkpoints: List['SingleCheckpoint'], checkpoint_dir: str):
        """
        :param checkpoints: sorted list of checkpoints from oldest to newest. checkpoint[-1] is
            considered to be the most recent checkpoint.
        :param checkpoint_dir: checkpoint directory which is prepended to checkpoint names
            to form full paths
        """
        self._checkpoints = checkpoints
        # fixed typo: attribute was previously named `_checkpoin_dir`
        self._checkpoint_dir = checkpoint_dir

    @property
    def all_checkpoints(self) -> List['SingleCheckpoint']:
        """
        :return: list of all checkpoints, oldest first
        """
        return self._checkpoints

    @property
    def last_checkpoint(self) -> 'SingleCheckpoint':
        """
        :return: the most recent checkpoint
        """
        return self._checkpoints[-1]

    @property
    def all_model_checkpoint_paths(self) -> List[str]:
        """
        TF compatible function call to get all checkpoints
        :return: list of all available model checkpoint paths
        """
        return [os.path.join(self._checkpoint_dir, c.name) for c in self._checkpoints]

    @property
    def model_checkpoint_path(self) -> str:
        """
        TF compatible call to get most recent checkpoint
        :return: path of the most recent model checkpoint
        """
        return os.path.join(self._checkpoint_dir, self._checkpoints[-1].name)

    def __str__(self):
        out_str = 'model_checkpoint_path: {}\n'.format(self.model_checkpoint_path)
        for c in self.all_model_checkpoint_paths:
            out_str += 'all_model_checkpoint_paths: {}\n'.format(c)
        return out_str

    def __repr__(self):
        return str(self._checkpoints)
class CheckpointStateFile(object):
    """
    Helper class for reading from and writing to the checkpoint state file.
    The state file records the name of the most recent fully-saved checkpoint.
    """
    checkpoint_state_filename = '.coach_checkpoint'

    def __init__(self, checkpoint_dir: str):
        """
        :param checkpoint_dir: directory containing (or that will contain) the state file
        """
        self._checkpoint_state_path = os.path.join(checkpoint_dir, self.checkpoint_state_filename)

    def exists(self) -> bool:
        """
        :return: True if checkpoint state file exists, false otherwise
        """
        return os.path.exists(self._checkpoint_state_path)

    def read(self) -> Union[None, 'SingleCheckpoint']:
        """
        Read the checkpoint state file and parse its content.
        :return: the checkpoint named by the state file, or None if the file doesn't
            exist or its content doesn't match the checkpoint filename pattern
        """
        if not self.exists():
            return None
        with open(self._checkpoint_state_path, 'r') as fd:
            # the file holds a single checkpoint name; 256 bytes is ample
            return CheckpointFilenameParser().parse(fd.read(256))

    def write(self, data: 'SingleCheckpoint') -> None:
        """
        Write the checkpoint's name to the checkpoint state file.
        :param data: checkpoint whose name is recorded as the most recent checkpoint
        """
        with open(self._checkpoint_state_path, 'w') as fd:
            fd.write(data.name)

    @property
    def filename(self) -> str:
        """:return: basename of the checkpoint state file"""
        return self.checkpoint_state_filename

    @property
    def path(self) -> str:
        """:return: full path to the checkpoint state file"""
        return self._checkpoint_state_path
class CheckpointStateReader(object):
    """
    Class for discovering checkpoints in a directory, either from the checkpoint
    state file or by scanning the directory contents.
    """
    def __init__(self, checkpoint_dir: str, checkpoint_state_optional: bool=True):
        """
        :param checkpoint_dir: path to checkpoint directory
        :param checkpoint_state_optional: If True, checkpoint state file is optional and if not found,
            directory is scanned to find the latest checkpoint. Default is True for backward compatibility
        """
        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_state_file = CheckpointStateFile(self._checkpoint_dir)
        self._checkpoint_state_optional = checkpoint_state_optional

    def get_latest(self) -> Union[None, 'SingleCheckpoint']:
        """
        Tries to read the checkpoint state file. If that fails and the state file is
        optional, discovers the latest checkpoint by scanning the entire directory.
        :return: checkpoint object representing the latest checkpoint, or None if
            no valid checkpoint was found
        """
        latest = self._checkpoint_state_file.read()
        if latest is None and self._checkpoint_state_optional:
            # fall back to a directory scan; checkpoints come back sorted by number
            all_checkpoints = _filter_checkpoint_files(os.listdir(self._checkpoint_dir))
            if len(all_checkpoints) > 0:
                latest = all_checkpoints[-1]
        return latest

    def get_all(self) -> List['SingleCheckpoint']:
        """
        Reads both the checkpoint state file as well as contents of the directory and
        merges them into one list.
        :return: sorted list of checkpoint objects, oldest first
        """
        # discover all checkpoint files in the directory
        all_checkpoints = _filter_checkpoint_files(os.listdir(self._checkpoint_dir))
        last_checkpoint = self._checkpoint_state_file.read()
        if last_checkpoint is not None:
            # remove excess checkpoints: higher checkpoint number, but not recent (e.g. from a previous run)
            all_checkpoints = all_checkpoints[: all_checkpoints.index(last_checkpoint) + 1]
        elif not self._checkpoint_state_optional:
            # if last_checkpoint is not discovered from the checkpoint-state file and it isn't optional, then
            # all checkpoint files discovered must be partial or invalid, so don't return anything
            all_checkpoints.clear()
        return all_checkpoints
class CheckpointStateUpdater(object):
    """
    Maintains the in-memory list of known checkpoints for a directory and keeps
    the on-disk checkpoint state file in sync with the most recent one.
    """
    def __init__(self, checkpoint_dir: str, read_all: bool=False):
        """
        :param checkpoint_dir: path to checkpoint directory
        :param read_all: whether to scan the directory for existing checkpoints
        """
        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_state_file = CheckpointStateFile(checkpoint_dir)
        # Seed the checkpoint list from whatever already exists on disk
        reader = CheckpointStateReader(checkpoint_dir)
        if read_all:
            self._all_checkpoints = reader.get_all()
        else:
            most_recent = reader.get_latest()
            self._all_checkpoints = [] if most_recent is None else [most_recent]

    def update(self, checkpoint: SingleCheckpoint) -> None:
        """
        Record a new checkpoint as the most recent one.
        :param checkpoint: SingleCheckpoint object containing name and number of checkpoint
        """
        self._all_checkpoints.append(checkpoint)
        # Persist only the checkpoint name; readers parse it back into a SingleCheckpoint
        self._checkpoint_state_file.write(checkpoint)

    @property
    def last_checkpoint(self) -> Union[None, SingleCheckpoint]:
        """:return: most recently recorded checkpoint, or None if there are none"""
        return self._all_checkpoints[-1] if self._all_checkpoints else None

    @property
    def all_checkpoints(self) -> List[SingleCheckpoint]:
        """:return: all known checkpoints, oldest first"""
        return self._all_checkpoints

    def get_checkpoint_state(self) -> Union[None, CheckpointState]:
        """
        :return: a CheckpointState over all known checkpoints, or None if there are none
        """
        if not self._all_checkpoints:
            return None
        return CheckpointState(self._all_checkpoints, self._checkpoint_dir)
class CheckpointFilenameParser(object):
    """
    Helper object for parsing filenames that are potentially checkpoints
    """
    # Matches either a leading checkpoint number ("<num>_... .ckpt", group 2)
    # or a trailing TF-style number (".ckpt-<num>", group 4)
    coach_checkpoint_filename_pattern = r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?'

    def __init__(self):
        # compile once so repeated parse() calls are cheap
        self._prog = re.compile(self.coach_checkpoint_filename_pattern)

    def parse(self, filename: str) -> Union[None, 'SingleCheckpoint']:
        """
        Tries to parse the filename using the checkpoint filename pattern. If successful,
        it returns a SingleCheckpoint built from the parsed number and the matched name.
        Otherwise it returns None.
        :param filename: filename to be parsed
        :return: None or a SingleCheckpoint
        """
        m = self._prog.search(filename)
        if m is not None and (m.group(2) is not None or m.group(4) is not None):
            assert m.group(2) is None or m.group(4) is None  # Only one group must be valid
            checkpoint_num = int(m.group(2) if m.group(2) is not None else m.group(4))
            return SingleCheckpoint(checkpoint_num, m.group(0))
        return None
def _filter_checkpoint_files(filenames: List[str], sort_by_num: bool=True) -> List['SingleCheckpoint']:
    """
    Given a list of potential file names, return the ones that match the checkpoint
    pattern, parsed into SingleCheckpoint objects.
    :param filenames: list of all filenames
    :param sort_by_num: whether to sort the output result by checkpoint number
    :return: list of SingleCheckpoint objects
    """
    parser = CheckpointFilenameParser()
    # keep only names the parser recognizes as checkpoints
    checkpoints = [ckp for ckp in (parser.parse(fn) for fn in filenames) if ckp is not None]
    if sort_by_num:
        checkpoints.sort(key=lambda x: x.num)
    return checkpoints
def get_checkpoint_state(checkpoint_dir: str, all_checkpoints=False) ->Union[CheckpointState, None]:
    """
    Scan the checkpoint directory and build its checkpoint state.
    :param checkpoint_dir: directory where checkpoints are saved
    :param all_checkpoints: if True, scan the directory and return the list of all
        checkpoints as well as the most recent one
    :return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoints
        by checkpoint-number. If no matching files are found, returns None.
    """
    updater = CheckpointStateUpdater(checkpoint_dir, read_all=all_checkpoints)
    return updater.get_checkpoint_state()

View File

@@ -2,8 +2,7 @@ from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
from minio import Minio
from minio.error import ResponseError
from configparser import ConfigParser, Error
from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from rl_coach.checkpoint import CheckpointStateFile
from rl_coach.data_stores.data_store import SyncFiles
import os
@@ -24,6 +23,7 @@ class S3DataStoreParameters(DataStoreParameters):
class S3DataStore(DataStore):
def __init__(self, params: S3DataStoreParameters):
super(S3DataStore, self).__init__(params)
self.params = params
access_key = None
secret_key = None
@@ -51,14 +51,15 @@ class S3DataStore(DataStore):
def save_to_store(self):
try:
# remove lock file if it exists
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
# Acquire lock
self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)
checkpoint_file = None
for root, dirs, files in os.walk(self.params.checkpoint_dir):
for filename in files:
if filename == 'checkpoint':
if filename == CheckpointStateFile.checkpoint_state_filename:
checkpoint_file = (root, filename)
continue
abs_name = os.path.abspath(os.path.join(root, filename))
@@ -69,6 +70,7 @@ class S3DataStore(DataStore):
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
# release lock
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
except ResponseError as e:
@@ -76,14 +78,16 @@ class S3DataStore(DataStore):
def load_from_store(self):
try:
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
# wait until lock is removed
while True:
objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.LOCKFILE.value)
if next(objects, None) is None:
try:
self.mc.fget_object(self.params.bucket_name, "checkpoint", filename)
# fetch checkpoint state file from S3
self.mc.fget_object(self.params.bucket_name, state_file.filename, state_file.path)
except Exception as e:
continue
break
@@ -101,13 +105,9 @@ class S3DataStore(DataStore):
except Exception as e:
pass
ckpt = CheckpointState()
if os.path.exists(filename):
contents = open(filename, 'r').read()
text_format.Merge(contents, ckpt)
rel_path = os.path.relpath(ckpt.model_checkpoint_path, self.params.checkpoint_dir)
objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=rel_path, recursive=True)
checkpoint_state = state_file.read()
if checkpoint_state is not None:
objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True)
for obj in objects:
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name))
if not os.path.exists(filename):

View File

@@ -25,6 +25,7 @@ import contextlib
from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \
VisualizationParameters, \
Parameters, PresetValidationParameters, RunType
from rl_coach.checkpoint import CheckpointStateUpdater, get_checkpoint_state, SingleCheckpoint
from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
EnvironmentSteps, \
StepMethod, Transition
@@ -32,7 +33,7 @@ from rl_coach.environments.environment import Environment
from rl_coach.level_manager import LevelManager
from rl_coach.logger import screen, Logger
from rl_coach.saver import SaverCollection
from rl_coach.utils import get_checkpoint_state, set_cpu, start_shell_command_and_wait
from rl_coach.utils import set_cpu, start_shell_command_and_wait
from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
from rl_coach.memories.backend.memory_impl import get_memory_backend
from rl_coach.data_stores.data_store import SyncFiles
@@ -115,6 +116,7 @@ class GraphManager(object):
self.checkpoint_id = 0
self.checkpoint_saver = None
self.checkpoint_state_updater = None
self.graph_logger = Logger()
self.data_store = None
@@ -495,7 +497,7 @@ class GraphManager(object):
self.act(EnvironmentEpisodes(1))
self.sync()
if self.should_stop():
if self.task_parameters.checkpoint_save_dir:
if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir):
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
if hasattr(self, 'data_store_params'):
data_store = self.get_data_store(self.data_store_params)
@@ -579,29 +581,35 @@ class GraphManager(object):
self.save_checkpoint()
def save_checkpoint(self):
# create current session's checkpoint directory
if self.task_parameters.checkpoint_save_dir is None:
self.task_parameters.checkpoint_save_dir = os.path.join(self.task_parameters.experiment_path, 'checkpoint')
filename = "{}_Step-{}.ckpt".format(
self.checkpoint_id,
self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])
if not os.path.exists(self.task_parameters.checkpoint_save_dir):
os.mkdir(self.task_parameters.checkpoint_save_dir) # Create directory structure
if self.checkpoint_state_updater is None:
self.checkpoint_state_updater = CheckpointStateUpdater(self.task_parameters.checkpoint_save_dir)
checkpoint_name = "{}_Step-{}.ckpt".format(
self.checkpoint_id, self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])
checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir, checkpoint_name)
checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
filename)
if not os.path.exists(os.path.dirname(checkpoint_path)):
os.mkdir(os.path.dirname(checkpoint_path)) # Create directory structure
if not isinstance(self.task_parameters, DistributedTaskParameters):
saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
else:
saved_checkpoint_path = checkpoint_path
# this is required in order for agents to save additional information like a DND for example
[manager.save_checkpoint(filename) for manager in self.level_managers]
[manager.save_checkpoint(checkpoint_name) for manager in self.level_managers]
# the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
if self.task_parameters.export_onnx_graph:
self.save_onnx_graph()
# write the new checkpoint name to a file to signal this checkpoint has been fully saved
self.checkpoint_state_updater.update(SingleCheckpoint(self.checkpoint_id, checkpoint_name))
screen.log_dict(
OrderedDict([
("Saving in path", saved_checkpoint_path),

View File

@@ -12,37 +12,26 @@ import os
import math
from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
from rl_coach.checkpoint import CheckpointStateFile, CheckpointStateReader
from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from rl_coach.data_stores.data_store import SyncFiles
def has_checkpoint(checkpoint_dir):
"""
True if a checkpoint is present in checkpoint_dir
"""
if os.path.isdir(checkpoint_dir):
if len(os.listdir(checkpoint_dir)) > 0:
return os.path.isfile(os.path.join(checkpoint_dir, "checkpoint"))
return False
def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
"""
block until there is a checkpoint in checkpoint_dir
"""
chkpt_state_file = CheckpointStateFile(checkpoint_dir)
for i in range(timeout):
if data_store:
data_store.load_from_store()
if has_checkpoint(checkpoint_dir):
if chkpt_state_file.read() is not None:
return
time.sleep(10)
# one last time
if has_checkpoint(checkpoint_dir):
if chkpt_state_file.read() is not None:
return
raise ValueError((
@@ -54,21 +43,6 @@ def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
))
def data_store_ckpt_load(data_store):
while True:
data_store.load_from_store()
time.sleep(10)
def get_latest_checkpoint(checkpoint_dir):
if os.path.exists(os.path.join(checkpoint_dir, 'checkpoint')):
ckpt = CheckpointState()
contents = open(os.path.join(checkpoint_dir, 'checkpoint'), 'r').read()
text_format.Merge(contents, ckpt)
rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
return int(rel_path.split('_Step')[0])
def should_stop(checkpoint_dir):
return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))
@@ -83,6 +57,7 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
graph_manager.create_graph(task_parameters)
with graph_manager.phase_context(RunPhase.TRAIN):
chkpt_state_reader = CheckpointStateReader(checkpoint_dir, checkpoint_state_optional=False)
last_checkpoint = 0
act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps)/num_workers)
@@ -97,20 +72,20 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
elif type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentEpisodes:
graph_manager.act(EnvironmentEpisodes(num_steps=act_steps))
new_checkpoint = get_latest_checkpoint(checkpoint_dir)
new_checkpoint = chkpt_state_reader.get_latest()
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
while new_checkpoint < last_checkpoint + 1:
while new_checkpoint is None or new_checkpoint.num < last_checkpoint + 1:
if should_stop(checkpoint_dir):
break
if data_store:
data_store.load_from_store()
new_checkpoint = get_latest_checkpoint(checkpoint_dir)
new_checkpoint = chkpt_state_reader.get_latest()
graph_manager.restore_checkpoint()
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.ASYNC:
if new_checkpoint > last_checkpoint:
if new_checkpoint is not None and new_checkpoint.num > last_checkpoint:
graph_manager.restore_checkpoint()
last_checkpoint = new_checkpoint
if new_checkpoint is not None:
last_checkpoint = new_checkpoint.num

View File

@@ -2,7 +2,7 @@ import os
import pytest
import tempfile
from rl_coach import utils
from rl_coach import checkpoint
@pytest.mark.unit_test
@@ -10,8 +10,15 @@ def test_get_checkpoint_state():
files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext', '1.test.ckpt.ext', 'prefix.10.test.ckpt.ext']
with tempfile.TemporaryDirectory() as temp_dir:
[open(os.path.join(temp_dir, fn), 'a').close() for fn in files]
checkpoint_state = utils.get_checkpoint_state(temp_dir)
checkpoint_state = checkpoint.get_checkpoint_state(temp_dir, all_checkpoints=True)
assert checkpoint_state.model_checkpoint_path == os.path.join(temp_dir, '4.test.ckpt')
assert checkpoint_state.all_model_checkpoint_paths == \
[os.path.join(temp_dir, f[:-4]) for f in sorted(files[:-1])]
reader = checkpoint.CheckpointStateReader(temp_dir, checkpoint_state_optional=False)
assert reader.get_latest() is None
assert len(reader.get_all()) == 0
reader = checkpoint.CheckpointStateReader(temp_dir)
assert reader.get_latest().num == 4
assert [ckp.num for ckp in reader.get_all()] == [1, 2, 3, 4]

View File

@@ -30,7 +30,7 @@ def training_worker(graph_manager, task_parameters):
graph_manager.setup_memory_backend()
while(steps < graph_manager.improve_steps.num_steps):
while steps < graph_manager.improve_steps.num_steps:
graph_manager.phase = core_types.RunPhase.TRAIN
graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps)

View File

@@ -546,58 +546,3 @@ def start_shell_command_and_wait(command):
def indent_string(string):
return '\t' + string.replace('\n', '\n\t')
class CheckpointState(object):
"""
Helper class for checkpoint directory information. It replicates
the CheckpointState protobuf class in tensorflow.
"""
def __init__(self, checkpoints: List[str]):
self._checkpoints = checkpoints
@property
def all_model_checkpoint_paths(self):
return self._checkpoints
@property
def model_checkpoint_path(self):
return self._checkpoints[-1]
def __str__(self):
out_str = 'model_checkpoint_path: {}\n'.format(self.model_checkpoint_path)
for c in self._checkpoints:
out_str += 'all_model_checkpoint_paths: {}\n'.format(c)
return out_str
def __repr__(self):
return str(self._checkpoints)
COACH_CHECKPOINT_PATTERN = r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?'
def get_checkpoint_state(checkpoint_dir: Union[str, List[str]], filename_pattern: str=COACH_CHECKPOINT_PATTERN) ->\
Union[CheckpointState, None]:
"""
Finds the latest checkpoint file. It uses the first group of filename_pattern (i.e. group(2) or group(4) to sort
the checkpoint names and find the latest checkpoint
:param checkpoint_dir: directory where checkpoints are saved or list of all files in a directory
:param filename_pattern: regex pattern for checkpoint filenames
:return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoint names. If no matching
files are found, returns None.
"""
prog = re.compile(filename_pattern)
checkpoints = dict()
filenames = os.listdir(checkpoint_dir) if isinstance(checkpoint_dir, str) else checkpoint_dir
for name in filenames:
m = prog.search(name)
if m is not None and (m.group(2) is not None or m.group(4) is not None):
if m.group(2) is not None and m.group(4) is not None:
assert m.group(2) == m.group(4)
checkpoint_count = int(m.group(2) if m.group(2) is not None else m.group(4))
full_path = os.path.join(checkpoint_dir, m.group(0)) if isinstance(checkpoint_dir, str) else m.group(0)
checkpoints[checkpoint_count] = full_path
if len(checkpoints) == 0:
return None
return CheckpointState([checkpoints[k] for k in sorted(checkpoints.keys())])