diff --git a/rl_coach/data_stores/__init__.py b/rl_coach/data_stores/__init__.py new file mode 100644 index 0000000..cf26739 --- /dev/null +++ b/rl_coach/data_stores/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/rl_coach/data_stores/data_store_impl.py b/rl_coach/data_stores/data_store_impl.py index 9727fc1..d98dfcd 100644 --- a/rl_coach/data_stores/data_store_impl.py +++ b/rl_coach/data_stores/data_store_impl.py @@ -1,5 +1,6 @@ from rl_coach.data_stores.nfs_data_store import NFSDataStore, NFSDataStoreParameters from rl_coach.data_stores.s3_data_store import S3DataStore, S3DataStoreParameters +from rl_coach.data_stores.data_store import DataStoreParameters def get_data_store(params): @@ -10,3 +11,14 @@ def get_data_store(params): data_store = S3DataStore(params) return data_store + +def construct_data_store_params(json: dict): + ds_params_instance = None + ds_params = DataStoreParameters(json['store_type'], json['orchestrator_type'], json['orchestrator_params']) + if json['store_type'] == 'nfs': + ds_params_instance = NFSDataStoreParameters(ds_params) + elif json['store_type'] == 's3': + ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=json['end_point'], + bucket_name=json['bucket_name'], checkpoint_dir=json['checkpoint_dir']) + + return ds_params_instance diff --git a/rl_coach/data_stores/s3_data_store.py b/rl_coach/data_stores/s3_data_store.py index 
1b623e2..5ceb00a 100644 --- a/rl_coach/data_stores/s3_data_store.py +++ b/rl_coach/data_stores/s3_data_store.py @@ -46,6 +46,7 @@ class S3DataStore(DataStore): def save_to_store(self): try: + print("saving to s3") for root, dirs, files in os.walk(self.params.checkpoint_dir): for filename in files: abs_name = os.path.abspath(os.path.join(root, filename)) @@ -56,6 +57,7 @@ class S3DataStore(DataStore): def load_from_store(self): try: + print("loading from s3") objects = self.mc.list_objects_v2(self.params.bucket_name, recursive=True) for obj in objects: filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name)) diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index 230c90b..5ee24ed 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -30,7 +30,8 @@ from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, T from rl_coach.environments.environment import Environment from rl_coach.level_manager import LevelManager from rl_coach.logger import screen, Logger from rl_coach.utils import set_cpu, start_shell_command_and_wait +from rl_coach.data_stores.data_store_impl import get_data_store 
class ScheduleParameters(Parameters): @@ -367,6 +372,11 @@ class GraphManager(object): """ self.verify_graph_was_created() + if hasattr(self, 'data_store_params') and hasattr(self.agent_params.memory, 'memory_backend_params'): + if self.agent_params.memory.memory_backend_params.run_type == "worker": + data_store = get_data_store(self.data_store_params) + data_store.load_from_store() + # perform several steps of playing result = None @@ -522,6 +532,11 @@ class GraphManager(object): self.checkpoint_id += 1 self.last_checkpoint_saving_time = time.time() + if hasattr(self, 'data_store_params'): + data_store = get_data_store(self.data_store_params) + data_store.save_to_store() + + def improve(self): """ The main loop of the run. diff --git a/rl_coach/orchestrators/kubernetes_orchestrator.py b/rl_coach/orchestrators/kubernetes_orchestrator.py index b22b681..d690623 100644 --- a/rl_coach/orchestrators/kubernetes_orchestrator.py +++ b/rl_coach/orchestrators/kubernetes_orchestrator.py @@ -4,9 +4,12 @@ import json import time from typing import List +from configparser import ConfigParser, Error from rl_coach.orchestrators.deploy import Deploy, DeployParameters -from kubernetes import client, config +from kubernetes import client as k8sclient, config as k8sconfig from rl_coach.memories.backend.memory import MemoryBackendParameters from rl_coach.memories.backend.memory_impl import get_memory_backend +from rl_coach.data_stores.data_store import DataStoreParameters +from rl_coach.data_stores.data_store_impl import get_data_store class RunTypeParameters(): @@ -29,8 +32,9 @@ class RunTypeParameters(): class KubernetesParameters(DeployParameters): - def __init__(self, run_type_params: List[RunTypeParameters], kubeconfig: str = None, namespace: str = "", nfs_server: str = None, - nfs_path: str = None, checkpoint_dir: str = '/checkpoint', memory_backend_parameters: MemoryBackendParameters = None): + def __init__(self, run_type_params: List[RunTypeParameters], kubeconfig: str = None, namespace: str = None, + nfs_server: str = 
None, nfs_path: str = None, checkpoint_dir: str = '/checkpoint', + memory_backend_parameters: MemoryBackendParameters = None, data_store_params: DataStoreParameters = None): self.run_type_params = {} for run_type_param in run_type_params: @@ -41,195 +44,204 @@ class KubernetesParameters(DeployParameters): self.nfs_path = nfs_path self.checkpoint_dir = checkpoint_dir self.memory_backend_parameters = memory_backend_parameters + self.data_store_params = data_store_params class Kubernetes(Deploy): - def __init__(self, deploy_parameters: KubernetesParameters): - super().__init__(deploy_parameters) - self.deploy_parameters = deploy_parameters - if self.deploy_parameters.kubeconfig: - config.load_kube_config() + def __init__(self, params: KubernetesParameters): + super().__init__(params) + self.params = params + if self.params.kubeconfig: + k8sconfig.load_kube_config() else: - config.load_incluster_config() + k8sconfig.load_incluster_config() + + if not self.params.namespace: + _, current_context = k8sconfig.list_kube_config_contexts() + self.params.namespace = current_context['context']['namespace'] - if not self.deploy_parameters.namespace: - _, current_context = config.list_kube_config_contexts() - self.deploy_parameters.namespace = current_context['context']['namespace'] self.nfs_pvc_name = 'nfs-checkpoint-pvc' if os.environ.get('http_proxy'): - client.Configuration._default.proxy = os.environ.get('http_proxy') + k8sclient.Configuration._default.proxy = os.environ.get('http_proxy') - self.deploy_parameters.memory_backend_parameters.orchestrator_params = {'namespace': self.deploy_parameters.namespace} - self.memory_backend = get_memory_backend(self.deploy_parameters.memory_backend_parameters) + self.params.memory_backend_parameters.orchestrator_params = {'namespace': self.params.namespace} + self.memory_backend = get_memory_backend(self.params.memory_backend_parameters) + + self.params.data_store_params.orchestrator_params = {'namespace': self.params.namespace} + 
self.data_store = get_data_store(self.params.data_store_params) + + if self.params.data_store_params.store_type == "s3": + self.s3_access_key = None + self.s3_secret_key = None + if self.params.data_store_params.creds_file: + s3config = ConfigParser() + s3config.read(self.params.data_store_params.creds_file) + try: + self.s3_access_key = s3config.get('default', 'aws_access_key_id') + self.s3_secret_key = s3config.get('default', 'aws_secret_access_key') + except Error as e: + print("Error when reading S3 credentials file: %s", e) + else: + self.s3_access_key = os.environ.get('ACCESS_KEY_ID') + self.s3_secret_key = os.environ.get('SECRET_ACCESS_KEY') def setup(self) -> bool: self.memory_backend.deploy() - if not self.create_nfs_resources(): - return False - return True - - def create_nfs_resources(self): - persistent_volume = client.V1PersistentVolume( - api_version="v1", - kind="PersistentVolume", - metadata=client.V1ObjectMeta( - name='nfs-checkpoint-pv', - labels={'app': 'nfs-checkpoint-pv'} - ), - spec=client.V1PersistentVolumeSpec( - access_modes=["ReadWriteMany"], - nfs=client.V1NFSVolumeSource( - path=self.deploy_parameters.nfs_path, - server=self.deploy_parameters.nfs_server - ), - capacity={'storage': '10Gi'}, - storage_class_name="" - ) - ) - api_client = client.CoreV1Api() - try: - api_client.create_persistent_volume(persistent_volume) - except client.rest.ApiException as e: - print("Got exception: %s\n while creating the NFS PV", e) - return False - - persistent_volume_claim = client.V1PersistentVolumeClaim( - api_version="v1", - kind="PersistentVolumeClaim", - metadata=client.V1ObjectMeta( - name="nfs-checkpoint-pvc" - ), - spec=client.V1PersistentVolumeClaimSpec( - access_modes=["ReadWriteMany"], - resources=client.V1ResourceRequirements( - requests={'storage': '10Gi'} - ), - selector=client.V1LabelSelector( - match_labels={'app': 'nfs-checkpoint-pv'} - ), - storage_class_name="" - ) - ) - - try: - 
api_client.create_namespaced_persistent_volume_claim(self.deploy_parameters.namespace, persistent_volume_claim) - except client.rest.ApiException as e: - print("Got exception: %s\n while creating the NFS PVC", e) + if not self.data_store.deploy(): return False return True def deploy_trainer(self) -> bool: - trainer_params = self.deploy_parameters.run_type_params.get('trainer', None) + trainer_params = self.params.run_type_params.get('trainer', None) if not trainer_params: return False - trainer_params.command += ['--memory_backend_params', json.dumps(self.deploy_parameters.memory_backend_parameters.__dict__)] + trainer_params.command += ['--memory_backend_params', json.dumps(self.params.memory_backend_parameters.__dict__)] + trainer_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)] + name = "{}-{}".format(trainer_params.run_type, uuid.uuid4()) - container = client.V1Container( - name=name, - image=trainer_params.image, - command=trainer_params.command, - args=trainer_params.arguments, - image_pull_policy='Always', - volume_mounts=[client.V1VolumeMount( - name='nfs-pvc', - mount_path=trainer_params.checkpoint_dir - )] - ) - template = client.V1PodTemplateSpec( - metadata=client.V1ObjectMeta(labels={'app': name}), - spec=client.V1PodSpec( - containers=[container], - volumes=[client.V1Volume( - name="nfs-pvc", - persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource( - claim_name=self.nfs_pvc_name - ) + if self.params.data_store_params.store_type == "nfs": + container = k8sclient.V1Container( + name=name, + image=trainer_params.image, + command=trainer_params.command, + args=trainer_params.arguments, + image_pull_policy='Always', + volume_mounts=[k8sclient.V1VolumeMount( + name='nfs-pvc', + mount_path=trainer_params.checkpoint_dir )] - ), - ) - deployment_spec = client.V1DeploymentSpec( + ) + template = k8sclient.V1PodTemplateSpec( + metadata=k8sclient.V1ObjectMeta(labels={'app': name}), + 
spec=k8sclient.V1PodSpec( + containers=[container], + volumes=[k8sclient.V1Volume( + name="nfs-pvc", + persistent_volume_claim=k8sclient.V1PersistentVolumeClaimVolumeSource( + claim_name=self.nfs_pvc_name + ) + )] + ), + ) + else: + container = k8sclient.V1Container( + name=name, + image=trainer_params.image, + command=trainer_params.command, + args=trainer_params.arguments, + image_pull_policy='Always', + env=[k8sclient.V1EnvVar("ACCESS_KEY_ID", self.s3_access_key), + k8sclient.V1EnvVar("SECRET_ACCESS_KEY", self.s3_secret_key)] + ) + template = k8sclient.V1PodTemplateSpec( + metadata=k8sclient.V1ObjectMeta(labels={'app': name}), + spec=k8sclient.V1PodSpec( + containers=[container] + ), + ) + + deployment_spec = k8sclient.V1DeploymentSpec( replicas=trainer_params.num_replicas, template=template, - selector=client.V1LabelSelector( + selector=k8sclient.V1LabelSelector( match_labels={'app': name} ) ) - deployment = client.V1Deployment( + deployment = k8sclient.V1Deployment( api_version='apps/v1', kind='Deployment', - metadata=client.V1ObjectMeta(name=name), + metadata=k8sclient.V1ObjectMeta(name=name), spec=deployment_spec ) - api_client = client.AppsV1Api() + api_client = k8sclient.AppsV1Api() try: - api_client.create_namespaced_deployment(self.deploy_parameters.namespace, deployment) + api_client.create_namespaced_deployment(self.params.namespace, deployment) trainer_params.orchestration_params['deployment_name'] = name return True - except client.rest.ApiException as e: + except k8sclient.rest.ApiException as e: print("Got exception: %s\n while creating deployment", e) return False def deploy_worker(self): - worker_params = self.deploy_parameters.run_type_params.get('worker', None) + worker_params = self.params.run_type_params.get('worker', None) if not worker_params: return False - worker_params.command += ['--memory_backend_params', json.dumps(self.deploy_parameters.memory_backend_parameters.__dict__)] + worker_params.command += ['--memory_backend_params', 
json.dumps(self.params.memory_backend_parameters.__dict__)] + worker_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)] + name = "{}-{}".format(worker_params.run_type, uuid.uuid4()) - container = client.V1Container( - name=name, - image=worker_params.image, - command=worker_params.command, - args=worker_params.arguments, - image_pull_policy='Always', - volume_mounts=[client.V1VolumeMount( - name='nfs-pvc', - mount_path=worker_params.checkpoint_dir - )] - ) - template = client.V1PodTemplateSpec( - metadata=client.V1ObjectMeta(labels={'app': name}), - spec=client.V1PodSpec( - containers=[container], - volumes=[client.V1Volume( - name="nfs-pvc", - persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource( - claim_name=self.nfs_pvc_name - ) - )], - ), - ) + if self.params.data_store_params.store_type == "nfs": + container = k8sclient.V1Container( + name=name, + image=worker_params.image, + command=worker_params.command, + args=worker_params.arguments, + image_pull_policy='Always', + volume_mounts=[k8sclient.V1VolumeMount( + name='nfs-pvc', + mount_path=worker_params.checkpoint_dir + )] + ) + template = k8sclient.V1PodTemplateSpec( + metadata=k8sclient.V1ObjectMeta(labels={'app': name}), + spec=k8sclient.V1PodSpec( + containers=[container], + volumes=[k8sclient.V1Volume( + name="nfs-pvc", + persistent_volume_claim=k8sclient.V1PersistentVolumeClaimVolumeSource( + claim_name=self.nfs_pvc_name + ) + )], + ), + ) + else: + container = k8sclient.V1Container( + name=name, + image=worker_params.image, + command=worker_params.command, + args=worker_params.arguments, + image_pull_policy='Always', + env=[k8sclient.V1EnvVar("ACCESS_KEY_ID", self.s3_access_key), + k8sclient.V1EnvVar("SECRET_ACCESS_KEY", self.s3_secret_key)] + ) + template = k8sclient.V1PodTemplateSpec( + metadata=k8sclient.V1ObjectMeta(labels={'app': name}), + spec=k8sclient.V1PodSpec( + containers=[container] + ) + ) - deployment_spec = client.V1DeploymentSpec( + 
deployment_spec = k8sclient.V1DeploymentSpec( replicas=worker_params.num_replicas, template=template, - selector=client.V1LabelSelector( + selector=k8sclient.V1LabelSelector( match_labels={'app': name} ) ) - deployment = client.V1Deployment( + deployment = k8sclient.V1Deployment( api_version='apps/v1', kind="Deployment", - metadata=client.V1ObjectMeta(name=name), + metadata=k8sclient.V1ObjectMeta(name=name), spec=deployment_spec ) - api_client = client.AppsV1Api() + api_client = k8sclient.AppsV1Api() try: - api_client.create_namespaced_deployment(self.deploy_parameters.namespace, deployment) + api_client.create_namespaced_deployment(self.params.namespace, deployment) worker_params.orchestration_params['deployment_name'] = name return True - except client.rest.ApiException as e: + except k8sclient.rest.ApiException as e: print("Got exception: %s\n while creating deployment", e) return False @@ -237,19 +249,19 @@ class Kubernetes(Deploy): pass def trainer_logs(self): - trainer_params = self.deploy_parameters.run_type_params.get('trainer', None) + trainer_params = self.params.run_type_params.get('trainer', None) if not trainer_params: return - api_client = client.CoreV1Api() + api_client = k8sclient.CoreV1Api() pod = None try: - pods = api_client.list_namespaced_pod(self.deploy_parameters.namespace, label_selector='app={}'.format( + pods = api_client.list_namespaced_pod(self.params.namespace, label_selector='app={}'.format( trainer_params.orchestration_params['deployment_name'] )) pod = pods.items[0] - except client.rest.ApiException as e: + except k8sclient.rest.ApiException as e: print("Got exception: %s\n while reading pods", e) return @@ -264,17 +276,17 @@ class Kubernetes(Deploy): # Try to tail the pod logs try: print(corev1_api.read_namespaced_pod_log( - pod_name, self.deploy_parameters.namespace, follow=True + pod_name, self.params.namespace, follow=True ), flush=True) - except client.rest.ApiException as e: + except k8sclient.rest.ApiException as e: pass # 
This part will get executed if the pod is one of the following phases: not ready, failed or terminated. # Check if the pod has errored out, else just try again. # Get the pod try: - pod = corev1_api.read_namespaced_pod(pod_name, self.deploy_parameters.namespace) - except client.rest.ApiException as e: + pod = corev1_api.read_namespaced_pod(pod_name, self.params.namespace) + except k8sclient.rest.ApiException as e: continue if not hasattr(pod, 'status') or not pod.status: @@ -293,18 +305,19 @@ class Kubernetes(Deploy): return def undeploy(self): - trainer_params = self.deploy_parameters.run_type_params.get('trainer', None) - api_client = client.AppsV1Api() - delete_options = client.V1DeleteOptions() + trainer_params = self.params.run_type_params.get('trainer', None) + api_client = k8sclient.AppsV1Api() + delete_options = k8sclient.V1DeleteOptions() if trainer_params: try: - api_client.delete_namespaced_deployment(trainer_params.orchestration_params['deployment_name'], self.deploy_parameters.namespace, delete_options) - except client.rest.ApiException as e: + api_client.delete_namespaced_deployment(trainer_params.orchestration_params['deployment_name'], self.params.namespace, delete_options) + except k8sclient.rest.ApiException as e: print("Got exception: %s\n while deleting trainer", e) - worker_params = self.deploy_parameters.run_type_params.get('worker', None) + worker_params = self.params.run_type_params.get('worker', None) if worker_params: try: - api_client.delete_namespaced_deployment(worker_params.orchestration_params['deployment_name'], self.deploy_parameters.namespace, delete_options) - except client.rest.ApiException as e: + api_client.delete_namespaced_deployment(worker_params.orchestration_params['deployment_name'], self.params.namespace, delete_options) + except k8sclient.rest.ApiException as e: print("Got exception: %s\n while deleting workers", e) self.memory_backend.undeploy() + self.data_store.undeploy() diff --git 
a/rl_coach/orchestrators/start_training.py b/rl_coach/orchestrators/start_training.py index 547bb5c..05af288 100644 --- a/rl_coach/orchestrators/start_training.py +++ b/rl_coach/orchestrators/start_training.py @@ -2,19 +2,36 @@ import argparse from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, RunTypeParameters from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters +from rl_coach.data_stores.data_store import DataStoreParameters +from rl_coach.data_stores.s3_data_store import S3DataStoreParameters +from rl_coach.data_stores.nfs_data_store import NFSDataStoreParameters -def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, nfs_server: str="", nfs_path: str="", memory_backend: str=""): +def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, nfs_server: str=None, nfs_path: str=None, + memory_backend: str=None, data_store: str=None, s3_end_point: str=None, s3_bucket_name: str=None): rollout_command = ['python3', 'rl_coach/rollout_worker.py', '-p', preset] training_command = ['python3', 'rl_coach/training_worker.py', '-p', preset] - memory_backend_params = RedisPubSubMemoryBackendParameters() + memory_backend_params = None + if memory_backend == "redispubsub": + memory_backend_params = RedisPubSubMemoryBackendParameters() + + ds_params_instance = None + if data_store == "s3": + ds_params = DataStoreParameters("s3", "", "") + ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=s3_end_point, bucket_name=s3_bucket_name, + checkpoint_dir="/checkpoint") + elif data_store == "nfs": + ds_params = DataStoreParameters("nfs", "kubernetes", {"namespace": "default"}) + ds_params_instance = NFSDataStoreParameters(ds_params) worker_run_type_params = RunTypeParameters(image, rollout_command, run_type="worker") trainer_run_type_params = RunTypeParameters(image, training_command, run_type="trainer") - orchestration_params = 
KubernetesParameters([worker_run_type_params, trainer_run_type_params], kubeconfig='~/.kube/config', nfs_server=nfs_server, - nfs_path=nfs_path, memory_backend_parameters=memory_backend_params) + orchestration_params = KubernetesParameters([worker_run_type_params, trainer_run_type_params], + kubeconfig='~/.kube/config', nfs_server=nfs_server, nfs_path=nfs_path, + memory_backend_parameters=memory_backend_params, + data_store_params=ds_params_instance) orchestrator = Kubernetes(orchestration_params) if not orchestrator.setup(): print("Could not setup") @@ -36,7 +53,7 @@ def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, n orchestrator.trainer_logs() except KeyboardInterrupt: pass - orchestrator.undeploy() + # orchestrator.undeploy() if __name__ == '__main__': @@ -46,21 +63,33 @@ if __name__ == '__main__': type=str, required=True) parser.add_argument('-p', '--preset', - help="(string) Name of a preset to run (class name from the 'presets' directory.)", + help="(string) Name of a preset to run (class name from the 'presets' directory).", type=str, required=True) + parser.add_argument('--memory-backend', + help="(string) Memory backend to use.", + type=str, + default="redispubsub") + parser.add_argument('-ds', '--data-store', + help="(string) Data store to use.", + type=str, + default="s3") parser.add_argument('-ns', '--nfs-server', - help="(string) Addresss of the nfs server.)", + help="(string) Address of the nfs server.", type=str, required=True) parser.add_argument('-np', '--nfs-path', - help="(string) Exported path for the nfs server", + help="(string) Exported path for the nfs server.", type=str, required=True) - parser.add_argument('--memory_backend', - help="(string) Memory backend to use", + parser.add_argument('--s3-end-point', + help="(string) S3 endpoint to use when S3 data store is used.", type=str, - default="redispubsub") + required=True) + parser.add_argument('--s3-bucket-name', + help="(string) S3 bucket name to use when S3 
data store is used.", + type=str, + required=True) # parser.add_argument('--checkpoint_dir', # help='(string) Path to a folder containing a checkpoint to write the model to.', @@ -68,4 +97,6 @@ if __name__ == '__main__': # default='/checkpoint') args = parser.parse_args() - main(preset=args.preset, image=args.image, nfs_server=args.nfs_server, nfs_path=args.nfs_path, memory_backend=args.memory_backend) + main(preset=args.preset, image=args.image, nfs_server=args.nfs_server, nfs_path=args.nfs_path, + memory_backend=args.memory_backend, data_store=args.data_store, s3_end_point=args.s3_end_point, + s3_bucket_name=args.s3_bucket_name) diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py index 7d8a371..b2f98d8 100644 --- a/rl_coach/rollout_worker.py +++ b/rl_coach/rollout_worker.py @@ -12,11 +12,14 @@ import time import os import json +from threading import Thread + from rl_coach.base_parameters import TaskParameters from rl_coach.coach import expand_preset from rl_coach.core_types import EnvironmentEpisodes, RunPhase from rl_coach.utils import short_dynamic_import from rl_coach.memories.backend.memory_impl import construct_memory_params +from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params # Q: specify alternative distributed memory, or should this go in the preset? 
@@ -27,17 +30,23 @@ def has_checkpoint(checkpoint_dir): """ True if a checkpoint is present in checkpoint_dir """ - return len(os.listdir(checkpoint_dir)) > 0 + if os.path.isdir(checkpoint_dir): + if len(os.listdir(checkpoint_dir)) > 0: + return os.path.isfile(os.path.join(checkpoint_dir, "checkpoint")) + return False -def wait_for_checkpoint(checkpoint_dir, timeout=10): +def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10): """ block until there is a checkpoint in checkpoint_dir """ for i in range(timeout): + if data_store: + data_store.load_from_store() + if has_checkpoint(checkpoint_dir): return - time.sleep(1) + time.sleep(10) # one last time if has_checkpoint(checkpoint_dir): @@ -52,20 +61,26 @@ def wait_for_checkpoint(checkpoint_dir, timeout=10): )) +def data_store_ckpt_load(data_store): + while True: + data_store.load_from_store() + time.sleep(10) + def rollout_worker(graph_manager, checkpoint_dir): """ - restore a checkpoint then perform rollouts using the restored model + wait for first checkpoint then perform rollouts using the model """ wait_for_checkpoint(checkpoint_dir) task_parameters = TaskParameters() task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir + time.sleep(30) graph_manager.create_graph(task_parameters) graph_manager.phase = RunPhase.TRAIN for i in range(10000000): - graph_manager.act(EnvironmentEpisodes(num_steps=10)) graph_manager.restore_checkpoint() + graph_manager.act(EnvironmentEpisodes(num_steps=10)) graph_manager.phase = RunPhase.UNDEFINED @@ -91,6 +106,9 @@ def main(): parser.add_argument('--memory_backend_params', help="(string) JSON string of the memory backend params", type=str) + parser.add_argument('--data_store_params', + help="(string) JSON string of the data store params", + type=str) args = parser.parse_args() @@ -98,9 +116,20 @@ def main(): if args.memory_backend_params: args.memory_backend_params = json.loads(args.memory_backend_params) - if 'run_type' not in args.memory_backend_params: - 
args.memory_backend_params['run_type'] = 'worker' + print(args.memory_backend_params) + args.memory_backend_params['run_type'] = 'worker' + print(construct_memory_params(args.memory_backend_params)) graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(args.memory_backend_params)) + + if args.data_store_params: + data_store_params = construct_data_store_params(json.loads(args.data_store_params)) + data_store_params.checkpoint_dir = args.checkpoint_dir + graph_manager.data_store_params = data_store_params + data_store = get_data_store(data_store_params) + wait_for_checkpoint(checkpoint_dir=args.checkpoint_dir, data_store=data_store) + # thread = Thread(target = data_store_ckpt_load, args = [data_store]) + # thread.start() + rollout_worker( graph_manager=graph_manager, checkpoint_dir=args.checkpoint_dir, diff --git a/rl_coach/training_worker.py b/rl_coach/training_worker.py index 85f2052..4da27ed 100644 --- a/rl_coach/training_worker.py +++ b/rl_coach/training_worker.py @@ -4,15 +4,19 @@ import argparse import time import json +from threading import Thread + from rl_coach.base_parameters import TaskParameters from rl_coach.coach import expand_preset from rl_coach import core_types from rl_coach.utils import short_dynamic_import from rl_coach.memories.backend.memory_impl import construct_memory_params +from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params -# Q: specify alternative distributed memory, or should this go in the preset? -# A: preset must define distributed memory to be used. we aren't going to take a non-distributed preset and automatically distribute it. 
- +def data_store_ckpt_save(data_store): + while True: + data_store.save_to_store() + time.sleep(10) def training_worker(graph_manager, checkpoint_dir): """ @@ -58,16 +62,26 @@ def main(): parser.add_argument('--memory_backend_params', help="(string) JSON string of the memory backend params", type=str) + parser.add_argument('--data_store_params', + help="(string) JSON string of the data store params", + type=str) args = parser.parse_args() graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True) if args.memory_backend_params: args.memory_backend_params = json.loads(args.memory_backend_params) - if 'run_type' not in args.memory_backend_params: - args.memory_backend_params['run_type'] = 'trainer' + args.memory_backend_params['run_type'] = 'trainer' graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(args.memory_backend_params)) + if args.data_store_params: + data_store_params = construct_data_store_params(json.loads(args.data_store_params)) + data_store_params.checkpoint_dir = args.checkpoint_dir + graph_manager.data_store_params = data_store_params + # data_store = get_data_store(data_store_params) + # thread = Thread(target = data_store_ckpt_save, args = [data_store]) + # thread.start() + training_worker( graph_manager=graph_manager, checkpoint_dir=args.checkpoint_dir,