mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Implement framework-agnostic rollout and training workers (#137)
* Added checkpoint state file to coach checkpointing. * Removed TF specific code from rollout_worker, training_worker, and s3_data_store
This commit is contained in:
committed by
Balaji Subramaniam
parent
4a6c404070
commit
5332013bd1
298
rl_coach/checkpoint.py
Normal file
298
rl_coach/checkpoint.py
Normal file
@@ -0,0 +1,298 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
Module providing helper classes and functions for reading/writing checkpoint state
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import List, Union, Tuple
|
||||
|
||||
|
||||
class SingleCheckpoint(object):
    """
    Helper class for storing checkpoint name and number.

    Instances are compared by value (name and number) and are hashable, so
    they can be used in sets/dicts and with ``list.index`` as done by
    the checkpoint readers.
    """
    def __init__(self, num: int, name: str):
        """
        :param num: checkpoint number
        :param name: checkpoint name (i.e. the prefix for all checkpoint files)
        """
        self._num = num
        self._name = name

    @property
    def num(self) -> int:
        return self._num

    @property
    def name(self) -> str:
        return self._name

    def __str__(self):
        return self._name

    def __repr__(self):
        return str(self)

    def __eq__(self, other: 'SingleCheckpoint'):
        if not isinstance(other, SingleCheckpoint):
            return False
        return self._name == other._name and self._num == other._num

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # BUG FIX: defining __eq__ without __hash__ makes instances unhashable
        # in Python 3 (__hash__ is implicitly set to None). Hash on the same
        # fields __eq__ compares so equal objects hash equally.
        return hash((self._num, self._name))
|
||||
|
||||
|
||||
class CheckpointState(object):
    """
    Helper class for checkpoint directory information. It replicates
    the CheckpointState protobuf class in tensorflow with addition of
    two new functions: last_checkpoint() and all_checkpoints()
    """
    def __init__(self, checkpoints: List[SingleCheckpoint], checkpoint_dir: str):
        """
        :param checkpoints: sorted list of checkpoints from oldest to newest. checkpoint[-1] is
            considered to be the most recent checkpoint.
        :param checkpoint_dir: checkpoint directory which is prepended to checkpoint names
            when building full paths
        """
        self._checkpoints = checkpoints
        # BUG FIX: renamed from the original misspelled '_checkpoin_dir'.
        # The attribute is private and only read inside this class, so the
        # rename is safe for all callers.
        self._checkpoint_dir = checkpoint_dir

    @property
    def all_checkpoints(self) -> List[SingleCheckpoint]:
        """
        :return: list of all checkpoints
        """
        return self._checkpoints

    @property
    def last_checkpoint(self) -> SingleCheckpoint:
        """
        :return: the most recent checkpoint
        """
        return self._checkpoints[-1]

    @property
    def all_model_checkpoint_paths(self) -> List[str]:
        """
        TF compatible function call to get all checkpoints
        :return: list of all available model checkpoint paths
        """
        return [os.path.join(self._checkpoint_dir, c.name) for c in self._checkpoints]

    @property
    def model_checkpoint_path(self) -> str:
        """
        TF compatible call to get most recent checkpoint
        :return: path of the most recent model checkpoint
        """
        return os.path.join(self._checkpoint_dir, self._checkpoints[-1].name)

    def __str__(self):
        out_str = 'model_checkpoint_path: {}\n'.format(self.model_checkpoint_path)
        for c in self.all_model_checkpoint_paths:
            out_str += 'all_model_checkpoint_paths: {}\n'.format(c)
        return out_str

    def __repr__(self):
        return str(self._checkpoints)
|
||||
|
||||
|
||||
class CheckpointStateFile(object):
    """
    Helper class for reading from and writing to the checkpoint state file.

    The state file holds the name of the most recently completed checkpoint,
    acting as a framework-agnostic replacement for TF's 'checkpoint' file.
    """
    checkpoint_state_filename = '.coach_checkpoint'

    def __init__(self, checkpoint_dir: str):
        """
        :param checkpoint_dir: directory that contains (or will contain) the state file
        """
        self._checkpoint_state_path = os.path.join(checkpoint_dir, self.checkpoint_state_filename)

    def exists(self) -> bool:
        """
        :return: True if checkpoint state file exists, false otherwise
        """
        return os.path.exists(self._checkpoint_state_path)

    def read(self) -> Union[None, SingleCheckpoint]:
        """
        Read checkpoint state file and interpret its content.
        :return: SingleCheckpoint described by the file, or None if the file does
            not exist or its content does not match the checkpoint filename pattern
        """
        if not self.exists():
            return None
        with open(self._checkpoint_state_path, 'r') as fd:
            # 256 bytes is more than enough for any checkpoint name
            return CheckpointFilenameParser().parse(fd.read(256))

    def write(self, data: SingleCheckpoint) -> None:
        """
        Writes the checkpoint's name to the checkpoint state file.
        :param data: checkpoint to record as the most recent one
            (BUG FIX: the original docstring wrongly described this as 'string data')
        """
        with open(self._checkpoint_state_path, 'w') as fd:
            fd.write(data.name)

    @property
    def filename(self) -> str:
        return self.checkpoint_state_filename

    @property
    def path(self) -> str:
        return self._checkpoint_state_path
|
||||
|
||||
|
||||
class CheckpointStateReader(object):
    """
    Class for scanning a checkpoint directory and reporting which checkpoints
    are available, optionally trusting only the checkpoint state file.
    """
    def __init__(self, checkpoint_dir: str, checkpoint_state_optional: bool=True):
        """
        :param checkpoint_dir: path to checkpoint directory
        :param checkpoint_state_optional: If True, checkpoint state file is optional and if not found,
            directory is scanned to find the latest checkpoint. Default is True for backward compatibility
        """
        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_state_file = CheckpointStateFile(self._checkpoint_dir)
        self._checkpoint_state_optional = checkpoint_state_optional

    def get_latest(self) -> SingleCheckpoint:
        """
        Tries to read the checkpoint state file. If that fails, discovers latest by reading the entire directory.
        :return: checkpoint object representing the latest checkpoint, or None if none was found
        """
        latest = self._checkpoint_state_file.read()
        if latest is None and self._checkpoint_state_optional:
            all_checkpoints = _filter_checkpoint_files(os.listdir(self._checkpoint_dir))
            if len(all_checkpoints) > 0:
                latest = all_checkpoints[-1]
        return latest

    def get_all(self) -> List[SingleCheckpoint]:
        """
        Reads both the checkpoint state file as well as contents of the directory and merges them into one list.
        :return: list of checkpoint objects, sorted oldest to newest
        """
        # discover all checkpoint files in directory if requested or if a valid checkpoint-state file doesn't exist
        all_checkpoints = _filter_checkpoint_files(os.listdir(self._checkpoint_dir))
        last_checkpoint = self._checkpoint_state_file.read()
        if last_checkpoint is not None:
            # remove excess checkpoints: higher checkpoint number, but not recent (e.g. from a previous run)
            try:
                all_checkpoints = all_checkpoints[: all_checkpoints.index(last_checkpoint) + 1]
            except ValueError:
                # BUG FIX: the original .index() call was unguarded and raised
                # ValueError when the state file referenced a checkpoint whose
                # files are not (yet) in the directory (e.g. a partial sync).
                # In that case keep whatever valid checkpoints were discovered.
                pass
        elif not self._checkpoint_state_optional:
            # if last_checkpoint is not discovered from the checkpoint-state file and it isn't optional, then
            # all checkpoint files discovered must be partial or invalid, so don't return anything
            all_checkpoints.clear()
        return all_checkpoints
|
||||
|
||||
|
||||
class CheckpointStateUpdater(object):
    """
    Tracks the checkpoints of a directory in memory and records every new
    checkpoint in the checkpoint state file.
    """
    def __init__(self, checkpoint_dir: str, read_all: bool=False):
        """
        :param checkpoint_dir: path to checkpoint directory
        :param read_all: whether to scan the directory for existing checkpoints
        """
        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_state_file = CheckpointStateFile(checkpoint_dir)
        # Seed the in-memory checkpoint list from whatever is already on disk
        reader = CheckpointStateReader(checkpoint_dir)
        if read_all:
            self._all_checkpoints = reader.get_all()
        else:
            newest = reader.get_latest()
            self._all_checkpoints = [] if newest is None else [newest]

    def update(self, checkpoint: SingleCheckpoint) -> None:
        """
        Update the checkpoint state with the latest checkpoint.
        :param checkpoint: SingleCheckpoint object containing name and number of checkpoint
        """
        self._all_checkpoints.append(checkpoint)
        # Persisting just the checkpoint name is enough for the state file
        self._checkpoint_state_file.write(checkpoint)

    @property
    def last_checkpoint(self) -> Union[None, SingleCheckpoint]:
        return self._all_checkpoints[-1] if self._all_checkpoints else None

    @property
    def all_checkpoints(self) -> List[SingleCheckpoint]:
        return self._all_checkpoints

    def get_checkpoint_state(self) -> Union[None, CheckpointState]:
        """
        :return: The most recent checkpoint state, or None if no checkpoints are tracked
        """
        if not self._all_checkpoints:
            return None
        return CheckpointState(self._all_checkpoints, self._checkpoint_dir)
|
||||
|
||||
|
||||
class CheckpointFilenameParser(object):
    """
    Helper object for parsing filenames that are potentially checkpoints.
    """
    # group(2): leading coach-style number, e.g. '12' in '12_Step-100.ckpt'
    # group(4): trailing TF-style number, e.g. '100' in 'model.ckpt-100'
    coach_checkpoint_filename_pattern = r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?'

    def __init__(self):
        self._prog = re.compile(self.coach_checkpoint_filename_pattern)

    def parse(self, filename: str) -> Union[None, SingleCheckpoint]:
        """
        Tries to parse the filename using the checkpoint filename pattern.
        :param filename: filename to be parsed
        :return: a SingleCheckpoint built from the matched number and name, or None
            if the filename does not look like a checkpoint.
            (BUG FIX: the original docstring wrongly claimed a
            (checkpoint-number, checkpoint-name) tuple was returned.)
        """
        m = self._prog.search(filename)
        if m is not None and (m.group(2) is not None or m.group(4) is not None):
            assert m.group(2) is None or m.group(4) is None  # Only one group must be valid
            checkpoint_num = int(m.group(2) if m.group(2) is not None else m.group(4))
            return SingleCheckpoint(checkpoint_num, m.group(0))
        return None
|
||||
|
||||
|
||||
def _filter_checkpoint_files(filenames: List[str], sort_by_num: bool=True) -> List[SingleCheckpoint]:
    """
    Given a list of potential file names, return the ones that match the checkpoint
    filename pattern, parsed into checkpoint objects.
    :param filenames: list of all filenames
    :param sort_by_num: whether to sort the output result by checkpoint number
    :return: list of SingleCheckpoint objects, one per filename that parsed as a checkpoint.
        (BUG FIX: the original docstring wrongly claimed
        (checkpoint-number, checkpoint-filename) tuples were returned.)
    """
    parser = CheckpointFilenameParser()
    # generator avoids materializing an intermediate list of parse results
    checkpoints = [ckp for ckp in (parser.parse(fn) for fn in filenames) if ckp is not None]
    if sort_by_num:
        checkpoints.sort(key=lambda x: x.num)
    return checkpoints
|
||||
|
||||
|
||||
def get_checkpoint_state(checkpoint_dir: str, all_checkpoints=False) -> Union[CheckpointState, None]:
    """
    Scan checkpoint directory and find the list of checkpoint files.
    :param checkpoint_dir: directory where checkpoints are saved
    :param all_checkpoints: if True, scan the directory and return list of all checkpoints
        as well as the most recent one
    :return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoints by checkpoint-number.
        If no matching files are found, returns None.
    """
    updater = CheckpointStateUpdater(checkpoint_dir, read_all=all_checkpoints)
    return updater.get_checkpoint_state()
|
||||
@@ -2,8 +2,7 @@ from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
|
||||
from minio import Minio
|
||||
from minio.error import ResponseError
|
||||
from configparser import ConfigParser, Error
|
||||
from google.protobuf import text_format
|
||||
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
|
||||
from rl_coach.checkpoint import CheckpointStateFile
|
||||
from rl_coach.data_stores.data_store import SyncFiles
|
||||
|
||||
import os
|
||||
@@ -24,6 +23,7 @@ class S3DataStoreParameters(DataStoreParameters):
|
||||
|
||||
class S3DataStore(DataStore):
|
||||
def __init__(self, params: S3DataStoreParameters):
|
||||
super(S3DataStore, self).__init__(params)
|
||||
self.params = params
|
||||
access_key = None
|
||||
secret_key = None
|
||||
@@ -51,14 +51,15 @@ class S3DataStore(DataStore):
|
||||
|
||||
def save_to_store(self):
|
||||
try:
|
||||
# remove lock file if it exists
|
||||
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
|
||||
|
||||
# Acquire lock
|
||||
self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)
|
||||
|
||||
checkpoint_file = None
|
||||
for root, dirs, files in os.walk(self.params.checkpoint_dir):
|
||||
for filename in files:
|
||||
if filename == 'checkpoint':
|
||||
if filename == CheckpointStateFile.checkpoint_state_filename:
|
||||
checkpoint_file = (root, filename)
|
||||
continue
|
||||
abs_name = os.path.abspath(os.path.join(root, filename))
|
||||
@@ -69,6 +70,7 @@ class S3DataStore(DataStore):
|
||||
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
|
||||
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
|
||||
|
||||
# release lock
|
||||
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
|
||||
|
||||
except ResponseError as e:
|
||||
@@ -76,14 +78,16 @@ class S3DataStore(DataStore):
|
||||
|
||||
def load_from_store(self):
|
||||
try:
|
||||
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
|
||||
state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
|
||||
|
||||
# wait until lock is removed
|
||||
while True:
|
||||
objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.LOCKFILE.value)
|
||||
|
||||
if next(objects, None) is None:
|
||||
try:
|
||||
self.mc.fget_object(self.params.bucket_name, "checkpoint", filename)
|
||||
# fetch checkpoint state file from S3
|
||||
self.mc.fget_object(self.params.bucket_name, state_file.filename, state_file.path)
|
||||
except Exception as e:
|
||||
continue
|
||||
break
|
||||
@@ -101,13 +105,9 @@ class S3DataStore(DataStore):
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
ckpt = CheckpointState()
|
||||
if os.path.exists(filename):
|
||||
contents = open(filename, 'r').read()
|
||||
text_format.Merge(contents, ckpt)
|
||||
rel_path = os.path.relpath(ckpt.model_checkpoint_path, self.params.checkpoint_dir)
|
||||
|
||||
objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=rel_path, recursive=True)
|
||||
checkpoint_state = state_file.read()
|
||||
if checkpoint_state is not None:
|
||||
objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True)
|
||||
for obj in objects:
|
||||
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name))
|
||||
if not os.path.exists(filename):
|
||||
|
||||
@@ -25,6 +25,7 @@ import contextlib
|
||||
from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \
|
||||
VisualizationParameters, \
|
||||
Parameters, PresetValidationParameters, RunType
|
||||
from rl_coach.checkpoint import CheckpointStateUpdater, get_checkpoint_state, SingleCheckpoint
|
||||
from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
|
||||
EnvironmentSteps, \
|
||||
StepMethod, Transition
|
||||
@@ -32,7 +33,7 @@ from rl_coach.environments.environment import Environment
|
||||
from rl_coach.level_manager import LevelManager
|
||||
from rl_coach.logger import screen, Logger
|
||||
from rl_coach.saver import SaverCollection
|
||||
from rl_coach.utils import get_checkpoint_state, set_cpu, start_shell_command_and_wait
|
||||
from rl_coach.utils import set_cpu, start_shell_command_and_wait
|
||||
from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
|
||||
from rl_coach.memories.backend.memory_impl import get_memory_backend
|
||||
from rl_coach.data_stores.data_store import SyncFiles
|
||||
@@ -115,6 +116,7 @@ class GraphManager(object):
|
||||
self.checkpoint_id = 0
|
||||
|
||||
self.checkpoint_saver = None
|
||||
self.checkpoint_state_updater = None
|
||||
self.graph_logger = Logger()
|
||||
self.data_store = None
|
||||
|
||||
@@ -495,7 +497,7 @@ class GraphManager(object):
|
||||
self.act(EnvironmentEpisodes(1))
|
||||
self.sync()
|
||||
if self.should_stop():
|
||||
if self.task_parameters.checkpoint_save_dir:
|
||||
if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir):
|
||||
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
|
||||
if hasattr(self, 'data_store_params'):
|
||||
data_store = self.get_data_store(self.data_store_params)
|
||||
@@ -579,29 +581,35 @@ class GraphManager(object):
|
||||
self.save_checkpoint()
|
||||
|
||||
def save_checkpoint(self):
|
||||
# create current session's checkpoint directory
|
||||
if self.task_parameters.checkpoint_save_dir is None:
|
||||
self.task_parameters.checkpoint_save_dir = os.path.join(self.task_parameters.experiment_path, 'checkpoint')
|
||||
|
||||
filename = "{}_Step-{}.ckpt".format(
|
||||
self.checkpoint_id,
|
||||
self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])
|
||||
if not os.path.exists(self.task_parameters.checkpoint_save_dir):
|
||||
os.mkdir(self.task_parameters.checkpoint_save_dir) # Create directory structure
|
||||
|
||||
if self.checkpoint_state_updater is None:
|
||||
self.checkpoint_state_updater = CheckpointStateUpdater(self.task_parameters.checkpoint_save_dir)
|
||||
|
||||
checkpoint_name = "{}_Step-{}.ckpt".format(
|
||||
self.checkpoint_id, self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])
|
||||
checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir, checkpoint_name)
|
||||
|
||||
checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
|
||||
filename)
|
||||
if not os.path.exists(os.path.dirname(checkpoint_path)):
|
||||
os.mkdir(os.path.dirname(checkpoint_path)) # Create directory structure
|
||||
if not isinstance(self.task_parameters, DistributedTaskParameters):
|
||||
saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
|
||||
else:
|
||||
saved_checkpoint_path = checkpoint_path
|
||||
|
||||
# this is required in order for agents to save additional information like a DND for example
|
||||
[manager.save_checkpoint(filename) for manager in self.level_managers]
|
||||
[manager.save_checkpoint(checkpoint_name) for manager in self.level_managers]
|
||||
|
||||
# the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
|
||||
if self.task_parameters.export_onnx_graph:
|
||||
self.save_onnx_graph()
|
||||
|
||||
# write the new checkpoint name to a file to signal this checkpoint has been fully saved
|
||||
self.checkpoint_state_updater.update(SingleCheckpoint(self.checkpoint_id, checkpoint_name))
|
||||
|
||||
screen.log_dict(
|
||||
OrderedDict([
|
||||
("Saving in path", saved_checkpoint_path),
|
||||
|
||||
@@ -12,37 +12,26 @@ import os
|
||||
import math
|
||||
|
||||
from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
|
||||
from rl_coach.checkpoint import CheckpointStateFile, CheckpointStateReader
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
|
||||
from google.protobuf import text_format
|
||||
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
|
||||
from rl_coach.data_stores.data_store import SyncFiles
|
||||
|
||||
|
||||
def has_checkpoint(checkpoint_dir):
|
||||
"""
|
||||
True if a checkpoint is present in checkpoint_dir
|
||||
"""
|
||||
if os.path.isdir(checkpoint_dir):
|
||||
if len(os.listdir(checkpoint_dir)) > 0:
|
||||
return os.path.isfile(os.path.join(checkpoint_dir, "checkpoint"))
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
|
||||
"""
|
||||
block until there is a checkpoint in checkpoint_dir
|
||||
"""
|
||||
chkpt_state_file = CheckpointStateFile(checkpoint_dir)
|
||||
for i in range(timeout):
|
||||
if data_store:
|
||||
data_store.load_from_store()
|
||||
|
||||
if has_checkpoint(checkpoint_dir):
|
||||
if chkpt_state_file.read() is not None:
|
||||
return
|
||||
time.sleep(10)
|
||||
|
||||
# one last time
|
||||
if has_checkpoint(checkpoint_dir):
|
||||
if chkpt_state_file.read() is not None:
|
||||
return
|
||||
|
||||
raise ValueError((
|
||||
@@ -54,21 +43,6 @@ def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
|
||||
))
|
||||
|
||||
|
||||
def data_store_ckpt_load(data_store):
|
||||
while True:
|
||||
data_store.load_from_store()
|
||||
time.sleep(10)
|
||||
|
||||
|
||||
def get_latest_checkpoint(checkpoint_dir):
|
||||
if os.path.exists(os.path.join(checkpoint_dir, 'checkpoint')):
|
||||
ckpt = CheckpointState()
|
||||
contents = open(os.path.join(checkpoint_dir, 'checkpoint'), 'r').read()
|
||||
text_format.Merge(contents, ckpt)
|
||||
rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
|
||||
return int(rel_path.split('_Step')[0])
|
||||
|
||||
|
||||
def should_stop(checkpoint_dir):
|
||||
return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))
|
||||
|
||||
@@ -83,6 +57,7 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
|
||||
graph_manager.create_graph(task_parameters)
|
||||
with graph_manager.phase_context(RunPhase.TRAIN):
|
||||
|
||||
chkpt_state_reader = CheckpointStateReader(checkpoint_dir, checkpoint_state_optional=False)
|
||||
last_checkpoint = 0
|
||||
|
||||
act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps)/num_workers)
|
||||
@@ -97,20 +72,20 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
|
||||
elif type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentEpisodes:
|
||||
graph_manager.act(EnvironmentEpisodes(num_steps=act_steps))
|
||||
|
||||
new_checkpoint = get_latest_checkpoint(checkpoint_dir)
|
||||
|
||||
new_checkpoint = chkpt_state_reader.get_latest()
|
||||
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
|
||||
while new_checkpoint < last_checkpoint + 1:
|
||||
while new_checkpoint is None or new_checkpoint.num < last_checkpoint + 1:
|
||||
if should_stop(checkpoint_dir):
|
||||
break
|
||||
if data_store:
|
||||
data_store.load_from_store()
|
||||
new_checkpoint = get_latest_checkpoint(checkpoint_dir)
|
||||
new_checkpoint = chkpt_state_reader.get_latest()
|
||||
|
||||
graph_manager.restore_checkpoint()
|
||||
|
||||
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.ASYNC:
|
||||
if new_checkpoint > last_checkpoint:
|
||||
if new_checkpoint is not None and new_checkpoint.num > last_checkpoint:
|
||||
graph_manager.restore_checkpoint()
|
||||
|
||||
last_checkpoint = new_checkpoint
|
||||
if new_checkpoint is not None:
|
||||
last_checkpoint = new_checkpoint.num
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import pytest
|
||||
import tempfile
|
||||
|
||||
from rl_coach import utils
|
||||
from rl_coach import checkpoint
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
@@ -10,8 +10,15 @@ def test_get_checkpoint_state():
|
||||
files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext', '1.test.ckpt.ext', 'prefix.10.test.ckpt.ext']
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
[open(os.path.join(temp_dir, fn), 'a').close() for fn in files]
|
||||
checkpoint_state = utils.get_checkpoint_state(temp_dir)
|
||||
checkpoint_state = checkpoint.get_checkpoint_state(temp_dir, all_checkpoints=True)
|
||||
assert checkpoint_state.model_checkpoint_path == os.path.join(temp_dir, '4.test.ckpt')
|
||||
assert checkpoint_state.all_model_checkpoint_paths == \
|
||||
[os.path.join(temp_dir, f[:-4]) for f in sorted(files[:-1])]
|
||||
|
||||
reader = checkpoint.CheckpointStateReader(temp_dir, checkpoint_state_optional=False)
|
||||
assert reader.get_latest() is None
|
||||
assert len(reader.get_all()) == 0
|
||||
|
||||
reader = checkpoint.CheckpointStateReader(temp_dir)
|
||||
assert reader.get_latest().num == 4
|
||||
assert [ckp.num for ckp in reader.get_all()] == [1, 2, 3, 4]
|
||||
@@ -30,7 +30,7 @@ def training_worker(graph_manager, task_parameters):
|
||||
|
||||
graph_manager.setup_memory_backend()
|
||||
|
||||
while(steps < graph_manager.improve_steps.num_steps):
|
||||
while steps < graph_manager.improve_steps.num_steps:
|
||||
|
||||
graph_manager.phase = core_types.RunPhase.TRAIN
|
||||
graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
|
||||
|
||||
@@ -546,58 +546,3 @@ def start_shell_command_and_wait(command):
|
||||
|
||||
def indent_string(string):
|
||||
return '\t' + string.replace('\n', '\n\t')
|
||||
|
||||
|
||||
class CheckpointState(object):
|
||||
"""
|
||||
Helper class for checkpoint directory information. It replicates
|
||||
the CheckpointState protobuf class in tensorflow.
|
||||
"""
|
||||
def __init__(self, checkpoints: List[str]):
|
||||
self._checkpoints = checkpoints
|
||||
|
||||
@property
|
||||
def all_model_checkpoint_paths(self):
|
||||
return self._checkpoints
|
||||
|
||||
@property
|
||||
def model_checkpoint_path(self):
|
||||
return self._checkpoints[-1]
|
||||
|
||||
def __str__(self):
|
||||
out_str = 'model_checkpoint_path: {}\n'.format(self.model_checkpoint_path)
|
||||
for c in self._checkpoints:
|
||||
out_str += 'all_model_checkpoint_paths: {}\n'.format(c)
|
||||
return out_str
|
||||
|
||||
def __repr__(self):
|
||||
return str(self._checkpoints)
|
||||
|
||||
|
||||
COACH_CHECKPOINT_PATTERN = r'\A(([0-9]+)[^0-9])?.*?\.ckpt(-([0-9]+))?'
|
||||
|
||||
|
||||
def get_checkpoint_state(checkpoint_dir: Union[str, List[str]], filename_pattern: str=COACH_CHECKPOINT_PATTERN) ->\
|
||||
Union[CheckpointState, None]:
|
||||
"""
|
||||
Finds the latest checkpoint file. It uses the first group of filename_pattern (i.e. group(2) or group(4) to sort
|
||||
the checkpoint names and find the latest checkpoint
|
||||
:param checkpoint_dir: directory where checkpoints are saved or list of all files in a directory
|
||||
:param filename_pattern: regex pattern for checkpoint filenames
|
||||
:return: a CheckpointState for checkpoint_dir containing a sorted list of checkpoint names. If no matching
|
||||
files are found, returns None.
|
||||
"""
|
||||
prog = re.compile(filename_pattern)
|
||||
checkpoints = dict()
|
||||
filenames = os.listdir(checkpoint_dir) if isinstance(checkpoint_dir, str) else checkpoint_dir
|
||||
for name in filenames:
|
||||
m = prog.search(name)
|
||||
if m is not None and (m.group(2) is not None or m.group(4) is not None):
|
||||
if m.group(2) is not None and m.group(4) is not None:
|
||||
assert m.group(2) == m.group(4)
|
||||
checkpoint_count = int(m.group(2) if m.group(2) is not None else m.group(4))
|
||||
full_path = os.path.join(checkpoint_dir, m.group(0)) if isinstance(checkpoint_dir, str) else m.group(0)
|
||||
checkpoints[checkpoint_count] = full_path
|
||||
if len(checkpoints) == 0:
|
||||
return None
|
||||
return CheckpointState([checkpoints[k] for k in sorted(checkpoints.keys())])
|
||||
|
||||
Reference in New Issue
Block a user