mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Fix warning on import TF or MxNet, when only one of the frameworks is installed (#140)
This commit is contained in:
committed by
Gal Leibovich
parent
9e66bb653e
commit
155b78b995
@@ -21,17 +21,6 @@ from rl_coach.logger import failed_imports
|
|||||||
from rl_coach.saver import SaverCollection
|
from rl_coach.saver import SaverCollection
|
||||||
from rl_coach.spaces import SpacesDefinition
|
from rl_coach.spaces import SpacesDefinition
|
||||||
from rl_coach.utils import force_list
|
from rl_coach.utils import force_list
|
||||||
try:
|
|
||||||
import tensorflow as tf
|
|
||||||
from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork
|
|
||||||
except ImportError:
|
|
||||||
failed_imports.append("tensorflow")
|
|
||||||
|
|
||||||
try:
|
|
||||||
import mxnet as mx
|
|
||||||
from rl_coach.architectures.mxnet_components.general_network import GeneralMxnetNetwork
|
|
||||||
except ImportError:
|
|
||||||
failed_imports.append("mxnet")
|
|
||||||
|
|
||||||
|
|
||||||
class NetworkWrapper(object):
|
class NetworkWrapper(object):
|
||||||
@@ -53,15 +42,19 @@ class NetworkWrapper(object):
|
|||||||
self.sess = None
|
self.sess = None
|
||||||
|
|
||||||
if self.network_parameters.framework == Frameworks.tensorflow:
|
if self.network_parameters.framework == Frameworks.tensorflow:
|
||||||
if "tensorflow" not in failed_imports:
|
try:
|
||||||
general_network = GeneralTensorFlowNetwork.construct
|
import tensorflow as tf
|
||||||
else:
|
except ImportError:
|
||||||
raise Exception('Install tensorflow before using it as framework')
|
raise Exception('Install tensorflow before using it as framework')
|
||||||
|
from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork
|
||||||
|
general_network = GeneralTensorFlowNetwork.construct
|
||||||
elif self.network_parameters.framework == Frameworks.mxnet:
|
elif self.network_parameters.framework == Frameworks.mxnet:
|
||||||
if "mxnet" not in failed_imports:
|
try:
|
||||||
general_network = GeneralMxnetNetwork.construct
|
import mxnet as mx
|
||||||
else:
|
except ImportError:
|
||||||
raise Exception('Install mxnet before using it as framework')
|
raise Exception('Install mxnet before using it as framework')
|
||||||
|
from rl_coach.architectures.mxnet_components.general_network import GeneralMxnetNetwork
|
||||||
|
general_network = GeneralMxnetNetwork.construct
|
||||||
else:
|
else:
|
||||||
raise Exception("{} Framework is not supported"
|
raise Exception("{} Framework is not supported"
|
||||||
.format(Frameworks().to_string(self.network_parameters.framework)))
|
.format(Frameworks().to_string(self.network_parameters.framework)))
|
||||||
|
|||||||
Reference in New Issue
Block a user