1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Add RedisDataStore (#295)

* GraphManager.set_session also sets self.sess

* make sure that GraphManager.fetch_from_worker uses training phase

* remove unnecessary phase setting in training worker

* reorganize rollout worker

* provide default name to GlobalVariableSaver.__init__ since it isn't really used anyway

* allow dividing TrainingSteps and EnvironmentSteps

* add timestamps to the log

* added redis data store

* merge conflict fix
This commit is contained in:
Zach Dwiel
2019-08-28 14:15:58 -04:00
committed by shadiendrawis
parent 34e1c04f29
commit 7b0fccb041
18 changed files with 528 additions and 120 deletions

View File

@@ -45,6 +45,7 @@ from rl_coach.memories.backend.memory_impl import construct_memory_params
from rl_coach.data_stores.data_store import DataStoreParameters
from rl_coach.data_stores.s3_data_store import S3DataStoreParameters
from rl_coach.data_stores.nfs_data_store import NFSDataStoreParameters
from rl_coach.data_stores.redis_data_store import RedisDataStoreParameters
from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
from rl_coach.training_worker import training_worker
from rl_coach.rollout_worker import rollout_worker
@@ -97,29 +98,25 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
memory_backend_params['run_type'] = str(args.distributed_coach_run_type)
graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(memory_backend_params))
data_store = None
data_store_params = None
if args.data_store_params:
data_store_params = construct_data_store_params(json.loads(args.data_store_params))
data_store_params.expt_dir = args.experiment_path
data_store_params.checkpoint_dir = ckpt_inside_container
graph_manager.data_store_params = data_store_params
data_store = None
if args.data_store_params:
data_store = get_data_store(data_store_params)
if args.distributed_coach_run_type == RunType.TRAINER:
task_parameters.checkpoint_save_dir = ckpt_inside_container
training_worker(
graph_manager=graph_manager,
task_parameters=task_parameters,
data_store=data_store,
task_parameters=task_parameters,
is_multi_node_test=args.is_multi_node_test
)
if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
task_parameters.checkpoint_restore_path = ckpt_inside_container
rollout_worker(
graph_manager=graph_manager,
data_store=data_store,
@@ -162,6 +159,11 @@ def handle_distributed_coach_orchestrator(args):
elif args.data_store == "nfs":
ds_params = DataStoreParameters("nfs", "kubernetes", "")
ds_params_instance = NFSDataStoreParameters(ds_params)
elif args.data_store == "redis":
ds_params = DataStoreParameters("redis", "kubernetes", "")
ds_params_instance = RedisDataStoreParameters(ds_params)
else:
raise ValueError("data_store {} found. Expected 's3' or 'nfs'".format(args.data_store))
worker_run_type_params = RunTypeParameters(args.image, rollout_command, run_type=str(RunType.ROLLOUT_WORKER), num_replicas=args.num_workers)
trainer_run_type_params = RunTypeParameters(args.image, trainer_command, run_type=str(RunType.TRAINER))
@@ -375,7 +377,7 @@ class CoachLauncher(object):
if args.image == '':
screen.error("Image cannot be empty.")
data_store_choices = ['s3', 'nfs']
data_store_choices = ['s3', 'nfs', 'redis']
if args.data_store not in data_store_choices:
screen.warning("{} data store is unsupported.".format(args.data_store))
screen.error("Supported data stores are {}.".format(data_store_choices))