diff --git a/rl_coach/architectures/tensorflow_components/general_network.py b/rl_coach/architectures/tensorflow_components/general_network.py index cecc157..ca72bd7 100644 --- a/rl_coach/architectures/tensorflow_components/general_network.py +++ b/rl_coach/architectures/tensorflow_components/general_network.py @@ -15,6 +15,7 @@ # import copy +from types import MethodType from typing import Dict, List, Union import numpy as np @@ -73,14 +74,14 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture): return construct_on_device() @staticmethod - def _tf_device(device: Union[str, Device]) -> str: + def _tf_device(device: Union[str, MethodType, Device]) -> str: """ Convert device to tensorflow-specific device representation - :param device: either a specific string (used in distributed mode) which is returned without - any change or a Device type + :param device: either a specific string or method (used in distributed mode) which is returned without + any change or a Device type, which will be converted to a string :return: tensorflow-specific string for device """ - if isinstance(device, str): + if isinstance(device, str) or isinstance(device, MethodType): return device elif isinstance(device, Device): if device.device_type == DeviceType.CPU: