1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Refactor launcher to be object-oriented (#63)

* Import of annoy library uses failed_import mechanism.
This commit is contained in:
Leo Dirac
2018-11-10 12:10:19 -08:00
committed by Gal Leibovich
parent 3fd433ffab
commit 2804a7c24f
2 changed files with 417 additions and 355 deletions

View File

@@ -51,200 +51,6 @@ if len(set(failed_imports)) > 0:
screen.warning("Warning: failed to import the following packages - {}".format(', '.join(set(failed_imports)))) screen.warning("Warning: failed to import the following packages - {}".format(', '.join(set(failed_imports))))
def get_graph_manager_from_args(args: argparse.Namespace) -> 'GraphManager':
"""
Return the graph manager according to the command line arguments given by the user
:param args: the arguments given by the user
:return: the updated graph manager
"""
graph_manager = None
# if a preset was given we will load the graph manager for the preset
if args.preset is not None:
graph_manager = short_dynamic_import(args.preset, ignore_module_case=True)
# for human play we need to create a custom graph manager
if args.play:
env_params = short_dynamic_import(args.environment_type, ignore_module_case=True)()
env_params.human_control = True
schedule_params = HumanPlayScheduleParameters()
graph_manager = BasicRLGraphManager(HumanAgentParameters(), env_params, schedule_params, VisualizationParameters())
# Set framework
# Note: Some graph managers (e.g. HAC preset) create multiple agents and the attribute is called agents_params
if hasattr(graph_manager, 'agent_params'):
for network_parameters in graph_manager.agent_params.network_wrappers.values():
network_parameters.framework = args.framework
elif hasattr(graph_manager, 'agents_params'):
for ap in graph_manager.agents_params:
for network_parameters in ap.network_wrappers.values():
network_parameters.framework = args.framework
if args.level:
if isinstance(graph_manager.env_params.level, SingleLevelSelection):
graph_manager.env_params.level.select(args.level)
else:
graph_manager.env_params.level = args.level
# set the seed for the environment
if args.seed is not None:
graph_manager.env_params.seed = args.seed
# visualization
graph_manager.visualization_parameters.dump_gifs = graph_manager.visualization_parameters.dump_gifs or args.dump_gifs
graph_manager.visualization_parameters.dump_mp4 = graph_manager.visualization_parameters.dump_mp4 or args.dump_mp4
graph_manager.visualization_parameters.render = args.render
graph_manager.visualization_parameters.tensorboard = args.tensorboard
graph_manager.visualization_parameters.print_networks_summary = args.print_networks_summary
# update the custom parameters
if args.custom_parameter is not None:
unstripped_key_value_pairs = [pair.split('=') for pair in args.custom_parameter.split(';')]
stripped_key_value_pairs = [tuple([pair[0].strip(), pair[1].strip()]) for pair in
unstripped_key_value_pairs if len(pair) == 2]
# load custom parameters into run_dict
for key, value in stripped_key_value_pairs:
exec("graph_manager.{}={}".format(key, value))
return graph_manager
def display_all_presets_and_exit():
# list available presets
screen.log_title("Available Presets:")
for preset in sorted(list_all_presets()):
print(preset)
sys.exit(0)
def expand_preset(preset):
if preset.lower() in [p.lower() for p in list_all_presets()]:
preset = "{}.py:graph_manager".format(os.path.join(get_base_dir(), 'presets', preset))
else:
preset = "{}".format(preset)
# if a graph manager variable was not specified, try the default of :graph_manager
if len(preset.split(":")) == 1:
preset += ":graph_manager"
# verify that the preset exists
preset_path = preset.split(":")[0]
if not os.path.exists(preset_path):
screen.error("The given preset ({}) cannot be found.".format(preset))
# verify that the preset can be instantiated
try:
short_dynamic_import(preset, ignore_module_case=True)
except TypeError as e:
traceback.print_exc()
screen.error('Internal Error: ' + str(e) + "\n\nThe given preset ({}) cannot be instantiated."
.format(preset))
return preset
def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
"""
Parse the arguments that the user entered
:param parser: the argparse command line parser
:return: the parsed arguments
"""
args = parser.parse_args()
# if no arg is given
if len(sys.argv) == 1:
parser.print_help()
sys.exit(0)
# list available presets
if args.list:
display_all_presets_and_exit()
# Read args from config file for distributed Coach.
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
coach_config = ConfigParser({
'image': '',
'memory_backend': 'redispubsub',
'data_store': 's3',
's3_end_point': 's3.amazonaws.com',
's3_bucket_name': '',
's3_creds_file': ''
})
try:
coach_config.read(args.distributed_coach_config_path)
args.image = coach_config.get('coach', 'image')
args.memory_backend = coach_config.get('coach', 'memory_backend')
args.data_store = coach_config.get('coach', 'data_store')
args.s3_end_point = coach_config.get('coach', 's3_end_point')
args.s3_bucket_name = coach_config.get('coach', 's3_bucket_name')
args.s3_creds_file = coach_config.get('coach', 's3_creds_file')
except Error as e:
screen.error("Error when reading distributed Coach config file: {}".format(e))
if args.image == '':
screen.error("Image cannot be empty.")
data_store_choices = ['s3']
if args.data_store not in data_store_choices:
screen.warning("{} data store is unsupported.".format(args.data_store))
screen.error("Supported data stores are {}.".format(data_store_choices))
memory_backend_choices = ['redispubsub']
if args.memory_backend not in memory_backend_choices:
screen.warning("{} memory backend is not supported.".format(args.memory_backend))
screen.error("Supported memory backends are {}.".format(memory_backend_choices))
if args.s3_bucket_name == '':
screen.error("S3 bucket name cannot be empty.")
if args.s3_creds_file == '':
args.s3_creds_file = None
if args.play and args.distributed_coach:
screen.error("Playing is not supported in distributed Coach.")
# replace a short preset name with the full path
if args.preset is not None:
args.preset = expand_preset(args.preset)
# validate the checkpoints args
if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
screen.error("The requested checkpoint folder to load from does not exist.")
# no preset was given. check if the user requested to play some environment on its own
if args.preset is None and args.play and not args.environment_type:
screen.error('When no preset is given for Coach to run, and the user requests human control over '
'the environment, the user is expected to input the desired environment_type and level.'
'\nAt least one of these parameters was not given.')
elif args.preset and args.play:
screen.error("Both the --preset and the --play flags were set. These flags can not be used together. "
"For human control, please use the --play flag together with the environment type flag (-et)")
elif args.preset is None and not args.play:
screen.error("Please choose a preset using the -p flag or use the --play flag together with choosing an "
"environment type (-et) in order to play the game.")
# get experiment name and path
args.experiment_name = logger.get_experiment_name(args.experiment_name)
args.experiment_path = logger.get_experiment_path(args.experiment_name)
if args.play and args.num_workers > 1:
screen.warning("Playing the game as a human is only available with a single worker. "
"The number of workers will be reduced to 1")
args.num_workers = 1
args.framework = Frameworks[args.framework.lower()]
# checkpoints
args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.checkpoint_save_secs is not None else None
if args.export_onnx_graph and not args.checkpoint_save_secs:
screen.warning("Exporting ONNX graphs requires setting the --checkpoint_save_secs flag. "
"The --export_onnx_graph will have no effect.")
return args
def add_items_to_dict(target_dict, source_dict): def add_items_to_dict(target_dict, source_dict):
updated_task_parameters = copy.copy(source_dict) updated_task_parameters = copy.copy(source_dict)
updated_task_parameters.update(target_dict) updated_task_parameters.update(target_dict)
@@ -263,6 +69,10 @@ def open_dashboard(experiment_path):
def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'): def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'):
"""
Runs the graph_manager using the configured task_parameters.
This stand-alone method is a convenience for multiprocessing.
"""
graph_manager.create_graph(task_parameters) graph_manager.create_graph(task_parameters)
# let the adventure begin # let the adventure begin
@@ -360,7 +170,243 @@ def handle_distributed_coach_orchestrator(graph_manager, args):
orchestrator.undeploy() orchestrator.undeploy()
def main(): class CoachLauncher(object):
"""
This class is responsible for gathering all user-specified configuration options, parsing them,
instantiating a GraphManager and then starting that GraphManager with either improve() or evaluate().
This class is also responsible for launching multiple processes.
It is structured so that it can be sub-classed to provide alternate mechanisms to configure and launch
Coach jobs.
The key entry-point for this class is the .launch() method which is expected to be called from __main__
and handle absolutely everything for a job.
"""
def launch(self):
"""
Main entry point for the class, and the standard way to run coach from the command line.
Parses command-line arguments through argparse, instantiates a GraphManager and then runs it.
"""
parser = self.get_argument_parser()
args = self.get_config_args(parser)
graph_manager = self.get_graph_manager_from_args(args)
self.run_graph_manager(graph_manager, args)
def get_graph_manager_from_args(self, args: argparse.Namespace) -> 'GraphManager':
"""
Return the graph manager according to the command line arguments given by the user.
:param args: the arguments given by the user
:return: the graph manager, not bound to task_parameters yet.
"""
graph_manager = None
# if a preset was given we will load the graph manager for the preset
if args.preset is not None:
graph_manager = short_dynamic_import(args.preset, ignore_module_case=True)
# for human play we need to create a custom graph manager
if args.play:
env_params = short_dynamic_import(args.environment_type, ignore_module_case=True)()
env_params.human_control = True
schedule_params = HumanPlayScheduleParameters()
graph_manager = BasicRLGraphManager(HumanAgentParameters(), env_params, schedule_params, VisualizationParameters())
# Set framework
# Note: Some graph managers (e.g. HAC preset) create multiple agents and the attribute is called agents_params
if hasattr(graph_manager, 'agent_params'):
for network_parameters in graph_manager.agent_params.network_wrappers.values():
network_parameters.framework = args.framework
elif hasattr(graph_manager, 'agents_params'):
for ap in graph_manager.agents_params:
for network_parameters in ap.network_wrappers.values():
network_parameters.framework = args.framework
if args.level:
if isinstance(graph_manager.env_params.level, SingleLevelSelection):
graph_manager.env_params.level.select(args.level)
else:
graph_manager.env_params.level = args.level
# set the seed for the environment
if args.seed is not None:
graph_manager.env_params.seed = args.seed
# visualization
graph_manager.visualization_parameters.dump_gifs = graph_manager.visualization_parameters.dump_gifs or args.dump_gifs
graph_manager.visualization_parameters.dump_mp4 = graph_manager.visualization_parameters.dump_mp4 or args.dump_mp4
graph_manager.visualization_parameters.render = args.render
graph_manager.visualization_parameters.tensorboard = args.tensorboard
graph_manager.visualization_parameters.print_networks_summary = args.print_networks_summary
# update the custom parameters
if args.custom_parameter is not None:
unstripped_key_value_pairs = [pair.split('=') for pair in args.custom_parameter.split(';')]
stripped_key_value_pairs = [tuple([pair[0].strip(), pair[1].strip()]) for pair in
unstripped_key_value_pairs if len(pair) == 2]
# load custom parameters into run_dict
for key, value in stripped_key_value_pairs:
exec("graph_manager.{}={}".format(key, value))
return graph_manager
def display_all_presets_and_exit(self):
# list available presets
screen.log_title("Available Presets:")
for preset in sorted(list_all_presets()):
print(preset)
sys.exit(0)
def expand_preset(self, preset):
"""
Replace a short preset name with the full python path, and verify that it can be imported.
"""
if preset.lower() in [p.lower() for p in list_all_presets()]:
preset = "{}.py:graph_manager".format(os.path.join(get_base_dir(), 'presets', preset))
else:
preset = "{}".format(preset)
# if a graph manager variable was not specified, try the default of :graph_manager
if len(preset.split(":")) == 1:
preset += ":graph_manager"
# verify that the preset exists
preset_path = preset.split(":")[0]
if not os.path.exists(preset_path):
screen.error("The given preset ({}) cannot be found.".format(preset))
# verify that the preset can be instantiated
try:
short_dynamic_import(preset, ignore_module_case=True)
except TypeError as e:
traceback.print_exc()
screen.error('Internal Error: ' + str(e) + "\n\nThe given preset ({}) cannot be instantiated."
.format(preset))
return preset
def get_config_args(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
"""
Returns a Namespace object with all the user-specified configuration options needed to launch.
This implementation uses argparse to take arguments from the CLI, but this can be over-ridden by
another method that gets its configuration from elsewhere. An equivalent method however must
return an identically structured Namespace object, which conforms to the structure defined by
get_argument_parser.
This method parses the arguments that the user entered, does some basic validation, and
modification of user-specified values in short form to be more explicit.
:param parser: a parser object which implicitly defines the format of the Namespace that
is expected to be returned.
:return: the parsed arguments as a Namespace
"""
args = parser.parse_args()
if args.nocolor:
screen.set_use_colors(False)
# if no arg is given
if len(sys.argv) == 1:
parser.print_help()
sys.exit(0)
# list available presets
if args.list:
self.display_all_presets_and_exit()
# Read args from config file for distributed Coach.
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
coach_config = ConfigParser({
'image': '',
'memory_backend': 'redispubsub',
'data_store': 's3',
's3_end_point': 's3.amazonaws.com',
's3_bucket_name': '',
's3_creds_file': ''
})
try:
coach_config.read(args.distributed_coach_config_path)
args.image = coach_config.get('coach', 'image')
args.memory_backend = coach_config.get('coach', 'memory_backend')
args.data_store = coach_config.get('coach', 'data_store')
args.s3_end_point = coach_config.get('coach', 's3_end_point')
args.s3_bucket_name = coach_config.get('coach', 's3_bucket_name')
args.s3_creds_file = coach_config.get('coach', 's3_creds_file')
except Error as e:
screen.error("Error when reading distributed Coach config file: {}".format(e))
if args.image == '':
screen.error("Image cannot be empty.")
data_store_choices = ['s3']
if args.data_store not in data_store_choices:
screen.warning("{} data store is unsupported.".format(args.data_store))
screen.error("Supported data stores are {}.".format(data_store_choices))
memory_backend_choices = ['redispubsub']
if args.memory_backend not in memory_backend_choices:
screen.warning("{} memory backend is not supported.".format(args.memory_backend))
screen.error("Supported memory backends are {}.".format(memory_backend_choices))
if args.s3_bucket_name == '':
screen.error("S3 bucket name cannot be empty.")
if args.s3_creds_file == '':
args.s3_creds_file = None
if args.play and args.distributed_coach:
screen.error("Playing is not supported in distributed Coach.")
# replace a short preset name with the full path
if args.preset is not None:
args.preset = self.expand_preset(args.preset)
# validate the checkpoints args
if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
screen.error("The requested checkpoint folder to load from does not exist.")
# no preset was given. check if the user requested to play some environment on its own
if args.preset is None and args.play and not args.environment_type:
screen.error('When no preset is given for Coach to run, and the user requests human control over '
'the environment, the user is expected to input the desired environment_type and level.'
'\nAt least one of these parameters was not given.')
elif args.preset and args.play:
screen.error("Both the --preset and the --play flags were set. These flags can not be used together. "
"For human control, please use the --play flag together with the environment type flag (-et)")
elif args.preset is None and not args.play:
screen.error("Please choose a preset using the -p flag or use the --play flag together with choosing an "
"environment type (-et) in order to play the game.")
# get experiment name and path
args.experiment_name = logger.get_experiment_name(args.experiment_name)
args.experiment_path = logger.get_experiment_path(args.experiment_name)
if args.play and args.num_workers > 1:
screen.warning("Playing the game as a human is only available with a single worker. "
"The number of workers will be reduced to 1")
args.num_workers = 1
args.framework = Frameworks[args.framework.lower()]
# checkpoints
args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.checkpoint_save_secs is not None else None
if args.export_onnx_graph and not args.checkpoint_save_secs:
screen.warning("Exporting ONNX graphs requires setting the --checkpoint_save_secs flag. "
"The --export_onnx_graph will have no effect.")
return args
def get_argument_parser(self) -> argparse.ArgumentParser:
"""
This returns an ArgumentParser object which defines the set of options that customers are expected to supply in order
to launch a coach job.
"""
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-p', '--preset', parser.add_argument('-p', '--preset',
help="(string) Name of a preset to run (class name from the 'presets' directory.)", help="(string) Name of a preset to run (class name from the 'presets' directory.)",
@@ -429,6 +475,14 @@ def main():
help="(string) Choose an environment type class to override on top of the selected preset.", help="(string) Choose an environment type class to override on top of the selected preset.",
default=None, default=None,
type=str) type=str)
parser.add_argument('-ept', '--exploration_policy_type',
help="(string) Choose an exploration policy type class to override on top of the selected "
"preset."
"If no preset is defined, a preset can be set from the command-line by combining settings "
"which are set by using --agent_type, --experiment_type, --environemnt_type"
,
default=None,
type=str)
parser.add_argument('-lvl', '--level', parser.add_argument('-lvl', '--level',
help="(string) Choose the level that will be played in the environment that was selected." help="(string) Choose the level that will be played in the environment that was selected."
"This value will override the level parameter in the environment class." "This value will override the level parameter in the environment class."
@@ -489,15 +543,12 @@ def main():
default=RunType.ORCHESTRATOR, default=RunType.ORCHESTRATOR,
choices=list(RunType)) choices=list(RunType))
args = parse_arguments(parser) return parser
if args.nocolor:
screen.set_use_colors(False)
graph_manager = get_graph_manager_from_args(args)
def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namespace):
if args.distributed_coach and not graph_manager.agent_params.algorithm.distributed_coach_synchronization_type: if args.distributed_coach and not graph_manager.agent_params.algorithm.distributed_coach_synchronization_type:
screen.error("{} preset is not supported using distributed Coach.".format(args.preset)) screen.error("{} algorithm is not supported using distributed Coach.".format(graph_manager.agent_params.algorithm))
# Intel optimized TF seems to run significantly faster when limiting to a single OMP thread. # Intel optimized TF seems to run significantly faster when limiting to a single OMP thread.
# This will not affect GPU runs. # This will not affect GPU runs.
@@ -526,6 +577,12 @@ def main():
# Single-threaded runs # Single-threaded runs
if args.num_workers == 1: if args.num_workers == 1:
self.start_single_threaded(graph_manager, args)
else:
self.start_multi_threaded(graph_manager, args)
def start_single_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
# Start the training or evaluation # Start the training or evaluation
task_parameters = TaskParameters( task_parameters = TaskParameters(
framework_type=args.framework, framework_type=args.framework,
@@ -541,8 +598,8 @@ def main():
start_graph(graph_manager=graph_manager, task_parameters=task_parameters) start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
# Multi-threaded runs
else: def start_multi_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
total_tasks = args.num_workers total_tasks = args.num_workers
if args.evaluation_worker: if args.evaluation_worker:
total_tasks += 1 total_tasks += 1
@@ -578,7 +635,6 @@ def main():
checkpoint_save_dir=args.checkpoint_save_dir, checkpoint_save_dir=args.checkpoint_save_dir,
export_onnx_graph=args.export_onnx_graph export_onnx_graph=args.export_onnx_graph
) )
# we assume that only the evaluation workers are rendering # we assume that only the evaluation workers are rendering
graph_manager.visualization_parameters.render = args.render and evaluation_worker graph_manager.visualization_parameters.render = args.render and evaluation_worker
p = Process(target=start_graph, args=(graph_manager, task_parameters)) p = Process(target=start_graph, args=(graph_manager, task_parameters))
@@ -608,4 +664,5 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() launcher = CoachLauncher()
launcher.launch()

View File

@@ -18,7 +18,12 @@ import os
import pickle import pickle
import numpy as np import numpy as np
from annoy import AnnoyIndex try:
import annoy
from annoy import AnnoyIndex
except ImportError:
from rl_coach.logger import failed_imports
failed_imports.append("annoy")
class AnnoyDictionary(object): class AnnoyDictionary(object):