mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Moved tf.variable_scope and tf.device calls to framework-specific architecture (#136)
This commit is contained in:
committed by
Gal Leibovich
parent
559969d3dd
commit
87a7848b0a
@@ -15,24 +15,26 @@ from rl_coach.core_types import GradientClippingMethod
# Alias for values that may be either an imperative NDArray or a symbolic graph node.
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
def to_mx_ndarray(data: Union[list, tuple, np.ndarray, NDArray, int, float], ctx: mx.Context=None) ->\
        Union[List[NDArray], Tuple[NDArray], NDArray]:
    """
    Convert data to mx.nd.NDArray. Data can be a list or tuple of np.ndarray, int, or float or
    it can be np.ndarray, int, or float. Lists and tuples are converted recursively, preserving
    their container type.

    :param data: input data to be converted
    :param ctx: context of the data (CPU, GPU0, GPU1, etc.). If None, MXNet's default context
        is used for new arrays and already-converted NDArrays are accepted on any device.
    :return: converted output data
    :raises TypeError: if data (or a nested element) is of an unsupported type
    """
    if isinstance(data, list):
        data = [to_mx_ndarray(d, ctx=ctx) for d in data]
    elif isinstance(data, tuple):
        data = tuple(to_mx_ndarray(d, ctx=ctx) for d in data)
    elif isinstance(data, np.ndarray):
        data = nd.array(data, ctx=ctx)
    elif isinstance(data, NDArray):
        # Only validate device placement when an explicit context was requested.
        # NDArray.context is always a concrete mx.Context (never None), so comparing
        # it against the default ctx=None would reject every pre-converted NDArray.
        if ctx is not None:
            assert data.context == ctx, \
                'NDArray context {} does not match requested context {}'.format(data.context, ctx)
    elif isinstance(data, (int, float)):
        data = nd.array([data], ctx=ctx)
    else:
        raise TypeError('Unsupported data type: {}'.format(type(data)))
    return data
Reference in New Issue
Block a user