1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Fix warning on import TF or MxNet, when only one of the frameworks is installed (#140)

This commit is contained in:
Sina Afrooze
2018-12-05 01:52:24 -08:00
committed by Gal Leibovich
parent 9e66bb653e
commit 155b78b995

View File

@@ -21,17 +21,6 @@ from rl_coach.logger import failed_imports
from rl_coach.saver import SaverCollection
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import force_list
try:
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork
except ImportError:
failed_imports.append("tensorflow")
try:
import mxnet as mx
from rl_coach.architectures.mxnet_components.general_network import GeneralMxnetNetwork
except ImportError:
failed_imports.append("mxnet")
class NetworkWrapper(object):
@@ -53,15 +42,19 @@ class NetworkWrapper(object):
self.sess = None
if self.network_parameters.framework == Frameworks.tensorflow:
if "tensorflow" not in failed_imports:
general_network = GeneralTensorFlowNetwork.construct
else:
try:
import tensorflow as tf
except ImportError:
raise Exception('Install tensorflow before using it as framework')
from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork
general_network = GeneralTensorFlowNetwork.construct
elif self.network_parameters.framework == Frameworks.mxnet:
if "mxnet" not in failed_imports:
general_network = GeneralMxnetNetwork.construct
else:
try:
import mxnet as mx
except ImportError:
raise Exception('Install mxnet before using it as framework')
from rl_coach.architectures.mxnet_components.general_network import GeneralMxnetNetwork
general_network = GeneralMxnetNetwork.construct
else:
raise Exception("{} Framework is not supported"
.format(Frameworks().to_string(self.network_parameters.framework)))