mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Add RedisDataStore (#295)
* GraphManager.set_session also sets self.sess
* make sure that GraphManager.fetch_from_worker uses training phase
* remove unnecessary phase setting in training worker
* reorganize rollout worker
* provide default name to GlobalVariableSaver.__init__ since it isn't really used anyway
* allow dividing TrainingSteps and EnvironmentSteps
* add timestamps to the log
* added redis data store
* conflict merge fix
This commit is contained in:
committed by
shadiendrawis
parent
34e1c04f29
commit
7b0fccb041
@@ -45,6 +45,7 @@ from rl_coach.memories.backend.memory_impl import construct_memory_params
|
||||
from rl_coach.data_stores.data_store import DataStoreParameters
|
||||
from rl_coach.data_stores.s3_data_store import S3DataStoreParameters
|
||||
from rl_coach.data_stores.nfs_data_store import NFSDataStoreParameters
|
||||
from rl_coach.data_stores.redis_data_store import RedisDataStoreParameters
|
||||
from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
|
||||
from rl_coach.training_worker import training_worker
|
||||
from rl_coach.rollout_worker import rollout_worker
|
||||
@@ -97,29 +98,25 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
|
||||
memory_backend_params['run_type'] = str(args.distributed_coach_run_type)
|
||||
graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(memory_backend_params))
|
||||
|
||||
data_store = None
|
||||
data_store_params = None
|
||||
if args.data_store_params:
|
||||
data_store_params = construct_data_store_params(json.loads(args.data_store_params))
|
||||
data_store_params.expt_dir = args.experiment_path
|
||||
data_store_params.checkpoint_dir = ckpt_inside_container
|
||||
graph_manager.data_store_params = data_store_params
|
||||
|
||||
data_store = None
|
||||
if args.data_store_params:
|
||||
data_store = get_data_store(data_store_params)
|
||||
|
||||
if args.distributed_coach_run_type == RunType.TRAINER:
|
||||
task_parameters.checkpoint_save_dir = ckpt_inside_container
|
||||
training_worker(
|
||||
graph_manager=graph_manager,
|
||||
task_parameters=task_parameters,
|
||||
data_store=data_store,
|
||||
task_parameters=task_parameters,
|
||||
is_multi_node_test=args.is_multi_node_test
|
||||
)
|
||||
|
||||
if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
|
||||
task_parameters.checkpoint_restore_path = ckpt_inside_container
|
||||
|
||||
rollout_worker(
|
||||
graph_manager=graph_manager,
|
||||
data_store=data_store,
|
||||
@@ -162,6 +159,11 @@ def handle_distributed_coach_orchestrator(args):
|
||||
elif args.data_store == "nfs":
|
||||
ds_params = DataStoreParameters("nfs", "kubernetes", "")
|
||||
ds_params_instance = NFSDataStoreParameters(ds_params)
|
||||
elif args.data_store == "redis":
|
||||
ds_params = DataStoreParameters("redis", "kubernetes", "")
|
||||
ds_params_instance = RedisDataStoreParameters(ds_params)
|
||||
else:
|
||||
raise ValueError("data_store {} found. Expected 's3' or 'nfs'".format(args.data_store))
|
||||
|
||||
worker_run_type_params = RunTypeParameters(args.image, rollout_command, run_type=str(RunType.ROLLOUT_WORKER), num_replicas=args.num_workers)
|
||||
trainer_run_type_params = RunTypeParameters(args.image, trainer_command, run_type=str(RunType.TRAINER))
|
||||
@@ -375,7 +377,7 @@ class CoachLauncher(object):
|
||||
if args.image == '':
|
||||
screen.error("Image cannot be empty.")
|
||||
|
||||
data_store_choices = ['s3', 'nfs']
|
||||
data_store_choices = ['s3', 'nfs', 'redis']
|
||||
if args.data_store not in data_store_choices:
|
||||
screen.warning("{} data store is unsupported.".format(args.data_store))
|
||||
screen.error("Supported data stores are {}.".format(data_store_choices))
|
||||
|
||||
Reference in New Issue
Block a user