mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Added ability to switch between tensorflow and mxnet using -f commandline argument. (#48)
NOTE: tensorflow framework works fine if mxnet is not installed in env, but mxnet will not work if tensorflow is not installed because of the code in network_wrapper.
This commit is contained in:
committed by
Scott Leishman
parent
2046358ab0
commit
95b4fc6888
@@ -19,12 +19,17 @@ from typing import List, Tuple
|
||||
from rl_coach.base_parameters import Frameworks, AgentParameters
|
||||
from rl_coach.logger import failed_imports
|
||||
from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork
|
||||
except ImportError:
|
||||
failed_imports.append("TensorFlow")
|
||||
failed_imports.append("tensorflow")
|
||||
|
||||
try:
|
||||
import mxnet as mx
|
||||
from rl_coach.architectures.mxnet_components.general_network import GeneralMxnetNetwork
|
||||
except ImportError:
|
||||
failed_imports.append("mxnet")
|
||||
|
||||
|
||||
class NetworkWrapper(object):
|
||||
@@ -42,7 +47,15 @@ class NetworkWrapper(object):
|
||||
self.sess = None
|
||||
|
||||
if self.network_parameters.framework == Frameworks.tensorflow:
|
||||
general_network = GeneralTensorFlowNetwork
|
||||
if "tensorflow" not in failed_imports:
|
||||
general_network = GeneralTensorFlowNetwork
|
||||
else:
|
||||
raise Exception('Install tensorflow before using it as framework')
|
||||
elif self.network_parameters.framework == Frameworks.mxnet:
|
||||
if "mxnet" not in failed_imports:
|
||||
general_network = GeneralMxnetNetwork
|
||||
else:
|
||||
raise Exception('Install mxnet before using it as framework')
|
||||
else:
|
||||
raise Exception("{} Framework is not supported"
|
||||
.format(Frameworks().to_string(self.network_parameters.framework)))
|
||||
|
||||
Reference in New Issue
Block a user