mirror of
https://github.com/gryf/coach.git
synced 2026-03-17 23:33:37 +01:00
Re-enable NFS data store. (#101)
This commit is contained in:
committed by
GitHub
parent
a0f25034c3
commit
dea1826658
@@ -42,6 +42,7 @@ from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
|
||||
from rl_coach.memories.backend.memory_impl import construct_memory_params
|
||||
from rl_coach.data_stores.data_store import DataStoreParameters
|
||||
from rl_coach.data_stores.s3_data_store import S3DataStoreParameters
|
||||
from rl_coach.data_stores.nfs_data_store import NFSDataStoreParameters
|
||||
from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
|
||||
from rl_coach.training_worker import training_worker
|
||||
from rl_coach.rollout_worker import rollout_worker, wait_for_checkpoint
|
||||
@@ -137,6 +138,9 @@ def handle_distributed_coach_orchestrator(graph_manager, args):
|
||||
ds_params = DataStoreParameters("s3", "", "")
|
||||
ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=args.s3_end_point, bucket_name=args.s3_bucket_name,
|
||||
creds_file=args.s3_creds_file, checkpoint_dir=ckpt_inside_container)
|
||||
elif args.data_store == "nfs":
|
||||
ds_params = DataStoreParameters("nfs", "kubernetes", "")
|
||||
ds_params_instance = NFSDataStoreParameters(ds_params)
|
||||
|
||||
worker_run_type_params = RunTypeParameters(args.image, rollout_command, run_type=str(RunType.ROLLOUT_WORKER), num_replicas=args.num_workers)
|
||||
trainer_run_type_params = RunTypeParameters(args.image, trainer_command, run_type=str(RunType.TRAINER))
|
||||
@@ -328,16 +332,17 @@ class CoachLauncher(object):
|
||||
args.image = coach_config.get('coach', 'image')
|
||||
args.memory_backend = coach_config.get('coach', 'memory_backend')
|
||||
args.data_store = coach_config.get('coach', 'data_store')
|
||||
args.s3_end_point = coach_config.get('coach', 's3_end_point')
|
||||
args.s3_bucket_name = coach_config.get('coach', 's3_bucket_name')
|
||||
args.s3_creds_file = coach_config.get('coach', 's3_creds_file')
|
||||
if args.data_store == 's3':
|
||||
args.s3_end_point = coach_config.get('coach', 's3_end_point')
|
||||
args.s3_bucket_name = coach_config.get('coach', 's3_bucket_name')
|
||||
args.s3_creds_file = coach_config.get('coach', 's3_creds_file')
|
||||
except Error as e:
|
||||
screen.error("Error when reading distributed Coach config file: {}".format(e))
|
||||
|
||||
if args.image == '':
|
||||
screen.error("Image cannot be empty.")
|
||||
|
||||
data_store_choices = ['s3']
|
||||
data_store_choices = ['s3', 'nfs']
|
||||
if args.data_store not in data_store_choices:
|
||||
screen.warning("{} data store is unsupported.".format(args.data_store))
|
||||
screen.error("Supported data stores are {}.".format(data_store_choices))
|
||||
@@ -347,11 +352,11 @@ class CoachLauncher(object):
|
||||
screen.warning("{} memory backend is not supported.".format(args.memory_backend))
|
||||
screen.error("Supported memory backends are {}.".format(memory_backend_choices))
|
||||
|
||||
if args.s3_bucket_name == '':
|
||||
screen.error("S3 bucket name cannot be empty.")
|
||||
|
||||
if args.s3_creds_file == '':
|
||||
args.s3_creds_file = None
|
||||
if args.data_store == 's3':
|
||||
if args.s3_bucket_name == '':
|
||||
screen.error("S3 bucket name cannot be empty.")
|
||||
if args.s3_creds_file == '':
|
||||
args.s3_creds_file = None
|
||||
|
||||
if args.play and args.distributed_coach:
|
||||
screen.error("Playing is not supported in distributed Coach.")
|
||||
|
||||
Reference in New Issue
Block a user