
Uploading checkpoint if crd provided (#191)

* Uploading checkpoint if crd provided
* Changing the calculation of total steps because of a recent change in core_types

Fixes #195
Author: Ajay Deshpande
Date: 2019-04-26 12:27:33 -07:00
Committed by: Scott Leishman
Parent: b3db9ce77d
Commit: 33dc29ee99
8 changed files with 122 additions and 40 deletions
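At a high level, the commit threads the local checkpoint restore directory ("crd") from the launcher through the Kubernetes orchestrator into the data store, which uploads its contents before the trainer and rollout workers start. Below is a rough stand-alone sketch of that hook pattern; DummyDataStore and the path are illustrative stand-ins, the real classes are the DataStore implementations changed in this diff.

    # Illustrative sketch only: mimics the setup_checkpoint_dir() hook added in this diff.
    class DummyDataStore:
        def _save_to_store(self, checkpoint_dir):
            print("uploading checkpoint files from", checkpoint_dir)

        def setup_checkpoint_dir(self, crd=None):
            # New hook: pre-seed the store from a local checkpoint dir, if one was given.
            if crd:
                self._save_to_store(crd)

    def setup(data_store, crd=None):
        # Mirrors the orchestrator's setup(crd): deploy the store, then upload crd if provided.
        data_store.setup_checkpoint_dir(crd)
        return True

    setup(DummyDataStore(), crd="./checkpoint")  # hypothetical local path; crd=None also works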

View File

@@ -103,21 +103,22 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
         data_store_params.checkpoint_dir = ckpt_inside_container
         graph_manager.data_store_params = data_store_params
 
+    data_store = None
+    if args.data_store_params:
+        data_store = get_data_store(data_store_params)
+
     if args.distributed_coach_run_type == RunType.TRAINER:
         task_parameters.checkpoint_save_dir = ckpt_inside_container
         training_worker(
             graph_manager=graph_manager,
             task_parameters=task_parameters,
+            data_store=data_store,
             is_multi_node_test=args.is_multi_node_test
         )
 
     if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
         task_parameters.checkpoint_restore_path = ckpt_inside_container
 
-        data_store = None
-        if args.data_store_params:
-            data_store = get_data_store(data_store_params)
-
         rollout_worker(
             graph_manager=graph_manager,
             data_store=data_store,
@@ -169,7 +170,7 @@ def handle_distributed_coach_orchestrator(args):
         memory_backend_parameters=memory_backend_params,
         data_store_params=ds_params_instance)
     orchestrator = Kubernetes(orchestration_params)
-    if not orchestrator.setup():
+    if not orchestrator.setup(args.checkpoint_restore_dir):
         print("Could not setup.")
         return 1
@@ -394,7 +395,9 @@ class CoachLauncher(object):
         # validate the checkpoints args
         if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
-            screen.error("The requested checkpoint folder to load from does not exist.")
+            # If distributed trainer, the checkpoint dir is not yet available so skipping the check in that case.
+            if not (args.distributed_coach and args.distributed_coach_run_type in [RunType.TRAINER, RunType.ROLLOUT_WORKER]):
+                screen.error("The requested checkpoint folder to load from does not exist.")
 
         # validate the checkpoints args
         if args.checkpoint_restore_file is not None and not glob(args.checkpoint_restore_file + '*'):
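The relaxed validation above can be read as a small predicate: the "folder does not exist" error should only fire when the run is not a distributed trainer or rollout worker, because for those workers the restore directory only exists later, inside the container. A hedged sketch of that condition, with plain strings standing in for the RunType values:

    # Sketch of the relaxed check; "trainer"/"rollout_worker" stand in for RunType members.
    def should_raise_missing_dir_error(dir_exists, distributed_coach, run_type):
        distributed_worker = distributed_coach and run_type in ("trainer", "rollout_worker")
        return (not dir_exists) and (not distributed_worker)

    assert should_raise_missing_dir_error(False, False, None) is True       # local run, bad path
    assert should_raise_missing_dir_error(False, True, "trainer") is False  # dir appears later in the container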

View File

@@ -44,7 +44,11 @@ class DataStore(object):
     def load_from_store(self):
         pass
 
+    def setup_checkpoint_dir(self, crd=None):
+        pass
+
 
 class SyncFiles(Enum):
     FINISHED = ".finished"
     LOCKFILE = ".lock"
+    TRAINER_READY = ".ready"
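The new TRAINER_READY entry gives the trainer and the rollout workers a third marker file to rendezvous on, alongside the existing .finished and .lock files. A small stand-alone sketch of how the marker is produced and consumed elsewhere in this diff (the enum is copied here so the snippet runs on its own; paths are examples):

    import os
    from enum import Enum

    class SyncFiles(Enum):              # local copy for the sketch; rl_coach ships its own
        FINISHED = ".finished"
        LOCKFILE = ".lock"
        TRAINER_READY = ".ready"

    def signal_trainer_ready(checkpoint_dir):
        # trainer side: drop an empty marker next to the checkpoints (see signal_ready below)
        open(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value), 'w').close()

    def trainer_is_ready(checkpoint_dir):
        # rollout-worker side: poll for the marker after syncing from the data store
        return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value))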

View File

@@ -284,3 +284,8 @@ class NFSDataStore(DataStore):
                 return False
         return True
 
+    def setup_checkpoint_dir(self, crd=None):
+        if crd:
+            # TODO: find a way to upload this to the deployed nfs store.
+            pass

View File

@@ -77,6 +77,9 @@ class S3DataStore(DataStore):
         return True
 
     def save_to_store(self):
+        self._save_to_store(self.params.checkpoint_dir)
+
+    def _save_to_store(self, checkpoint_dir):
         """
         save_to_store() uploads the policy checkpoint, gifs and videos to the S3 data store. It reads the checkpoint state files and
         uploads only the latest checkpoint files to S3. It is used by the trainer in Coach when used in the distributed mode.
@@ -88,24 +91,32 @@ class S3DataStore(DataStore):
             # Acquire lock
             self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)
 
-            state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
+            state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))
             if state_file.exists():
                 ckpt_state = state_file.read()
                 checkpoint_file = None
-                for root, dirs, files in os.walk(self.params.checkpoint_dir):
+                for root, dirs, files in os.walk(checkpoint_dir):
                     for filename in files:
                         if filename == CheckpointStateFile.checkpoint_state_filename:
                             checkpoint_file = (root, filename)
                             continue
                         if filename.startswith(ckpt_state.name):
                             abs_name = os.path.abspath(os.path.join(root, filename))
-                            rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
+                            rel_name = os.path.relpath(abs_name, checkpoint_dir)
                             self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
 
                 abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1]))
-                rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
+                rel_name = os.path.relpath(abs_name, checkpoint_dir)
                 self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
 
+            # upload Finished if present
+            if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value)):
+                self.mc.put_object(self.params.bucket_name, SyncFiles.FINISHED.value, io.BytesIO(b''), 0)
+
+            # upload Ready if present
+            if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value)):
+                self.mc.put_object(self.params.bucket_name, SyncFiles.TRAINER_READY.value, io.BytesIO(b''), 0)
+
             # release lock
             self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
@@ -121,6 +132,7 @@ class S3DataStore(DataStore):
             if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'gifs')):
                 for filename in os.listdir(os.path.join(self.params.expt_dir, 'gifs')):
                     self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'gifs', filename))
 
         except ResponseError as e:
             print("Got exception: %s\n while saving to S3", e)
@@ -157,6 +169,18 @@ class S3DataStore(DataStore):
                 except Exception as e:
                     pass
 
+            # Check if there's a ready file
+            objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.TRAINER_READY.value)
+            if next(objects, None) is not None:
+                try:
+                    self.mc.fget_object(
+                        self.params.bucket_name, SyncFiles.TRAINER_READY.value,
+                        os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.TRAINER_READY.value))
+                    )
+                except Exception as e:
+                    pass
+
             checkpoint_state = state_file.read()
             if checkpoint_state is not None:
                 objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True)
@@ -167,3 +191,7 @@ class S3DataStore(DataStore):
         except ResponseError as e:
             print("Got exception: %s\n while loading from S3", e)
 
+    def setup_checkpoint_dir(self, crd=None):
+        if crd:
+            self._save_to_store(crd)
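With the body factored into _save_to_store(checkpoint_dir), the same upload path serves both the periodic save_to_store() call and the new setup_checkpoint_dir(crd) pre-seeding. The core of that path is a directory walk that keeps only the checkpoint-state file plus the files belonging to the latest checkpoint; here is a simplified, self-contained rendering of that selection logic (the file-name handling is an assumption for the example, the real parsing is done by CheckpointStateFile):

    import os

    def files_for_latest_checkpoint(checkpoint_dir, latest_name, state_filename="checkpoint"):
        """Return (object_key, local_path) pairs for the newest checkpoint only."""
        selected = []
        for root, dirs, files in os.walk(checkpoint_dir):
            for filename in files:
                if filename == state_filename or filename.startswith(latest_name):
                    abs_name = os.path.abspath(os.path.join(root, filename))
                    rel_name = os.path.relpath(abs_name, checkpoint_dir)  # becomes the S3 object key
                    selected.append((rel_name, abs_name))
        return selected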

View File

@@ -37,6 +37,7 @@ from rl_coach.utils import set_cpu, start_shell_command_and_wait
 from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
 from rl_coach.memories.backend.memory_impl import get_memory_backend
+from rl_coach.data_stores.data_store import SyncFiles
 from rl_coach.checkpoint import CheckpointStateReader
 from rl_coach.core_types import TimeTypes
@@ -589,6 +590,10 @@ class GraphManager(object):
             [manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers]
 
+            # Set the last checkpoint ID
+            chkpt_state_reader = CheckpointStateReader(self.task_parameters.checkpoint_restore_path, checkpoint_state_optional=False)
+            self.checkpoint_id = chkpt_state_reader.get_latest().num + 1
+
     def _get_checkpoint_state_tf(self, checkpoint_restore_dir):
         import tensorflow as tf
         return tf.train.get_checkpoint_state(checkpoint_restore_dir)
@@ -721,6 +726,13 @@ class GraphManager(object):
         return data_store_creator(param)
 
+    def signal_ready(self):
+        if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir):
+            open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.TRAINER_READY.value), 'w').close()
+        if hasattr(self, 'data_store_params'):
+            data_store = self.get_data_store(self.data_store_params)
+            data_store.save_to_store()
+
     def close(self) -> None:
         """
         Clean up to close environments.
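The restore path also resumes the checkpoint numbering: after loading, the next checkpoint the trainer writes is numbered one past the newest restored one, so freshly saved files do not collide with the ones just downloaded. A trivial worked example of that rule (the checkpoint naming referenced in the comment is illustrative):

    # If the newest restored checkpoint has id 12 (a name like "12_Step-..."), then
    # get_latest().num == 12 and the next checkpoint the trainer saves becomes 13.
    def next_checkpoint_id(latest_restored_num):
        return latest_restored_num + 1

    assert next_checkpoint_id(12) == 13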

View File

@@ -118,7 +118,7 @@ class Kubernetes(Deploy):
             self.s3_access_key = os.environ.get('ACCESS_KEY_ID')
             self.s3_secret_key = os.environ.get('SECRET_ACCESS_KEY')
 
-    def setup(self) -> bool:
+    def setup(self, crd=None) -> bool:
         """
         Deploys the memory backend and data stores if required.
         """
@@ -128,6 +128,9 @@ class Kubernetes(Deploy):
             return False
         if self.params.data_store_params.store_type == "nfs":
             self.nfs_pvc = self.data_store.get_info()
 
+        # Upload checkpoints in checkpoint_restore_dir (if provided) to the data store
+        self.data_store.setup_checkpoint_dir(crd)
+
         return True
 
     def deploy_trainer(self) -> bool:
@@ -141,7 +144,6 @@ class Kubernetes(Deploy):
         trainer_params.command += ['--memory_backend_params', json.dumps(self.params.memory_backend_parameters.__dict__)]
         trainer_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)]
 
         name = "{}-{}".format(trainer_params.run_type, uuid.uuid4())
 
         if self.params.data_store_params.store_type == "nfs":

View File

@@ -33,30 +33,50 @@ from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
 from rl_coach.data_stores.data_store import SyncFiles
 
 
+def wait_for(wait_func, data_store=None, timeout=10):
+    """
+    block until wait_func is true
+    """
+    for i in range(timeout):
+        if data_store:
+            data_store.load_from_store()
+        if wait_func():
+            return
+        time.sleep(10)
+
+    # one last time
+    if wait_func():
+        return
+
+    raise ValueError((
+        'Waited {timeout} seconds, but condition timed out'
+    ).format(
+        timeout=timeout,
+    ))
+
+
 def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
     """
     block until there is a checkpoint in checkpoint_dir
     """
     chkpt_state_file = CheckpointStateFile(checkpoint_dir)
-    for i in range(timeout):
-        if data_store:
-            data_store.load_from_store()
-        if chkpt_state_file.read() is not None:
-            return
-        time.sleep(10)
 
-    # one last time
-    if chkpt_state_file.read() is not None:
-        return
+    def wait():
+        return chkpt_state_file.read() is not None
 
-    raise ValueError((
-        'Waited {timeout} seconds, but checkpoint never found in '
-        '{checkpoint_dir}'
-    ).format(
-        timeout=timeout,
-        checkpoint_dir=checkpoint_dir,
-    ))
+    wait_for(wait, data_store, timeout)
+
+
+def wait_for_trainer_ready(checkpoint_dir, data_store=None, timeout=10):
+    """
+    Block until trainer is ready
+    """
+    def wait():
+        return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value))
+
+    wait_for(wait, data_store, timeout)
 
 
 def should_stop(checkpoint_dir):
@@ -69,17 +89,18 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
"""
checkpoint_dir = task_parameters.checkpoint_restore_path
wait_for_checkpoint(checkpoint_dir, data_store)
wait_for_trainer_ready(checkpoint_dir, data_store)
graph_manager.create_graph(task_parameters)
with graph_manager.phase_context(RunPhase.TRAIN):
chkpt_state_reader = CheckpointStateReader(checkpoint_dir, checkpoint_state_optional=False)
last_checkpoint = 0
last_checkpoint = chkpt_state_reader.get_latest().num
# this worker should play a fraction of the total playing steps per rollout
act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps / num_workers
for i in range(graph_manager.improve_steps / act_steps):
training_steps = (graph_manager.improve_steps / act_steps.num_steps).num_steps
for i in range(training_steps):
if should_stop(checkpoint_dir):
break
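The loop-count change is the second bullet of the commit message: dividing step objects from core_types now yields another step object rather than a plain int, so the code divides by act_steps.num_steps and then takes .num_steps of the result. A plain-integer rendering of the arithmetic with made-up numbers (the real code goes through EnvironmentSteps objects):

    # Assumed example values, for illustration only.
    improve_steps = 10000          # graph_manager.improve_steps.num_steps
    playing_steps = 100            # num_consecutive_playing_steps.num_steps
    num_workers = 4

    act_steps = playing_steps // num_workers        # 25 environment steps per worker per rollout
    training_steps = improve_steps // act_steps     # 400 loop iterations for this worker
    assert (act_steps, training_steps) == (25, 400)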

View File

@@ -24,24 +24,30 @@ from rl_coach import core_types
 from rl_coach.logger import screen
 
 
 def data_store_ckpt_save(data_store):
     while True:
         data_store.save_to_store()
         time.sleep(10)
 
 
+def data_store_ckpt_load(data_store):
+    if data_store:
+        data_store.load_from_store()
+
+
-def training_worker(graph_manager, task_parameters, is_multi_node_test):
+def training_worker(graph_manager, task_parameters, data_store, is_multi_node_test):
     """
     restore a checkpoint then perform rollouts using the restored model
     :param graph_manager: An instance of the graph manager
     :param task_parameters: An instance of task parameters
     :param is_multi_node_test: If this is a multi node test insted of a normal run.
     """
-    # initialize graph
-    graph_manager.create_graph(task_parameters)
-
-    # save randomly initialized graph
-    graph_manager.save_checkpoint()
+    # Load checkpoint if provided
+    if task_parameters.checkpoint_restore_path:
+        data_store_ckpt_load(data_store)
+
+        # initialize graph
+        graph_manager.create_graph(task_parameters)
+    else:
+        # initialize graph
+        graph_manager.create_graph(task_parameters)
+
+        # save randomly initialized graph
+        graph_manager.save_checkpoint()
 
     # training loop
     steps = 0
@@ -50,6 +56,7 @@ def training_worker(graph_manager, task_parameters, is_multi_node_test):
     eval_offset = 1
 
     graph_manager.setup_memory_backend()
+    graph_manager.signal_ready()
 
     while steps < graph_manager.improve_steps.num_steps:
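The net effect on trainer start-up: when a restore path is configured, the worker first pulls the uploaded checkpoint from the data store and lets graph creation restore from it; only a fresh run still publishes a randomly initialized checkpoint for the rollout workers. A stand-alone sketch of that branch (the callables are stand-ins for the GraphManager calls in the diff above):

    def start_trainer(restore_path, load_from_store, create_graph, save_checkpoint):
        if restore_path:
            load_from_store()    # sync the pre-uploaded checkpoint before building the graph
            create_graph()       # graph creation restores from restore_path
        else:
            create_graph()
            save_checkpoint()    # publish a randomly initialized checkpoint instead

    start_trainer("/checkpoint",                      # illustrative container path from the diff
                  lambda: print("load from store"),
                  lambda: print("create graph"),
                  lambda: print("save checkpoint"))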