mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Integrate coach.py params with distributed Coach. (#42)

* Integrate coach.py params with distributed Coach.
* Minor improvements
- Use enums instead of constants.
- Reduce code duplication.
- Ask experiment name with timeout.
Authored by Balaji Subramaniam on 2018-11-05 09:33:30 -08:00; committed by GitHub
parent 95b4fc6888
commit 7e7006305a
13 changed files with 263 additions and 285 deletions
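In orchestrator mode, the new code reads the distributed-Coach settings from an INI-style file passed via -dcp/--distributed_coach_config_path. As a rough sketch (not part of the commit), the snippet below shows the [coach] section that the new parse_arguments() code expects; the section and key names are taken from the diff, while the image, bucket, and credentials values are placeholders.

# Hypothetical config sketch: mirrors the keys and defaults read by parse_arguments().
from configparser import ConfigParser

example_ini = """
[coach]
image = my-registry/coach:latest
memory_backend = redispubsub
data_store = s3
s3_end_point = s3.amazonaws.com
s3_bucket_name = my-coach-bucket
s3_creds_file = /home/user/.aws/credentials
"""

config = ConfigParser()
config.read_string(example_ini)
assert config.get('coach', 'memory_backend') == 'redispubsub'
assert config.get('coach', 'data_store') == 's3'

With such a file in place, the orchestrator would be started along the lines of python3 rl_coach/coach.py -p <preset> -dc -dcp <path to config>, and it re-launches the same script inside the trainer and rollout-worker containers with --distributed_coach_run_type set accordingly.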

rl_coach/coach.py View File

@@ -17,6 +17,7 @@ import sys
sys.path.append('.')
import copy
from configparser import ConfigParser, Error
from rl_coach.core_types import EnvironmentSteps
import os
from rl_coach import logger
@@ -26,6 +27,7 @@ import argparse
import atexit
import time
import sys
import json
from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskParameters, DistributedTaskParameters
from multiprocessing import Process
from multiprocessing.managers import BaseManager
@@ -35,6 +37,14 @@ from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port
from rl_coach.agents.human_agent import HumanAgentParameters
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.environments.environment import SingleLevelSelection
from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, RunType, RunTypeParameters
from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
from rl_coach.memories.backend.memory_impl import construct_memory_params
from rl_coach.data_stores.data_store import DataStoreParameters
from rl_coach.data_stores.s3_data_store import S3DataStoreParameters
from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
from rl_coach.training_worker import training_worker
from rl_coach.rollout_worker import rollout_worker, wait_for_checkpoint
if len(set(failed_imports)) > 0:
@@ -108,6 +118,7 @@ def display_all_presets_and_exit():
        print(preset)
    sys.exit(0)


def expand_preset(preset):
    if preset.lower() in [p.lower() for p in list_all_presets()]:
        preset = "{}.py:graph_manager".format(os.path.join(get_base_dir(), 'presets', preset))
@@ -150,6 +161,49 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
    if args.list:
        display_all_presets_and_exit()

    # Read args from config file for distributed Coach.
    if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
        coach_config = ConfigParser({
            'image': '',
            'memory_backend': 'redispubsub',
            'data_store': 's3',
            's3_end_point': 's3.amazonaws.com',
            's3_bucket_name': '',
            's3_creds_file': ''
        })
        try:
            coach_config.read(args.distributed_coach_config_path)
            args.image = coach_config.get('coach', 'image')
            args.memory_backend = coach_config.get('coach', 'memory_backend')
            args.data_store = coach_config.get('coach', 'data_store')
            args.s3_end_point = coach_config.get('coach', 's3_end_point')
            args.s3_bucket_name = coach_config.get('coach', 's3_bucket_name')
            args.s3_creds_file = coach_config.get('coach', 's3_creds_file')
        except Error as e:
            screen.error("Error when reading distributed Coach config file: {}".format(e))

        if args.image == '':
            screen.error("Image cannot be empty.")

        data_store_choices = ['s3']
        if args.data_store not in data_store_choices:
            screen.warning("{} data store is unsupported.".format(args.data_store))
            screen.error("Supported data stores are {}.".format(data_store_choices))

        memory_backend_choices = ['redispubsub']
        if args.memory_backend not in memory_backend_choices:
            screen.warning("{} memory backend is not supported.".format(args.memory_backend))
            screen.error("Supported memory backends are {}.".format(memory_backend_choices))

        if args.s3_bucket_name == '':
            screen.error("S3 bucket name cannot be empty.")

        if args.s3_creds_file == '':
            args.s3_creds_file = None

    if args.play and args.distributed_coach:
        screen.error("Playing is not supported in distributed Coach.")

    # replace a short preset name with the full path
    if args.preset is not None:
        args.preset = expand_preset(args.preset)
@@ -217,6 +271,94 @@ def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'
    graph_manager.improve()


def handle_distributed_coach_tasks(graph_manager, args):
    ckpt_inside_container = "/checkpoint"
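    # The memory-backend and data-store settings below arrive as JSON strings through the
    # hidden --memory_backend_params / --data_store_params options (declared with
    # help=argparse.SUPPRESS in main()); presumably the orchestration layer fills them in
    # when it launches each trainer / rollout-worker container.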
    memory_backend_params = None
    if args.memory_backend_params:
        memory_backend_params = json.loads(args.memory_backend_params)
        memory_backend_params['run_type'] = str(args.distributed_coach_run_type)
        graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(memory_backend_params))

    data_store_params = None
    if args.data_store_params:
        data_store_params = construct_data_store_params(json.loads(args.data_store_params))
        data_store_params.checkpoint_dir = ckpt_inside_container
        graph_manager.data_store_params = data_store_params

    if args.distributed_coach_run_type == RunType.TRAINER:
        training_worker(
            graph_manager=graph_manager,
            checkpoint_dir=ckpt_inside_container
        )

    if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
        data_store = None
        if args.data_store_params:
            data_store = get_data_store(data_store_params)

        wait_for_checkpoint(checkpoint_dir=ckpt_inside_container, data_store=data_store)

        rollout_worker(
            graph_manager=graph_manager,
            checkpoint_dir=ckpt_inside_container,
            data_store=data_store,
            num_workers=args.num_workers
        )
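
# handle_distributed_coach_orchestrator (below) builds trainer and rollout-worker commands
# that re-invoke coach.py with the matching --distributed_coach_run_type, sets up the Redis
# pub/sub memory backend and the S3 data store parameters, deploys both run types to
# Kubernetes, follows the trainer's logs (orchestrator.trainer_logs()) until interrupted,
# and finally undeploys everything.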
def handle_distributed_coach_orchestrator(graph_manager, args):
    ckpt_inside_container = "/checkpoint"
    rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + sys.argv[1:]
    trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + sys.argv[1:]

    if '--experiment_name' not in rollout_command:
        rollout_command = rollout_command + ['--experiment_name', args.experiment_name]

    if '--experiment_name' not in trainer_command:
        trainer_command = trainer_command + ['--experiment_name', args.experiment_name]

    memory_backend_params = None
    if args.memory_backend == "redispubsub":
        memory_backend_params = RedisPubSubMemoryBackendParameters()

    ds_params_instance = None
    if args.data_store == "s3":
        ds_params = DataStoreParameters("s3", "", "")
        ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=args.s3_end_point, bucket_name=args.s3_bucket_name,
                                                   creds_file=args.s3_creds_file, checkpoint_dir=ckpt_inside_container)

    worker_run_type_params = RunTypeParameters(args.image, rollout_command, run_type=str(RunType.ROLLOUT_WORKER), num_replicas=args.num_workers)
    trainer_run_type_params = RunTypeParameters(args.image, trainer_command, run_type=str(RunType.TRAINER))

    orchestration_params = KubernetesParameters([worker_run_type_params, trainer_run_type_params],
                                                kubeconfig='~/.kube/config',
                                                memory_backend_parameters=memory_backend_params,
                                                data_store_params=ds_params_instance)
    orchestrator = Kubernetes(orchestration_params)
    if not orchestrator.setup():
        print("Could not setup.")
        return

    if orchestrator.deploy_trainer():
        print("Successfully deployed trainer.")
    else:
        print("Could not deploy trainer.")
        return

    if orchestrator.deploy_worker():
        print("Successfully deployed rollout worker(s).")
    else:
        print("Could not deploy rollout worker(s).")
        return

    try:
        orchestrator.trainer_logs()
    except KeyboardInterrupt:
        pass

    orchestrator.undeploy()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--preset',
@@ -329,11 +471,35 @@ def main():
help="(int) A seed to use for running the experiment",
default=None,
type=int)
parser.add_argument('-dc', '--distributed_coach',
help="(flag) Use distributed Coach.",
action='store_true')
parser.add_argument('-dcp', '--distributed_coach_config_path',
help="(string) Path to config file when using distributed rollout workers."
"Only distributed Coach parameters should be provided through this config file."
"Rest of the parameters are provided using Coach command line options."
"Used only with --distributed_coach flag."
"Ignored if --distributed_coach flag is not used.",
type=str)
parser.add_argument('--memory_backend_params',
help=argparse.SUPPRESS,
type=str)
parser.add_argument('--data_store_params',
help=argparse.SUPPRESS,
type=str)
parser.add_argument('--distributed_coach_run_type',
help=argparse.SUPPRESS,
type=RunType,
default=RunType.ORCHESTRATOR,
choices=list(RunType))
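    # Assuming RunType is an Enum (per the commit message's "use enums instead of constants"),
    # argparse calls RunType(<raw string>) on the command-line value, so the enum's values
    # must be the strings passed on the CLI, and str(run_type) (used when the orchestrator
    # rebuilds the worker commands) must round-trip to those same strings.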
    args = parse_arguments(parser)

    graph_manager = get_graph_manager_from_args(args)

    if args.distributed_coach and not graph_manager.agent_params.algorithm.distributed_coach_synchronization_type:
        screen.error("{} preset is not supported using distributed Coach.".format(args.preset))

    # Intel optimized TF seems to run significantly faster when limiting to a single OMP thread.
    # This will not affect GPU runs.
    os.environ["OMP_NUM_THREADS"] = "1"
@@ -343,7 +509,7 @@ def main():
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(args.tf_verbosity)

    # turn off the summary at the end of the run if necessary
-    if not args.no_summary:
+    if not args.no_summary and not args.distributed_coach:
        atexit.register(logger.summarize_experiment)
        screen.change_terminal_title(args.experiment_name)
@@ -351,6 +517,14 @@ def main():
    if args.open_dashboard:
        open_dashboard(args.experiment_path)

    if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
        handle_distributed_coach_tasks(graph_manager, args)
        return

    if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
        handle_distributed_coach_orchestrator(graph_manager, args)
        return

    # Single-threaded runs
    if args.num_workers == 1:
        # Start the training or evaluation