diff --git a/rl_coach/coach.py b/rl_coach/coach.py index aa18a14..1e6bfba 100644 --- a/rl_coach/coach.py +++ b/rl_coach/coach.py @@ -92,6 +92,39 @@ def get_graph_manager_from_args(args: argparse.Namespace) -> 'GraphManager': return graph_manager +def display_all_presets_and_exit(): + # list available presets + if args.list: + screen.log_title("Available Presets:") + for preset in sorted(list_all_presets()): + print(preset) + sys.exit(0) + +def expand_preset(preset): + if preset.lower() in [p.lower() for p in list_all_presets()]: + preset = "{}.py:graph_manager".format(os.path.join(get_base_dir(), 'presets', preset)) + else: + preset = "{}".format(preset) + # if a graph manager variable was not specified, try the default of :graph_manager + if len(preset.split(":")) == 1: + preset += ":graph_manager" + + # verify that the preset exists + preset_path = preset.split(":")[0] + if not os.path.exists(preset_path): + screen.error("The given preset ({}) cannot be found.".format(preset)) + + # verify that the preset can be instantiated + try: + short_dynamic_import(preset, ignore_module_case=True) + except TypeError as e: + traceback.print_exc() + screen.error('Internal Error: ' + str(e) + "\n\nThe given preset ({}) cannot be instantiated." + .format(preset)) + + return preset + + def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace: """ Parse the arguments that the user entered @@ -103,38 +136,15 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace: # if no arg is given if len(sys.argv) == 1: parser.print_help() - exit(0) + sys.exit(0) # list available presets - preset_names = list_all_presets() if args.list: - screen.log_title("Available Presets:") - for preset in sorted(preset_names): - print(preset) - sys.exit(0) + display_all_presets_and_exit() # replace a short preset name with the full path if args.preset is not None: - if args.preset.lower() in [p.lower() for p in preset_names]: - args.preset = "{}.py:graph_manager".format(os.path.join(get_base_dir(), 'presets', args.preset)) - else: - args.preset = "{}".format(args.preset) - # if a graph manager variable was not specified, try the default of :graph_manager - if len(args.preset.split(":")) == 1: - args.preset += ":graph_manager" - - # verify that the preset exists - preset_path = args.preset.split(":")[0] - if not os.path.exists(preset_path): - screen.error("The given preset ({}) cannot be found.".format(args.preset)) - - # verify that the preset can be instantiated - try: - short_dynamic_import(args.preset, ignore_module_case=True) - except TypeError as e: - traceback.print_exc() - screen.error('Internal Error: ' + str(e) + "\n\nThe given preset ({}) cannot be instantiated." - .format(args.preset)) + args.preset = expand_preset(args.preset) # validate the checkpoints args if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir): @@ -179,6 +189,9 @@ def add_items_to_dict(target_dict, source_dict): def open_dashboard(experiment_path): + """ + open X11 based dashboard in a new process (nonblocking) + """ dashboard_path = 'python {}/dashboard.py'.format(get_base_dir()) cmd = "{} --experiment_dir {}".format(dashboard_path, experiment_path) screen.log_title("Opening dashboard - experiment path: {}".format(experiment_path))