diff --git a/rl_coach/architectures/mxnet_components/architecture.py b/rl_coach/architectures/mxnet_components/architecture.py index 0f665d3..6c64a31 100644 --- a/rl_coach/architectures/mxnet_components/architecture.py +++ b/rl_coach/architectures/mxnet_components/architecture.py @@ -26,6 +26,7 @@ from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOS from rl_coach.architectures.mxnet_components import utils from rl_coach.architectures.mxnet_components.savers import ParameterDictSaver, OnnxSaver from rl_coach.base_parameters import AgentParameters +from rl_coach.logger import screen from rl_coach.saver import SaverCollection from rl_coach.spaces import SpacesDefinition from rl_coach.utils import force_list, squeeze_list @@ -58,7 +59,7 @@ class MxnetArchitecture(Architecture): self.network_is_trainable = network_is_trainable self.is_training = False self.model = None # type: GeneralModel - self._devices = devices + self._devices = self._sanitize_device_list(devices) self.is_chief = self.ap.task_parameters.task_index == 0 self.network_is_global = not self.network_is_local and global_network is None @@ -76,6 +77,23 @@ class MxnetArchitecture(Architecture): def __str__(self): return self.model.summary(*self._dummy_model_inputs()) + @staticmethod + def _sanitize_device_list(devices: List[mx.Context]) -> List[mx.Context]: + """ + Returns intersection of devices with available devices. If no intersection, returns mx.cpu() + :param devices: list of requested devices + :return: list of devices that are actually available + """ + actual_device = [mx.cpu()] + [mx.gpu(i) for i in mx.test_utils.list_gpus()] + intersection = [dev for dev in devices if dev in actual_device] + if len(intersection) == 0: + intersection = [mx.cpu()] + screen.log('Requested devices {} not available. Default to CPU context.'.format(devices)) + elif len(intersection) < len(devices): + screen.log('{} not available, using {}.'.format( + [dev for dev in devices if dev not in intersection], intersection)) + return intersection + def _model_grads(self, index: int=0) ->\ Union[Generator[NDArray, NDArray, Any], Generator[List[NDArray], List[NDArray], Any]]: """