mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Added ability to switch between tensorflow and mxnet using -f commandline argument. (#48)
NOTE: tensorflow framework works fine if mxnet is not installed in env, but mxnet will not work if tensorflow is not installed because of the code in network_wrapper.
This commit is contained in:
committed by
Scott Leishman
parent
2046358ab0
commit
95b4fc6888
@@ -61,6 +61,16 @@ def get_graph_manager_from_args(args: argparse.Namespace) -> 'GraphManager':
|
||||
schedule_params = HumanPlayScheduleParameters()
|
||||
graph_manager = BasicRLGraphManager(HumanAgentParameters(), env_params, schedule_params, VisualizationParameters())
|
||||
|
||||
# Set framework
|
||||
# Note: Some graph managers (e.g. HAC preset) create multiple agents and the attribute is called agents_params
|
||||
if hasattr(graph_manager, 'agent_params'):
|
||||
for network_parameters in graph_manager.agent_params.network_wrappers.values():
|
||||
network_parameters.framework = args.framework
|
||||
elif hasattr(graph_manager, 'agents_params'):
|
||||
for ap in graph_manager.agents_params:
|
||||
for network_parameters in ap.network_wrappers.values():
|
||||
network_parameters.framework = args.framework
|
||||
|
||||
if args.level:
|
||||
if isinstance(graph_manager.env_params.level, SingleLevelSelection):
|
||||
graph_manager.env_params.level.select(args.level)
|
||||
@@ -344,7 +354,7 @@ def main():
|
||||
# Single-threaded runs
|
||||
if args.num_workers == 1:
|
||||
# Start the training or evaluation
|
||||
task_parameters = TaskParameters(framework_type="tensorflow", # TODO: tensorflow shouldn't be hardcoded
|
||||
task_parameters = TaskParameters(framework_type=args.framework,
|
||||
evaluate_only=args.evaluate,
|
||||
experiment_path=args.experiment_path,
|
||||
seed=args.seed,
|
||||
@@ -373,7 +383,7 @@ def main():
|
||||
|
||||
def start_distributed_task(job_type, task_index, evaluation_worker=False,
|
||||
shared_memory_scratchpad=shared_memory_scratchpad):
|
||||
task_parameters = DistributedTaskParameters(framework_type="tensorflow", # TODO: tensorflow should'nt be hardcoded
|
||||
task_parameters = DistributedTaskParameters(framework_type=args.framework,
|
||||
parameters_server_hosts=ps_hosts,
|
||||
worker_hosts=worker_hosts,
|
||||
job_type=job_type,
|
||||
|
||||
Reference in New Issue
Block a user