1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Added code to fall back to CPU if GPU not available. (#150)

- Code will also prune GPU list if more than available GPUs is requested.
This commit is contained in:
Sina Afrooze
2018-11-24 22:32:26 -08:00
committed by Gal Leibovich
parent 7d25477942
commit 77fb561668

View File

@@ -26,6 +26,7 @@ from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOS
from rl_coach.architectures.mxnet_components import utils from rl_coach.architectures.mxnet_components import utils
from rl_coach.architectures.mxnet_components.savers import ParameterDictSaver, OnnxSaver from rl_coach.architectures.mxnet_components.savers import ParameterDictSaver, OnnxSaver
from rl_coach.base_parameters import AgentParameters from rl_coach.base_parameters import AgentParameters
from rl_coach.logger import screen
from rl_coach.saver import SaverCollection from rl_coach.saver import SaverCollection
from rl_coach.spaces import SpacesDefinition from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import force_list, squeeze_list from rl_coach.utils import force_list, squeeze_list
@@ -58,7 +59,7 @@ class MxnetArchitecture(Architecture):
self.network_is_trainable = network_is_trainable self.network_is_trainable = network_is_trainable
self.is_training = False self.is_training = False
self.model = None # type: GeneralModel self.model = None # type: GeneralModel
self._devices = devices self._devices = self._sanitize_device_list(devices)
self.is_chief = self.ap.task_parameters.task_index == 0 self.is_chief = self.ap.task_parameters.task_index == 0
self.network_is_global = not self.network_is_local and global_network is None self.network_is_global = not self.network_is_local and global_network is None
@@ -76,6 +77,23 @@ class MxnetArchitecture(Architecture):
def __str__(self): def __str__(self):
return self.model.summary(*self._dummy_model_inputs()) return self.model.summary(*self._dummy_model_inputs())
@staticmethod
def _sanitize_device_list(devices: List[mx.Context]) -> List[mx.Context]:
"""
Returns intersection of devices with available devices. If no intersection, returns mx.cpu()
:param devices: list of requested devices
:return: list of devices that are actually available
"""
actual_device = [mx.cpu()] + [mx.gpu(i) for i in mx.test_utils.list_gpus()]
intersection = [dev for dev in devices if dev in actual_device]
if len(intersection) == 0:
intersection = [mx.cpu()]
screen.log('Requested devices {} not available. Default to CPU context.'.format(devices))
elif len(intersection) < len(devices):
screen.log('{} not available, using {}.'.format(
[dev for dev in devices if dev not in intersection], intersection))
return intersection
def _model_grads(self, index: int=0) ->\ def _model_grads(self, index: int=0) ->\
Union[Generator[NDArray, NDArray, Any], Generator[List[NDArray], List[NDArray], Any]]: Union[Generator[NDArray, NDArray, Any], Generator[List[NDArray], List[NDArray], Any]]:
""" """