diff --git a/rl_coach/coach.py b/rl_coach/coach.py index 8db1cd0..8404a38 100644 --- a/rl_coach/coach.py +++ b/rl_coach/coach.py @@ -42,6 +42,7 @@ from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters from rl_coach.memories.backend.memory_impl import construct_memory_params from rl_coach.data_stores.data_store import DataStoreParameters from rl_coach.data_stores.s3_data_store import S3DataStoreParameters +from rl_coach.data_stores.nfs_data_store import NFSDataStoreParameters from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params from rl_coach.training_worker import training_worker from rl_coach.rollout_worker import rollout_worker, wait_for_checkpoint @@ -137,6 +138,9 @@ def handle_distributed_coach_orchestrator(graph_manager, args): ds_params = DataStoreParameters("s3", "", "") ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=args.s3_end_point, bucket_name=args.s3_bucket_name, creds_file=args.s3_creds_file, checkpoint_dir=ckpt_inside_container) + elif args.data_store == "nfs": + ds_params = DataStoreParameters("nfs", "kubernetes", "") + ds_params_instance = NFSDataStoreParameters(ds_params) worker_run_type_params = RunTypeParameters(args.image, rollout_command, run_type=str(RunType.ROLLOUT_WORKER), num_replicas=args.num_workers) trainer_run_type_params = RunTypeParameters(args.image, trainer_command, run_type=str(RunType.TRAINER)) @@ -328,16 +332,17 @@ class CoachLauncher(object): args.image = coach_config.get('coach', 'image') args.memory_backend = coach_config.get('coach', 'memory_backend') args.data_store = coach_config.get('coach', 'data_store') - args.s3_end_point = coach_config.get('coach', 's3_end_point') - args.s3_bucket_name = coach_config.get('coach', 's3_bucket_name') - args.s3_creds_file = coach_config.get('coach', 's3_creds_file') + if args.data_store == 's3': + args.s3_end_point = coach_config.get('coach', 's3_end_point') + args.s3_bucket_name = coach_config.get('coach', 's3_bucket_name') + args.s3_creds_file = coach_config.get('coach', 's3_creds_file') except Error as e: screen.error("Error when reading distributed Coach config file: {}".format(e)) if args.image == '': screen.error("Image cannot be empty.") - data_store_choices = ['s3'] + data_store_choices = ['s3', 'nfs'] if args.data_store not in data_store_choices: screen.warning("{} data store is unsupported.".format(args.data_store)) screen.error("Supported data stores are {}.".format(data_store_choices)) @@ -347,11 +352,11 @@ class CoachLauncher(object): screen.warning("{} memory backend is not supported.".format(args.memory_backend)) screen.error("Supported memory backends are {}.".format(memory_backend_choices)) - if args.s3_bucket_name == '': - screen.error("S3 bucket name cannot be empty.") - - if args.s3_creds_file == '': - args.s3_creds_file = None + if args.data_store == 's3': + if args.s3_bucket_name == '': + screen.error("S3 bucket name cannot be empty.") + if args.s3_creds_file == '': + args.s3_creds_file = None if args.play and args.distributed_coach: screen.error("Playing is not supported in distributed Coach.") diff --git a/rl_coach/data_stores/nfs_data_store.py b/rl_coach/data_stores/nfs_data_store.py index ba2e057..6b93ce4 100644 --- a/rl_coach/data_stores/nfs_data_store.py +++ b/rl_coach/data_stores/nfs_data_store.py @@ -135,9 +135,9 @@ class NFSDataStore(DataStore): ) try: - k8s_core_v1_api_client.create_namespaced_service(self.params.namespace, service) + svc_response = k8s_core_v1_api_client.create_namespaced_service(self.params.namespace, service) self.params.svc_name = svc_name - self.params.server = 'nfs-service.{}.svc.cluster.local'.format(self.params.namespace) + self.params.server = svc_response.spec.cluster_ip except k8sclient.rest.ApiException as e: print("Got exception: %s\n while creating a service for nfs-server", e) return False diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py index 76d55ea..758d583 100644 --- a/rl_coach/rollout_worker.py +++ b/rl_coach/rollout_worker.py @@ -68,6 +68,8 @@ def get_latest_checkpoint(checkpoint_dir): rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir) return int(rel_path.split('_Step')[0]) + return 0 + def should_stop(checkpoint_dir): return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))