diff --git a/rl_coach/coach.py b/rl_coach/coach.py
index f84a758..56d2b2e 100644
--- a/rl_coach/coach.py
+++ b/rl_coach/coach.py
@@ -103,21 +103,22 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
         data_store_params.checkpoint_dir = ckpt_inside_container
         graph_manager.data_store_params = data_store_params
 
+    data_store = None
+    if args.data_store_params:
+        data_store = get_data_store(data_store_params)
+
     if args.distributed_coach_run_type == RunType.TRAINER:
         task_parameters.checkpoint_save_dir = ckpt_inside_container
         training_worker(
             graph_manager=graph_manager,
             task_parameters=task_parameters,
+            data_store=data_store,
             is_multi_node_test=args.is_multi_node_test
         )
 
     if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
         task_parameters.checkpoint_restore_path = ckpt_inside_container
 
-        data_store = None
-        if args.data_store_params:
-            data_store = get_data_store(data_store_params)
-
         rollout_worker(
             graph_manager=graph_manager,
             data_store=data_store,
@@ -169,7 +170,7 @@ def handle_distributed_coach_orchestrator(args):
                                                 memory_backend_parameters=memory_backend_params,
                                                 data_store_params=ds_params_instance)
     orchestrator = Kubernetes(orchestration_params)
-    if not orchestrator.setup():
+    if not orchestrator.setup(args.checkpoint_restore_dir):
         print("Could not setup.")
         return 1
 
@@ -394,7 +395,9 @@ class CoachLauncher(object):
 
         # validate the checkpoints args
         if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
-            screen.error("The requested checkpoint folder to load from does not exist.")
+            # If distributed trainer, the checkpoint dir is not yet available, so skip the check in that case.
+            if not (args.distributed_coach and args.distributed_coach_run_type in [RunType.TRAINER, RunType.ROLLOUT_WORKER]):
+                screen.error("The requested checkpoint folder to load from does not exist.")
 
         # validate the checkpoints args
         if args.checkpoint_restore_file is not None and not glob(args.checkpoint_restore_file + '*'):
diff --git a/rl_coach/data_stores/data_store.py b/rl_coach/data_stores/data_store.py
index 3b5bef8..b4cb7f4 100644
--- a/rl_coach/data_stores/data_store.py
+++ b/rl_coach/data_stores/data_store.py
@@ -44,7 +44,11 @@ class DataStore(object):
     def load_from_store(self):
         pass
 
+    def setup_checkpoint_dir(self, crd=None):
+        pass
+
 
 class SyncFiles(Enum):
     FINISHED = ".finished"
     LOCKFILE = ".lock"
+    TRAINER_READY = ".ready"
diff --git a/rl_coach/data_stores/nfs_data_store.py b/rl_coach/data_stores/nfs_data_store.py
index 5463eca..16fde54 100644
--- a/rl_coach/data_stores/nfs_data_store.py
+++ b/rl_coach/data_stores/nfs_data_store.py
@@ -284,3 +284,8 @@ class NFSDataStore(DataStore):
             return False
 
         return True
+
+    def setup_checkpoint_dir(self, crd=None):
+        if crd:
+            # TODO: find a way to upload this to the deployed nfs store.
+            pass
diff --git a/rl_coach/data_stores/s3_data_store.py b/rl_coach/data_stores/s3_data_store.py
index 589ee5f..959422a 100644
--- a/rl_coach/data_stores/s3_data_store.py
+++ b/rl_coach/data_stores/s3_data_store.py
@@ -77,6 +77,9 @@ class S3DataStore(DataStore):
         return True
 
     def save_to_store(self):
