mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
fix dist. tf (#153)
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
#
|
||||
|
||||
import copy
|
||||
from types import MethodType
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -73,14 +74,14 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
return construct_on_device()
|
||||
|
||||
@staticmethod
|
||||
def _tf_device(device: Union[str, Device]) -> str:
|
||||
def _tf_device(device: Union[str, MethodType, Device]) -> str:
|
||||
"""
|
||||
Convert device to tensorflow-specific device representation
|
||||
:param device: either a specific string (used in distributed mode) which is returned without
|
||||
any change or a Device type
|
||||
:param device: either a specific string or method (used in distributed mode) which is returned without
|
||||
any change or a Device type, which will be converted to a string
|
||||
:return: tensorflow-specific string for device
|
||||
"""
|
||||
if isinstance(device, str):
|
||||
if isinstance(device, str) or isinstance(device, MethodType):
|
||||
return device
|
||||
elif isinstance(device, Device):
|
||||
if device.device_type == DeviceType.CPU:
|
||||
|
||||
Reference in New Issue
Block a user