1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Moved tf.variable_scope and tf.device calls to framework-specific architecture (#136)

This commit is contained in:
Sina Afrooze
2018-11-22 12:52:22 -08:00
committed by Gal Leibovich
parent 559969d3dd
commit 87a7848b0a
11 changed files with 219 additions and 91 deletions

View File

@@ -15,24 +15,26 @@ from rl_coach.core_types import GradientClippingMethod
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
def to_mx_ndarray(data: Union[list, tuple, np.ndarray, NDArray, int, float]) ->\
def to_mx_ndarray(data: Union[list, tuple, np.ndarray, NDArray, int, float], ctx: mx.Context=None) ->\
Union[List[NDArray], Tuple[NDArray], NDArray]:
"""
Convert data to mx.nd.NDArray. Data can be a list or tuple of np.ndarray, int, or float or
it can be np.ndarray, int, or float
:param data: input data to be converted
:param ctx: context of the data (CPU, GPU0, GPU1, etc.)
:return: converted output data
"""
if isinstance(data, list):
data = [to_mx_ndarray(d) for d in data]
data = [to_mx_ndarray(d, ctx=ctx) for d in data]
elif isinstance(data, tuple):
data = tuple(to_mx_ndarray(d) for d in data)
data = tuple(to_mx_ndarray(d, ctx=ctx) for d in data)
elif isinstance(data, np.ndarray):
data = nd.array(data)
data = nd.array(data, ctx=ctx)
elif isinstance(data, NDArray):
assert data.context == ctx
pass
elif isinstance(data, int) or isinstance(data, float):
data = nd.array([data])
data = nd.array([data], ctx=ctx)
else:
raise TypeError('Unsupported data type: {}'.format(type(data)))
return data