mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Added code to fall back to CPU if GPU not available. (#150)
- Code will also prune the GPU list if more GPUs are requested than are available.
This commit is contained in:
committed by
Gal Leibovich
parent
7d25477942
commit
77fb561668
@@ -26,6 +26,7 @@ from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOS
|
|||||||
from rl_coach.architectures.mxnet_components import utils
|
from rl_coach.architectures.mxnet_components import utils
|
||||||
from rl_coach.architectures.mxnet_components.savers import ParameterDictSaver, OnnxSaver
|
from rl_coach.architectures.mxnet_components.savers import ParameterDictSaver, OnnxSaver
|
||||||
from rl_coach.base_parameters import AgentParameters
|
from rl_coach.base_parameters import AgentParameters
|
||||||
|
from rl_coach.logger import screen
|
||||||
from rl_coach.saver import SaverCollection
|
from rl_coach.saver import SaverCollection
|
||||||
from rl_coach.spaces import SpacesDefinition
|
from rl_coach.spaces import SpacesDefinition
|
||||||
from rl_coach.utils import force_list, squeeze_list
|
from rl_coach.utils import force_list, squeeze_list
|
||||||
@@ -58,7 +59,7 @@ class MxnetArchitecture(Architecture):
|
|||||||
self.network_is_trainable = network_is_trainable
|
self.network_is_trainable = network_is_trainable
|
||||||
self.is_training = False
|
self.is_training = False
|
||||||
self.model = None # type: GeneralModel
|
self.model = None # type: GeneralModel
|
||||||
self._devices = devices
|
self._devices = self._sanitize_device_list(devices)
|
||||||
|
|
||||||
self.is_chief = self.ap.task_parameters.task_index == 0
|
self.is_chief = self.ap.task_parameters.task_index == 0
|
||||||
self.network_is_global = not self.network_is_local and global_network is None
|
self.network_is_global = not self.network_is_local and global_network is None
|
||||||
@@ -76,6 +77,23 @@ class MxnetArchitecture(Architecture):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.model.summary(*self._dummy_model_inputs())
|
return self.model.summary(*self._dummy_model_inputs())
|
||||||
|
|
||||||
|
@staticmethod
def _sanitize_device_list(devices: List[mx.Context]) -> List[mx.Context]:
    """
    Restrict the requested devices to the ones that are available.

    The set of available devices is the CPU context plus one GPU context for every entry
    reported by ``mx.test_utils.list_gpus()``. Requested devices are kept in their original
    order. Whenever at least one requested device is dropped, a message is written through
    ``screen.log``.

    :param devices: list of requested devices
    :return: the requested devices that are available, or ``[mx.cpu()]`` if none of them are
    """
    # The CPU context is always considered available; GPUs are discovered at call time.
    available = [mx.cpu()]
    available.extend(mx.gpu(gpu_id) for gpu_id in mx.test_utils.list_gpus())

    usable = [device for device in devices if device in available]

    # Nothing that was asked for can be used: fall back to a single CPU context.
    if not usable:
        fallback = [mx.cpu()]
        screen.log('Requested devices {} not available. Default to CPU context.'.format(devices))
        return fallback

    # Only part of the request can be honoured: report what was dropped and what remains.
    if len(usable) < len(devices):
        missing = [device for device in devices if device not in usable]
        screen.log('{} not available, using {}.'.format(missing, usable))

    return usable
|
||||||
|
|
||||||
def _model_grads(self, index: int=0) ->\
|
def _model_grads(self, index: int=0) ->\
|
||||||
Union[Generator[NDArray, NDArray, Any], Generator[List[NDArray], List[NDArray], Any]]:
|
Union[Generator[NDArray, NDArray, Any], Generator[List[NDArray], List[NDArray], Any]]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user