From 155b78b99530f17640f39a57b0717f764bf80f54 Mon Sep 17 00:00:00 2001 From: Sina Afrooze Date: Wed, 5 Dec 2018 01:52:24 -0800 Subject: [PATCH] Fix warning on import TF or MxNet, when only one of the frameworks is installed (#140) --- rl_coach/architectures/network_wrapper.py | 27 +++++++++-------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/rl_coach/architectures/network_wrapper.py b/rl_coach/architectures/network_wrapper.py index 644a151..26f2920 100644 --- a/rl_coach/architectures/network_wrapper.py +++ b/rl_coach/architectures/network_wrapper.py @@ -21,17 +21,6 @@ from rl_coach.logger import failed_imports from rl_coach.saver import SaverCollection from rl_coach.spaces import SpacesDefinition from rl_coach.utils import force_list -try: - import tensorflow as tf - from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork -except ImportError: - failed_imports.append("tensorflow") - -try: - import mxnet as mx - from rl_coach.architectures.mxnet_components.general_network import GeneralMxnetNetwork -except ImportError: - failed_imports.append("mxnet") class NetworkWrapper(object): @@ -53,15 +42,19 @@ class NetworkWrapper(object): self.sess = None if self.network_parameters.framework == Frameworks.tensorflow: - if "tensorflow" not in failed_imports: - general_network = GeneralTensorFlowNetwork.construct - else: + try: + import tensorflow as tf + except ImportError: raise Exception('Install tensorflow before using it as framework') + from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork + general_network = GeneralTensorFlowNetwork.construct elif self.network_parameters.framework == Frameworks.mxnet: - if "mxnet" not in failed_imports: - general_network = GeneralMxnetNetwork.construct - else: + try: + import mxnet as mx + except ImportError: raise Exception('Install mxnet before using it as framework') + from rl_coach.architectures.mxnet_components.general_network import GeneralMxnetNetwork + general_network = GeneralMxnetNetwork.construct else: raise Exception("{} Framework is not supported" .format(Frameworks().to_string(self.network_parameters.framework)))