+        self._save_to_store(self.params.checkpoint_dir)
+
+    def _save_to_store(self, checkpoint_dir):
         """
         save_to_store() uploads the policy checkpoint, gifs and videos to the S3 data store. It reads the checkpoint
         state files and uploads only the latest checkpoint files to S3. It is used by the trainer in Coach when used in
         the distributed mode.
@@ -88,24 +91,32 @@ class S3DataStore(DataStore):
 
             # Acquire lock
             self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)
 
-            state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
+            state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))
             if state_file.exists():
                 ckpt_state = state_file.read()
                 checkpoint_file = None
-                for root, dirs, files in os.walk(self.params.checkpoint_dir):
+                for root, dirs, files in os.walk(checkpoint_dir):
                     for filename in files:
                         if filename == CheckpointStateFile.checkpoint_state_filename:
                             checkpoint_file = (root, filename)
                             continue
                         if filename.startswith(ckpt_state.name):
                             abs_name = os.path.abspath(os.path.join(root, filename))
-                            rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
+                            rel_name = os.path.relpath(abs_name, checkpoint_dir)
                             self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
 
                 abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1]))
-                rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
+                rel_name = os.path.relpath(abs_name, checkpoint_dir)
                 self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
 
+            # upload Finished if present
+            if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value)):
+                self.mc.put_object(self.params.bucket_name, SyncFiles.FINISHED.value, io.BytesIO(b''), 0)
+
+            # upload Ready if present
+            if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value)):
+                self.mc.put_object(self.params.bucket_name, SyncFiles.TRAINER_READY.value, io.BytesIO(b''), 0)
+
             # release lock
             self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
@@ -121,6 +132,7 @@ class S3DataStore(DataStore):
             if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'gifs')):
                 for filename in os.listdir(os.path.join(self.params.expt_dir, 'gifs')):
                     self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'gifs', filename))
+
         except ResponseError as e:
             print("Got exception: %s\n while saving to S3", e)
 
@@ -157,6 +169,18 @@ class S3DataStore(DataStore):
                 except Exception as e:
                     pass
 
+            # Check if there's a ready file
+            objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.TRAINER_READY.value)
+
+            if next(objects, None) is not None:
+                try:
+                    self.mc.fget_object(
+                        self.params.bucket_name, SyncFiles.TRAINER_READY.value,
+                        os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.TRAINER_READY.value))
+                    )
+                except Exception as e:
+                    pass
+
             checkpoint_state = state_file.read()
             if checkpoint_state is not None:
                 objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True)
@@ -167,3 +191,7 @@ class S3DataStore(DataStore):
 
         except ResponseError as e:
             print("Got exception: %s\n while loading from S3", e)
+
+    def setup_checkpoint_dir(self, crd=None):
+        if crd:
+            self._save_to_store(crd)
diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py
index d8618d7..dd98731 100644
--- a/rl_coach/graph_managers/graph_manager.py
+++ b/rl_coach/graph_managers/graph_manager.py
@@ -37,6 +37,7 @@ from rl_coach.utils import set_cpu, start_shell_command_and_wait
 from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
 from rl_coach.memories.backend.memory_impl import get_memory_backend
 from rl_coach.data_stores.data_store import SyncFiles
+from rl_coach.checkpoint import CheckpointStateReader
 from rl_coach.core_types import TimeTypes
 
 
@@ -589,6 +590,10 @@ class GraphManager(object):
 
         [manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers]
 
+        # Set the last checkpoint ID
+        chkpt_state_reader = CheckpointStateReader(self.task_parameters.checkpoint_restore_path, checkpoint_state_optional=False)
+        self.checkpoint_id = chkpt_state_reader.get_latest().num + 1
+
     def _get_checkpoint_state_tf(self, checkpoint_restore_dir):
         import tensorflow as tf
         return tf.train.get_checkpoint_state(checkpoint_restore_dir)
@@ -721,6 +726,13 @@ class GraphManager(object):
 
         return data_store_creator(param)
 
+    def signal_ready(self):
+        if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir):
+            open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.TRAINER_READY.value), 'w').close()
+        if hasattr(self, 'data_store_params'):
+            data_store = self.get_data_store(self.data_store_params)
+            data_store.save_to_store()
+
     def close(self) -> None:
         """
         Clean up to close environments.
diff --git a/rl_coach/orchestrators/kubernetes_orchestrator.py b/rl_coach/orchestrators/kubernetes_orchestrator.py
index faf4be9..a94e15b 100644
--- a/rl_coach/orchestrators/kubernetes_orchestrator.py
+++ b/rl_coach/orchestrators/kubernetes_orchestrator.py
@@ -118,7 +118,7 @@ class Kubernetes(Deploy):
             self.s3_access_key = os.environ.get('ACCESS_KEY_ID')
             self.s3_secret_key = os.environ.get('SECRET_ACCESS_KEY')
 
-    def setup(self) -> bool:
+    def setup(self, crd=None) -> bool:
         """
         Deploys the memory backend and data stores if required.
         """
@@ -128,6 +128,9 @@ class Kubernetes(Deploy):
                 return False
             if self.params.data_store_params.store_type == "nfs":
                 self.nfs_pvc = self.data_store.get_info()
+
+        # Upload checkpoints in checkpoint_restore_dir (if provided) to the data store
+        self.data_store.setup_checkpoint_dir(crd)
 
         return True
 
     def deploy_trainer(self) -> bool:
@@ -141,7 +144,6 @@ class Kubernetes(Deploy):
 
         trainer_params.command += ['--memory_backend_params', json.dumps(self.params.memory_backend_parameters.__dict__)]
         trainer_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)]
