mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Added ability to switch between tensorflow and mxnet using the -f command-line argument. (#48)

NOTE: the tensorflow framework works fine if mxnet is not installed in the environment, but mxnet will not work if tensorflow is not installed, because of the code in network_wrapper.
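
A plausible minimal illustration of that failure mode (hypothetical function name; the exact offending line in network_wrapper is not shown in this diff): a try/except only guards the import statement itself, so any module-level code that touches the tf name unconditionally, for example a type annotation evaluated when the function is defined, still raises NameError at import time when TensorFlow is absent:

    failed_imports = []
    try:
        import tensorflow as tf
    except ImportError:
        failed_imports.append("tensorflow")

    # Unquoted annotations are evaluated when the def statement runs, i.e.
    # while the module is being imported. If the import above failed, a
    # signature like `def save_graph(sess: tf.Session)` raises NameError
    # even for callers that only want the mxnet code path.
    def save_graph(sess: 'tf.Session'):  # quoting defers evaluation
        pass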
Sina Afrooze
2018-10-30 15:29:34 -07:00
committed by Scott Leishman
parent 2046358ab0
commit 95b4fc6888
8 changed files with 47 additions and 21 deletions

View File

@@ -19,12 +19,17 @@ from typing import List, Tuple
 from rl_coach.base_parameters import Frameworks, AgentParameters
 from rl_coach.logger import failed_imports
 from rl_coach.spaces import SpacesDefinition
 
 try:
     import tensorflow as tf
     from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork
 except ImportError:
-    failed_imports.append("TensorFlow")
+    failed_imports.append("tensorflow")
+try:
+    import mxnet as mx
+    from rl_coach.architectures.mxnet_components.general_network import GeneralMxnetNetwork
+except ImportError:
+    failed_imports.append("mxnet")
 
 
 class NetworkWrapper(object):
@@ -42,7 +47,15 @@ class NetworkWrapper(object):
         self.sess = None
 
         if self.network_parameters.framework == Frameworks.tensorflow:
-            general_network = GeneralTensorFlowNetwork
+            if "tensorflow" not in failed_imports:
+                general_network = GeneralTensorFlowNetwork
+            else:
+                raise Exception('Install tensorflow before using it as framework')
+        elif self.network_parameters.framework == Frameworks.mxnet:
+            if "mxnet" not in failed_imports:
+                general_network = GeneralMxnetNetwork
+            else:
+                raise Exception('Install mxnet before using it as framework')
         else:
             raise Exception("{} Framework is not supported"
                             .format(Frameworks().to_string(self.network_parameters.framework)))

View File

@@ -29,6 +29,7 @@ from rl_coach.filters.filter import NoInputFilter
 
 class Frameworks(Enum):
     tensorflow = "TensorFlow"
+    mxnet = "MXNet"
 
 
 class EmbedderScheme(Enum):
@@ -415,7 +416,7 @@ class AgentParameters(Parameters):
 
 
 class TaskParameters(Parameters):
-    def __init__(self, framework_type: str='tensorflow', evaluate_only: bool=False, use_cpu: bool=False,
+    def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False,
                  experiment_path='/tmp', seed=None, checkpoint_save_secs=None):
         """
         :param framework_type: deep learning framework type. currently only tensorflow is supported
@@ -435,7 +436,7 @@ class TaskParameters(Parameters):
 
 
 class DistributedTaskParameters(TaskParameters):
-    def __init__(self, framework_type: str, parameters_server_hosts: str, worker_hosts: str, job_type: str,
+    def __init__(self, framework_type: Frameworks, parameters_server_hosts: str, worker_hosts: str, job_type: str,
                  task_index: int, evaluate_only: bool=False, num_tasks: int=None,
                  num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
                  shared_memory_scratchpad=None, seed=None):
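
With framework_type now typed as a Frameworks member rather than a bare string, callers pass the enum directly. A minimal sketch, assuming rl_coach is importable (the parameter values are illustrative only):

    from rl_coach.base_parameters import Frameworks, TaskParameters

    # old, pre-commit style: TaskParameters(framework_type="tensorflow", ...)
    task_parameters = TaskParameters(framework_type=Frameworks.mxnet,
                                     evaluate_only=False,
                                     experiment_path='/tmp')

One benefit of the enum: a misspelled framework name now fails immediately at the call site (AttributeError on the Frameworks lookup) instead of falling through to the "Framework is not supported" exception deep inside network construction.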

View File

@@ -61,6 +61,16 @@ def get_graph_manager_from_args(args: argparse.Namespace) -> 'GraphManager':
         schedule_params = HumanPlayScheduleParameters()
         graph_manager = BasicRLGraphManager(HumanAgentParameters(), env_params, schedule_params, VisualizationParameters())
 
+    # Set framework
+    # Note: Some graph managers (e.g. HAC preset) create multiple agents and the attribute is called agents_params
+    if hasattr(graph_manager, 'agent_params'):
+        for network_parameters in graph_manager.agent_params.network_wrappers.values():
+            network_parameters.framework = args.framework
+    elif hasattr(graph_manager, 'agents_params'):
+        for ap in graph_manager.agents_params:
+            for network_parameters in ap.network_wrappers.values():
+                network_parameters.framework = args.framework
+
     if args.level:
         if isinstance(graph_manager.env_params.level, SingleLevelSelection):
             graph_manager.env_params.level.select(args.level)
@@ -344,7 +354,7 @@ def main():
     # Single-threaded runs
     if args.num_workers == 1:
         # Start the training or evaluation
-        task_parameters = TaskParameters(framework_type="tensorflow",  # TODO: tensorflow shouldn't be hardcoded
+        task_parameters = TaskParameters(framework_type=args.framework,
                                          evaluate_only=args.evaluate,
                                          experiment_path=args.experiment_path,
                                          seed=args.seed,
@@ -373,7 +383,7 @@ def main():
 
         def start_distributed_task(job_type, task_index, evaluation_worker=False,
                                    shared_memory_scratchpad=shared_memory_scratchpad):
-            task_parameters = DistributedTaskParameters(framework_type="tensorflow",  # TODO: tensorflow should'nt be hardcoded
+            task_parameters = DistributedTaskParameters(framework_type=args.framework,
                                                         parameters_server_hosts=ps_hosts,
                                                         worker_hosts=worker_hosts,
                                                         job_type=job_type,
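
For context, a hedged sketch of how the new -f flag could populate args.framework with a Frameworks member; the argparse definition is not part of this diff, so the conversion and default shown here are assumptions, not the repository's actual code:

    import argparse
    from rl_coach.base_parameters import Frameworks

    parser = argparse.ArgumentParser()
    # Map the lowercase flag value onto the enum member of the same name so
    # that downstream code can compare against Frameworks.tensorflow directly.
    parser.add_argument('-f', '--framework',
                        type=lambda name: Frameworks[name.lower()],
                        default=Frameworks.tensorflow,
                        choices=list(Frameworks))

    args = parser.parse_args(['-f', 'mxnet'])
    assert args.framework is Frameworks.mxnet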

View File

@@ -1,7 +1,7 @@
 import os
 import sys
 
-from rl_coach.base_parameters import TaskParameters
+from rl_coach.base_parameters import TaskParameters, Frameworks
 
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
 import tensorflow as tf
@@ -16,7 +16,7 @@ def test_get_QActionStateValue_predictions():
     from rl_coach.presets.CartPole_DQN import graph_manager as cartpole_dqn_graph_manager
     assert cartpole_dqn_graph_manager
     cartpole_dqn_graph_manager.create_graph(task_parameters=
-                                            TaskParameters(framework_type="tensorflow",
+                                            TaskParameters(framework_type=Frameworks.tensorflow,
                                                            experiment_path="./experiments/test"))
     cartpole_dqn_graph_manager.improve_steps.num_steps = 1
     cartpole_dqn_graph_manager.steps_between_evaluation_periods.num_steps = 5

View File

@@ -2,7 +2,7 @@ import os
 import sys
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
 import tensorflow as tf
-from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters
+from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks
 from rl_coach.utils import get_open_port
 from multiprocessing import Process
 from tensorflow import logging
@@ -16,7 +16,7 @@ def test_basic_rl_graph_manager_with_pong_a3c():
     from rl_coach.presets.Atari_A3C import graph_manager
     assert graph_manager
     graph_manager.env_params.level = "PongDeterministic-v4"
-    graph_manager.create_graph(task_parameters=TaskParameters(framework_type="tensorflow",
+    graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
                                                               experiment_path="./experiments/test"))
     # graph_manager.improve()
 
@@ -27,7 +27,7 @@ def test_basic_rl_graph_manager_with_pong_nec():
     from rl_coach.presets.Atari_NEC import graph_manager
     assert graph_manager
     graph_manager.env_params.level = "PongDeterministic-v4"
-    graph_manager.create_graph(task_parameters=TaskParameters(framework_type="tensorflow",
+    graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
                                                               experiment_path="./experiments/test"))
     # graph_manager.improve()
 
@@ -37,7 +37,7 @@ def test_basic_rl_graph_manager_with_cartpole_dqn():
     tf.reset_default_graph()
     from rl_coach.presets.CartPole_DQN import graph_manager
     assert graph_manager
-    graph_manager.create_graph(task_parameters=TaskParameters(framework_type="tensorflow",
+    graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
                                                               experiment_path="./experiments/test"))
     # graph_manager.improve()
 

View File

@@ -363,7 +363,7 @@
     "if not os.path.exists(log_path):\n",
     "    os.makedirs(log_path)\n",
     "    \n",
-    "task_parameters = TaskParameters(framework_type=\"tensorflow\", \n",
+    "task_parameters = TaskParameters(framework_type=Frameworks.tensorflow, \n",
     "                                 evaluate_only=False,\n",
     "                                 experiment_path=log_path)\n",
     "\n",

View File

@@ -1,8 +1,10 @@
 {
  "cells": [
   {
-   "cell_type": "markdown",
+   "cell_type": "code",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
     "In this tutorial we'll add the DeepMind Control Suite environment to Coach, and create a preset that trains the DDPG agent on the new environment."
    ]
@@ -341,7 +343,7 @@
     "if not os.path.exists(log_path):\n",
     "    os.makedirs(log_path)\n",
     "    \n",
-    "task_parameters = TaskParameters(framework_type=\"tensorflow\", \n",
+    "task_parameters = TaskParameters(framework_type=Frameworks.tensorflow, \n",
     "                                 evaluate_only=False,\n",
     "                                 experiment_path=log_path)\n",
     "\n",

View File

@@ -368,7 +368,7 @@
     "if not os.path.exists(log_path):\n",
     "    os.makedirs(log_path)\n",
     "    \n",
-    "task_parameters = TaskParameters(framework_type=\"tensorflow\", \n",
+    "task_parameters = TaskParameters(framework_type=Frameworks.tensorflow, \n",
     "                                 evaluate_only=False,\n",
     "                                 experiment_path=log_path)\n",
     "\n",