From 95b4fc68883d71d122d082a3a4a0073a7daf1c24 Mon Sep 17 00:00:00 2001 From: Sina Afrooze Date: Tue, 30 Oct 2018 15:29:34 -0700 Subject: [PATCH] Added ability to switch between tensorflow and mxnet using -f commandline argument. (#48) NOTE: tensorflow framework works fine if mxnet is not installed in env, but mxnet will not work if tensorflow is not installed because of the code in network_wrapper. --- rl_coach/architectures/network_wrapper.py | 19 ++++++++++++++++--- rl_coach/base_parameters.py | 5 +++-- rl_coach/coach.py | 14 ++++++++++++-- .../test_agent_external_communication.py | 4 ++-- .../test_basic_rl_graph_manager.py | 8 ++++---- tutorials/1. Implementing an Algorithm.ipynb | 6 +++--- tutorials/2. Adding an Environment.ipynb | 10 ++++++---- ...Implementing a Hierarchical RL Graph.ipynb | 2 +- 8 files changed, 47 insertions(+), 21 deletions(-) diff --git a/rl_coach/architectures/network_wrapper.py b/rl_coach/architectures/network_wrapper.py index 5c868d5..042c93c 100644 --- a/rl_coach/architectures/network_wrapper.py +++ b/rl_coach/architectures/network_wrapper.py @@ -19,12 +19,17 @@ from typing import List, Tuple from rl_coach.base_parameters import Frameworks, AgentParameters from rl_coach.logger import failed_imports from rl_coach.spaces import SpacesDefinition - try: import tensorflow as tf from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork except ImportError: - failed_imports.append("TensorFlow") + failed_imports.append("tensorflow") + +try: + import mxnet as mx + from rl_coach.architectures.mxnet_components.general_network import GeneralMxnetNetwork +except ImportError: + failed_imports.append("mxnet") class NetworkWrapper(object): @@ -42,7 +47,15 @@ class NetworkWrapper(object): self.sess = None if self.network_parameters.framework == Frameworks.tensorflow: - general_network = GeneralTensorFlowNetwork + if "tensorflow" not in failed_imports: + general_network = GeneralTensorFlowNetwork + else: + raise Exception('Install tensorflow before using it as framework') + elif self.network_parameters.framework == Frameworks.mxnet: + if "mxnet" not in failed_imports: + general_network = GeneralMxnetNetwork + else: + raise Exception('Install mxnet before using it as framework') else: raise Exception("{} Framework is not supported" .format(Frameworks().to_string(self.network_parameters.framework))) diff --git a/rl_coach/base_parameters.py b/rl_coach/base_parameters.py index 028b649..a48a22c 100644 --- a/rl_coach/base_parameters.py +++ b/rl_coach/base_parameters.py @@ -29,6 +29,7 @@ from rl_coach.filters.filter import NoInputFilter class Frameworks(Enum): tensorflow = "TensorFlow" + mxnet = "MXNet" class EmbedderScheme(Enum): @@ -415,7 +416,7 @@ class AgentParameters(Parameters): class TaskParameters(Parameters): - def __init__(self, framework_type: str='tensorflow', evaluate_only: bool=False, use_cpu: bool=False, + def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False, experiment_path='/tmp', seed=None, checkpoint_save_secs=None): """ :param framework_type: deep learning framework type. currently only tensorflow is supported @@ -435,7 +436,7 @@ class TaskParameters(Parameters): class DistributedTaskParameters(TaskParameters): - def __init__(self, framework_type: str, parameters_server_hosts: str, worker_hosts: str, job_type: str, + def __init__(self, framework_type: Frameworks, parameters_server_hosts: str, worker_hosts: str, job_type: str, task_index: int, evaluate_only: bool=False, num_tasks: int=None, num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None, shared_memory_scratchpad=None, seed=None): diff --git a/rl_coach/coach.py b/rl_coach/coach.py index e479974..cab0492 100644 --- a/rl_coach/coach.py +++ b/rl_coach/coach.py @@ -61,6 +61,16 @@ def get_graph_manager_from_args(args: argparse.Namespace) -> 'GraphManager': schedule_params = HumanPlayScheduleParameters() graph_manager = BasicRLGraphManager(HumanAgentParameters(), env_params, schedule_params, VisualizationParameters()) + # Set framework + # Note: Some graph managers (e.g. HAC preset) create multiple agents and the attribute is called agents_params + if hasattr(graph_manager, 'agent_params'): + for network_parameters in graph_manager.agent_params.network_wrappers.values(): + network_parameters.framework = args.framework + elif hasattr(graph_manager, 'agents_params'): + for ap in graph_manager.agents_params: + for network_parameters in ap.network_wrappers.values(): + network_parameters.framework = args.framework + if args.level: if isinstance(graph_manager.env_params.level, SingleLevelSelection): graph_manager.env_params.level.select(args.level) @@ -344,7 +354,7 @@ def main(): # Single-threaded runs if args.num_workers == 1: # Start the training or evaluation - task_parameters = TaskParameters(framework_type="tensorflow", # TODO: tensorflow shouldn't be hardcoded + task_parameters = TaskParameters(framework_type=args.framework, evaluate_only=args.evaluate, experiment_path=args.experiment_path, seed=args.seed, @@ -373,7 +383,7 @@ def main(): def start_distributed_task(job_type, task_index, evaluation_worker=False, shared_memory_scratchpad=shared_memory_scratchpad): - task_parameters = DistributedTaskParameters(framework_type="tensorflow", # TODO: tensorflow should'nt be hardcoded + task_parameters = DistributedTaskParameters(framework_type=args.framework, parameters_server_hosts=ps_hosts, worker_hosts=worker_hosts, job_type=job_type, diff --git a/rl_coach/tests/agents/test_agent_external_communication.py b/rl_coach/tests/agents/test_agent_external_communication.py index 0bef271..77f0a89 100644 --- a/rl_coach/tests/agents/test_agent_external_communication.py +++ b/rl_coach/tests/agents/test_agent_external_communication.py @@ -1,7 +1,7 @@ import os import sys -from rl_coach.base_parameters import TaskParameters +from rl_coach.base_parameters import TaskParameters, Frameworks sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) import tensorflow as tf @@ -16,7 +16,7 @@ def test_get_QActionStateValue_predictions(): from rl_coach.presets.CartPole_DQN import graph_manager as cartpole_dqn_graph_manager assert cartpole_dqn_graph_manager cartpole_dqn_graph_manager.create_graph(task_parameters= - TaskParameters(framework_type="tensorflow", + TaskParameters(framework_type=Frameworks.tensorflow, experiment_path="./experiments/test")) cartpole_dqn_graph_manager.improve_steps.num_steps = 1 cartpole_dqn_graph_manager.steps_between_evaluation_periods.num_steps = 5 diff --git a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py index e489373..214ef31 100644 --- a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py +++ b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py @@ -2,7 +2,7 @@ import os import sys sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) import tensorflow as tf -from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters +from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks from rl_coach.utils import get_open_port from multiprocessing import Process from tensorflow import logging @@ -16,7 +16,7 @@ def test_basic_rl_graph_manager_with_pong_a3c(): from rl_coach.presets.Atari_A3C import graph_manager assert graph_manager graph_manager.env_params.level = "PongDeterministic-v4" - graph_manager.create_graph(task_parameters=TaskParameters(framework_type="tensorflow", + graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow, experiment_path="./experiments/test")) # graph_manager.improve() @@ -27,7 +27,7 @@ def test_basic_rl_graph_manager_with_pong_nec(): from rl_coach.presets.Atari_NEC import graph_manager assert graph_manager graph_manager.env_params.level = "PongDeterministic-v4" - graph_manager.create_graph(task_parameters=TaskParameters(framework_type="tensorflow", + graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow, experiment_path="./experiments/test")) # graph_manager.improve() @@ -37,7 +37,7 @@ def test_basic_rl_graph_manager_with_cartpole_dqn(): tf.reset_default_graph() from rl_coach.presets.CartPole_DQN import graph_manager assert graph_manager - graph_manager.create_graph(task_parameters=TaskParameters(framework_type="tensorflow", + graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow, experiment_path="./experiments/test")) # graph_manager.improve() diff --git a/tutorials/1. Implementing an Algorithm.ipynb b/tutorials/1. Implementing an Algorithm.ipynb index a309964..459a958 100644 --- a/tutorials/1. Implementing an Algorithm.ipynb +++ b/tutorials/1. Implementing an Algorithm.ipynb @@ -363,9 +363,9 @@ "if not os.path.exists(log_path):\n", " os.makedirs(log_path)\n", " \n", - "task_parameters = TaskParameters(framework_type=\"tensorflow\", \n", - " evaluate_only=False,\n", - " experiment_path=log_path)\n", + "task_parameters = TaskParameters(framework_type=Frameworks.tensorflow, \n", + " evaluate_only=False,\n", + " experiment_path=log_path)\n", "\n", "task_parameters.__dict__['checkpoint_save_secs'] = None\n", "\n", diff --git a/tutorials/2. Adding an Environment.ipynb b/tutorials/2. Adding an Environment.ipynb index 71869d1..9f0fa71 100644 --- a/tutorials/2. Adding an Environment.ipynb +++ b/tutorials/2. Adding an Environment.ipynb @@ -1,8 +1,10 @@ { "cells": [ { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ "In this tutorial we'll add the DeepMind Control Suite environment to Coach, and create a preset that trains the DDPG agent on the new environment." ] @@ -341,9 +343,9 @@ "if not os.path.exists(log_path):\n", " os.makedirs(log_path)\n", " \n", - "task_parameters = TaskParameters(framework_type=\"tensorflow\", \n", - " evaluate_only=False,\n", - " experiment_path=log_path)\n", + "task_parameters = TaskParameters(framework_type=Frameworks.tensorflow, \n", + " evaluate_only=False,\n", + " experiment_path=log_path)\n", "\n", "task_parameters.__dict__['checkpoint_save_secs'] = None\n", "\n", diff --git a/tutorials/3. Implementing a Hierarchical RL Graph.ipynb b/tutorials/3. Implementing a Hierarchical RL Graph.ipynb index f200c2c..191d267 100644 --- a/tutorials/3. Implementing a Hierarchical RL Graph.ipynb +++ b/tutorials/3. Implementing a Hierarchical RL Graph.ipynb @@ -368,7 +368,7 @@ "if not os.path.exists(log_path):\n", " os.makedirs(log_path)\n", " \n", - "task_parameters = TaskParameters(framework_type=\"tensorflow\", \n", + "task_parameters = TaskParameters(framework_type=Frameworks.tensorflow, \n", " evaluate_only=False,\n", " experiment_path=log_path)\n", "\n",