mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Re-enable NFS data store. (#101)
This commit is contained in:
committed by
GitHub
parent
a0f25034c3
commit
dea1826658
@@ -42,6 +42,7 @@ from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
|
||||
from rl_coach.memories.backend.memory_impl import construct_memory_params
|
||||
from rl_coach.data_stores.data_store import DataStoreParameters
|
||||
from rl_coach.data_stores.s3_data_store import S3DataStoreParameters
|
||||
from rl_coach.data_stores.nfs_data_store import NFSDataStoreParameters
|
||||
from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
|
||||
from rl_coach.training_worker import training_worker
|
||||
from rl_coach.rollout_worker import rollout_worker, wait_for_checkpoint
|
||||
@@ -137,6 +138,9 @@ def handle_distributed_coach_orchestrator(graph_manager, args):
|
||||
ds_params = DataStoreParameters("s3", "", "")
|
||||
ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=args.s3_end_point, bucket_name=args.s3_bucket_name,
|
||||
creds_file=args.s3_creds_file, checkpoint_dir=ckpt_inside_container)
|
||||
elif args.data_store == "nfs":
|
||||
ds_params = DataStoreParameters("nfs", "kubernetes", "")
|
||||
ds_params_instance = NFSDataStoreParameters(ds_params)
|
||||
|
||||
worker_run_type_params = RunTypeParameters(args.image, rollout_command, run_type=str(RunType.ROLLOUT_WORKER), num_replicas=args.num_workers)
|
||||
trainer_run_type_params = RunTypeParameters(args.image, trainer_command, run_type=str(RunType.TRAINER))
|
||||
@@ -328,16 +332,17 @@ class CoachLauncher(object):
|
||||
args.image = coach_config.get('coach', 'image')
|
||||
args.memory_backend = coach_config.get('coach', 'memory_backend')
|
||||
args.data_store = coach_config.get('coach', 'data_store')
|
||||
args.s3_end_point = coach_config.get('coach', 's3_end_point')
|
||||
args.s3_bucket_name = coach_config.get('coach', 's3_bucket_name')
|
||||
args.s3_creds_file = coach_config.get('coach', 's3_creds_file')
|
||||
if args.data_store == 's3':
|
||||
args.s3_end_point = coach_config.get('coach', 's3_end_point')
|
||||
args.s3_bucket_name = coach_config.get('coach', 's3_bucket_name')
|
||||
args.s3_creds_file = coach_config.get('coach', 's3_creds_file')
|
||||
except Error as e:
|
||||
screen.error("Error when reading distributed Coach config file: {}".format(e))
|
||||
|
||||
if args.image == '':
|
||||
screen.error("Image cannot be empty.")
|
||||
|
||||
data_store_choices = ['s3']
|
||||
data_store_choices = ['s3', 'nfs']
|
||||
if args.data_store not in data_store_choices:
|
||||
screen.warning("{} data store is unsupported.".format(args.data_store))
|
||||
screen.error("Supported data stores are {}.".format(data_store_choices))
|
||||
@@ -347,11 +352,11 @@ class CoachLauncher(object):
|
||||
screen.warning("{} memory backend is not supported.".format(args.memory_backend))
|
||||
screen.error("Supported memory backends are {}.".format(memory_backend_choices))
|
||||
|
||||
if args.s3_bucket_name == '':
|
||||
screen.error("S3 bucket name cannot be empty.")
|
||||
|
||||
if args.s3_creds_file == '':
|
||||
args.s3_creds_file = None
|
||||
if args.data_store == 's3':
|
||||
if args.s3_bucket_name == '':
|
||||
screen.error("S3 bucket name cannot be empty.")
|
||||
if args.s3_creds_file == '':
|
||||
args.s3_creds_file = None
|
||||
|
||||
if args.play and args.distributed_coach:
|
||||
screen.error("Playing is not supported in distributed Coach.")
|
||||
|
||||
@@ -135,9 +135,9 @@ class NFSDataStore(DataStore):
|
||||
)
|
||||
|
||||
try:
|
||||
k8s_core_v1_api_client.create_namespaced_service(self.params.namespace, service)
|
||||
svc_response = k8s_core_v1_api_client.create_namespaced_service(self.params.namespace, service)
|
||||
self.params.svc_name = svc_name
|
||||
self.params.server = 'nfs-service.{}.svc.cluster.local'.format(self.params.namespace)
|
||||
self.params.server = svc_response.spec.cluster_ip
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while creating a service for nfs-server", e)
|
||||
return False
|
||||
|
||||
@@ -68,6 +68,8 @@ def get_latest_checkpoint(checkpoint_dir):
|
||||
rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
|
||||
return int(rel_path.split('_Step')[0])
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def should_stop(checkpoint_dir):
|
||||
return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))
|
||||
|
||||
Reference in New Issue
Block a user