1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Implement framework-agnostic rollout and training workers (#137)

* Added checkpoint state file to coach checkpointing.

* Removed TF specific code from rollout_worker, training_worker, and s3_data_store
This commit is contained in:
Sina Afrooze
2018-11-23 18:05:44 -08:00
committed by Balaji Subramaniam
parent 4a6c404070
commit 5332013bd1
7 changed files with 350 additions and 117 deletions

298
rl_coach/checkpoint.py Normal file
View File

@@ -0,0 +1,298 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Module providing helper classes and functions for reading/writing checkpoint state
"""
import os
import re
from typing import List, Union, Tuple
class SingleCheckpoint(object):
    """
    Helper class for storing a checkpoint name and number.
    Instances compare equal when both the number and the name match.
    """
    def __init__(self, num: int, name: str):
        """
        :param num: checkpoint number
        :param name: checkpoint name (i.e. the prefix for all checkpoint files)
        """
        self._num = num
        self._name = name

    @property
    def num(self) -> int:
        return self._num

    @property
    def name(self) -> str:
        return self._name

    def __str__(self):
        return self._name

    def __repr__(self):
        return str(self)

    def __eq__(self, other: 'SingleCheckpoint'):
        if not isinstance(other, SingleCheckpoint):
            return False
        return self._name == other._name and self._num == other._num

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # defining __eq__ alone would make instances unhashable; keep hash
        # consistent with equality so checkpoints can live in sets/dict keys
        return hash((self._num, self._name))
class CheckpointState(object):
    """
    Helper class for checkpoint directory information. It replicates
    the CheckpointState protobuf class in tensorflow with addition of
    two new functions: last_checkpoint() and all_checkpoints()
    """
    def __init__(self, checkpoints: List[SingleCheckpoint], checkpoint_dir: str):
        """
        :param checkpoints: sorted list of checkpoints from oldest to newest. checkpoint[-1] is
            considered to be the most recent checkpoint.
        :param checkpoint_dir: checkpoint directory which is added to the paths
        """
        self._checkpoints = checkpoints
        # fixed misspelled attribute name (was '_checkpoin_dir'); private to this class
        self._checkpoint_dir = checkpoint_dir

    @property
    def all_checkpoints(self) -> List[SingleCheckpoint]:
        """
        :return: list of all checkpoints
        """
        return self._checkpoints

    @property
    def last_checkpoint(self) -> SingleCheckpoint:
        """
        :return: the most recent checkpoint
        """
        return self._checkpoints[-1]

    @property
    def all_model_checkpoint_paths(self) -> List[str]:
        """
        TF compatible function call to get all checkpoints
        :return: list of all available model checkpoint paths
        """
        return [os.path.join(self._checkpoint_dir, c.name) for c in self._checkpoints]

    @property
    def model_checkpoint_path(self) -> str:
        """
        TF compatible call to get most recent checkpoint
        :return: path of the most recent model checkpoint
        """
        return os.path.join(self._checkpoint_dir, self._checkpoints[-1].name)

    def __str__(self):
        out_str = 'model_checkpoint_path: {}\n'.format(self.model_checkpoint_path)
        for c in self.all_model_checkpoint_paths:
            out_str += 'all_model_checkpoint_paths: {}\n'.format(c)
        return out_str

    def __repr__(self):
        return str(self._checkpoints)
class CheckpointStateFile(object):
    """
    Helper class for reading from and writing to the checkpoint state file,
    which records the name of the most recently completed checkpoint.
    """
    checkpoint_state_filename = '.coach_checkpoint'

    def __init__(self, checkpoint_dir: str):
        """
        :param checkpoint_dir: directory in which the checkpoint state file resides
        """
        self._checkpoint_state_path = os.path.join(checkpoint_dir, self.checkpoint_state_filename)

    def exists(self) -> bool:
        """
        :return: True if checkpoint state file exists, false otherwise
        """
        return os.path.exists(self._checkpoint_state_path)

    def read(self) -> Union[None, SingleCheckpoint]:
        """
        Read the checkpoint state file and interpret its content.
        :return: the checkpoint recorded in the state file, or None if the file is absent
        """
        if not self.exists():
            return None
        with open(self._checkpoint_state_path, 'r') as state_file:
            content = state_file.read(256)
        return CheckpointFilenameParser().parse(content)

    def write(self, data: SingleCheckpoint) -> None:
        """
        Write the checkpoint name to the checkpoint state file.
        :param data: checkpoint whose name is to be recorded
        """
        with open(self._checkpoint_state_path, 'w') as state_file:
            state_file.write(data.name)

    @property
    def filename(self) -> str:
        return self.checkpoint_state_filename

    @property
    def path(self) -> str:
        return self._checkpoint_state_path
class CheckpointStateReader(object):
    """
    Class for reading the checkpoint state, either from the checkpoint-state file
    or by scanning the checkpoint directory.
    """
    def __init__(self, checkpoint_dir: str, checkpoint_state_optional: bool=True):
        """
        :param checkpoint_dir: path to checkpoint directory
        :param checkpoint_state_optional: If True, checkpoint state file is optional and if not found,
            directory is scanned to find the latest checkpoint. Default is True for backward compatibility
        """
        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_state_file = CheckpointStateFile(self._checkpoint_dir)
        self._checkpoint_state_optional = checkpoint_state_optional

    def get_latest(self) -> SingleCheckpoint:
        """
        Tries to read the checkpoint state file. If that fails (and the state file is optional),
        discovers the latest checkpoint by reading the entire directory.
        :return: checkpoint object representing the latest checkpoint, or None if none was found
        """
        latest = self._checkpoint_state_file.read()
        if latest is None and self._checkpoint_state_optional:
            all_checkpoints = _filter_checkpoint_files(os.listdir(self._checkpoint_dir))
            if len(all_checkpoints) > 0:
                latest = all_checkpoints[-1]
        return latest

    def get_all(self) -> List[SingleCheckpoint]:
        """
        Reads both the checkpoint state file as well as contents of the directory and merges them
        into one list.
        :return: list of checkpoint objects, sorted by checkpoint number
        """
        # discover all checkpoint files in directory if requested or if a valid checkpoint-state file doesn't exist
        all_checkpoints = _filter_checkpoint_files(os.listdir(self._checkpoint_dir))
        last_checkpoint = self._checkpoint_state_file.read()
        if last_checkpoint is not None and last_checkpoint in all_checkpoints:
            # remove excess checkpoints: higher checkpoint number, but not recent (e.g. from a previous run)
            all_checkpoints = all_checkpoints[: all_checkpoints.index(last_checkpoint) + 1]
        elif not self._checkpoint_state_optional:
            # if a valid last checkpoint is not found (no state file, or the state file names a
            # checkpoint whose files are not on disk) and the state file isn't optional, then all
            # checkpoint files discovered must be partial or invalid, so don't return anything.
            # Note: a state-file entry with no matching file previously raised ValueError here.
            all_checkpoints.clear()
        return all_checkpoints
class CheckpointStateUpdater(object):
    """
    Class for scanning checkpoint directory and updating the checkpoint state
    """
    def __init__(self, checkpoint_dir: str, read_all: bool=False):
        """
        :param checkpoint_dir: path to checkpoint directory
        :param read_all: whether to scan the directory for existing checkpoints
        """
        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_state_file = CheckpointStateFile(checkpoint_dir)
        # seed the in-memory checkpoint list from what already exists on disk
        state_reader = CheckpointStateReader(checkpoint_dir)
        if read_all:
            self._all_checkpoints = state_reader.get_all()
        else:
            latest = state_reader.get_latest()
            self._all_checkpoints = [] if latest is None else [latest]

    def update(self, checkpoint: SingleCheckpoint) -> None:
        """
        Update the checkpoint state with the latest checkpoint.
        :param checkpoint: SingleCheckpoint object containing name and number of checkpoint
        """
        self._all_checkpoints.append(checkpoint)
        # writing the checkpoint name to the state file signals the checkpoint is fully saved
        self._checkpoint_state_file.write(checkpoint)

    @property
    def last_checkpoint(self) -> Union[None, SingleCheckpoint]:
        return self._all_checkpoints[-1] if self._all_checkpoints else None

    @property
    def all_checkpoints(self) -> List[SingleCheckpoint]:
        return self._all_checkpoints

    def get_checkpoint_state(self) -> Union[None, CheckpointState]:
        """
        :return: The most recent checkpoint state, or None if no checkpoints are known
        """
        if not self._all_checkpoints:
            return None
        return CheckpointState(self._all_checkpoints, self._checkpoint_dir)
class CheckpointFilenameParser(object):
    """
    Helper object for parsing filenames that are potentially checkpoints.
    """
    coach_checkpoint_filename_pattern = r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?'

    def __init__(self):
        # compile once so repeated parse() calls reuse the same pattern
        self._prog = re.compile(self.coach_checkpoint_filename_pattern)

    def parse(self, filename: str) -> Union[None, SingleCheckpoint]:
        """
        Try to interpret filename as a checkpoint file name.
        :param filename: filename to be parsed
        :return: a SingleCheckpoint on success, or None if the name doesn't look like a checkpoint
        """
        match = self._prog.search(filename)
        if match is None:
            return None
        prefix_num, suffix_num = match.group(2), match.group(4)
        if prefix_num is None and suffix_num is None:
            return None
        assert prefix_num is None or suffix_num is None  # only one numbering group may be valid
        checkpoint_num = int(prefix_num if prefix_num is not None else suffix_num)
        return SingleCheckpoint(checkpoint_num, match.group(0))
def _filter_checkpoint_files(filenames: List[str], sort_by_num: bool=True) -> List[SingleCheckpoint]:
    """
    Select, out of a list of candidate file names, the ones that match the checkpoint
    filename pattern, and convert each match into a checkpoint object.
    :param filenames: list of all filenames
    :param sort_by_num: whether to sort the output result by checkpoint number
    :return: list of SingleCheckpoint objects for the names that matched
    """
    parser = CheckpointFilenameParser()
    parsed = (parser.parse(name) for name in filenames)
    checkpoints = [checkpoint for checkpoint in parsed if checkpoint is not None]
    if sort_by_num:
        checkpoints.sort(key=lambda checkpoint: checkpoint.num)
    return checkpoints
def get_checkpoint_state(checkpoint_dir: str, all_checkpoints=False) -> Union[CheckpointState, None]:
    """
    Scan checkpoint directory and find the list of checkpoint files.
    :param checkpoint_dir: directory where checkpoints are saved
    :param all_checkpoints: if True, scan the directory and return list of all checkpoints
        as well as the most recent one
    :return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoints by
        checkpoint-number. If no matching files are found, returns None.
    """
    updater = CheckpointStateUpdater(checkpoint_dir, read_all=all_checkpoints)
    return updater.get_checkpoint_state()

View File

@@ -2,8 +2,7 @@ from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
from minio import Minio from minio import Minio
from minio.error import ResponseError from minio.error import ResponseError
from configparser import ConfigParser, Error from configparser import ConfigParser, Error
from google.protobuf import text_format from rl_coach.checkpoint import CheckpointStateFile
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from rl_coach.data_stores.data_store import SyncFiles from rl_coach.data_stores.data_store import SyncFiles
import os import os
@@ -24,6 +23,7 @@ class S3DataStoreParameters(DataStoreParameters):
class S3DataStore(DataStore): class S3DataStore(DataStore):
def __init__(self, params: S3DataStoreParameters): def __init__(self, params: S3DataStoreParameters):
super(S3DataStore, self).__init__(params)
self.params = params self.params = params
access_key = None access_key = None
secret_key = None secret_key = None
@@ -51,14 +51,15 @@ class S3DataStore(DataStore):
def save_to_store(self): def save_to_store(self):
try: try:
# remove lock file if it exists
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value) self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
# Acquire lock
self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0) self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)
checkpoint_file = None checkpoint_file = None
for root, dirs, files in os.walk(self.params.checkpoint_dir): for root, dirs, files in os.walk(self.params.checkpoint_dir):
for filename in files: for filename in files:
if filename == 'checkpoint': if filename == CheckpointStateFile.checkpoint_state_filename:
checkpoint_file = (root, filename) checkpoint_file = (root, filename)
continue continue
abs_name = os.path.abspath(os.path.join(root, filename)) abs_name = os.path.abspath(os.path.join(root, filename))
@@ -69,6 +70,7 @@ class S3DataStore(DataStore):
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir) rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name) self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
# release lock
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value) self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
except ResponseError as e: except ResponseError as e:
@@ -76,14 +78,16 @@ class S3DataStore(DataStore):
def load_from_store(self): def load_from_store(self):
try: try:
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint")) state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
# wait until lock is removed
while True: while True:
objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.LOCKFILE.value) objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.LOCKFILE.value)
if next(objects, None) is None: if next(objects, None) is None:
try: try:
self.mc.fget_object(self.params.bucket_name, "checkpoint", filename) # fetch checkpoint state file from S3
self.mc.fget_object(self.params.bucket_name, state_file.filename, state_file.path)
except Exception as e: except Exception as e:
continue continue
break break
@@ -101,13 +105,9 @@ class S3DataStore(DataStore):
except Exception as e: except Exception as e:
pass pass
ckpt = CheckpointState() checkpoint_state = state_file.read()
if os.path.exists(filename): if checkpoint_state is not None:
contents = open(filename, 'r').read() objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True)
text_format.Merge(contents, ckpt)
rel_path = os.path.relpath(ckpt.model_checkpoint_path, self.params.checkpoint_dir)
objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=rel_path, recursive=True)
for obj in objects: for obj in objects:
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name)) filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name))
if not os.path.exists(filename): if not os.path.exists(filename):

View File

@@ -25,6 +25,7 @@ import contextlib
from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \ from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \
VisualizationParameters, \ VisualizationParameters, \
Parameters, PresetValidationParameters, RunType Parameters, PresetValidationParameters, RunType
from rl_coach.checkpoint import CheckpointStateUpdater, get_checkpoint_state, SingleCheckpoint
from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \ from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
EnvironmentSteps, \ EnvironmentSteps, \
StepMethod, Transition StepMethod, Transition
@@ -32,7 +33,7 @@ from rl_coach.environments.environment import Environment
from rl_coach.level_manager import LevelManager from rl_coach.level_manager import LevelManager
from rl_coach.logger import screen, Logger from rl_coach.logger import screen, Logger
from rl_coach.saver import SaverCollection from rl_coach.saver import SaverCollection
from rl_coach.utils import get_checkpoint_state, set_cpu, start_shell_command_and_wait from rl_coach.utils import set_cpu, start_shell_command_and_wait
from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
from rl_coach.memories.backend.memory_impl import get_memory_backend from rl_coach.memories.backend.memory_impl import get_memory_backend
from rl_coach.data_stores.data_store import SyncFiles from rl_coach.data_stores.data_store import SyncFiles
@@ -115,6 +116,7 @@ class GraphManager(object):
self.checkpoint_id = 0 self.checkpoint_id = 0
self.checkpoint_saver = None self.checkpoint_saver = None
self.checkpoint_state_updater = None
self.graph_logger = Logger() self.graph_logger = Logger()
self.data_store = None self.data_store = None
@@ -495,7 +497,7 @@ class GraphManager(object):
self.act(EnvironmentEpisodes(1)) self.act(EnvironmentEpisodes(1))
self.sync() self.sync()
if self.should_stop(): if self.should_stop():
if self.task_parameters.checkpoint_save_dir: if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir):
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close() open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
if hasattr(self, 'data_store_params'): if hasattr(self, 'data_store_params'):
data_store = self.get_data_store(self.data_store_params) data_store = self.get_data_store(self.data_store_params)
@@ -579,29 +581,35 @@ class GraphManager(object):
self.save_checkpoint() self.save_checkpoint()
def save_checkpoint(self): def save_checkpoint(self):
# create current session's checkpoint directory
if self.task_parameters.checkpoint_save_dir is None: if self.task_parameters.checkpoint_save_dir is None:
self.task_parameters.checkpoint_save_dir = os.path.join(self.task_parameters.experiment_path, 'checkpoint') self.task_parameters.checkpoint_save_dir = os.path.join(self.task_parameters.experiment_path, 'checkpoint')
filename = "{}_Step-{}.ckpt".format( if not os.path.exists(self.task_parameters.checkpoint_save_dir):
self.checkpoint_id, os.mkdir(self.task_parameters.checkpoint_save_dir) # Create directory structure
self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])
if self.checkpoint_state_updater is None:
self.checkpoint_state_updater = CheckpointStateUpdater(self.task_parameters.checkpoint_save_dir)
checkpoint_name = "{}_Step-{}.ckpt".format(
self.checkpoint_id, self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])
checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir, checkpoint_name)
checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
filename)
if not os.path.exists(os.path.dirname(checkpoint_path)):
os.mkdir(os.path.dirname(checkpoint_path)) # Create directory structure
if not isinstance(self.task_parameters, DistributedTaskParameters): if not isinstance(self.task_parameters, DistributedTaskParameters):
saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path) saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
else: else:
saved_checkpoint_path = checkpoint_path saved_checkpoint_path = checkpoint_path
# this is required in order for agents to save additional information like a DND for example # this is required in order for agents to save additional information like a DND for example
[manager.save_checkpoint(filename) for manager in self.level_managers] [manager.save_checkpoint(checkpoint_name) for manager in self.level_managers]
# the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used # the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
if self.task_parameters.export_onnx_graph: if self.task_parameters.export_onnx_graph:
self.save_onnx_graph() self.save_onnx_graph()
# write the new checkpoint name to a file to signal this checkpoint has been fully saved
self.checkpoint_state_updater.update(SingleCheckpoint(self.checkpoint_id, checkpoint_name))
screen.log_dict( screen.log_dict(
OrderedDict([ OrderedDict([
("Saving in path", saved_checkpoint_path), ("Saving in path", saved_checkpoint_path),

View File

@@ -12,37 +12,26 @@ import os
import math import math
from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
from rl_coach.checkpoint import CheckpointStateFile, CheckpointStateReader
from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from rl_coach.data_stores.data_store import SyncFiles from rl_coach.data_stores.data_store import SyncFiles
def has_checkpoint(checkpoint_dir):
"""
True if a checkpoint is present in checkpoint_dir
"""
if os.path.isdir(checkpoint_dir):
if len(os.listdir(checkpoint_dir)) > 0:
return os.path.isfile(os.path.join(checkpoint_dir, "checkpoint"))
return False
def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10): def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
""" """
block until there is a checkpoint in checkpoint_dir block until there is a checkpoint in checkpoint_dir
""" """
chkpt_state_file = CheckpointStateFile(checkpoint_dir)
for i in range(timeout): for i in range(timeout):
if data_store: if data_store:
data_store.load_from_store() data_store.load_from_store()
if has_checkpoint(checkpoint_dir): if chkpt_state_file.read() is not None:
return return
time.sleep(10) time.sleep(10)
# one last time # one last time
if has_checkpoint(checkpoint_dir): if chkpt_state_file.read() is not None:
return return
raise ValueError(( raise ValueError((
@@ -54,21 +43,6 @@ def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
)) ))
def data_store_ckpt_load(data_store):
while True:
data_store.load_from_store()
time.sleep(10)
def get_latest_checkpoint(checkpoint_dir):
if os.path.exists(os.path.join(checkpoint_dir, 'checkpoint')):
ckpt = CheckpointState()
contents = open(os.path.join(checkpoint_dir, 'checkpoint'), 'r').read()
text_format.Merge(contents, ckpt)
rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
return int(rel_path.split('_Step')[0])
def should_stop(checkpoint_dir): def should_stop(checkpoint_dir):
return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value)) return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))
@@ -83,6 +57,7 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
graph_manager.create_graph(task_parameters) graph_manager.create_graph(task_parameters)
with graph_manager.phase_context(RunPhase.TRAIN): with graph_manager.phase_context(RunPhase.TRAIN):
chkpt_state_reader = CheckpointStateReader(checkpoint_dir, checkpoint_state_optional=False)
last_checkpoint = 0 last_checkpoint = 0
act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps)/num_workers) act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps)/num_workers)
@@ -97,20 +72,20 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
elif type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentEpisodes: elif type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentEpisodes:
graph_manager.act(EnvironmentEpisodes(num_steps=act_steps)) graph_manager.act(EnvironmentEpisodes(num_steps=act_steps))
new_checkpoint = get_latest_checkpoint(checkpoint_dir) new_checkpoint = chkpt_state_reader.get_latest()
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC: if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
while new_checkpoint < last_checkpoint + 1: while new_checkpoint is None or new_checkpoint.num < last_checkpoint + 1:
if should_stop(checkpoint_dir): if should_stop(checkpoint_dir):
break break
if data_store: if data_store:
data_store.load_from_store() data_store.load_from_store()
new_checkpoint = get_latest_checkpoint(checkpoint_dir) new_checkpoint = chkpt_state_reader.get_latest()
graph_manager.restore_checkpoint() graph_manager.restore_checkpoint()
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.ASYNC: if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.ASYNC:
if new_checkpoint > last_checkpoint: if new_checkpoint is not None and new_checkpoint.num > last_checkpoint:
graph_manager.restore_checkpoint() graph_manager.restore_checkpoint()
last_checkpoint = new_checkpoint if new_checkpoint is not None:
last_checkpoint = new_checkpoint.num

View File

@@ -2,7 +2,7 @@ import os
import pytest import pytest
import tempfile import tempfile
from rl_coach import utils from rl_coach import checkpoint
@pytest.mark.unit_test @pytest.mark.unit_test
@@ -10,8 +10,15 @@ def test_get_checkpoint_state():
files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext', '1.test.ckpt.ext', 'prefix.10.test.ckpt.ext'] files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext', '1.test.ckpt.ext', 'prefix.10.test.ckpt.ext']
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
[open(os.path.join(temp_dir, fn), 'a').close() for fn in files] [open(os.path.join(temp_dir, fn), 'a').close() for fn in files]
checkpoint_state = utils.get_checkpoint_state(temp_dir) checkpoint_state = checkpoint.get_checkpoint_state(temp_dir, all_checkpoints=True)
assert checkpoint_state.model_checkpoint_path == os.path.join(temp_dir, '4.test.ckpt') assert checkpoint_state.model_checkpoint_path == os.path.join(temp_dir, '4.test.ckpt')
assert checkpoint_state.all_model_checkpoint_paths == \ assert checkpoint_state.all_model_checkpoint_paths == \
[os.path.join(temp_dir, f[:-4]) for f in sorted(files[:-1])] [os.path.join(temp_dir, f[:-4]) for f in sorted(files[:-1])]
reader = checkpoint.CheckpointStateReader(temp_dir, checkpoint_state_optional=False)
assert reader.get_latest() is None
assert len(reader.get_all()) == 0
reader = checkpoint.CheckpointStateReader(temp_dir)
assert reader.get_latest().num == 4
assert [ckp.num for ckp in reader.get_all()] == [1, 2, 3, 4]

View File

@@ -30,7 +30,7 @@ def training_worker(graph_manager, task_parameters):
graph_manager.setup_memory_backend() graph_manager.setup_memory_backend()
while(steps < graph_manager.improve_steps.num_steps): while steps < graph_manager.improve_steps.num_steps:
graph_manager.phase = core_types.RunPhase.TRAIN graph_manager.phase = core_types.RunPhase.TRAIN
graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps)

View File

@@ -546,58 +546,3 @@ def start_shell_command_and_wait(command):
def indent_string(string): def indent_string(string):
return '\t' + string.replace('\n', '\n\t') return '\t' + string.replace('\n', '\n\t')
class CheckpointState(object):
"""
Helper class for checkpoint directory information. It replicates
the CheckpointState protobuf class in tensorflow.
"""
def __init__(self, checkpoints: List[str]):
self._checkpoints = checkpoints
@property
def all_model_checkpoint_paths(self):
return self._checkpoints
@property
def model_checkpoint_path(self):
return self._checkpoints[-1]
def __str__(self):
out_str = 'model_checkpoint_path: {}\n'.format(self.model_checkpoint_path)
for c in self._checkpoints:
out_str += 'all_model_checkpoint_paths: {}\n'.format(c)
return out_str
def __repr__(self):
return str(self._checkpoints)
COACH_CHECKPOINT_PATTERN = r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?'
def get_checkpoint_state(checkpoint_dir: Union[str, List[str]], filename_pattern: str=COACH_CHECKPOINT_PATTERN) ->\
Union[CheckpointState, None]:
"""
Finds the latest checkpoint file. It uses the first group of filename_pattern (i.e. group(2) or group(4) to sort
the checkpoint names and find the latest checkpoint
:param checkpoint_dir: directory where checkpoints are saved or list of all files in a directory
:param filename_pattern: regex pattern for checkpoint filenames
:return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoint names. If no matching
files are found, returns None.
"""
prog = re.compile(filename_pattern)
checkpoints = dict()
filenames = os.listdir(checkpoint_dir) if isinstance(checkpoint_dir, str) else checkpoint_dir
for name in filenames:
m = prog.search(name)
if m is not None and (m.group(2) is not None or m.group(4) is not None):
if m.group(2) is not None and m.group(4) is not None:
assert m.group(2) == m.group(4)
checkpoint_count = int(m.group(2) if m.group(2) is not None else m.group(4))
full_path = os.path.join(checkpoint_dir, m.group(0)) if isinstance(checkpoint_dir, str) else m.group(0)
checkpoints[checkpoint_count] = full_path
if len(checkpoints) == 0:
return None
return CheckpointState([checkpoints[k] for k in sorted(checkpoints.keys())])