mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Added the ability to switch between TensorFlow and MXNet using the -f command-line argument. (#48)
NOTE: The TensorFlow framework works fine if MXNet is not installed in the environment, but MXNet will not work if TensorFlow is not installed, because of the code in network_wrapper.
This commit is contained in:
committed by
Scott Leishman
parent
2046358ab0
commit
95b4fc6888
@@ -19,12 +19,17 @@ from typing import List, Tuple
|
||||
from rl_coach.base_parameters import Frameworks, AgentParameters
|
||||
from rl_coach.logger import failed_imports
|
||||
from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork
|
||||
except ImportError:
|
||||
failed_imports.append("TensorFlow")
|
||||
failed_imports.append("tensorflow")
|
||||
|
||||
try:
|
||||
import mxnet as mx
|
||||
from rl_coach.architectures.mxnet_components.general_network import GeneralMxnetNetwork
|
||||
except ImportError:
|
||||
failed_imports.append("mxnet")
|
||||
|
||||
|
||||
class NetworkWrapper(object):
|
||||
@@ -42,7 +47,15 @@ class NetworkWrapper(object):
|
||||
self.sess = None
|
||||
|
||||
if self.network_parameters.framework == Frameworks.tensorflow:
|
||||
if "tensorflow" not in failed_imports:
|
||||
general_network = GeneralTensorFlowNetwork
|
||||
else:
|
||||
raise Exception('Install tensorflow before using it as framework')
|
||||
elif self.network_parameters.framework == Frameworks.mxnet:
|
||||
if "mxnet" not in failed_imports:
|
||||
general_network = GeneralMxnetNetwork
|
||||
else:
|
||||
raise Exception('Install mxnet before using it as framework')
|
||||
else:
|
||||
raise Exception("{} Framework is not supported"
|
||||
.format(Frameworks().to_string(self.network_parameters.framework)))
|
||||
|
||||
@@ -29,6 +29,7 @@ from rl_coach.filters.filter import NoInputFilter
|
||||
|
||||
class Frameworks(Enum):
|
||||
tensorflow = "TensorFlow"
|
||||
mxnet = "MXNet"
|
||||
|
||||
|
||||
class EmbedderScheme(Enum):
|
||||
@@ -415,7 +416,7 @@ class AgentParameters(Parameters):
|
||||
|
||||
|
||||
class TaskParameters(Parameters):
|
||||
def __init__(self, framework_type: str='tensorflow', evaluate_only: bool=False, use_cpu: bool=False,
|
||||
def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False,
|
||||
experiment_path='/tmp', seed=None, checkpoint_save_secs=None):
|
||||
"""
|
||||
:param framework_type: deep learning framework type, given as a Frameworks enum value (tensorflow or mxnet)
|
||||
@@ -435,7 +436,7 @@ class TaskParameters(Parameters):
|
||||
|
||||
|
||||
class DistributedTaskParameters(TaskParameters):
|
||||
def __init__(self, framework_type: str, parameters_server_hosts: str, worker_hosts: str, job_type: str,
|
||||
def __init__(self, framework_type: Frameworks, parameters_server_hosts: str, worker_hosts: str, job_type: str,
|
||||
task_index: int, evaluate_only: bool=False, num_tasks: int=None,
|
||||
num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
|
||||
shared_memory_scratchpad=None, seed=None):
|
||||
|
||||
@@ -61,6 +61,16 @@ def get_graph_manager_from_args(args: argparse.Namespace) -> 'GraphManager':
|
||||
schedule_params = HumanPlayScheduleParameters()
|
||||
graph_manager = BasicRLGraphManager(HumanAgentParameters(), env_params, schedule_params, VisualizationParameters())
|
||||
|
||||
# Set framework
|
||||
# Note: Some graph managers (e.g. HAC preset) create multiple agents and the attribute is called agents_params
|
||||
if hasattr(graph_manager, 'agent_params'):
|
||||
for network_parameters in graph_manager.agent_params.network_wrappers.values():
|
||||
network_parameters.framework = args.framework
|
||||
elif hasattr(graph_manager, 'agents_params'):
|
||||
for ap in graph_manager.agents_params:
|
||||
for network_parameters in ap.network_wrappers.values():
|
||||
network_parameters.framework = args.framework
|
||||
|
||||
if args.level:
|
||||
if isinstance(graph_manager.env_params.level, SingleLevelSelection):
|
||||
graph_manager.env_params.level.select(args.level)
|
||||
@@ -344,7 +354,7 @@ def main():
|
||||
# Single-threaded runs
|
||||
if args.num_workers == 1:
|
||||
# Start the training or evaluation
|
||||
task_parameters = TaskParameters(framework_type="tensorflow", # TODO: tensorflow shouldn't be hardcoded
|
||||
task_parameters = TaskParameters(framework_type=args.framework,
|
||||
evaluate_only=args.evaluate,
|
||||
experiment_path=args.experiment_path,
|
||||
seed=args.seed,
|
||||
@@ -373,7 +383,7 @@ def main():
|
||||
|
||||
def start_distributed_task(job_type, task_index, evaluation_worker=False,
|
||||
shared_memory_scratchpad=shared_memory_scratchpad):
|
||||
task_parameters = DistributedTaskParameters(framework_type="tensorflow", # TODO: tensorflow should'nt be hardcoded
|
||||
task_parameters = DistributedTaskParameters(framework_type=args.framework,
|
||||
parameters_server_hosts=ps_hosts,
|
||||
worker_hosts=worker_hosts,
|
||||
job_type=job_type,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from rl_coach.base_parameters import TaskParameters
|
||||
from rl_coach.base_parameters import TaskParameters, Frameworks
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
import tensorflow as tf
|
||||
@@ -16,7 +16,7 @@ def test_get_QActionStateValue_predictions():
|
||||
from rl_coach.presets.CartPole_DQN import graph_manager as cartpole_dqn_graph_manager
|
||||
assert cartpole_dqn_graph_manager
|
||||
cartpole_dqn_graph_manager.create_graph(task_parameters=
|
||||
TaskParameters(framework_type="tensorflow",
|
||||
TaskParameters(framework_type=Frameworks.tensorflow,
|
||||
experiment_path="./experiments/test"))
|
||||
cartpole_dqn_graph_manager.improve_steps.num_steps = 1
|
||||
cartpole_dqn_graph_manager.steps_between_evaluation_periods.num_steps = 5
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
import tensorflow as tf
|
||||
from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters
|
||||
from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks
|
||||
from rl_coach.utils import get_open_port
|
||||
from multiprocessing import Process
|
||||
from tensorflow import logging
|
||||
@@ -16,7 +16,7 @@ def test_basic_rl_graph_manager_with_pong_a3c():
|
||||
from rl_coach.presets.Atari_A3C import graph_manager
|
||||
assert graph_manager
|
||||
graph_manager.env_params.level = "PongDeterministic-v4"
|
||||
graph_manager.create_graph(task_parameters=TaskParameters(framework_type="tensorflow",
|
||||
graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
|
||||
experiment_path="./experiments/test"))
|
||||
# graph_manager.improve()
|
||||
|
||||
@@ -27,7 +27,7 @@ def test_basic_rl_graph_manager_with_pong_nec():
|
||||
from rl_coach.presets.Atari_NEC import graph_manager
|
||||
assert graph_manager
|
||||
graph_manager.env_params.level = "PongDeterministic-v4"
|
||||
graph_manager.create_graph(task_parameters=TaskParameters(framework_type="tensorflow",
|
||||
graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
|
||||
experiment_path="./experiments/test"))
|
||||
# graph_manager.improve()
|
||||
|
||||
@@ -37,7 +37,7 @@ def test_basic_rl_graph_manager_with_cartpole_dqn():
|
||||
tf.reset_default_graph()
|
||||
from rl_coach.presets.CartPole_DQN import graph_manager
|
||||
assert graph_manager
|
||||
graph_manager.create_graph(task_parameters=TaskParameters(framework_type="tensorflow",
|
||||
graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
|
||||
experiment_path="./experiments/test"))
|
||||
# graph_manager.improve()
|
||||
|
||||
|
||||
@@ -363,7 +363,7 @@
|
||||
"if not os.path.exists(log_path):\n",
|
||||
" os.makedirs(log_path)\n",
|
||||
" \n",
|
||||
"task_parameters = TaskParameters(framework_type=\"tensorflow\", \n",
|
||||
"task_parameters = TaskParameters(framework_type=Frameworks.tensorflow, \n",
|
||||
" evaluate_only=False,\n",
|
||||
" experiment_path=log_path)\n",
|
||||
"\n",
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"In this tutorial we'll add the DeepMind Control Suite environment to Coach, and create a preset that trains the DDPG agent on the new environment."
|
||||
]
|
||||
@@ -341,7 +343,7 @@
|
||||
"if not os.path.exists(log_path):\n",
|
||||
" os.makedirs(log_path)\n",
|
||||
" \n",
|
||||
"task_parameters = TaskParameters(framework_type=\"tensorflow\", \n",
|
||||
"task_parameters = TaskParameters(framework_type=Frameworks.tensorflow, \n",
|
||||
" evaluate_only=False,\n",
|
||||
" experiment_path=log_path)\n",
|
||||
"\n",
|
||||
|
||||
@@ -368,7 +368,7 @@
|
||||
"if not os.path.exists(log_path):\n",
|
||||
" os.makedirs(log_path)\n",
|
||||
" \n",
|
||||
"task_parameters = TaskParameters(framework_type=\"tensorflow\", \n",
|
||||
"task_parameters = TaskParameters(framework_type=Frameworks.tensorflow, \n",
|
||||
" evaluate_only=False,\n",
|
||||
" experiment_path=log_path)\n",
|
||||
"\n",
|
||||
|
||||
Reference in New Issue
Block a user