Integrate coach.py params with distributed Coach. (#42)
* Integrate coach.py params with distributed Coach.
* Minor improvements:
  - Use enums instead of constants.
  - Reduce code duplication.
  - Ask experiment name with timeout.
committed by GitHub
parent 95b4fc6888
commit 7e7006305a
@@ -17,6 +17,7 @@ import sys
sys.path.append('.')

import copy
from configparser import ConfigParser, Error
from rl_coach.core_types import EnvironmentSteps
import os
from rl_coach import logger
@@ -26,6 +27,7 @@ import argparse
import atexit
import time
import sys
import json
from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskParameters, DistributedTaskParameters
from multiprocessing import Process
from multiprocessing.managers import BaseManager
@@ -35,6 +37,14 @@ from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port
from rl_coach.agents.human_agent import HumanAgentParameters
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.environments.environment import SingleLevelSelection
from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, RunType, RunTypeParameters
from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
from rl_coach.memories.backend.memory_impl import construct_memory_params
from rl_coach.data_stores.data_store import DataStoreParameters
from rl_coach.data_stores.s3_data_store import S3DataStoreParameters
from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
from rl_coach.training_worker import training_worker
from rl_coach.rollout_worker import rollout_worker, wait_for_checkpoint


if len(set(failed_imports)) > 0:
@@ -108,6 +118,7 @@ def display_all_presets_and_exit():
        print(preset)
    sys.exit(0)


def expand_preset(preset):
    if preset.lower() in [p.lower() for p in list_all_presets()]:
        preset = "{}.py:graph_manager".format(os.path.join(get_base_dir(), 'presets', preset))
@@ -150,6 +161,49 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
    if args.list:
        display_all_presets_and_exit()

    # Read args from config file for distributed Coach.
    if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
        coach_config = ConfigParser({
            'image': '',
            'memory_backend': 'redispubsub',
            'data_store': 's3',
            's3_end_point': 's3.amazonaws.com',
            's3_bucket_name': '',
            's3_creds_file': ''
        })
        try:
            coach_config.read(args.distributed_coach_config_path)
            args.image = coach_config.get('coach', 'image')
            args.memory_backend = coach_config.get('coach', 'memory_backend')
            args.data_store = coach_config.get('coach', 'data_store')
            args.s3_end_point = coach_config.get('coach', 's3_end_point')
            args.s3_bucket_name = coach_config.get('coach', 's3_bucket_name')
            args.s3_creds_file = coach_config.get('coach', 's3_creds_file')
        except Error as e:
            screen.error("Error when reading distributed Coach config file: {}".format(e))

        if args.image == '':
            screen.error("Image cannot be empty.")

        data_store_choices = ['s3']
        if args.data_store not in data_store_choices:
            screen.warning("{} data store is unsupported.".format(args.data_store))
            screen.error("Supported data stores are {}.".format(data_store_choices))

        memory_backend_choices = ['redispubsub']
        if args.memory_backend not in memory_backend_choices:
            screen.warning("{} memory backend is not supported.".format(args.memory_backend))
            screen.error("Supported memory backends are {}.".format(memory_backend_choices))

        if args.s3_bucket_name == '':
            screen.error("S3 bucket name cannot be empty.")

        if args.s3_creds_file == '':
            args.s3_creds_file = None

    if args.play and args.distributed_coach:
        screen.error("Playing is not supported in distributed Coach.")

    # replace a short preset name with the full path
    if args.preset is not None:
        args.preset = expand_preset(args.preset)
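Note: the hunk above reads distributed-Coach settings from an INI-style file passed via --distributed_coach_config_path, looking them up in a 'coach' section and falling back to the defaults handed to ConfigParser (redispubsub, s3, s3.amazonaws.com). A minimal sketch of such a config file, with placeholder image and bucket values that are not part of this commit, might look like this:

[coach]
# container image for the trainer and rollout workers; must not be empty
image = my-registry/coach:latest
memory_backend = redispubsub
data_store = s3
s3_end_point = s3.amazonaws.com
# bucket used as the checkpoint data store; must not be empty
s3_bucket_name = my-coach-bucket
# optional AWS credentials file; an empty value is turned into None
s3_creds_file =

Keys omitted from the file fall back to the ConfigParser defaults shown in the diff, so in practice only image and s3_bucket_name must be supplied.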
@@ -217,6 +271,94 @@ def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'
    graph_manager.improve()


def handle_distributed_coach_tasks(graph_manager, args):
    ckpt_inside_container = "/checkpoint"

    memory_backend_params = None
    if args.memory_backend_params:
        memory_backend_params = json.loads(args.memory_backend_params)
        memory_backend_params['run_type'] = str(args.distributed_coach_run_type)
        graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(memory_backend_params))

    data_store_params = None
    if args.data_store_params:
        data_store_params = construct_data_store_params(json.loads(args.data_store_params))
        data_store_params.checkpoint_dir = ckpt_inside_container
        graph_manager.data_store_params = data_store_params

    if args.distributed_coach_run_type == RunType.TRAINER:
        training_worker(
            graph_manager=graph_manager,
            checkpoint_dir=ckpt_inside_container
        )

    if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
        data_store = None
        if args.data_store_params:
            data_store = get_data_store(data_store_params)
            wait_for_checkpoint(checkpoint_dir=ckpt_inside_container, data_store=data_store)

        rollout_worker(
            graph_manager=graph_manager,
            checkpoint_dir=ckpt_inside_container,
            data_store=data_store,
            num_workers=args.num_workers
        )


def handle_distributed_coach_orchestrator(graph_manager, args):
    ckpt_inside_container = "/checkpoint"
    rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + sys.argv[1:]
    trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + sys.argv[1:]

    if '--experiment_name' not in rollout_command:
        rollout_command = rollout_command + ['--experiment_name', args.experiment_name]

    if '--experiment_name' not in trainer_command:
        trainer_command = trainer_command + ['--experiment_name', args.experiment_name]

    memory_backend_params = None
    if args.memory_backend == "redispubsub":
        memory_backend_params = RedisPubSubMemoryBackendParameters()

    ds_params_instance = None
    if args.data_store == "s3":
        ds_params = DataStoreParameters("s3", "", "")
        ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=args.s3_end_point, bucket_name=args.s3_bucket_name,
                                                   creds_file=args.s3_creds_file, checkpoint_dir=ckpt_inside_container)

    worker_run_type_params = RunTypeParameters(args.image, rollout_command, run_type=str(RunType.ROLLOUT_WORKER), num_replicas=args.num_workers)
    trainer_run_type_params = RunTypeParameters(args.image, trainer_command, run_type=str(RunType.TRAINER))

    orchestration_params = KubernetesParameters([worker_run_type_params, trainer_run_type_params],
                                                kubeconfig='~/.kube/config',
                                                memory_backend_parameters=memory_backend_params,
                                                data_store_params=ds_params_instance)
    orchestrator = Kubernetes(orchestration_params)
    if not orchestrator.setup():
        print("Could not setup.")
        return

    if orchestrator.deploy_trainer():
        print("Successfully deployed trainer.")
    else:
        print("Could not deploy trainer.")
        return

    if orchestrator.deploy_worker():
        print("Successfully deployed rollout worker(s).")
    else:
        print("Could not deploy rollout worker(s).")
        return

    try:
        orchestrator.trainer_logs()
    except KeyboardInterrupt:
        pass

    orchestrator.undeploy()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--preset',
@@ -329,11 +471,35 @@ def main():
                        help="(int) A seed to use for running the experiment",
                        default=None,
                        type=int)
    parser.add_argument('-dc', '--distributed_coach',
                        help="(flag) Use distributed Coach.",
                        action='store_true')
    parser.add_argument('-dcp', '--distributed_coach_config_path',
                        help="(string) Path to config file when using distributed rollout workers."
                             "Only distributed Coach parameters should be provided through this config file."
                             "Rest of the parameters are provided using Coach command line options."
                             "Used only with --distributed_coach flag."
                             "Ignored if --distributed_coach flag is not used.",
                        type=str)
    parser.add_argument('--memory_backend_params',
                        help=argparse.SUPPRESS,
                        type=str)
    parser.add_argument('--data_store_params',
                        help=argparse.SUPPRESS,
                        type=str)
    parser.add_argument('--distributed_coach_run_type',
                        help=argparse.SUPPRESS,
                        type=RunType,
                        default=RunType.ORCHESTRATOR,
                        choices=list(RunType))

    args = parse_arguments(parser)

    graph_manager = get_graph_manager_from_args(args)

    if args.distributed_coach and not graph_manager.agent_params.algorithm.distributed_coach_synchronization_type:
        screen.error("{} preset is not supported using distributed Coach.".format(args.preset))

    # Intel optimized TF seems to run significantly faster when limiting to a single OMP thread.
    # This will not affect GPU runs.
    os.environ["OMP_NUM_THREADS"] = "1"
@@ -343,7 +509,7 @@ def main():
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(args.tf_verbosity)

    # turn off the summary at the end of the run if necessary
    if not args.no_summary:
    if not args.no_summary and not args.distributed_coach:
        atexit.register(logger.summarize_experiment)
        screen.change_terminal_title(args.experiment_name)

@@ -351,6 +517,14 @@ def main():
    if args.open_dashboard:
        open_dashboard(args.experiment_path)

    if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
        handle_distributed_coach_tasks(graph_manager, args)
        return

    if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
        handle_distributed_coach_orchestrator(graph_manager, args)
        return

    # Single-threaded runs
    if args.num_workers == 1:
        # Start the training or evaluation
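With a config file like the one sketched earlier, an orchestrator run could be launched along these lines (the preset name and config path are illustrative, not taken from this commit; --distributed_coach_run_type defaults to the orchestrator, so it only needs to be set explicitly for the spawned pods):

python3 rl_coach/coach.py -p CartPole_ClippedPPO -dc -dcp /path/to/coach_config.ini

From there handle_distributed_coach_orchestrator deploys the trainer and rollout workers to Kubernetes via ~/.kube/config, re-invoking coach.py inside the containers with --distributed_coach_run_type set to the trainer or rollout worker; the suppressed --memory_backend_params and --data_store_params options consumed by handle_distributed_coach_tasks are presumably filled in by the orchestrator machinery rather than by hand.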