-
         name = "{}-{}".format(trainer_params.run_type, uuid.uuid4())
 
         if self.params.data_store_params.store_type == "nfs":
diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py
index 1d681cb..6ecc38f 100644
--- a/rl_coach/rollout_worker.py
+++ b/rl_coach/rollout_worker.py
@@ -33,30 +33,50 @@ from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
 from rl_coach.data_stores.data_store import SyncFiles
 
 
+def wait_for(wait_func, data_store=None, timeout=10):
+    """
+    block until wait_func is true
+    """
+    for i in range(timeout):
+        if data_store:
+            data_store.load_from_store()
+
+        if wait_func():
+            return
+        time.sleep(10)
+
+    # one last time
+    if wait_func():
+        return
+
+    raise ValueError((
+        'Waited {timeout} seconds, but condition timed out'
+    ).format(
+        timeout=timeout,
+    ))
+
+
 def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
     """
     block until there is a checkpoint in checkpoint_dir
     """
     chkpt_state_file = CheckpointStateFile(checkpoint_dir)
 
-    for i in range(timeout):
-        if data_store:
-            data_store.load_from_store()
-        if chkpt_state_file.read() is not None:
-            return
-        time.sleep(10)
+    def wait():
+        return chkpt_state_file.read() is not None
 
-    # one last time
-    if chkpt_state_file.read() is not None:
-        return
+    wait_for(wait, data_store, timeout)
 
-    raise ValueError((
-        'Waited {timeout} seconds, but checkpoint never found in '
-        '{checkpoint_dir}'
-    ).format(
-        timeout=timeout,
-        checkpoint_dir=checkpoint_dir,
-    ))
+
+def wait_for_trainer_ready(checkpoint_dir, data_store=None, timeout=10):
+    """
+    Block until trainer is ready
+    """
+
+    def wait():
+        return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value))
+
+    wait_for(wait, data_store, timeout)
 
 
 def should_stop(checkpoint_dir):
@@ -69,17 +89,18 @@ def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
     """
     checkpoint_dir = task_parameters.checkpoint_restore_path
     wait_for_checkpoint(checkpoint_dir, data_store)
+    wait_for_trainer_ready(checkpoint_dir, data_store)
 
     graph_manager.create_graph(task_parameters)
     with graph_manager.phase_context(RunPhase.TRAIN):
 
         chkpt_state_reader = CheckpointStateReader(checkpoint_dir, checkpoint_state_optional=False)
-        last_checkpoint = 0
+        last_checkpoint = chkpt_state_reader.get_latest().num
 
         # this worker should play a fraction of the total playing steps per rollout
         act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps / num_workers
-
-        for i in range(graph_manager.improve_steps / act_steps):
+        training_steps = (graph_manager.improve_steps / act_steps.num_steps).num_steps
+        for i in range(training_steps):
 
             if should_stop(checkpoint_dir):
                 break
diff --git a/rl_coach/training_worker.py b/rl_coach/training_worker.py
index 0ada249..6c5c838 100644
--- a/rl_coach/training_worker.py
+++ b/rl_coach/training_worker.py
@@ -24,24 +24,30 @@ from rl_coach import core_types
 from rl_coach.logger import screen
 
 
-def data_store_ckpt_save(data_store):
-    while True:
-        data_store.save_to_store()
-        time.sleep(10)
+def data_store_ckpt_load(data_store):
+    if data_store:
+        data_store.load_from_store()
 
 
-def training_worker(graph_manager, task_parameters, is_multi_node_test):
+def training_worker(graph_manager, task_parameters, data_store, is_multi_node_test):
     """
     restore a checkpoint then perform rollouts using the restored model
     :param graph_manager: An instance of the graph manager
     :param task_parameters: An instance of task parameters
     :param is_multi_node_test: If this is a multi node test instead of a normal run.
     """
-    # initialize graph
-    graph_manager.create_graph(task_parameters)
+    # Load checkpoint if provided
+    if task_parameters.checkpoint_restore_path:
+        data_store_ckpt_load(data_store)
+        # initialize graph
+        graph_manager.create_graph(task_parameters)
 
-    # save randomly initialized graph
-    graph_manager.save_checkpoint()
+    else:
+        # initialize graph
+        graph_manager.create_graph(task_parameters)
+
+        # save randomly initialized graph
+        graph_manager.save_checkpoint()
 
     # training loop
     steps = 0
@@ -50,6 +56,7 @@ def training_worker(graph_manager, task_parameters, is_multi_node_test):
     eval_offset = 1
 
     graph_manager.setup_memory_backend()
+    graph_manager.signal_ready()
 
     while steps < graph_manager.improve_steps.num_steps: