diff --git a/rl_coach/coach.py b/rl_coach/coach.py
index 9f61baf..6c4e6e8 100644
--- a/rl_coach/coach.py
+++ b/rl_coach/coach.py
@@ -96,6 +96,7 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
     data_store_params = None
     if args.data_store_params:
         data_store_params = construct_data_store_params(json.loads(args.data_store_params))
+        data_store_params.expt_dir = args.experiment_path
         data_store_params.checkpoint_dir = ckpt_inside_container
         graph_manager.data_store_params = data_store_params
 
@@ -151,7 +152,7 @@ def handle_distributed_coach_orchestrator(args):
     if args.data_store == "s3":
         ds_params = DataStoreParameters("s3", "", "")
         ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=args.s3_end_point, bucket_name=args.s3_bucket_name,
-                                                   creds_file=args.s3_creds_file, checkpoint_dir=ckpt_inside_container)
+                                                   creds_file=args.s3_creds_file, checkpoint_dir=ckpt_inside_container, expt_dir=args.experiment_path)
     elif args.data_store == "nfs":
         ds_params = DataStoreParameters("nfs", "kubernetes", "")
         ds_params_instance = NFSDataStoreParameters(ds_params)
diff --git a/rl_coach/data_stores/data_store_impl.py b/rl_coach/data_stores/data_store_impl.py
index d98dfcd..eb50d21 100644
--- a/rl_coach/data_stores/data_store_impl.py
+++ b/rl_coach/data_stores/data_store_impl.py
@@ -18,7 +18,10 @@ def construct_data_store_params(json: dict):
     if json['store_type'] == 'nfs':
         ds_params_instance = NFSDataStoreParameters(ds_params)
     elif json['store_type'] == 's3':
-        ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=json['end_point'],
-                                                   bucket_name=json['bucket_name'], checkpoint_dir=json['checkpoint_dir'])
+        ds_params_instance = S3DataStoreParameters(ds_params=ds_params,
+                                                   end_point=json['end_point'],
+                                                   bucket_name=json['bucket_name'],
+                                                   checkpoint_dir=json['checkpoint_dir'],
+                                                   expt_dir=json['expt_dir'])
 
     return ds_params_instance
diff --git a/rl_coach/data_stores/s3_data_store.py b/rl_coach/data_stores/s3_data_store.py
index 7a643a1..01eaf55 100644
--- a/rl_coach/data_stores/s3_data_store.py
+++ b/rl_coach/data_stores/s3_data_store.py
@@ -12,13 +12,14 @@ import io
 
 
 class S3DataStoreParameters(DataStoreParameters):
     def __init__(self, ds_params, creds_file: str = None, end_point: str = None, bucket_name: str = None,
-                 checkpoint_dir: str = None):
+                 checkpoint_dir: str = None, expt_dir: str = None):
         super().__init__(ds_params.store_type, ds_params.orchestrator_type, ds_params.orchestrator_params)
         self.creds_file = creds_file
         self.end_point = end_point
         self.bucket_name = bucket_name
         self.checkpoint_dir = checkpoint_dir
+        self.expt_dir = expt_dir
 
 
 class S3DataStore(DataStore):
@@ -73,6 +74,18 @@ class S3DataStore(DataStore):
 
             # release lock
             self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
 
+            if self.params.expt_dir and os.path.exists(self.params.expt_dir):
+                for filename in os.listdir(self.params.expt_dir):
+                    if filename.endswith((".csv", ".json")):
+                        self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, filename))
+
+            if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'videos')):
+                for filename in os.listdir(os.path.join(self.params.expt_dir, 'videos')):
+                    self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'videos', filename))
+
+            if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'gifs')):
+                for filename in os.listdir(os.path.join(self.params.expt_dir, 'gifs')):
+                    self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'gifs', filename))
         except ResponseError as e:
             print("Got exception: %s\n while saving to S3", e)
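
Reviewer note: the three new upload loops in save_to_store() differ only in the directory they scan and an optional suffix filter. A minimal sketch of how they could be folded into one helper follows; _upload_dir is a hypothetical name, not part of this patch, and it assumes the same minio client held in self.mc and the fput_object(bucket_name, object_name, file_path) call used above.

    import os

    def _upload_dir(self, dir_path, suffixes=None):
        # Hypothetical consolidation of the three upload blocks added in this
        # patch: push each file listed in dir_path to the configured bucket,
        # keyed by its bare filename, skipping entries that miss the suffix
        # filter. No-op if dir_path is unset or missing, matching the
        # os.path.exists() guards above.
        if not (dir_path and os.path.exists(dir_path)):
            return
        for filename in os.listdir(dir_path):
            if suffixes and not filename.endswith(suffixes):
                continue
            self.mc.fput_object(self.params.bucket_name, filename,
                                os.path.join(dir_path, filename))

    # save_to_store() could then end with:
    #   self._upload_dir(self.params.expt_dir, suffixes=(".csv", ".json"))
    #   self._upload_dir(os.path.join(self.params.expt_dir, 'videos'))
    #   self._upload_dir(os.path.join(self.params.expt_dir, 'gifs'))

One caveat applies either way: objects are keyed by bare filename, so files with the same name in expt_dir, videos/, or gifs/ would overwrite one another in the bucket.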