mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Fix warning on import TF or MxNet, when only one of the frameworks is installed (#140)
This commit is contained in:
committed by
Gal Leibovich
parent
9e66bb653e
commit
155b78b995
@@ -21,17 +21,6 @@ from rl_coach.logger import failed_imports
|
||||
from rl_coach.saver import SaverCollection
|
||||
from rl_coach.spaces import SpacesDefinition
|
||||
from rl_coach.utils import force_list
|
||||
try:
|
||||
import tensorflow as tf
|
||||
from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork
|
||||
except ImportError:
|
||||
failed_imports.append("tensorflow")
|
||||
|
||||
try:
|
||||
import mxnet as mx
|
||||
from rl_coach.architectures.mxnet_components.general_network import GeneralMxnetNetwork
|
||||
except ImportError:
|
||||
failed_imports.append("mxnet")
|
||||
|
||||
|
||||
class NetworkWrapper(object):
|
||||
@@ -53,15 +42,19 @@ class NetworkWrapper(object):
|
||||
self.sess = None
|
||||
|
||||
if self.network_parameters.framework == Frameworks.tensorflow:
|
||||
if "tensorflow" not in failed_imports:
|
||||
general_network = GeneralTensorFlowNetwork.construct
|
||||
else:
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
raise Exception('Install tensorflow before using it as framework')
|
||||
from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork
|
||||
general_network = GeneralTensorFlowNetwork.construct
|
||||
elif self.network_parameters.framework == Frameworks.mxnet:
|
||||
if "mxnet" not in failed_imports:
|
||||
general_network = GeneralMxnetNetwork.construct
|
||||
else:
|
||||
try:
|
||||
import mxnet as mx
|
||||
except ImportError:
|
||||
raise Exception('Install mxnet before using it as framework')
|
||||
from rl_coach.architectures.mxnet_components.general_network import GeneralMxnetNetwork
|
||||
general_network = GeneralMxnetNetwork.construct
|
||||
else:
|
||||
raise Exception("{} Framework is not supported"
|
||||
.format(Frameworks().to_string(self.network_parameters.framework)))
|
||||
|
||||
Reference in New Issue
Block a user