1
0
mirror of https://github.com/gryf/coach.git synced 2026-02-15 05:25:55 +01:00

Sync experiment dir, videos, gifs to S3. (#147)

This commit is contained in:
Balaji Subramaniam
2018-11-23 20:52:12 -08:00
committed by Ajay Deshpande
parent 5332013bd1
commit 13d2679af4
3 changed files with 21 additions and 4 deletions

View File

@@ -18,7 +18,10 @@ def construct_data_store_params(json: dict):
if json['store_type'] == 'nfs':
ds_params_instance = NFSDataStoreParameters(ds_params)
elif json['store_type'] == 's3':
ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=json['end_point'],
bucket_name=json['bucket_name'], checkpoint_dir=json['checkpoint_dir'])
ds_params_instance = S3DataStoreParameters(ds_params=ds_params,
end_point=json['end_point'],
bucket_name=json['bucket_name'],
checkpoint_dir=json['checkpoint_dir'],
expt_dir=json['expt_dir'])
return ds_params_instance

View File

@@ -12,13 +12,14 @@ import io
class S3DataStoreParameters(DataStoreParameters):
def __init__(self, ds_params, creds_file: str = None, end_point: str = None, bucket_name: str = None,
checkpoint_dir: str = None):
checkpoint_dir: str = None, expt_dir: str = None):
super().__init__(ds_params.store_type, ds_params.orchestrator_type, ds_params.orchestrator_params)
self.creds_file = creds_file
self.end_point = end_point
self.bucket_name = bucket_name
self.checkpoint_dir = checkpoint_dir
self.expt_dir = expt_dir
class S3DataStore(DataStore):
@@ -73,6 +74,18 @@ class S3DataStore(DataStore):
# release lock
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
if self.params.expt_dir and os.path.exists(self.params.expt_dir):
for filename in os.listdir(self.params.expt_dir):
if filename.endswith((".csv", ".json")):
self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, filename))
if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'videos')):
for filename in os.listdir(os.path.join(self.params.expt_dir, 'videos')):
self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'videos', filename))
if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'gifs')):
for filename in os.listdir(os.path.join(self.params.expt_dir, 'gifs')):
self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'gifs', filename))
except ResponseError as e:
print("Got exception: %s\n while saving to S3", e)