mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Refactor launcher to be object-oriented (#63)
* Import of annoy library uses failed_import mechanism.
This commit is contained in:
@@ -51,200 +51,6 @@ if len(set(failed_imports)) > 0:
|
|||||||
screen.warning("Warning: failed to import the following packages - {}".format(', '.join(set(failed_imports))))
|
screen.warning("Warning: failed to import the following packages - {}".format(', '.join(set(failed_imports))))
|
||||||
|
|
||||||
|
|
||||||
def get_graph_manager_from_args(args: argparse.Namespace) -> 'GraphManager':
|
|
||||||
"""
|
|
||||||
Return the graph manager according to the command line arguments given by the user
|
|
||||||
:param args: the arguments given by the user
|
|
||||||
:return: the updated graph manager
|
|
||||||
"""
|
|
||||||
|
|
||||||
graph_manager = None
|
|
||||||
|
|
||||||
# if a preset was given we will load the graph manager for the preset
|
|
||||||
if args.preset is not None:
|
|
||||||
graph_manager = short_dynamic_import(args.preset, ignore_module_case=True)
|
|
||||||
|
|
||||||
# for human play we need to create a custom graph manager
|
|
||||||
if args.play:
|
|
||||||
env_params = short_dynamic_import(args.environment_type, ignore_module_case=True)()
|
|
||||||
env_params.human_control = True
|
|
||||||
schedule_params = HumanPlayScheduleParameters()
|
|
||||||
graph_manager = BasicRLGraphManager(HumanAgentParameters(), env_params, schedule_params, VisualizationParameters())
|
|
||||||
|
|
||||||
# Set framework
|
|
||||||
# Note: Some graph managers (e.g. HAC preset) create multiple agents and the attribute is called agents_params
|
|
||||||
if hasattr(graph_manager, 'agent_params'):
|
|
||||||
for network_parameters in graph_manager.agent_params.network_wrappers.values():
|
|
||||||
network_parameters.framework = args.framework
|
|
||||||
elif hasattr(graph_manager, 'agents_params'):
|
|
||||||
for ap in graph_manager.agents_params:
|
|
||||||
for network_parameters in ap.network_wrappers.values():
|
|
||||||
network_parameters.framework = args.framework
|
|
||||||
|
|
||||||
if args.level:
|
|
||||||
if isinstance(graph_manager.env_params.level, SingleLevelSelection):
|
|
||||||
graph_manager.env_params.level.select(args.level)
|
|
||||||
else:
|
|
||||||
graph_manager.env_params.level = args.level
|
|
||||||
|
|
||||||
# set the seed for the environment
|
|
||||||
if args.seed is not None:
|
|
||||||
graph_manager.env_params.seed = args.seed
|
|
||||||
|
|
||||||
# visualization
|
|
||||||
graph_manager.visualization_parameters.dump_gifs = graph_manager.visualization_parameters.dump_gifs or args.dump_gifs
|
|
||||||
graph_manager.visualization_parameters.dump_mp4 = graph_manager.visualization_parameters.dump_mp4 or args.dump_mp4
|
|
||||||
graph_manager.visualization_parameters.render = args.render
|
|
||||||
graph_manager.visualization_parameters.tensorboard = args.tensorboard
|
|
||||||
graph_manager.visualization_parameters.print_networks_summary = args.print_networks_summary
|
|
||||||
|
|
||||||
# update the custom parameters
|
|
||||||
if args.custom_parameter is not None:
|
|
||||||
unstripped_key_value_pairs = [pair.split('=') for pair in args.custom_parameter.split(';')]
|
|
||||||
stripped_key_value_pairs = [tuple([pair[0].strip(), pair[1].strip()]) for pair in
|
|
||||||
unstripped_key_value_pairs if len(pair) == 2]
|
|
||||||
|
|
||||||
# load custom parameters into run_dict
|
|
||||||
for key, value in stripped_key_value_pairs:
|
|
||||||
exec("graph_manager.{}={}".format(key, value))
|
|
||||||
|
|
||||||
return graph_manager
|
|
||||||
|
|
||||||
|
|
||||||
def display_all_presets_and_exit():
|
|
||||||
# list available presets
|
|
||||||
screen.log_title("Available Presets:")
|
|
||||||
for preset in sorted(list_all_presets()):
|
|
||||||
print(preset)
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
|
|
||||||
def expand_preset(preset):
|
|
||||||
if preset.lower() in [p.lower() for p in list_all_presets()]:
|
|
||||||
preset = "{}.py:graph_manager".format(os.path.join(get_base_dir(), 'presets', preset))
|
|
||||||
else:
|
|
||||||
preset = "{}".format(preset)
|
|
||||||
# if a graph manager variable was not specified, try the default of :graph_manager
|
|
||||||
if len(preset.split(":")) == 1:
|
|
||||||
preset += ":graph_manager"
|
|
||||||
|
|
||||||
# verify that the preset exists
|
|
||||||
preset_path = preset.split(":")[0]
|
|
||||||
if not os.path.exists(preset_path):
|
|
||||||
screen.error("The given preset ({}) cannot be found.".format(preset))
|
|
||||||
|
|
||||||
# verify that the preset can be instantiated
|
|
||||||
try:
|
|
||||||
short_dynamic_import(preset, ignore_module_case=True)
|
|
||||||
except TypeError as e:
|
|
||||||
traceback.print_exc()
|
|
||||||
screen.error('Internal Error: ' + str(e) + "\n\nThe given preset ({}) cannot be instantiated."
|
|
||||||
.format(preset))
|
|
||||||
|
|
||||||
return preset
|
|
||||||
|
|
||||||
|
|
||||||
def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
|
|
||||||
"""
|
|
||||||
Parse the arguments that the user entered
|
|
||||||
:param parser: the argparse command line parser
|
|
||||||
:return: the parsed arguments
|
|
||||||
"""
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# if no arg is given
|
|
||||||
if len(sys.argv) == 1:
|
|
||||||
parser.print_help()
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
# list available presets
|
|
||||||
if args.list:
|
|
||||||
display_all_presets_and_exit()
|
|
||||||
|
|
||||||
# Read args from config file for distributed Coach.
|
|
||||||
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
|
|
||||||
coach_config = ConfigParser({
|
|
||||||
'image': '',
|
|
||||||
'memory_backend': 'redispubsub',
|
|
||||||
'data_store': 's3',
|
|
||||||
's3_end_point': 's3.amazonaws.com',
|
|
||||||
's3_bucket_name': '',
|
|
||||||
's3_creds_file': ''
|
|
||||||
})
|
|
||||||
try:
|
|
||||||
coach_config.read(args.distributed_coach_config_path)
|
|
||||||
args.image = coach_config.get('coach', 'image')
|
|
||||||
args.memory_backend = coach_config.get('coach', 'memory_backend')
|
|
||||||
args.data_store = coach_config.get('coach', 'data_store')
|
|
||||||
args.s3_end_point = coach_config.get('coach', 's3_end_point')
|
|
||||||
args.s3_bucket_name = coach_config.get('coach', 's3_bucket_name')
|
|
||||||
args.s3_creds_file = coach_config.get('coach', 's3_creds_file')
|
|
||||||
except Error as e:
|
|
||||||
screen.error("Error when reading distributed Coach config file: {}".format(e))
|
|
||||||
|
|
||||||
if args.image == '':
|
|
||||||
screen.error("Image cannot be empty.")
|
|
||||||
|
|
||||||
data_store_choices = ['s3']
|
|
||||||
if args.data_store not in data_store_choices:
|
|
||||||
screen.warning("{} data store is unsupported.".format(args.data_store))
|
|
||||||
screen.error("Supported data stores are {}.".format(data_store_choices))
|
|
||||||
|
|
||||||
memory_backend_choices = ['redispubsub']
|
|
||||||
if args.memory_backend not in memory_backend_choices:
|
|
||||||
screen.warning("{} memory backend is not supported.".format(args.memory_backend))
|
|
||||||
screen.error("Supported memory backends are {}.".format(memory_backend_choices))
|
|
||||||
|
|
||||||
if args.s3_bucket_name == '':
|
|
||||||
screen.error("S3 bucket name cannot be empty.")
|
|
||||||
|
|
||||||
if args.s3_creds_file == '':
|
|
||||||
args.s3_creds_file = None
|
|
||||||
|
|
||||||
if args.play and args.distributed_coach:
|
|
||||||
screen.error("Playing is not supported in distributed Coach.")
|
|
||||||
|
|
||||||
# replace a short preset name with the full path
|
|
||||||
if args.preset is not None:
|
|
||||||
args.preset = expand_preset(args.preset)
|
|
||||||
|
|
||||||
# validate the checkpoints args
|
|
||||||
if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
|
|
||||||
screen.error("The requested checkpoint folder to load from does not exist.")
|
|
||||||
|
|
||||||
# no preset was given. check if the user requested to play some environment on its own
|
|
||||||
if args.preset is None and args.play and not args.environment_type:
|
|
||||||
screen.error('When no preset is given for Coach to run, and the user requests human control over '
|
|
||||||
'the environment, the user is expected to input the desired environment_type and level.'
|
|
||||||
'\nAt least one of these parameters was not given.')
|
|
||||||
elif args.preset and args.play:
|
|
||||||
screen.error("Both the --preset and the --play flags were set. These flags can not be used together. "
|
|
||||||
"For human control, please use the --play flag together with the environment type flag (-et)")
|
|
||||||
elif args.preset is None and not args.play:
|
|
||||||
screen.error("Please choose a preset using the -p flag or use the --play flag together with choosing an "
|
|
||||||
"environment type (-et) in order to play the game.")
|
|
||||||
|
|
||||||
# get experiment name and path
|
|
||||||
args.experiment_name = logger.get_experiment_name(args.experiment_name)
|
|
||||||
args.experiment_path = logger.get_experiment_path(args.experiment_name)
|
|
||||||
|
|
||||||
if args.play and args.num_workers > 1:
|
|
||||||
screen.warning("Playing the game as a human is only available with a single worker. "
|
|
||||||
"The number of workers will be reduced to 1")
|
|
||||||
args.num_workers = 1
|
|
||||||
|
|
||||||
args.framework = Frameworks[args.framework.lower()]
|
|
||||||
|
|
||||||
# checkpoints
|
|
||||||
args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.checkpoint_save_secs is not None else None
|
|
||||||
|
|
||||||
if args.export_onnx_graph and not args.checkpoint_save_secs:
|
|
||||||
screen.warning("Exporting ONNX graphs requires setting the --checkpoint_save_secs flag. "
|
|
||||||
"The --export_onnx_graph will have no effect.")
|
|
||||||
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def add_items_to_dict(target_dict, source_dict):
|
def add_items_to_dict(target_dict, source_dict):
|
||||||
updated_task_parameters = copy.copy(source_dict)
|
updated_task_parameters = copy.copy(source_dict)
|
||||||
updated_task_parameters.update(target_dict)
|
updated_task_parameters.update(target_dict)
|
||||||
@@ -263,6 +69,10 @@ def open_dashboard(experiment_path):
|
|||||||
|
|
||||||
|
|
||||||
def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'):
|
def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'):
|
||||||
|
"""
|
||||||
|
Runs the graph_manager using the configured task_parameters.
|
||||||
|
This stand-alone method is a convenience for multiprocessing.
|
||||||
|
"""
|
||||||
graph_manager.create_graph(task_parameters)
|
graph_manager.create_graph(task_parameters)
|
||||||
|
|
||||||
# let the adventure begin
|
# let the adventure begin
|
||||||
@@ -360,172 +170,419 @@ def handle_distributed_coach_orchestrator(graph_manager, args):
|
|||||||
orchestrator.undeploy()
|
orchestrator.undeploy()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
class CoachLauncher(object):
|
||||||
parser = argparse.ArgumentParser()
|
"""
|
||||||
parser.add_argument('-p', '--preset',
|
This class is responsible for gathering all user-specified configuration options, parsing them,
|
||||||
help="(string) Name of a preset to run (class name from the 'presets' directory.)",
|
instantiating a GraphManager and then starting that GraphManager with either improve() or evaluate().
|
||||||
default=None,
|
This class is also responsible for launching multiple processes.
|
||||||
type=str)
|
It is structured so that it can be sub-classed to provide alternate mechanisms to configure and launch
|
||||||
parser.add_argument('-l', '--list',
|
Coach jobs.
|
||||||
help="(flag) List all available presets",
|
|
||||||
action='store_true')
|
|
||||||
parser.add_argument('-e', '--experiment_name',
|
|
||||||
help="(string) Experiment name to be used to store the results.",
|
|
||||||
default='',
|
|
||||||
type=str)
|
|
||||||
parser.add_argument('-r', '--render',
|
|
||||||
help="(flag) Render environment",
|
|
||||||
action='store_true')
|
|
||||||
parser.add_argument('-f', '--framework',
|
|
||||||
help="(string) Neural network framework. Available values: tensorflow",
|
|
||||||
default='tensorflow',
|
|
||||||
type=str)
|
|
||||||
parser.add_argument('-n', '--num_workers',
|
|
||||||
help="(int) Number of workers for multi-process based agents, e.g. A3C",
|
|
||||||
default=1,
|
|
||||||
type=int)
|
|
||||||
parser.add_argument('-c', '--use_cpu',
|
|
||||||
help="(flag) Use only the cpu for training. If a GPU is not available, this flag will have no "
|
|
||||||
"effect and the CPU will be used either way.",
|
|
||||||
action='store_true')
|
|
||||||
parser.add_argument('-ew', '--evaluation_worker',
|
|
||||||
help="(int) If multiple workers are used, add an evaluation worker as well which will "
|
|
||||||
"evaluate asynchronously and independently during the training. NOTE: this worker will "
|
|
||||||
"ignore the evaluation settings in the preset's ScheduleParams.",
|
|
||||||
action='store_true')
|
|
||||||
parser.add_argument('--play',
|
|
||||||
help="(flag) Play as a human by controlling the game with the keyboard. "
|
|
||||||
"This option will save a replay buffer with the game play.",
|
|
||||||
action='store_true')
|
|
||||||
parser.add_argument('--evaluate',
|
|
||||||
help="(flag) Run evaluation only. This is a convenient way to disable "
|
|
||||||
"training in order to evaluate an existing checkpoint.",
|
|
||||||
action='store_true')
|
|
||||||
parser.add_argument('-v', '--verbosity',
|
|
||||||
help="(flag) Sets the verbosity level of Coach print outs. Can be either low or high.",
|
|
||||||
default="low",
|
|
||||||
type=str)
|
|
||||||
parser.add_argument('-tfv', '--tf_verbosity',
|
|
||||||
help="(flag) TensorFlow verbosity level",
|
|
||||||
default=3,
|
|
||||||
type=int)
|
|
||||||
parser.add_argument('--nocolor',
|
|
||||||
help="(flag) Turn off color-codes in screen logging. Ascii text only",
|
|
||||||
action='store_true')
|
|
||||||
parser.add_argument('-s', '--checkpoint_save_secs',
|
|
||||||
help="(int) Time in seconds between saving checkpoints of the model.",
|
|
||||||
default=None,
|
|
||||||
type=int)
|
|
||||||
parser.add_argument('-crd', '--checkpoint_restore_dir',
|
|
||||||
help='(string) Path to a folder containing a checkpoint to restore the model from.',
|
|
||||||
type=str)
|
|
||||||
parser.add_argument('-dg', '--dump_gifs',
|
|
||||||
help="(flag) Enable the gif saving functionality.",
|
|
||||||
action='store_true')
|
|
||||||
parser.add_argument('-dm', '--dump_mp4',
|
|
||||||
help="(flag) Enable the mp4 saving functionality.",
|
|
||||||
action='store_true')
|
|
||||||
parser.add_argument('-et', '--environment_type',
|
|
||||||
help="(string) Choose an environment type class to override on top of the selected preset.",
|
|
||||||
default=None,
|
|
||||||
type=str)
|
|
||||||
parser.add_argument('-lvl', '--level',
|
|
||||||
help="(string) Choose the level that will be played in the environment that was selected."
|
|
||||||
"This value will override the level parameter in the environment class."
|
|
||||||
,
|
|
||||||
default=None,
|
|
||||||
type=str)
|
|
||||||
parser.add_argument('-cp', '--custom_parameter',
|
|
||||||
help="(string) Semicolon separated parameters used to override specific parameters on top of"
|
|
||||||
" the selected preset (or on top of the command-line assembled one). "
|
|
||||||
"Whenever a parameter value is a string, it should be inputted as '\\\"string\\\"'. "
|
|
||||||
"For ex.: "
|
|
||||||
"\"visualization.render=False; num_training_iterations=500; optimizer='rmsprop'\"",
|
|
||||||
default=None,
|
|
||||||
type=str)
|
|
||||||
parser.add_argument('--print_networks_summary',
|
|
||||||
help="(flag) Print network summary to stdout",
|
|
||||||
action='store_true')
|
|
||||||
parser.add_argument('-tb', '--tensorboard',
|
|
||||||
help="(flag) When using the TensorFlow backend, enable TensorBoard log dumps. ",
|
|
||||||
action='store_true')
|
|
||||||
parser.add_argument('-ns', '--no_summary',
|
|
||||||
help="(flag) Prevent Coach from printing a summary and asking questions at the end of runs",
|
|
||||||
action='store_true')
|
|
||||||
parser.add_argument('-d', '--open_dashboard',
|
|
||||||
help="(flag) Open dashboard with the experiment when the run starts",
|
|
||||||
action='store_true')
|
|
||||||
parser.add_argument('--seed',
|
|
||||||
help="(int) A seed to use for running the experiment",
|
|
||||||
default=None,
|
|
||||||
type=int)
|
|
||||||
parser.add_argument('-onnx', '--export_onnx_graph',
|
|
||||||
help="(flag) Export the ONNX graph to the experiment directory. "
|
|
||||||
"This will have effect only if the --checkpoint_save_secs flag is used in order to store "
|
|
||||||
"checkpoints, since the weights checkpoint are needed for the ONNX graph. "
|
|
||||||
"Keep in mind that this can cause major overhead on the experiment. "
|
|
||||||
"Exporting ONNX graphs requires manually installing the tf2onnx package "
|
|
||||||
"(https://github.com/onnx/tensorflow-onnx).",
|
|
||||||
action='store_true')
|
|
||||||
parser.add_argument('-dc', '--distributed_coach',
|
|
||||||
help="(flag) Use distributed Coach.",
|
|
||||||
action='store_true')
|
|
||||||
parser.add_argument('-dcp', '--distributed_coach_config_path',
|
|
||||||
help="(string) Path to config file when using distributed rollout workers."
|
|
||||||
"Only distributed Coach parameters should be provided through this config file."
|
|
||||||
"Rest of the parameters are provided using Coach command line options."
|
|
||||||
"Used only with --distributed_coach flag."
|
|
||||||
"Ignored if --distributed_coach flag is not used.",
|
|
||||||
type=str)
|
|
||||||
parser.add_argument('--memory_backend_params',
|
|
||||||
help=argparse.SUPPRESS,
|
|
||||||
type=str)
|
|
||||||
parser.add_argument('--data_store_params',
|
|
||||||
help=argparse.SUPPRESS,
|
|
||||||
type=str)
|
|
||||||
parser.add_argument('--distributed_coach_run_type',
|
|
||||||
help=argparse.SUPPRESS,
|
|
||||||
type=RunType,
|
|
||||||
default=RunType.ORCHESTRATOR,
|
|
||||||
choices=list(RunType))
|
|
||||||
|
|
||||||
args = parse_arguments(parser)
|
The key entry-point for this class is the .launch() method which is expected to be called from __main__
|
||||||
|
and handle absolutely everything for a job.
|
||||||
|
"""
|
||||||
|
|
||||||
if args.nocolor:
|
def launch(self):
|
||||||
screen.set_use_colors(False)
|
"""
|
||||||
|
Main entry point for the class, and the standard way to run coach from the command line.
|
||||||
|
Parses command-line arguments through argparse, instantiates a GraphManager and then runs it.
|
||||||
|
"""
|
||||||
|
parser = self.get_argument_parser()
|
||||||
|
args = self.get_config_args(parser)
|
||||||
|
graph_manager = self.get_graph_manager_from_args(args)
|
||||||
|
self.run_graph_manager(graph_manager, args)
|
||||||
|
|
||||||
graph_manager = get_graph_manager_from_args(args)
|
|
||||||
|
def get_graph_manager_from_args(self, args: argparse.Namespace) -> 'GraphManager':
|
||||||
|
"""
|
||||||
|
Return the graph manager according to the command line arguments given by the user.
|
||||||
|
:param args: the arguments given by the user
|
||||||
|
:return: the graph manager, not bound to task_parameters yet.
|
||||||
|
"""
|
||||||
|
graph_manager = None
|
||||||
|
|
||||||
if args.distributed_coach and not graph_manager.agent_params.algorithm.distributed_coach_synchronization_type:
|
# if a preset was given we will load the graph manager for the preset
|
||||||
screen.error("{} preset is not supported using distributed Coach.".format(args.preset))
|
if args.preset is not None:
|
||||||
|
graph_manager = short_dynamic_import(args.preset, ignore_module_case=True)
|
||||||
|
|
||||||
# Intel optimized TF seems to run significantly faster when limiting to a single OMP thread.
|
# for human play we need to create a custom graph manager
|
||||||
# This will not affect GPU runs.
|
if args.play:
|
||||||
os.environ["OMP_NUM_THREADS"] = "1"
|
env_params = short_dynamic_import(args.environment_type, ignore_module_case=True)()
|
||||||
|
env_params.human_control = True
|
||||||
|
schedule_params = HumanPlayScheduleParameters()
|
||||||
|
graph_manager = BasicRLGraphManager(HumanAgentParameters(), env_params, schedule_params, VisualizationParameters())
|
||||||
|
|
||||||
# turn TF debug prints off
|
# Set framework
|
||||||
if args.framework == Frameworks.tensorflow:
|
# Note: Some graph managers (e.g. HAC preset) create multiple agents and the attribute is called agents_params
|
||||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(args.tf_verbosity)
|
if hasattr(graph_manager, 'agent_params'):
|
||||||
|
for network_parameters in graph_manager.agent_params.network_wrappers.values():
|
||||||
|
network_parameters.framework = args.framework
|
||||||
|
elif hasattr(graph_manager, 'agents_params'):
|
||||||
|
for ap in graph_manager.agents_params:
|
||||||
|
for network_parameters in ap.network_wrappers.values():
|
||||||
|
network_parameters.framework = args.framework
|
||||||
|
|
||||||
# turn off the summary at the end of the run if necessary
|
if args.level:
|
||||||
if not args.no_summary and not args.distributed_coach:
|
if isinstance(graph_manager.env_params.level, SingleLevelSelection):
|
||||||
atexit.register(logger.summarize_experiment)
|
graph_manager.env_params.level.select(args.level)
|
||||||
screen.change_terminal_title(args.experiment_name)
|
else:
|
||||||
|
graph_manager.env_params.level = args.level
|
||||||
|
|
||||||
# open dashboard
|
# set the seed for the environment
|
||||||
if args.open_dashboard:
|
if args.seed is not None:
|
||||||
open_dashboard(args.experiment_path)
|
graph_manager.env_params.seed = args.seed
|
||||||
|
|
||||||
if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
|
# visualization
|
||||||
handle_distributed_coach_tasks(graph_manager, args)
|
graph_manager.visualization_parameters.dump_gifs = graph_manager.visualization_parameters.dump_gifs or args.dump_gifs
|
||||||
return
|
graph_manager.visualization_parameters.dump_mp4 = graph_manager.visualization_parameters.dump_mp4 or args.dump_mp4
|
||||||
|
graph_manager.visualization_parameters.render = args.render
|
||||||
|
graph_manager.visualization_parameters.tensorboard = args.tensorboard
|
||||||
|
graph_manager.visualization_parameters.print_networks_summary = args.print_networks_summary
|
||||||
|
|
||||||
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
|
# update the custom parameters
|
||||||
handle_distributed_coach_orchestrator(graph_manager, args)
|
if args.custom_parameter is not None:
|
||||||
return
|
unstripped_key_value_pairs = [pair.split('=') for pair in args.custom_parameter.split(';')]
|
||||||
|
stripped_key_value_pairs = [tuple([pair[0].strip(), pair[1].strip()]) for pair in
|
||||||
|
unstripped_key_value_pairs if len(pair) == 2]
|
||||||
|
|
||||||
# Single-threaded runs
|
# load custom parameters into run_dict
|
||||||
if args.num_workers == 1:
|
for key, value in stripped_key_value_pairs:
|
||||||
|
exec("graph_manager.{}={}".format(key, value))
|
||||||
|
|
||||||
|
return graph_manager
|
||||||
|
|
||||||
|
|
||||||
|
def display_all_presets_and_exit(self):
|
||||||
|
# list available presets
|
||||||
|
screen.log_title("Available Presets:")
|
||||||
|
for preset in sorted(list_all_presets()):
|
||||||
|
print(preset)
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
def expand_preset(self, preset):
|
||||||
|
"""
|
||||||
|
Replace a short preset name with the full python path, and verify that it can be imported.
|
||||||
|
"""
|
||||||
|
if preset.lower() in [p.lower() for p in list_all_presets()]:
|
||||||
|
preset = "{}.py:graph_manager".format(os.path.join(get_base_dir(), 'presets', preset))
|
||||||
|
else:
|
||||||
|
preset = "{}".format(preset)
|
||||||
|
# if a graph manager variable was not specified, try the default of :graph_manager
|
||||||
|
if len(preset.split(":")) == 1:
|
||||||
|
preset += ":graph_manager"
|
||||||
|
|
||||||
|
# verify that the preset exists
|
||||||
|
preset_path = preset.split(":")[0]
|
||||||
|
if not os.path.exists(preset_path):
|
||||||
|
screen.error("The given preset ({}) cannot be found.".format(preset))
|
||||||
|
|
||||||
|
# verify that the preset can be instantiated
|
||||||
|
try:
|
||||||
|
short_dynamic_import(preset, ignore_module_case=True)
|
||||||
|
except TypeError as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
screen.error('Internal Error: ' + str(e) + "\n\nThe given preset ({}) cannot be instantiated."
|
||||||
|
.format(preset))
|
||||||
|
|
||||||
|
return preset
|
||||||
|
|
||||||
|
|
||||||
|
def get_config_args(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
|
||||||
|
"""
|
||||||
|
Returns a Namespace object with all the user-specified configuration options needed to launch.
|
||||||
|
This implementation uses argparse to take arguments from the CLI, but this can be over-ridden by
|
||||||
|
another method that gets its configuration from elsewhere. An equivalent method however must
|
||||||
|
return an identically structured Namespace object, which conforms to the structure defined by
|
||||||
|
get_argument_parser.
|
||||||
|
|
||||||
|
This method parses the arguments that the user entered, does some basic validation, and
|
||||||
|
modification of user-specified values in short form to be more explicit.
|
||||||
|
|
||||||
|
:param parser: a parser object which implicitly defines the format of the Namespace that
|
||||||
|
is expected to be returned.
|
||||||
|
:return: the parsed arguments as a Namespace
|
||||||
|
"""
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.nocolor:
|
||||||
|
screen.set_use_colors(False)
|
||||||
|
|
||||||
|
# if no arg is given
|
||||||
|
if len(sys.argv) == 1:
|
||||||
|
parser.print_help()
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# list available presets
|
||||||
|
if args.list:
|
||||||
|
self.display_all_presets_and_exit()
|
||||||
|
|
||||||
|
|
||||||
|
# Read args from config file for distributed Coach.
|
||||||
|
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
|
||||||
|
coach_config = ConfigParser({
|
||||||
|
'image': '',
|
||||||
|
'memory_backend': 'redispubsub',
|
||||||
|
'data_store': 's3',
|
||||||
|
's3_end_point': 's3.amazonaws.com',
|
||||||
|
's3_bucket_name': '',
|
||||||
|
's3_creds_file': ''
|
||||||
|
})
|
||||||
|
try:
|
||||||
|
coach_config.read(args.distributed_coach_config_path)
|
||||||
|
args.image = coach_config.get('coach', 'image')
|
||||||
|
args.memory_backend = coach_config.get('coach', 'memory_backend')
|
||||||
|
args.data_store = coach_config.get('coach', 'data_store')
|
||||||
|
args.s3_end_point = coach_config.get('coach', 's3_end_point')
|
||||||
|
args.s3_bucket_name = coach_config.get('coach', 's3_bucket_name')
|
||||||
|
args.s3_creds_file = coach_config.get('coach', 's3_creds_file')
|
||||||
|
except Error as e:
|
||||||
|
screen.error("Error when reading distributed Coach config file: {}".format(e))
|
||||||
|
|
||||||
|
if args.image == '':
|
||||||
|
screen.error("Image cannot be empty.")
|
||||||
|
|
||||||
|
data_store_choices = ['s3']
|
||||||
|
if args.data_store not in data_store_choices:
|
||||||
|
screen.warning("{} data store is unsupported.".format(args.data_store))
|
||||||
|
screen.error("Supported data stores are {}.".format(data_store_choices))
|
||||||
|
|
||||||
|
memory_backend_choices = ['redispubsub']
|
||||||
|
if args.memory_backend not in memory_backend_choices:
|
||||||
|
screen.warning("{} memory backend is not supported.".format(args.memory_backend))
|
||||||
|
screen.error("Supported memory backends are {}.".format(memory_backend_choices))
|
||||||
|
|
||||||
|
if args.s3_bucket_name == '':
|
||||||
|
screen.error("S3 bucket name cannot be empty.")
|
||||||
|
|
||||||
|
if args.s3_creds_file == '':
|
||||||
|
args.s3_creds_file = None
|
||||||
|
|
||||||
|
if args.play and args.distributed_coach:
|
||||||
|
screen.error("Playing is not supported in distributed Coach.")
|
||||||
|
|
||||||
|
# replace a short preset name with the full path
|
||||||
|
if args.preset is not None:
|
||||||
|
args.preset = self.expand_preset(args.preset)
|
||||||
|
|
||||||
|
# validate the checkpoints args
|
||||||
|
if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
|
||||||
|
screen.error("The requested checkpoint folder to load from does not exist.")
|
||||||
|
|
||||||
|
# no preset was given. check if the user requested to play some environment on its own
|
||||||
|
if args.preset is None and args.play and not args.environment_type:
|
||||||
|
screen.error('When no preset is given for Coach to run, and the user requests human control over '
|
||||||
|
'the environment, the user is expected to input the desired environment_type and level.'
|
||||||
|
'\nAt least one of these parameters was not given.')
|
||||||
|
elif args.preset and args.play:
|
||||||
|
screen.error("Both the --preset and the --play flags were set. These flags can not be used together. "
|
||||||
|
"For human control, please use the --play flag together with the environment type flag (-et)")
|
||||||
|
elif args.preset is None and not args.play:
|
||||||
|
screen.error("Please choose a preset using the -p flag or use the --play flag together with choosing an "
|
||||||
|
"environment type (-et) in order to play the game.")
|
||||||
|
|
||||||
|
# get experiment name and path
|
||||||
|
args.experiment_name = logger.get_experiment_name(args.experiment_name)
|
||||||
|
args.experiment_path = logger.get_experiment_path(args.experiment_name)
|
||||||
|
|
||||||
|
if args.play and args.num_workers > 1:
|
||||||
|
screen.warning("Playing the game as a human is only available with a single worker. "
|
||||||
|
"The number of workers will be reduced to 1")
|
||||||
|
args.num_workers = 1
|
||||||
|
|
||||||
|
args.framework = Frameworks[args.framework.lower()]
|
||||||
|
|
||||||
|
# checkpoints
|
||||||
|
args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.checkpoint_save_secs is not None else None
|
||||||
|
|
||||||
|
if args.export_onnx_graph and not args.checkpoint_save_secs:
|
||||||
|
screen.warning("Exporting ONNX graphs requires setting the --checkpoint_save_secs flag. "
|
||||||
|
"The --export_onnx_graph will have no effect.")
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def get_argument_parser(self) -> argparse.ArgumentParser:
|
||||||
|
"""
|
||||||
|
This returns an ArgumentParser object which defines the set of options that customers are expected to supply in order
|
||||||
|
to launch a coach job.
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-p', '--preset',
|
||||||
|
help="(string) Name of a preset to run (class name from the 'presets' directory.)",
|
||||||
|
default=None,
|
||||||
|
type=str)
|
||||||
|
parser.add_argument('-l', '--list',
|
||||||
|
help="(flag) List all available presets",
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('-e', '--experiment_name',
|
||||||
|
help="(string) Experiment name to be used to store the results.",
|
||||||
|
default='',
|
||||||
|
type=str)
|
||||||
|
parser.add_argument('-r', '--render',
|
||||||
|
help="(flag) Render environment",
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('-f', '--framework',
|
||||||
|
help="(string) Neural network framework. Available values: tensorflow",
|
||||||
|
default='tensorflow',
|
||||||
|
type=str)
|
||||||
|
parser.add_argument('-n', '--num_workers',
|
||||||
|
help="(int) Number of workers for multi-process based agents, e.g. A3C",
|
||||||
|
default=1,
|
||||||
|
type=int)
|
||||||
|
parser.add_argument('-c', '--use_cpu',
|
||||||
|
help="(flag) Use only the cpu for training. If a GPU is not available, this flag will have no "
|
||||||
|
"effect and the CPU will be used either way.",
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('-ew', '--evaluation_worker',
|
||||||
|
help="(int) If multiple workers are used, add an evaluation worker as well which will "
|
||||||
|
"evaluate asynchronously and independently during the training. NOTE: this worker will "
|
||||||
|
"ignore the evaluation settings in the preset's ScheduleParams.",
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('--play',
|
||||||
|
help="(flag) Play as a human by controlling the game with the keyboard. "
|
||||||
|
"This option will save a replay buffer with the game play.",
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('--evaluate',
|
||||||
|
help="(flag) Run evaluation only. This is a convenient way to disable "
|
||||||
|
"training in order to evaluate an existing checkpoint.",
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('-v', '--verbosity',
|
||||||
|
help="(flag) Sets the verbosity level of Coach print outs. Can be either low or high.",
|
||||||
|
default="low",
|
||||||
|
type=str)
|
||||||
|
parser.add_argument('-tfv', '--tf_verbosity',
|
||||||
|
help="(flag) TensorFlow verbosity level",
|
||||||
|
default=3,
|
||||||
|
type=int)
|
||||||
|
parser.add_argument('--nocolor',
|
||||||
|
help="(flag) Turn off color-codes in screen logging. Ascii text only",
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('-s', '--checkpoint_save_secs',
|
||||||
|
help="(int) Time in seconds between saving checkpoints of the model.",
|
||||||
|
default=None,
|
||||||
|
type=int)
|
||||||
|
parser.add_argument('-crd', '--checkpoint_restore_dir',
|
||||||
|
help='(string) Path to a folder containing a checkpoint to restore the model from.',
|
||||||
|
type=str)
|
||||||
|
parser.add_argument('-dg', '--dump_gifs',
|
||||||
|
help="(flag) Enable the gif saving functionality.",
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('-dm', '--dump_mp4',
|
||||||
|
help="(flag) Enable the mp4 saving functionality.",
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('-et', '--environment_type',
|
||||||
|
help="(string) Choose an environment type class to override on top of the selected preset.",
|
||||||
|
default=None,
|
||||||
|
type=str)
|
||||||
|
parser.add_argument('-ept', '--exploration_policy_type',
|
||||||
|
help="(string) Choose an exploration policy type class to override on top of the selected "
|
||||||
|
"preset."
|
||||||
|
"If no preset is defined, a preset can be set from the command-line by combining settings "
|
||||||
|
"which are set by using --agent_type, --experiment_type, --environemnt_type"
|
||||||
|
,
|
||||||
|
default=None,
|
||||||
|
type=str)
|
||||||
|
parser.add_argument('-lvl', '--level',
|
||||||
|
help="(string) Choose the level that will be played in the environment that was selected."
|
||||||
|
"This value will override the level parameter in the environment class."
|
||||||
|
,
|
||||||
|
default=None,
|
||||||
|
type=str)
|
||||||
|
parser.add_argument('-cp', '--custom_parameter',
|
||||||
|
help="(string) Semicolon separated parameters used to override specific parameters on top of"
|
||||||
|
" the selected preset (or on top of the command-line assembled one). "
|
||||||
|
"Whenever a parameter value is a string, it should be inputted as '\\\"string\\\"'. "
|
||||||
|
"For ex.: "
|
||||||
|
"\"visualization.render=False; num_training_iterations=500; optimizer='rmsprop'\"",
|
||||||
|
default=None,
|
||||||
|
type=str)
|
||||||
|
parser.add_argument('--print_networks_summary',
|
||||||
|
help="(flag) Print network summary to stdout",
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('-tb', '--tensorboard',
|
||||||
|
help="(flag) When using the TensorFlow backend, enable TensorBoard log dumps. ",
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('-ns', '--no_summary',
|
||||||
|
help="(flag) Prevent Coach from printing a summary and asking questions at the end of runs",
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('-d', '--open_dashboard',
|
||||||
|
help="(flag) Open dashboard with the experiment when the run starts",
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('--seed',
|
||||||
|
help="(int) A seed to use for running the experiment",
|
||||||
|
default=None,
|
||||||
|
type=int)
|
||||||
|
parser.add_argument('-onnx', '--export_onnx_graph',
|
||||||
|
help="(flag) Export the ONNX graph to the experiment directory. "
|
||||||
|
"This will have effect only if the --checkpoint_save_secs flag is used in order to store "
|
||||||
|
"checkpoints, since the weights checkpoint are needed for the ONNX graph. "
|
||||||
|
"Keep in mind that this can cause major overhead on the experiment. "
|
||||||
|
"Exporting ONNX graphs requires manually installing the tf2onnx package "
|
||||||
|
"(https://github.com/onnx/tensorflow-onnx).",
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('-dc', '--distributed_coach',
|
||||||
|
help="(flag) Use distributed Coach.",
|
||||||
|
action='store_true')
|
||||||
|
parser.add_argument('-dcp', '--distributed_coach_config_path',
|
||||||
|
help="(string) Path to config file when using distributed rollout workers."
|
||||||
|
"Only distributed Coach parameters should be provided through this config file."
|
||||||
|
"Rest of the parameters are provided using Coach command line options."
|
||||||
|
"Used only with --distributed_coach flag."
|
||||||
|
"Ignored if --distributed_coach flag is not used.",
|
||||||
|
type=str)
|
||||||
|
parser.add_argument('--memory_backend_params',
|
||||||
|
help=argparse.SUPPRESS,
|
||||||
|
type=str)
|
||||||
|
parser.add_argument('--data_store_params',
|
||||||
|
help=argparse.SUPPRESS,
|
||||||
|
type=str)
|
||||||
|
parser.add_argument('--distributed_coach_run_type',
|
||||||
|
help=argparse.SUPPRESS,
|
||||||
|
type=RunType,
|
||||||
|
default=RunType.ORCHESTRATOR,
|
||||||
|
choices=list(RunType))
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namespace):
|
||||||
|
if args.distributed_coach and not graph_manager.agent_params.algorithm.distributed_coach_synchronization_type:
|
||||||
|
screen.error("{} algorithm is not supported using distributed Coach.".format(graph_manager.agent_params.algorithm))
|
||||||
|
|
||||||
|
# Intel optimized TF seems to run significantly faster when limiting to a single OMP thread.
|
||||||
|
# This will not affect GPU runs.
|
||||||
|
os.environ["OMP_NUM_THREADS"] = "1"
|
||||||
|
|
||||||
|
# turn TF debug prints off
|
||||||
|
if args.framework == Frameworks.tensorflow:
|
||||||
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(args.tf_verbosity)
|
||||||
|
|
||||||
|
# turn off the summary at the end of the run if necessary
|
||||||
|
if not args.no_summary and not args.distributed_coach:
|
||||||
|
atexit.register(logger.summarize_experiment)
|
||||||
|
screen.change_terminal_title(args.experiment_name)
|
||||||
|
|
||||||
|
# open dashboard
|
||||||
|
if args.open_dashboard:
|
||||||
|
open_dashboard(args.experiment_path)
|
||||||
|
|
||||||
|
if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
|
||||||
|
handle_distributed_coach_tasks(graph_manager, args)
|
||||||
|
return
|
||||||
|
|
||||||
|
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
|
||||||
|
handle_distributed_coach_orchestrator(graph_manager, args)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Single-threaded runs
|
||||||
|
if args.num_workers == 1:
|
||||||
|
self.start_single_threaded(graph_manager, args)
|
||||||
|
else:
|
||||||
|
self.start_multi_threaded(graph_manager, args)
|
||||||
|
|
||||||
|
|
||||||
|
def start_single_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
|
||||||
# Start the training or evaluation
|
# Start the training or evaluation
|
||||||
task_parameters = TaskParameters(
|
task_parameters = TaskParameters(
|
||||||
framework_type=args.framework,
|
framework_type=args.framework,
|
||||||
@@ -541,8 +598,8 @@ def main():
|
|||||||
|
|
||||||
start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
|
start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
|
||||||
|
|
||||||
# Multi-threaded runs
|
|
||||||
else:
|
def start_multi_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
|
||||||
total_tasks = args.num_workers
|
total_tasks = args.num_workers
|
||||||
if args.evaluation_worker:
|
if args.evaluation_worker:
|
||||||
total_tasks += 1
|
total_tasks += 1
|
||||||
@@ -578,7 +635,6 @@ def main():
|
|||||||
checkpoint_save_dir=args.checkpoint_save_dir,
|
checkpoint_save_dir=args.checkpoint_save_dir,
|
||||||
export_onnx_graph=args.export_onnx_graph
|
export_onnx_graph=args.export_onnx_graph
|
||||||
)
|
)
|
||||||
|
|
||||||
# we assume that only the evaluation workers are rendering
|
# we assume that only the evaluation workers are rendering
|
||||||
graph_manager.visualization_parameters.render = args.render and evaluation_worker
|
graph_manager.visualization_parameters.render = args.render and evaluation_worker
|
||||||
p = Process(target=start_graph, args=(graph_manager, task_parameters))
|
p = Process(target=start_graph, args=(graph_manager, task_parameters))
|
||||||
@@ -608,4 +664,5 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
launcher = CoachLauncher()
|
||||||
|
launcher.launch()
|
||||||
|
|||||||
@@ -18,7 +18,12 @@ import os
|
|||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from annoy import AnnoyIndex
|
try:
|
||||||
|
import annoy
|
||||||
|
from annoy import AnnoyIndex
|
||||||
|
except ImportError:
|
||||||
|
from rl_coach.logger import failed_imports
|
||||||
|
failed_imports.append("annoy")
|
||||||
|
|
||||||
|
|
||||||
class AnnoyDictionary(object):
|
class AnnoyDictionary(object):
|
||||||
@@ -283,4 +288,4 @@ def load_dnd(model_dir):
|
|||||||
|
|
||||||
DND.dicts[a].index.build(50)
|
DND.dicts[a].index.build(50)
|
||||||
|
|
||||||
return DND
|
return DND
|
||||||
|
|||||||
Reference in New Issue
Block a user