1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Added ability to switch between tensorflow and mxnet using -f commandline argument. (#48)

NOTE: tensorflow framework works fine if mxnet is not installed in env, but mxnet will not work if tensorflow is not installed because of the code in network_wrapper.
This commit is contained in:
Sina Afrooze
2018-10-30 15:29:34 -07:00
committed by Scott Leishman
parent 2046358ab0
commit 95b4fc6888
8 changed files with 47 additions and 21 deletions

View File

@@ -61,6 +61,16 @@ def get_graph_manager_from_args(args: argparse.Namespace) -> 'GraphManager':
schedule_params = HumanPlayScheduleParameters()
graph_manager = BasicRLGraphManager(HumanAgentParameters(), env_params, schedule_params, VisualizationParameters())
# Set framework
# Note: Some graph managers (e.g. HAC preset) create multiple agents and the attribute is called agents_params
if hasattr(graph_manager, 'agent_params'):
for network_parameters in graph_manager.agent_params.network_wrappers.values():
network_parameters.framework = args.framework
elif hasattr(graph_manager, 'agents_params'):
for ap in graph_manager.agents_params:
for network_parameters in ap.network_wrappers.values():
network_parameters.framework = args.framework
if args.level:
if isinstance(graph_manager.env_params.level, SingleLevelSelection):
graph_manager.env_params.level.select(args.level)
@@ -344,7 +354,7 @@ def main():
# Single-threaded runs
if args.num_workers == 1:
# Start the training or evaluation
task_parameters = TaskParameters(framework_type="tensorflow", # TODO: tensorflow shouldn't be hardcoded
task_parameters = TaskParameters(framework_type=args.framework,
evaluate_only=args.evaluate,
experiment_path=args.experiment_path,
seed=args.seed,
@@ -373,7 +383,7 @@ def main():
def start_distributed_task(job_type, task_index, evaluation_worker=False,
shared_memory_scratchpad=shared_memory_scratchpad):
task_parameters = DistributedTaskParameters(framework_type="tensorflow", # TODO: tensorflow should'nt be hardcoded
task_parameters = DistributedTaskParameters(framework_type=args.framework,
parameters_server_hosts=ps_hosts,
worker_hosts=worker_hosts,
job_type=job_type,