diff --git a/dist-coach-config.template b/dist-coach-config.template new file mode 100644 index 0000000..1566345 --- /dev/null +++ b/dist-coach-config.template @@ -0,0 +1,7 @@ +[coach] +image = +memory_backend = redispubsub +data_store = s3 +s3_end_point = s3.amazonaws.com +s3_bucket_name = +s3_creds_file = diff --git a/rl_coach/base_parameters.py b/rl_coach/base_parameters.py index a48a22c..e23226e 100644 --- a/rl_coach/base_parameters.py +++ b/rl_coach/base_parameters.py @@ -53,6 +53,18 @@ class EmbeddingMergerType(Enum): #Multiply = 3 +# DistributedCoachSynchronizationType provides the synchronization type for distributed Coach. +# The default value is None, which means the algorithm or preset cannot be used with distributed Coach. +class DistributedCoachSynchronizationType(Enum): + # In SYNC mode, the trainer waits for all the experiences to be gathered from distributed rollout workers before + # training a new policy and the rollout workers wait for a new policy before gathering experiences. + SYNC = "sync" + + # In ASYNC mode, the trainer doesn't wait for any set of experiences to be gathered from distributed rollout workers + # and the rollout workers continously gather experiences loading new policies, whenever they become available. + ASYNC = "async" + + def iterable_to_items(obj): if isinstance(obj, dict) or isinstance(obj, OrderedDict) or isinstance(obj, types.MappingProxyType): items = obj.items() @@ -154,6 +166,9 @@ class AlgorithmParameters(Parameters): # intrinsic reward self.scale_external_reward_by_intrinsic_reward_value = False + # Distributed Coach params + self.distributed_coach_synchronization_type = None + class PresetValidationParameters(Parameters): def __init__(self): diff --git a/rl_coach/coach.py b/rl_coach/coach.py index cab0492..eecfc70 100644 --- a/rl_coach/coach.py +++ b/rl_coach/coach.py @@ -17,6 +17,7 @@ import sys sys.path.append('.') import copy +from configparser import ConfigParser, Error from rl_coach.core_types import EnvironmentSteps import os from rl_coach import logger @@ -26,6 +27,7 @@ import argparse import atexit import time import sys +import json from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskParameters, DistributedTaskParameters from multiprocessing import Process from multiprocessing.managers import BaseManager @@ -35,6 +37,14 @@ from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port from rl_coach.agents.human_agent import HumanAgentParameters from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager from rl_coach.environments.environment import SingleLevelSelection +from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, RunType, RunTypeParameters +from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters +from rl_coach.memories.backend.memory_impl import construct_memory_params +from rl_coach.data_stores.data_store import DataStoreParameters +from rl_coach.data_stores.s3_data_store import S3DataStoreParameters +from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params +from rl_coach.training_worker import training_worker +from rl_coach.rollout_worker import rollout_worker, wait_for_checkpoint if len(set(failed_imports)) > 0: @@ -108,6 +118,7 @@ def display_all_presets_and_exit(): print(preset) sys.exit(0) + def expand_preset(preset): if preset.lower() in [p.lower() for p in list_all_presets()]: preset = "{}.py:graph_manager".format(os.path.join(get_base_dir(), 'presets', preset)) @@ -150,6 +161,49 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace: if args.list: display_all_presets_and_exit() + # Read args from config file for distributed Coach. + if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR: + coach_config = ConfigParser({ + 'image': '', + 'memory_backend': 'redispubsub', + 'data_store': 's3', + 's3_end_point': 's3.amazonaws.com', + 's3_bucket_name': '', + 's3_creds_file': '' + }) + try: + coach_config.read(args.distributed_coach_config_path) + args.image = coach_config.get('coach', 'image') + args.memory_backend = coach_config.get('coach', 'memory_backend') + args.data_store = coach_config.get('coach', 'data_store') + args.s3_end_point = coach_config.get('coach', 's3_end_point') + args.s3_bucket_name = coach_config.get('coach', 's3_bucket_name') + args.s3_creds_file = coach_config.get('coach', 's3_creds_file') + except Error as e: + screen.error("Error when reading distributed Coach config file: {}".format(e)) + + if args.image == '': + screen.error("Image cannot be empty.") + + data_store_choices = ['s3'] + if args.data_store not in data_store_choices: + screen.warning("{} data store is unsupported.".format(args.data_store)) + screen.error("Supported data stores are {}.".format(data_store_choices)) + + memory_backend_choices = ['redispubsub'] + if args.memory_backend not in memory_backend_choices: + screen.warning("{} memory backend is not supported.".format(args.memory_backend)) + screen.error("Supported memory backends are {}.".format(memory_backend_choices)) + + if args.s3_bucket_name == '': + screen.error("S3 bucket name cannot be empty.") + + if args.s3_creds_file == '': + args.s3_creds_file = None + + if args.play and args.distributed_coach: + screen.error("Playing is not supported in distributed Coach.") + # replace a short preset name with the full path if args.preset is not None: args.preset = expand_preset(args.preset) @@ -217,6 +271,94 @@ def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters' graph_manager.improve() +def handle_distributed_coach_tasks(graph_manager, args): + ckpt_inside_container = "/checkpoint" + + memory_backend_params = None + if args.memory_backend_params: + memory_backend_params = json.loads(args.memory_backend_params) + memory_backend_params['run_type'] = str(args.distributed_coach_run_type) + graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(memory_backend_params)) + + data_store_params = None + if args.data_store_params: + data_store_params = construct_data_store_params(json.loads(args.data_store_params)) + data_store_params.checkpoint_dir = ckpt_inside_container + graph_manager.data_store_params = data_store_params + + if args.distributed_coach_run_type == RunType.TRAINER: + training_worker( + graph_manager=graph_manager, + checkpoint_dir=ckpt_inside_container + ) + + if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER: + data_store = None + if args.data_store_params: + data_store = get_data_store(data_store_params) + wait_for_checkpoint(checkpoint_dir=ckpt_inside_container, data_store=data_store) + + rollout_worker( + graph_manager=graph_manager, + checkpoint_dir=ckpt_inside_container, + data_store=data_store, + num_workers=args.num_workers + ) + + +def handle_distributed_coach_orchestrator(graph_manager, args): + ckpt_inside_container = "/checkpoint" + rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + sys.argv[1:] + trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + sys.argv[1:] + + if '--experiment_name' not in rollout_command: + rollout_command = rollout_command + ['--experiment_name', args.experiment_name] + + if '--experiment_name' not in trainer_command: + trainer_command = trainer_command + ['--experiment_name', args.experiment_name] + + memory_backend_params = None + if args.memory_backend == "redispubsub": + memory_backend_params = RedisPubSubMemoryBackendParameters() + + ds_params_instance = None + if args.data_store == "s3": + ds_params = DataStoreParameters("s3", "", "") + ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=args.s3_end_point, bucket_name=args.s3_bucket_name, + creds_file=args.s3_creds_file, checkpoint_dir=ckpt_inside_container) + + worker_run_type_params = RunTypeParameters(args.image, rollout_command, run_type=str(RunType.ROLLOUT_WORKER), num_replicas=args.num_workers) + trainer_run_type_params = RunTypeParameters(args.image, trainer_command, run_type=str(RunType.TRAINER)) + + orchestration_params = KubernetesParameters([worker_run_type_params, trainer_run_type_params], + kubeconfig='~/.kube/config', + memory_backend_parameters=memory_backend_params, + data_store_params=ds_params_instance) + orchestrator = Kubernetes(orchestration_params) + if not orchestrator.setup(): + print("Could not setup.") + return + + if orchestrator.deploy_trainer(): + print("Successfully deployed trainer.") + else: + print("Could not deploy trainer.") + return + + if orchestrator.deploy_worker(): + print("Successfully deployed rollout worker(s).") + else: + print("Could not deploy rollout worker(s).") + return + + try: + orchestrator.trainer_logs() + except KeyboardInterrupt: + pass + + orchestrator.undeploy() + + def main(): parser = argparse.ArgumentParser() parser.add_argument('-p', '--preset', @@ -329,11 +471,35 @@ def main(): help="(int) A seed to use for running the experiment", default=None, type=int) + parser.add_argument('-dc', '--distributed_coach', + help="(flag) Use distributed Coach.", + action='store_true') + parser.add_argument('-dcp', '--distributed_coach_config_path', + help="(string) Path to config file when using distributed rollout workers." + "Only distributed Coach parameters should be provided through this config file." + "Rest of the parameters are provided using Coach command line options." + "Used only with --distributed_coach flag." + "Ignored if --distributed_coach flag is not used.", + type=str) + parser.add_argument('--memory_backend_params', + help=argparse.SUPPRESS, + type=str) + parser.add_argument('--data_store_params', + help=argparse.SUPPRESS, + type=str) + parser.add_argument('--distributed_coach_run_type', + help=argparse.SUPPRESS, + type=RunType, + default=RunType.ORCHESTRATOR, + choices=list(RunType)) args = parse_arguments(parser) graph_manager = get_graph_manager_from_args(args) + if args.distributed_coach and not graph_manager.agent_params.algorithm.distributed_coach_synchronization_type: + screen.error("{} preset is not supported using distributed Coach.".format(args.preset)) + # Intel optimized TF seems to run significantly faster when limiting to a single OMP thread. # This will not affect GPU runs. os.environ["OMP_NUM_THREADS"] = "1" @@ -343,7 +509,7 @@ def main(): os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(args.tf_verbosity) # turn off the summary at the end of the run if necessary - if not args.no_summary: + if not args.no_summary and not args.distributed_coach: atexit.register(logger.summarize_experiment) screen.change_terminal_title(args.experiment_name) @@ -351,6 +517,14 @@ def main(): if args.open_dashboard: open_dashboard(args.experiment_path) + if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR: + handle_distributed_coach_tasks(graph_manager, args) + return + + if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR: + handle_distributed_coach_orchestrator(graph_manager, args) + return + # Single-threaded runs if args.num_workers == 1: # Start the training or evaluation diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index 4b5ae9f..b373c1c 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -33,6 +33,7 @@ from rl_coach.level_manager import LevelManager from rl_coach.logger import screen, Logger from rl_coach.utils import set_cpu, start_shell_command_and_wait from rl_coach.data_stores.data_store_impl import get_data_store +from rl_coach.orchestrators.kubernetes_orchestrator import RunType class ScheduleParameters(Parameters): @@ -361,9 +362,10 @@ class GraphManager(object): self.verify_graph_was_created() if hasattr(self, 'data_store_params') and hasattr(self.agent_params.memory, 'memory_backend_params'): - if self.agent_params.memory.memory_backend_params.run_type == "worker": + if self.agent_params.memory.memory_backend_params.run_type == str(RunType.ROLLOUT_WORKER): data_store = get_data_store(self.data_store_params) data_store.load_from_store() + # perform several steps of playing count_end = self.current_step_counter + steps while self.current_step_counter < count_end: diff --git a/rl_coach/logger.py b/rl_coach/logger.py index 02f1a9c..0a42296 100644 --- a/rl_coach/logger.py +++ b/rl_coach/logger.py @@ -18,7 +18,9 @@ import datetime import os import re import shutil +import signal import time +import uuid from subprocess import Popen, PIPE from typing import Union @@ -90,6 +92,23 @@ class ScreenLogger(object): def ask_input(self, title): return input("{}{}{}".format(Colors.BG_CYAN, title, Colors.END)) + def ask_input_with_timeout(self, title, timeout, msg_if_timeout='Timeout expired.'): + class TimeoutExpired(Exception): + pass + + def timeout_alarm_handler(signum, frame): + raise TimeoutExpired + + signal.signal(signal.SIGALRM, timeout_alarm_handler) + signal.alarm(timeout) + + try: + return input("{}{}{}".format(Colors.BG_CYAN, title, Colors.END)) + except TimeoutExpired: + self.warning(msg_if_timeout) + finally: + signal.alarm(0) + def ask_yes_no(self, title: str, default: Union[None, bool] = None): """ Ask the user for a yes / no question and return True if the answer is yes and False otherwise. @@ -333,10 +352,14 @@ def get_experiment_name(initial_experiment_name=''): match = None while match is None: if initial_experiment_name == '': - experiment_name = screen.ask_input("Please enter an experiment name: ") + msg_if_timeout = "Timeout waiting for experiement name." + experiment_name = screen.ask_input_with_timeout("Please enter an experiment name: ", 60, msg_if_timeout) else: experiment_name = initial_experiment_name + if not experiment_name: + experiment_name = '' + experiment_name = experiment_name.replace(" ", "_") match = re.match("^$|^[\w -/]{1,1000}$", experiment_name) diff --git a/rl_coach/orchestrators/kubernetes_orchestrator.py b/rl_coach/orchestrators/kubernetes_orchestrator.py index 4a50f3e..d2afb4d 100644 --- a/rl_coach/orchestrators/kubernetes_orchestrator.py +++ b/rl_coach/orchestrators/kubernetes_orchestrator.py @@ -2,6 +2,7 @@ import os import uuid import json import time +from enum import Enum from typing import List from configparser import ConfigParser, Error from rl_coach.orchestrators.deploy import Deploy, DeployParameters @@ -12,10 +13,19 @@ from rl_coach.data_stores.data_store import DataStoreParameters from rl_coach.data_stores.data_store_impl import get_data_store +class RunType(Enum): + ORCHESTRATOR = "orchestrator" + TRAINER = "trainer" + ROLLOUT_WORKER = "rollout-worker" + + def __str__(self): + return self.value + + class RunTypeParameters(): def __init__(self, image: str, command: list(), arguments: list() = None, - run_type: str = "trainer", checkpoint_dir: str = "/checkpoint", + run_type: str = str(RunType.TRAINER), checkpoint_dir: str = "/checkpoint", num_replicas: int = 1, orchestration_params: dict=None): self.image = image self.command = command @@ -97,12 +107,12 @@ class Kubernetes(Deploy): def deploy_trainer(self) -> bool: - trainer_params = self.params.run_type_params.get('trainer', None) + trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None) if not trainer_params: return False - trainer_params.command += ['--memory-backend-params', json.dumps(self.params.memory_backend_parameters.__dict__)] - trainer_params.command += ['--data-store-params', json.dumps(self.params.data_store_params.__dict__)] + trainer_params.command += ['--memory_backend_params', json.dumps(self.params.memory_backend_parameters.__dict__)] + trainer_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)] name = "{}-{}".format(trainer_params.run_type, uuid.uuid4()) @@ -175,13 +185,13 @@ class Kubernetes(Deploy): def deploy_worker(self): - worker_params = self.params.run_type_params.get('worker', None) + worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None) if not worker_params: return False - worker_params.command += ['--memory-backend-params', json.dumps(self.params.memory_backend_parameters.__dict__)] - worker_params.command += ['--data-store-params', json.dumps(self.params.data_store_params.__dict__)] - worker_params.command += ['--num-workers', '{}'.format(worker_params.num_replicas)] + worker_params.command += ['--memory_backend_params', json.dumps(self.params.memory_backend_parameters.__dict__)] + worker_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)] + worker_params.command += ['--num_workers', '{}'.format(worker_params.num_replicas)] name = "{}-{}".format(worker_params.run_type, uuid.uuid4()) @@ -255,7 +265,7 @@ class Kubernetes(Deploy): pass def trainer_logs(self): - trainer_params = self.params.run_type_params.get('trainer', None) + trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None) if not trainer_params: return @@ -313,7 +323,7 @@ class Kubernetes(Deploy): return def undeploy(self): - trainer_params = self.params.run_type_params.get('trainer', None) + trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None) api_client = k8sclient.AppsV1Api() delete_options = k8sclient.V1DeleteOptions() if trainer_params: @@ -321,7 +331,7 @@ class Kubernetes(Deploy): api_client.delete_namespaced_deployment(trainer_params.orchestration_params['deployment_name'], self.params.namespace, delete_options) except k8sclient.rest.ApiException as e: print("Got exception: %s\n while deleting trainer", e) - worker_params = self.params.run_type_params.get('worker', None) + worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None) if worker_params: try: api_client.delete_namespaced_deployment(worker_params.orchestration_params['deployment_name'], self.params.namespace, delete_options) diff --git a/rl_coach/orchestrators/start_training.py b/rl_coach/orchestrators/start_training.py deleted file mode 100644 index 100a870..0000000 --- a/rl_coach/orchestrators/start_training.py +++ /dev/null @@ -1,120 +0,0 @@ -import argparse - -from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, RunTypeParameters -from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters -from rl_coach.data_stores.data_store import DataStoreParameters -from rl_coach.data_stores.s3_data_store import S3DataStoreParameters -from rl_coach.data_stores.nfs_data_store import NFSDataStoreParameters - - -def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, nfs_server: str=None, nfs_path: str=None, - memory_backend: str=None, data_store: str=None, s3_end_point: str=None, s3_bucket_name: str=None, - s3_creds_file: str=None, policy_type: str="OFF"): - rollout_command = ['python3', 'rl_coach/rollout_worker.py', '-p', preset, '--policy-type', policy_type] - training_command = ['python3', 'rl_coach/training_worker.py', '-p', preset, '--policy-type', policy_type] - - memory_backend_params = None - if memory_backend == "redispubsub": - memory_backend_params = RedisPubSubMemoryBackendParameters() - - ds_params_instance = None - if data_store == "s3": - ds_params = DataStoreParameters("s3", "", "") - ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=s3_end_point, bucket_name=s3_bucket_name, - creds_file=s3_creds_file, checkpoint_dir="/checkpoint") - elif data_store == "nfs": - ds_params = DataStoreParameters("nfs", "kubernetes", {"namespace": "default"}) - ds_params_instance = NFSDataStoreParameters(ds_params) - - worker_run_type_params = RunTypeParameters(image, rollout_command, run_type="worker", num_replicas=num_workers) - trainer_run_type_params = RunTypeParameters(image, training_command, run_type="trainer") - - orchestration_params = KubernetesParameters([worker_run_type_params, trainer_run_type_params], - kubeconfig='~/.kube/config', nfs_server=nfs_server, nfs_path=nfs_path, - memory_backend_parameters=memory_backend_params, - data_store_params=ds_params_instance) - orchestrator = Kubernetes(orchestration_params) - if not orchestrator.setup(): - print("Could not setup") - return - - if orchestrator.deploy_trainer(): - print("Successfully deployed") - else: - print("Could not deploy") - return - - if orchestrator.deploy_worker(): - print("Successfully deployed") - else: - print("Could not deploy") - return - - try: - orchestrator.trainer_logs() - except KeyboardInterrupt: - pass - orchestrator.undeploy() - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--image', - help="(string) Name of a docker image.", - type=str, - required=True) - parser.add_argument('-p', '--preset', - help="(string) Name of a preset to run (class name from the 'presets' directory).", - type=str, - required=True) - parser.add_argument('--memory-backend', - help="(string) Memory backend to use.", - type=str, - choices=['redispubsub'], - default="redispubsub") - parser.add_argument('--data-store', - help="(string) Data store to use.", - type=str, - choices=['s3', 'nfs'], - default="s3") - parser.add_argument('--nfs-server', - help="(string) Addresss of the nfs server.", - type=str, - required=False) - parser.add_argument('--nfs-path', - help="(string) Exported path for the nfs server.", - type=str, - required=False) - parser.add_argument('--s3-end-point', - help="(string) S3 endpoint to use when S3 data store is used.", - type=str, - required=False) - parser.add_argument('--s3-bucket-name', - help="(string) S3 bucket name to use when S3 data store is used.", - type=str, - required=False) - parser.add_argument('--s3-creds-file', - help="(string) S3 credentials file to use when S3 data store is used.", - type=str, - required=False) - parser.add_argument('--num-workers', - help="(string) Number of rollout workers.", - type=int, - required=False, - default=1) - parser.add_argument('--policy-type', - help="(string) The type of policy: OFF/ON.", - type=str, - choices=['ON', 'OFF'], - default='OFF') - - # parser.add_argument('--checkpoint_dir', - # help='(string) Path to a folder containing a checkpoint to write the model to.', - # type=str, - # default='/checkpoint') - args = parser.parse_args() - - main(preset=args.preset, image=args.image, nfs_server=args.nfs_server, nfs_path=args.nfs_path, - memory_backend=args.memory_backend, data_store=args.data_store, s3_end_point=args.s3_end_point, - s3_bucket_name=args.s3_bucket_name, s3_creds_file=args.s3_creds_file, num_workers=args.num_workers, - policy_type=args.policy_type) diff --git a/rl_coach/orchestrators/test.py b/rl_coach/orchestrators/test.py deleted file mode 100644 index 3142b6a..0000000 --- a/rl_coach/orchestrators/test.py +++ /dev/null @@ -1,18 +0,0 @@ -from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes - -# image = 'gcr.io/constant-cubist-173123/coach:latest' -image = 'ajaysudh/testing:coach' -command = ['python3', 'rl_coach/rollout_worker.py', '-p', 'CartPole_DQN_distributed'] -# command = ['sleep', '10h'] - -params = KubernetesParameters(image, command, kubeconfig='~/.kube/config', redis_ip='redis-service.ajay.svc', redis_port=6379, num_workers=1) -# params = KubernetesParameters(image, command, kubeconfig='~/.kube/config') - -obj = Kubernetes(params) -if not obj.setup(): - print("Could not setup") - -if obj.deploy(): - print("Successfully deployed") -else: - print("Could not deploy") diff --git a/rl_coach/presets/CartPole_DQN.py b/rl_coach/presets/CartPole_DQN.py index 02a38d7..ba4472f 100644 --- a/rl_coach/presets/CartPole_DQN.py +++ b/rl_coach/presets/CartPole_DQN.py @@ -1,5 +1,5 @@ from rl_coach.agents.dqn_agent import DQNAgentParameters -from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters +from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps from rl_coach.environments.gym_environment import GymVectorEnvironment from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager diff --git a/rl_coach/presets/CartPole_PPO.py b/rl_coach/presets/CartPole_PPO.py index eb2ac09..40e06d1 100644 --- a/rl_coach/presets/CartPole_PPO.py +++ b/rl_coach/presets/CartPole_PPO.py @@ -1,6 +1,6 @@ from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters from rl_coach.architectures.tensorflow_components.layers import Dense -from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters +from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2 from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters @@ -44,6 +44,9 @@ agent_params.algorithm.optimization_epochs = 10 agent_params.algorithm.estimate_state_value_using_gae = True agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(2048) +# Distributed Coach synchronization type. +agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC + agent_params.exploration = EGreedyParameters() agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000) agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation', diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py index a3ae386..d039a9e 100644 --- a/rl_coach/rollout_worker.py +++ b/rl_coach/rollout_worker.py @@ -7,28 +7,16 @@ this rollout worker: - exits """ -import argparse import time import os -import json import math -from threading import Thread - -from rl_coach.base_parameters import TaskParameters -from rl_coach.coach import expand_preset +from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType from rl_coach.core_types import EnvironmentSteps, RunPhase -from rl_coach.utils import short_dynamic_import -from rl_coach.memories.backend.memory_impl import construct_memory_params -from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params from google.protobuf import text_format from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState -# Q: specify alternative distributed memory, or should this go in the preset? -# A: preset must define distributed memory to be used. we aren't going to take -# a non-distributed preset and automatically distribute it. - def has_checkpoint(checkpoint_dir): """ True if a checkpoint is present in checkpoint_dir @@ -39,6 +27,7 @@ def has_checkpoint(checkpoint_dir): return False + def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10): """ block until there is a checkpoint in checkpoint_dir @@ -79,7 +68,7 @@ def get_latest_checkpoint(checkpoint_dir): return int(rel_path.split('_Step')[0]) -def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers, policy_type): +def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers): """ wait for first checkpoint then perform rollouts using the model """ @@ -102,7 +91,7 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers, polic new_checkpoint = get_latest_checkpoint(checkpoint_dir) - if policy_type == 'ON': + if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC: while new_checkpoint < last_checkpoint + 1: if data_store: data_store.load_from_store() @@ -110,64 +99,8 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers, polic graph_manager.restore_checkpoint() - if policy_type == "OFF": - + if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.ASYNC: if new_checkpoint > last_checkpoint: graph_manager.restore_checkpoint() last_checkpoint = new_checkpoint - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('-p', '--preset', - help="(string) Name of a preset to run (class name from the 'presets' directory.)", - type=str, - required=True) - parser.add_argument('--checkpoint-dir', - help='(string) Path to a folder containing a checkpoint to restore the model from.', - type=str, - default='/checkpoint') - parser.add_argument('--memory-backend-params', - help="(string) JSON string of the memory backend params", - type=str) - parser.add_argument('--data-store-params', - help="(string) JSON string of the data store params", - type=str) - parser.add_argument('--num-workers', - help="(int) The number of workers started in this pool", - type=int, - default=1) - parser.add_argument('--policy-type', - help="(string) The type of policy: OFF/ON", - type=str, - default='OFF') - - args = parser.parse_args() - graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True) - - data_store = None - if args.memory_backend_params: - args.memory_backend_params = json.loads(args.memory_backend_params) - args.memory_backend_params['run_type'] = 'worker' - graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(args.memory_backend_params)) - - if args.data_store_params: - data_store_params = construct_data_store_params(json.loads(args.data_store_params)) - data_store_params.checkpoint_dir = args.checkpoint_dir - graph_manager.data_store_params = data_store_params - data_store = get_data_store(data_store_params) - wait_for_checkpoint(checkpoint_dir=args.checkpoint_dir, data_store=data_store) - # thread = Thread(target = data_store_ckpt_load, args = [data_store]) - # thread.start() - - rollout_worker( - graph_manager=graph_manager, - checkpoint_dir=args.checkpoint_dir, - data_store=data_store, - num_workers=args.num_workers, - policy_type=args.policy_type - ) - -if __name__ == '__main__': - main() diff --git a/rl_coach/training_worker.py b/rl_coach/training_worker.py index e107ad8..3d2f9e5 100644 --- a/rl_coach/training_worker.py +++ b/rl_coach/training_worker.py @@ -1,24 +1,18 @@ """ """ -import argparse import time -import json -from threading import Thread - -from rl_coach.base_parameters import TaskParameters -from rl_coach.coach import expand_preset +from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType from rl_coach import core_types -from rl_coach.utils import short_dynamic_import -from rl_coach.memories.backend.memory_impl import construct_memory_params -from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params + def data_store_ckpt_save(data_store): while True: data_store.save_to_store() time.sleep(10) -def training_worker(graph_manager, checkpoint_dir, policy_type): + +def training_worker(graph_manager, checkpoint_dir): """ restore a checkpoint then perform rollouts using the restored model """ @@ -49,55 +43,7 @@ def training_worker(graph_manager, checkpoint_dir, policy_type): graph_manager.evaluate(graph_manager.evaluation_steps) eval_offset += 1 - if policy_type == 'ON': + if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC: graph_manager.save_checkpoint() else: graph_manager.occasionally_save_checkpoint() - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('-p', '--preset', - help="(string) Name of a preset to run (class name from the 'presets' directory.)", - type=str, - required=True) - parser.add_argument('--checkpoint-dir', - help='(string) Path to a folder containing a checkpoint to write the model to.', - type=str, - default='/checkpoint') - parser.add_argument('--memory-backend-params', - help="(string) JSON string of the memory backend params", - type=str) - parser.add_argument('--data-store-params', - help="(string) JSON string of the data store params", - type=str) - parser.add_argument('--policy-type', - help="(string) The type of policy: OFF/ON", - type=str, - default='OFF') - args = parser.parse_args() - - graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True) - - if args.memory_backend_params: - args.memory_backend_params = json.loads(args.memory_backend_params) - args.memory_backend_params['run_type'] = 'trainer' - graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(args.memory_backend_params)) - - if args.data_store_params: - data_store_params = construct_data_store_params(json.loads(args.data_store_params)) - data_store_params.checkpoint_dir = args.checkpoint_dir - graph_manager.data_store_params = data_store_params - # data_store = get_data_store(data_store_params) - # thread = Thread(target = data_store_ckpt_save, args = [data_store]) - # thread.start() - - training_worker( - graph_manager=graph_manager, - checkpoint_dir=args.checkpoint_dir, - policy_type=args.policy_type - ) - - -if __name__ == '__main__': - main() diff --git a/rl_coach/utils.py b/rl_coach/utils.py index b542f1e..76729b3 100644 --- a/rl_coach/utils.py +++ b/rl_coach/utils.py @@ -23,6 +23,7 @@ import signal import sys import threading import time +import traceback from multiprocessing import Manager from subprocess import Popen from typing import List, Tuple @@ -30,6 +31,8 @@ from typing import List, Tuple import atexit import numpy as np +from rl_coach.logger import screen + killed_processes = [] eps = np.finfo(np.float32).eps