Mirror of https://github.com/gryf/coach.git
Sync experiment dir, videos, gifs to S3. (#147)
committed by Ajay Deshpande
parent 5332013bd1
commit 13d2679af4
@@ -96,6 +96,7 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
     data_store_params = None
     if args.data_store_params:
         data_store_params = construct_data_store_params(json.loads(args.data_store_params))
+        data_store_params.expt_dir = args.experiment_path
         data_store_params.checkpoint_dir = ckpt_inside_container
         graph_manager.data_store_params = data_store_params
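Read together with the construct_data_store_params() hunk further down, the worker-side flow after this change is roughly the following. This is a sketch, not code from the commit: the payload values are placeholders, and args, ckpt_inside_container, and graph_manager come from the surrounding function.

import json

# Illustrative --data_store_params payload; the key names match the
# construct_data_store_params() hunk below, the values are placeholders.
raw = json.dumps({"store_type": "s3",
                  "end_point": "s3.amazonaws.com",
                  "bucket_name": "my-coach-experiments",
                  "checkpoint_dir": "/checkpoint",
                  "expt_dir": "/experiments/my_run"})

data_store_params = construct_data_store_params(json.loads(raw))
# The worker then stamps its own view of the paths on top:
data_store_params.expt_dir = args.experiment_path
data_store_params.checkpoint_dir = ckpt_inside_container
graph_manager.data_store_params = data_store_params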
@@ -151,7 +152,7 @@ def handle_distributed_coach_orchestrator(args):
     if args.data_store == "s3":
         ds_params = DataStoreParameters("s3", "", "")
         ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=args.s3_end_point, bucket_name=args.s3_bucket_name,
-                                                   creds_file=args.s3_creds_file, checkpoint_dir=ckpt_inside_container)
+                                                   creds_file=args.s3_creds_file, checkpoint_dir=ckpt_inside_container, expt_dir=args.experiment_path)
     elif args.data_store == "nfs":
         ds_params = DataStoreParameters("nfs", "kubernetes", "")
         ds_params_instance = NFSDataStoreParameters(ds_params)
@@ -18,7 +18,10 @@ def construct_data_store_params(json: dict):
     if json['store_type'] == 'nfs':
         ds_params_instance = NFSDataStoreParameters(ds_params)
     elif json['store_type'] == 's3':
-        ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=json['end_point'],
-                                                   bucket_name=json['bucket_name'], checkpoint_dir=json['checkpoint_dir'])
+        ds_params_instance = S3DataStoreParameters(ds_params=ds_params,
+                                                   end_point=json['end_point'],
+                                                   bucket_name=json['bucket_name'],
+                                                   checkpoint_dir=json['checkpoint_dir'],
+                                                   expt_dir=json['expt_dir'])

     return ds_params_instance
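One sharp edge here: json['expt_dir'] is a hard key lookup, so a pre-existing payload that omits the new key raises KeyError in construct_data_store_params(). A lenient variant, shown as an alternative rather than what the commit does, would fall back to None, matching the expt_dir default on S3DataStoreParameters:

# Hypothetical lenient variant: tolerate older payloads without the key.
ds_params_instance = S3DataStoreParameters(ds_params=ds_params,
                                           end_point=json['end_point'],
                                           bucket_name=json['bucket_name'],
                                           checkpoint_dir=json['checkpoint_dir'],
                                           expt_dir=json.get('expt_dir'))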
@@ -12,13 +12,14 @@ import io

 class S3DataStoreParameters(DataStoreParameters):
     def __init__(self, ds_params, creds_file: str = None, end_point: str = None, bucket_name: str = None,
-                 checkpoint_dir: str = None):
+                 checkpoint_dir: str = None, expt_dir: str = None):

         super().__init__(ds_params.store_type, ds_params.orchestrator_type, ds_params.orchestrator_params)
         self.creds_file = creds_file
         self.end_point = end_point
         self.bucket_name = bucket_name
         self.checkpoint_dir = checkpoint_dir
+        self.expt_dir = expt_dir


 class S3DataStore(DataStore):
@@ -73,6 +74,18 @@ class S3DataStore(DataStore):
             # release lock
             self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)

+            if self.params.expt_dir and os.path.exists(self.params.expt_dir):
+                for filename in os.listdir(self.params.expt_dir):
+                    if filename.endswith((".csv", ".json")):
+                        self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, filename))
+
+            if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'videos')):
+                for filename in os.listdir(os.path.join(self.params.expt_dir, 'videos')):
+                    self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'videos', filename))
+
+            if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'gifs')):
+                for filename in os.listdir(os.path.join(self.params.expt_dir, 'gifs')):
+                    self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'gifs', filename))
         except ResponseError as e:
             print("Got exception: %s\n while saving to S3", e)
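The three new upload blocks differ only in source directory and filename filter, and each uploads under the bare filename, so objects from the experiment root, videos/, and gifs/ share one flat namespace in the bucket: a video and a gif with the same name would overwrite each other. A hypothetical consolidation, not part of the commit, that also keeps the subdirectories apart via an object-key prefix:

import os

def _sync_dir_to_s3(mc, bucket_name, dir_path, suffixes=None, prefix=""):
    """Upload files from dir_path to the bucket (hypothetical helper).

    mc is the minio client used elsewhere in S3DataStore; suffixes
    optionally filters filenames; prefix namespaces the object keys so
    videos/ and gifs/ do not collide at the bucket root.
    """
    if not os.path.exists(dir_path):
        return
    for filename in os.listdir(dir_path):
        if suffixes and not filename.endswith(suffixes):
            continue
        mc.fput_object(bucket_name, prefix + filename,
                       os.path.join(dir_path, filename))

# Inside S3DataStore.save_to_store(), mirroring the committed behavior:
if self.params.expt_dir:
    _sync_dir_to_s3(self.mc, self.params.bucket_name, self.params.expt_dir,
                    suffixes=(".csv", ".json"))
    _sync_dir_to_s3(self.mc, self.params.bucket_name,
                    os.path.join(self.params.expt_dir, 'videos'), prefix="videos/")
    _sync_dir_to_s3(self.mc, self.params.bucket_name,
                    os.path.join(self.params.expt_dir, 'gifs'), prefix="gifs/")

Incidentally, the pre-existing print() in the except branch passes the exception as a second argument rather than %-formatting it, so it prints the literal format string followed by the exception; this commit leaves that line unchanged.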