Mirror of https://github.com/gryf/coach.git, synced 2025-12-17 19:20:19 +01:00
Adding mxnet components to rl_coach/architectures (#60)
Adding mxnet components to rl_coach architectures.
- Supports PPO and DQN
- Tested with CartPole_PPO and CartPole_DQN
- Normalizing filters don't work right now (see #49) and are disabled in the CartPole_PPO preset
- Checkpointing is disabled for MXNet
280 rl_coach/architectures/mxnet_components/utils.py Normal file
@@ -0,0 +1,280 @@
"""
Module defining utility functions
"""
import inspect
from typing import Any, Dict, Generator, Iterable, List, Tuple, Union
from types import ModuleType

import mxnet as mx
from mxnet import nd
from mxnet.ndarray import NDArray
import numpy as np

from rl_coach.core_types import GradientClippingMethod

nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]

def to_mx_ndarray(data: Union[list, tuple, np.ndarray, NDArray, int, float]) ->\
        Union[List[NDArray], Tuple[NDArray], NDArray]:
    """
    Convert data to mx.nd.NDArray. Data can be a list or tuple of np.ndarray, int, or float or
    it can be np.ndarray, int, or float
    :param data: input data to be converted
    :return: converted output data
    """
    if isinstance(data, list):
        data = [to_mx_ndarray(d) for d in data]
    elif isinstance(data, tuple):
        data = tuple(to_mx_ndarray(d) for d in data)
    elif isinstance(data, np.ndarray):
        data = nd.array(data)
    elif isinstance(data, NDArray):
        pass
    elif isinstance(data, int) or isinstance(data, float):
        data = nd.array([data])
    else:
        raise TypeError('Unsupported data type: {}'.format(type(data)))
    return data

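# A minimal usage sketch (illustrative only; the `_example_*` helpers below are
# not part of the coach API): nested containers are converted element-wise, so
# the structure of the input is preserved.
def _example_to_mx_ndarray() -> None:
    batch = to_mx_ndarray([np.zeros((2, 3)), (1, 2.5)])
    assert batch[0].shape == (2, 3)     # np.ndarray -> NDArray
    assert isinstance(batch[1], tuple)  # tuples stay tuples
    assert batch[1][0].shape == (1,)    # scalars become shape-(1,) NDArrays
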
def asnumpy_or_asscalar(data: Union[NDArray, list, tuple]) -> Union[np.ndarray, np.number, list, tuple]:
    """
    Convert NDArray (or list or tuple of NDArray) to numpy. If shape is (1,), then convert to scalar instead.
    NOTE: This behavior is consistent with tensorflow
    :param data: NDArray or list or tuple of NDArray
    :return: data converted to numpy ndarray or to numpy scalar
    """
    if isinstance(data, list):
        data = [asnumpy_or_asscalar(d) for d in data]
    elif isinstance(data, tuple):
        data = tuple(asnumpy_or_asscalar(d) for d in data)
    elif isinstance(data, NDArray):
        data = data.asscalar() if data.shape == (1,) else data.asnumpy()
    else:
        raise TypeError('Unsupported data type: {}'.format(type(data)))
    return data

def global_norm(arrays: Union[Generator[NDArray, NDArray, NDArray], List[NDArray], Tuple[NDArray]]) -> NDArray:
    """
    Calculate global norm on list or tuple of NDArrays using this formula:
    `global_norm = sqrt(sum([l2norm(p)**2 for p in parameters]))`

    :param arrays: list or tuple of parameters to calculate global norm on
    :return: single-value NDArray
    """
    def _norm(array):
        if array.stype == 'default':
            x = array.reshape((-1,))
            return nd.dot(x, x)
        return array.norm().square()

    total_norm = nd.add_n(*[_norm(arr) for arr in arrays])
    total_norm = nd.sqrt(total_norm)
    return total_norm

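# A small numeric sketch of the formula above (illustrative only): the global
# norm of [3, 0] and [[0, 4]] is sqrt(3**2 + 4**2) == 5.
def _example_global_norm() -> None:
    grads = [nd.array([3.0, 0.0]), nd.array([[0.0, 4.0]])]
    assert abs(global_norm(grads).asscalar() - 5.0) < 1e-6
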
def split_outputs_per_head(outputs: Tuple[NDArray], heads: list) -> List[List[NDArray]]:
    """
    Split outputs into outputs per head
    :param outputs: list of all outputs
    :param heads: list of all heads
    :return: list of outputs for each head
    """
    head_outputs = []
    for h in heads:
        head_outputs.append(list(outputs[:h.num_outputs]))
        outputs = outputs[h.num_outputs:]
    assert len(outputs) == 0
    return head_outputs

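# Illustrative sketch with a stand-in head object; only the `num_outputs`
# attribute used above is modelled here, real coach heads carry more state.
def _example_split_outputs_per_head() -> None:
    class _FakeHead(object):
        def __init__(self, num_outputs):
            self.num_outputs = num_outputs

    outputs = (nd.zeros(1), nd.zeros(2), nd.zeros(3))
    per_head = split_outputs_per_head(outputs, [_FakeHead(1), _FakeHead(2)])
    assert [len(o) for o in per_head] == [1, 2]
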
def split_targets_per_loss(targets: list, losses: list) -> List[list]:
    """
    Splits targets into targets per loss
    :param targets: list of all targets (typically numpy ndarray)
    :param losses: list of all losses
    :return: list of targets for each loss
    """
    loss_targets = list()
    for l in losses:
        loss_data_len = len(l.input_schema.targets)
        assert len(targets) >= loss_data_len, "Data length doesn't match schema"
        loss_targets.append(targets[:loss_data_len])
        targets = targets[loss_data_len:]
    assert len(targets) == 0
    return loss_targets

def get_loss_agent_inputs(inputs: Dict[str, np.ndarray], head_type_idx: int, loss: Any) -> List[np.ndarray]:
    """
    Collects all inputs with prefix 'output_<head_idx>_' and matches them against agent_inputs in loss input schema.
    :param inputs: dictionary of all agent inputs, keyed by name
    :param head_type_idx: head-type index of the corresponding head
    :param loss: corresponding loss
    :return: list of agent inputs for this loss. This list matches the length in loss input schema.
    """
    loss_inputs = list()
    for k in sorted(inputs.keys()):
        if k.startswith('output_{}_'.format(head_type_idx)):
            loss_inputs.append(inputs[k])
    # Enforce that the number of inputs for this head_type matches the agent_inputs specified by the loss input_schema
    assert len(loss_inputs) == len(loss.input_schema.agent_inputs), "agent_input length doesn't match schema"
    return loss_inputs

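# Illustrative sketch of the 'output_<head_type_idx>_...' key prefix; the schema
# and loss objects below are hypothetical stand-ins, not coach classes.
def _example_get_loss_agent_inputs() -> None:
    class _FakeSchema(object):
        agent_inputs = ['alpha', 'beta']

    class _FakeLoss(object):
        input_schema = _FakeSchema()

    inputs = {'output_0_0': np.zeros(4), 'output_0_1': np.ones(4),
              'output_1_0': np.zeros(2)}
    # Only the two 'output_0_' entries are picked up for head_type_idx=0.
    assert len(get_loss_agent_inputs(inputs, head_type_idx=0, loss=_FakeLoss())) == 2
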
def align_loss_args(
        head_outputs: List[NDArray],
        agent_inputs: List[np.ndarray],
        targets: List[np.ndarray],
        loss: Any) -> List[np.ndarray]:
    """
    Creates a list of arguments from head_outputs, agent_inputs, and targets aligned with parameters of
    loss.loss_forward() based on their name in loss input_schema
    :param head_outputs: list of all head_outputs for this loss
    :param agent_inputs: list of all agent_inputs for this loss
    :param targets: list of all targets for this loss
    :param loss: corresponding loss
    :return: list of arguments in correct order to be passed to loss
    """
    arg_list = list()
    schema = loss.input_schema
    assert len(schema.head_outputs) == len(head_outputs)
    assert len(schema.agent_inputs) == len(agent_inputs)
    assert len(schema.targets) == len(targets)

    prev_found = True
    for arg_name in inspect.getfullargspec(loss.loss_forward).args[2:]:  # First two args are self and F
        found = False
        for schema_list, data in [(schema.head_outputs, head_outputs),
                                  (schema.agent_inputs, agent_inputs),
                                  (schema.targets, targets)]:
            try:
                arg_list.append(data[schema_list.index(arg_name)])
                found = True
                break
            except ValueError:
                continue
        assert not found or prev_found, "missing arguments detected!"
        prev_found = found
    return arg_list

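# Illustrative sketch of the alignment logic with hypothetical stand-ins: the
# argument order of loss_forward (here target before prediction) dictates the
# order of the returned list, regardless of how the data lists were passed in.
def _example_align_loss_args() -> None:
    class _FakeSchema(object):
        head_outputs = ['prediction']
        agent_inputs = []
        targets = ['target']

    class _FakeLoss(object):
        input_schema = _FakeSchema()

        def loss_forward(self, F, target, prediction):
            return F.square(prediction - target)

    args = align_loss_args(head_outputs=[nd.ones(3)], agent_inputs=[],
                           targets=[np.zeros(3)], loss=_FakeLoss())
    assert isinstance(args[0], np.ndarray)  # the target comes first
    assert isinstance(args[1], NDArray)     # then the head output
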
def to_tuple(data: Union[tuple, list, Any]):
    """
    If input is list, it is converted to tuple. If it's tuple, it is returned untouched. Otherwise
    returns a single-element tuple of the data.
    :return: tuple-ified data
    """
    if isinstance(data, tuple):
        pass
    elif isinstance(data, list):
        data = tuple(data)
    else:
        data = (data,)
    return data

def to_list(data: Union[tuple, list, Any]):
    """
    If input is tuple, it is converted to list. If it's list, it is returned untouched. Otherwise
    returns a single-element list of the data.
    :return: list-ified data
    """
    if isinstance(data, list):
        pass
    elif isinstance(data, tuple):
        data = list(data)
    else:
        data = [data]
    return data

def loss_output_dict(output: List[NDArray], schema: List[str]) -> Dict[str, List[NDArray]]:
    """
    Creates a dictionary for loss output based on the output schema. If two output values have the same
    type string in the schema, they are concatenated in the same dictionary item.
    :param output: list of output values
    :param schema: list of type-strings for output values
    :return: dictionary of keyword to list of NDArrays
    """
    assert len(output) == len(schema)
    output_dict = dict()
    for name, val in zip(schema, output):
        if name in output_dict:
            output_dict[name].append(val)
        else:
            output_dict[name] = [val]
    return output_dict

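# Illustrative sketch: repeated schema entries are grouped under a single key.
def _example_loss_output_dict() -> None:
    out = loss_output_dict([nd.array([1.]), nd.array([2.]), nd.array([3.])],
                           schema=['loss', 'regularization', 'regularization'])
    assert len(out['loss']) == 1 and len(out['regularization']) == 2
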
def clip_grad(
        grads: Union[Generator[NDArray, NDArray, NDArray], List[NDArray], Tuple[NDArray]],
        clip_method: GradientClippingMethod,
        clip_val: float,
        inplace=True) -> List[NDArray]:
    """
    Clip gradient values, in place by default
    :param grads: gradients to be clipped
    :param clip_method: clipping method
    :param clip_val: clipping value. Interpreted differently depending on clipping method.
    :param inplace: modify grads if True, otherwise create new NDArrays for the clipped values
    :return: clipped gradients
    """
    output = list(grads) if inplace else list(nd.empty(g.shape) for g in grads)
    if clip_method == GradientClippingMethod.ClipByGlobalNorm:
        norm_unclipped_grads = global_norm(grads)
        scale = clip_val / (norm_unclipped_grads.asscalar() + 1e-8)  # todo: use branching operators?
        if scale < 1.0:
            for g, o in zip(grads, output):
                nd.broadcast_mul(g, nd.array([scale]), out=o)
    elif clip_method == GradientClippingMethod.ClipByValue:
        for g, o in zip(grads, output):
            g.clip(-clip_val, clip_val, out=o)
    elif clip_method == GradientClippingMethod.ClipByNorm:
        for g, o in zip(grads, output):
            nd.broadcast_mul(g, nd.minimum(1.0, clip_val / (g.norm() + 1e-8)), out=o)
    else:
        raise KeyError('Unsupported gradient clipping method')
    return output

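# Illustrative sketch: with ClipByGlobalNorm the gradients are rescaled jointly
# so that their global norm does not exceed clip_val (1.0 in this sketch).
def _example_clip_grad() -> None:
    grads = [nd.array([3.0]), nd.array([4.0])]  # global norm == 5
    clipped = clip_grad(grads, GradientClippingMethod.ClipByGlobalNorm,
                        clip_val=1.0, inplace=False)
    assert abs(global_norm(clipped).asscalar() - 1.0) < 1e-3
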
def hybrid_clip(F: ModuleType, x: nd_sym_type, clip_lower: nd_sym_type, clip_upper: nd_sym_type) -> nd_sym_type:
    """
    Apply clipping to input x between clip_lower and clip_upper.
    Added because F.clip doesn't support clipping bounds that are mx.nd.NDArray or mx.sym.Symbol.

    :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
    :param x: input data
    :param clip_lower: lower bound used for clipping, should be of shape (1,)
    :param clip_upper: upper bound used for clipping, should be of shape (1,)
    :return: clipped data
    """
    x_clip_lower = clip_lower.broadcast_like(x)
    x_clip_upper = clip_upper.broadcast_like(x)
    x_clipped = F.stack(x, x_clip_lower, axis=0).max(axis=0)
    x_clipped = F.stack(x_clipped, x_clip_upper, axis=0).min(axis=0)
    return x_clipped

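# Illustrative sketch: the bounds are shape-(1,) NDArrays, which F.clip would
# reject, but hybrid_clip broadcasts them against x before taking max/min.
def _example_hybrid_clip() -> None:
    x = nd.array([-2.0, 0.5, 3.0])
    clipped = hybrid_clip(nd, x, clip_lower=nd.array([-1.0]),
                          clip_upper=nd.array([1.0]))
    assert clipped.asnumpy().tolist() == [-1.0, 0.5, 1.0]
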
def get_mxnet_activation_name(activation_name: str):
    """
    Convert coach activation name to mxnet specific activation name
    :param activation_name: name of the activation in coach
    :return: name of the activation in mxnet
    """
    activation_functions = {
        'relu': 'relu',
        'tanh': 'tanh',
        'sigmoid': 'sigmoid',
        # FIXME Add other activations
        # 'elu': tf.nn.elu,
        'selu': 'softrelu',
        # 'leaky_relu': tf.nn.leaky_relu,
        'none': None
    }
    assert activation_name in activation_functions, \
        "Activation function must be one of the following: {}. Instead it was: {}".format(
            activation_functions.keys(), activation_name)
    return activation_functions[activation_name]
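# Illustrative sketch: the returned string can be passed straight to Gluon
# layers, e.g. mx.gluon.nn.Dense(64, activation=get_mxnet_activation_name('tanh')).
def _example_get_mxnet_activation_name() -> None:
    assert get_mxnet_activation_name('selu') == 'softrelu'
    assert get_mxnet_activation_name('none') is None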