1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Added ability to switch between tensorflow and mxnet using -f commandline argument. (#48)

NOTE: tensorflow framework works fine if mxnet is not installed in env, but mxnet will not work if tensorflow is not installed because of the code in network_wrapper.
This commit is contained in:
Sina Afrooze
2018-10-30 15:29:34 -07:00
committed by Scott Leishman
parent 2046358ab0
commit 95b4fc6888
8 changed files with 47 additions and 21 deletions

View File

@@ -19,12 +19,17 @@ from typing import List, Tuple
from rl_coach.base_parameters import Frameworks, AgentParameters
from rl_coach.logger import failed_imports
from rl_coach.spaces import SpacesDefinition
try:
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork
except ImportError:
failed_imports.append("TensorFlow")
failed_imports.append("tensorflow")
try:
import mxnet as mx
from rl_coach.architectures.mxnet_components.general_network import GeneralMxnetNetwork
except ImportError:
failed_imports.append("mxnet")
class NetworkWrapper(object):
@@ -42,7 +47,15 @@ class NetworkWrapper(object):
self.sess = None
if self.network_parameters.framework == Frameworks.tensorflow:
general_network = GeneralTensorFlowNetwork
if "tensorflow" not in failed_imports:
general_network = GeneralTensorFlowNetwork
else:
raise Exception('Install tensorflow before using it as framework')
elif self.network_parameters.framework == Frameworks.mxnet:
if "mxnet" not in failed_imports:
general_network = GeneralMxnetNetwork
else:
raise Exception('Install mxnet before using it as framework')
else:
raise Exception("{} Framework is not supported"
.format(Frameworks().to_string(self.network_parameters.framework)))