1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 19:50:17 +01:00

fix dist. tf (#153)

This commit is contained in:
Gal Leibovich
2018-11-25 14:02:24 +02:00
committed by GitHub
parent 19a68812f6
commit 11170d5ba3

View File

@@ -15,6 +15,7 @@
#
import copy
from types import MethodType
from typing import Dict, List, Union
import numpy as np
@@ -73,14 +74,14 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
return construct_on_device()
@staticmethod
def _tf_device(device: Union[str, Device]) -> str:
def _tf_device(device: Union[str, MethodType, Device]) -> str:
"""
Convert device to tensorflow-specific device representation
:param device: either a specific string (used in distributed mode) which is returned without
any change or a Device type
:param device: either a specific string or method (used in distributed mode) which is returned without
any change or a Device type, which will be converted to a string
:return: tensorflow-specific string for device
"""
if isinstance(device, str):
if isinstance(device, str) or isinstance(device, MethodType):
return device
elif isinstance(device, Device):
if device.device_type == DeviceType.CPU: