1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Adding mxnet components to rl_coach/architectures (#60)

Adding mxnet components to rl_coach architectures.

- Supports PPO and DQN
- Tested with CartPole_PPO and CarPole_DQN
- Normalizing filters don't work right now (see #49) and are disabled in CartPole_PPO preset
- Checkpointing is disabled for MXNet
This commit is contained in:
Sina Afrooze
2018-11-07 07:07:15 -08:00
committed by Itai Caspi
parent e7a91b4dc3
commit 5fadb9c18e
39 changed files with 3864 additions and 44 deletions

View File

@@ -5,12 +5,12 @@ COPY setup.py /root/src/.
COPY requirements.txt /root/src/.
COPY README.md /root/src/.
WORKDIR /root/src
RUN pip3 install -e .
RUN pip3 install -e .[all]
# everything above here should be cached most of the time
COPY . /root/src
WORKDIR /root/src
RUN pip3 install -e .
RUN pip3 install -e .[all]
RUN chmod 777 /root/src/docker/docker_entrypoint.sh
ENTRYPOINT ["/root/src/docker/docker_entrypoint.sh"]

View File

@@ -41,22 +41,41 @@ class Architecture(object):
self.optimizer = None
self.ap = agent_parameters
def predict(self, inputs: Dict[str, np.ndarray]) -> List[np.ndarray]:
def predict(self,
inputs: Dict[str, np.ndarray],
outputs: List[Any] = None,
squeeze_output: bool = True,
initial_feed_dict: Dict[Any, np.ndarray] = None) -> Tuple[np.ndarray, ...]:
"""
Given input observations, use the model to make predictions (e.g. action or value).
:param inputs: current state (i.e. observations, measurements, goals, etc.)
(e.g. `{'observation': numpy.ndarray}` of shape (batch_size, observation_space_size))
:param outputs: list of outputs to return. Return all outputs if unspecified. Type of the list elements
depends on the framework backend.
:param squeeze_output: call squeeze_list on output before returning if True
:param initial_feed_dict: a dictionary of extra inputs for forward pass.
:return: predictions of action or value of shape (batch_size, action_space_size) for action predictions)
"""
pass
@staticmethod
def parallel_predict(sess: Any,
network_input_tuples: List[Tuple['Architecture', Dict[str, np.ndarray]]]) -> \
Tuple[np.ndarray, ...]:
"""
:param sess: active session to use for prediction
:param network_input_tuples: tuple of network and corresponding input
:return: list or tuple of outputs from all networks
"""
pass
def train_on_batch(self,
inputs: Dict[str, np.ndarray],
targets: List[np.ndarray],
scaler: float=1.,
additional_fetches: list=None,
importance_weights: np.ndarray=None) -> tuple:
importance_weights: np.ndarray=None) -> Tuple[float, List[float], float, list]:
"""
Given a batch of inputs (e.g. states) and targets (e.g. discounted rewards), takes a training step: i.e. runs a
forward pass and backward pass of the network, accumulates the gradients and applies an optimization step to
@@ -118,8 +137,7 @@ class Architecture(object):
targets: List[np.ndarray],
additional_fetches: list=None,
importance_weights: np.ndarray=None,
no_accumulation: bool=False) ->\
Tuple[float, List[float], float, list]:
no_accumulation: bool=False) -> Tuple[float, List[float], float, list]:
"""
Given a batch of inputs (i.e. states) and targets (e.g. discounted rewards), computes and accumulates the
gradients for model parameters. Will run forward and backward pass to compute gradients, clip the gradient
@@ -142,30 +160,33 @@ class Architecture(object):
calculated gradients
:return: tuple of total_loss, losses, norm_unclipped_grads, fetched_tensors
total_loss (float): sum of all head losses
losses (list of float): list of all losses. The order is list of target losses followed by list of regularization losses.
The specifics of losses is dependant on the network parameters (number of heads, etc.)
losses (list of float): list of all losses. The order is list of target losses followed by list of
regularization losses. The specifics of losses is dependant on the network parameters
(number of heads, etc.)
norm_unclippsed_grads (float): global norm of all gradients before any gradient clipping is applied
fetched_tensors: all values for additional_fetches
"""
pass
def apply_and_reset_gradients(self, gradients: List[np.ndarray]) -> None:
def apply_and_reset_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None:
"""
Applies the given gradients to the network weights and resets the gradient accumulations.
Has the same impact as calling `apply_gradients`, then `reset_accumulated_gradients`.
:param gradients: gradients for the parameter weights, taken from `accumulated_gradients` property
of an identical network (either self or another identical network)
:param scaler: A scaling factor that allows rescaling the gradients before applying them
"""
pass
def apply_gradients(self, gradients: List[np.ndarray]) -> None:
def apply_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None:
"""
Applies the given gradients to the network weights.
Will be performed sync or async depending on `network_parameters.async_training`
:param gradients: gradients for the parameter weights, taken from `accumulated_gradients` property
of an identical network (either self or another identical network)
:param scaler: A scaling factor that allows rescaling the gradients before applying them
"""
pass

View File

@@ -0,0 +1,405 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
from typing import Any, Dict, Generator, List, Tuple, Union
import numpy as np
import mxnet as mx
from mxnet import autograd, gluon, nd
from mxnet.ndarray import NDArray
from rl_coach.architectures.architecture import Architecture
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
from rl_coach.architectures.mxnet_components import utils
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import force_list, squeeze_list
class MxnetArchitecture(Architecture):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str= "",
global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
"""
:param agent_parameters: the agent parameters
:param spaces: the spaces definition of the agent
:param name: the name of the network
:param global_network: the global network replica that is shared between all the workers
:param network_is_local: is the network global (shared between workers) or local (dedicated to the worker)
:param network_is_trainable: is the network trainable (we can apply gradients on it)
"""
super().__init__(agent_parameters, spaces, name)
self.middleware = None
self.network_is_local = network_is_local
self.global_network = global_network
if not self.network_parameters.tensorflow_support:
raise ValueError('TensorFlow is not supported for this agent')
self.losses = [] # type: List[HeadLoss]
self.shared_accumulated_gradients = []
self.curr_rnn_c_in = None
self.curr_rnn_h_in = None
self.gradients_wrt_inputs = []
self.train_writer = None
self.accumulated_gradients = None
self.network_is_trainable = network_is_trainable
self.is_training = False
self.model = None # type: GeneralModel
self.is_chief = self.ap.task_parameters.task_index == 0
self.network_is_global = not self.network_is_local and global_network is None
self.distributed_training = self.network_is_global or self.network_is_local and global_network is not None
self.optimizer_type = self.network_parameters.optimizer_type
if self.ap.task_parameters.seed is not None:
mx.random.seed(self.ap.task_parameters.seed)
# Call to child class to create the model
self.construct_model()
self.trainer = None # type: gluon.Trainer
def __str__(self):
return self.model.summary(*self._dummy_model_inputs())
@property
def _model_grads(self) -> Generator[NDArray, NDArray, Any]:
"""
Creates a copy of model gradients and returns them in a list, in the same order as collect_params()
:return: a generator for model gradient values
"""
return (p.list_grad()[0].copy() for p in self.model.collect_params().values() if p.grad_req != 'null')
def _dummy_model_inputs(self) -> Tuple[NDArray, ...]:
"""
Creates a tuple of input arrays with correct shapes that can be used for shape inference
of the model weights and for printing the summary
:return: tuple of inputs for model forward pass
"""
allowed_inputs = copy.copy(self.spaces.state.sub_spaces)
allowed_inputs["action"] = copy.copy(self.spaces.action)
allowed_inputs["goal"] = copy.copy(self.spaces.goal)
embedders = self.model.nets[0].input_embedders
inputs = tuple(nd.zeros((1,) + tuple(allowed_inputs[emb.embedder_name].shape.tolist())) for emb in embedders)
return inputs
def construct_model(self) -> None:
"""
Construct network model. Implemented by child class.
"""
raise NotImplementedError
def set_session(self, sess) -> None:
"""
Initializes the model parameters and creates the model trainer.
NOTEL Session for mxnet backend must be None.
:param sess: must be None
"""
assert sess is None
# FIXME Add GPU initialization
# FIXME Add initializer
self.model.collect_params().initialize(ctx=mx.cpu())
# Hybridize model and losses
self.model.hybridize()
for l in self.losses:
l.hybridize()
# Pass dummy data with correct shape to trigger shape inference and full parameter initialization
self.model(*self._dummy_model_inputs())
if self.network_is_trainable:
self.trainer = gluon.Trainer(
self.model.collect_params(), optimizer=self.optimizer, update_on_kvstore=False)
def reset_accumulated_gradients(self) -> None:
"""
Reset model gradients as well as accumulated gradients to zero. If accumulated gradients
have not been created yet, it constructs them on CPU.
"""
# Set model gradients to zero
for p in self.model.collect_params().values():
p.zero_grad()
# Set accumulated gradients to zero if already initialized, otherwise create a copy
if self.accumulated_gradients:
for a in self.accumulated_gradients:
a *= 0
else:
self.accumulated_gradients = [g.copy() for g in self._model_grads]
def accumulate_gradients(self,
inputs: Dict[str, np.ndarray],
targets: List[np.ndarray],
additional_fetches: List[Tuple[int, str]] = None,
importance_weights: np.ndarray = None,
no_accumulation: bool = False) -> Tuple[float, List[float], float, list]:
"""
Runs a forward & backward pass, clips gradients if needed and accumulates them into the accumulation
:param inputs: environment states (observation, etc.) as well extra inputs required by loss. Shape of ndarray
is (batch_size, observation_space_size) or (batch_size, observation_space_size, stack_size)
:param targets: targets required by loss (e.g. sum of discounted rewards)
:param additional_fetches: additional fetches to calculate and return. Each fetch is specified as (int, str)
tuple of head-type-index and fetch-name. The tuple is obtained from each head.
:param importance_weights: ndarray of shape (batch_size,) to multiply with batch loss.
:param no_accumulation: if True, set gradient values to the new gradients, otherwise sum with previously
calculated gradients
:return: tuple of total_loss, losses, norm_unclipped_grads, fetched_tensors
total_loss (float): sum of all head losses
losses (list of float): list of all losses. The order is list of target losses followed by list of
regularization losses. The specifics of losses is dependant on the network parameters
(number of heads, etc.)
norm_unclippsed_grads (float): global norm of all gradients before any gradient clipping is applied
fetched_tensors: all values for additional_fetches
"""
if self.accumulated_gradients is None:
self.reset_accumulated_gradients()
embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders]
nd_inputs = tuple(nd.array(inputs[emb]) for emb in embedders)
assert self.middleware.__class__.__name__ != 'LSTMMiddleware', "LSTM middleware not supported"
targets = force_list(targets)
with autograd.record():
out_per_head = utils.split_outputs_per_head(self.model(*nd_inputs), self.model.output_heads)
tgt_per_loss = utils.split_targets_per_loss(targets, self.losses)
losses = list()
regularizations = list()
additional_fetches = [(k, None) for k in additional_fetches]
for h, h_loss, h_out, l_tgt in zip(self.model.output_heads, self.losses, out_per_head, tgt_per_loss):
l_in = utils.get_loss_agent_inputs(inputs, head_type_idx=h.head_type_idx, loss=h_loss)
# Align arguments with loss.loss_forward and convert to NDArray
l_args = utils.to_mx_ndarray(utils.align_loss_args(h_out, l_in, l_tgt, h_loss))
# Calculate loss and all auxiliary outputs
loss_outputs = utils.loss_output_dict(utils.to_list(h_loss(*l_args)), h_loss.output_schema)
if LOSS_OUT_TYPE_LOSS in loss_outputs:
losses.extend(loss_outputs[LOSS_OUT_TYPE_LOSS])
if LOSS_OUT_TYPE_REGULARIZATION in loss_outputs:
regularizations.extend(loss_outputs[LOSS_OUT_TYPE_REGULARIZATION])
# Set additional fetches
for i, fetch in enumerate(additional_fetches):
head_type_idx, fetch_name = fetch[0] # fetch key is a tuple of (head_type_index, fetch_name)
if head_type_idx == h.head_type_idx:
assert fetch[1] is None # sanity check that fetch is None
additional_fetches[i] = (fetch[0], loss_outputs[fetch_name])
# Total loss is losses and regularization (NOTE: order is important)
total_loss_list = losses + regularizations
total_loss = nd.add_n(*total_loss_list)
# Calculate gradients
total_loss.backward()
assert self.optimizer_type != 'LBFGS', 'LBFGS not supported'
# allreduce gradients from all contexts
self.trainer.allreduce_grads()
# Calculate global norm of gradients
# FIXME global norm is returned even when not used for clipping! Is this necessary?
# FIXME global norm might be calculated twice if clipping method is global norm
norm_unclipped_grads = utils.global_norm(self._model_grads)
# Clip gradients
if self.network_parameters.clip_gradients:
utils.clip_grad(
self._model_grads,
clip_method=self.network_parameters.gradients_clipping_method,
clip_val=self.network_parameters.clip_gradients,
inplace=True)
# Update self.accumulated_gradients depending on no_accumulation flag
if no_accumulation:
for acc_grad, model_grad in zip(self.accumulated_gradients, self._model_grads):
acc_grad[:] = model_grad
else:
for acc_grad, model_grad in zip(self.accumulated_gradients, self._model_grads):
acc_grad += model_grad
# result of of additional fetches
fetched_tensors = [fetch[1] for fetch in additional_fetches]
# convert everything to numpy or scalar before returning
result = utils.asnumpy_or_asscalar((total_loss, total_loss_list, norm_unclipped_grads, fetched_tensors))
return result
def apply_and_reset_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None:
"""
Applies the given gradients to the network weights and resets accumulated gradients to zero
:param gradients: The gradients to use for the update
:param scaler: A scaling factor that allows rescaling the gradients before applying them
"""
self.apply_gradients(gradients, scaler)
self.reset_accumulated_gradients()
def apply_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None:
"""
Applies the given gradients to the network weights
:param gradients: The gradients to use for the update
:param scaler: A scaling factor that allows rescaling the gradients before applying them.
The gradients will be MULTIPLIED by this factor
"""
assert self.optimizer_type != 'LBFGS'
batch_size = 1
if self.distributed_training and not self.network_parameters.async_training:
# rescale the gradients so that they average out with the gradients from the other workers
if self.network_parameters.scale_down_gradients_by_number_of_workers_for_sync_training:
batch_size = self.ap.task_parameters.num_training_tasks
# set parameter gradients to gradients passed in
for param_grad, gradient in zip(self._model_grads, gradients):
param_grad[:] = gradient
# update gradients
self.trainer.update(batch_size=batch_size)
def _predict(self, inputs: Dict[str, np.ndarray]) -> Tuple[NDArray, ...]:
"""
Run a forward pass of the network using the given input
:param inputs: The input dictionary for the network. Key is name of the embedder.
:return: The network output
WARNING: must only call once per state since each call is assumed by LSTM to be a new time step.
"""
embedders = [emb.embedder_name for emb in self.model.nets[0].input_embedders]
nd_inputs = tuple(nd.array(inputs[emb]) for emb in embedders)
assert self.middleware.__class__.__name__ != 'LSTMMiddleware'
output = self.model(*nd_inputs)
return output
def predict(self,
inputs: Dict[str, np.ndarray],
outputs: List[str]=None,
squeeze_output: bool=True,
initial_feed_dict: Dict[str, np.ndarray]=None) -> Tuple[np.ndarray, ...]:
"""
Run a forward pass of the network using the given input
:param inputs: The input dictionary for the network. Key is name of the embedder.
:param outputs: list of outputs to return. Return all outputs if unspecified (currently not supported)
:param squeeze_output: call squeeze_list on output if True
:param initial_feed_dict: a dictionary of extra inputs for forward pass (currently not supported)
:return: The network output
WARNING: must only call once per state since each call is assumed by LSTM to be a new time step.
"""
assert initial_feed_dict is None, "initial_feed_dict must be None"
assert outputs is None, "outputs must be None"
output = self._predict(inputs)
output = tuple(o.asnumpy() for o in output)
if squeeze_output:
output = squeeze_list(output)
return output
@staticmethod
def parallel_predict(sess: Any,
network_input_tuples: List[Tuple['MxnetArchitecture', Dict[str, np.ndarray]]]) -> \
Tuple[np.ndarray, ...]:
"""
:param sess: active session to use for prediction (must be None for MXNet)
:param network_input_tuples: tuple of network and corresponding input
:return: tuple of outputs from all networks
"""
assert sess is None
output = list()
for net, inputs in network_input_tuples:
output += net._predict(inputs)
return tuple(o.asnumpy() for o in output)
def train_on_batch(self,
inputs: Dict[str, np.ndarray],
targets: List[np.ndarray],
scaler: float = 1.,
additional_fetches: list = None,
importance_weights: np.ndarray = None) -> Tuple[float, List[float], float, list]:
"""
Given a batch of inputs (e.g. states) and targets (e.g. discounted rewards), takes a training step: i.e. runs a
forward pass and backward pass of the network, accumulates the gradients and applies an optimization step to
update the weights.
:param inputs: environment states (observation, etc.) as well extra inputs required by loss. Shape of ndarray
is (batch_size, observation_space_size) or (batch_size, observation_space_size, stack_size)
:param targets: targets required by loss (e.g. sum of discounted rewards)
:param scaler: value to scale gradients by before optimizing network weights
:param additional_fetches: additional fetches to calculate and return. Each fetch is specified as (int, str)
tuple of head-type-index and fetch-name. The tuple is obtained from each head.
:param importance_weights: ndarray of shape (batch_size,) to multiply with batch loss.
:return: tuple of total_loss, losses, norm_unclipped_grads, fetched_tensors
total_loss (float): sum of all head losses
losses (list of float): list of all losses. The order is list of target losses followed by list
of regularization losses. The specifics of losses is dependant on the network parameters
(number of heads, etc.)
norm_unclippsed_grads (float): global norm of all gradients before any gradient clipping is applied
fetched_tensors: all values for additional_fetches
"""
loss = self.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches,
importance_weights=importance_weights)
self.apply_and_reset_gradients(self.accumulated_gradients, scaler)
return loss
def get_weights(self) -> gluon.ParameterDict:
"""
:return: a ParameterDict containing all network weights
"""
return self.model.collect_params()
def set_weights(self, weights: gluon.ParameterDict, new_rate: float=1.0) -> None:
"""
Sets the network weights from the given ParameterDict
:param new_rate: ratio for adding new and old weight values: val=rate*weights + (1-rate)*old_weights
"""
old_weights = self.model.collect_params()
for name, p in weights.items():
name = name[len(weights.prefix):] # Strip prefix
old_p = old_weights[old_weights.prefix + name] # Add prefix
old_p.set_data(new_rate * p._reduce() + (1 - new_rate) * old_p._reduce())
def get_variable_value(self, variable: Union[gluon.Parameter, NDArray]) -> np.ndarray:
"""
Get the value of a variable
:param variable: the variable
:return: the value of the variable
"""
if isinstance(variable, gluon.Parameter):
variable = variable._reduce().asnumpy()
if isinstance(variable, NDArray):
return variable.asnumpy()
return variable
def set_variable_value(self, assign_op: callable, value: Any, placeholder=None) -> None:
"""
Updates value of a variable.
:param assign_op: a callable assign function for setting the variable
:param value: a value to set the variable to
:param placeholder: unused (placeholder in symbolic framework backends)
"""
assert callable(assign_op)
assign_op(value)
def set_is_training(self, state: bool) -> None:
"""
Set the phase of the network between training and testing
:param state: The current state (True = Training, False = Testing)
:return: None
"""
self.is_training = state
def reset_internal_memory(self) -> None:
"""
Reset any internal memory used by the network. For example, an LSTM internal state
:return: None
"""
assert self.middleware.__class__.__name__ != 'LSTMMiddleware', 'LSTM middleware not supported'

View File

@@ -0,0 +1,4 @@
from .image_embedder import ImageEmbedder
from .vector_embedder import VectorEmbedder
__all__ = ['ImageEmbedder', 'VectorEmbedder']

View File

@@ -0,0 +1,71 @@
from typing import Union
from types import ModuleType
import mxnet as mx
from mxnet.gluon import nn
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.mxnet_components.layers import convert_layer
from rl_coach.base_parameters import EmbedderScheme
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
class InputEmbedder(nn.HybridBlock):
def __init__(self, params: InputEmbedderParameters):
"""
An input embedder is the first part of the network, which takes the input from the state and produces a vector
embedding by passing it through a neural network. The embedder will mostly be input type dependent, and there
can be multiple embedders in a single network.
:param params: parameters object containing input_clipping, input_rescaling, batchnorm, activation_function
and dropout properties.
"""
super(InputEmbedder, self).__init__()
self.embedder_name = params.name
self.input_clipping = params.input_clipping
self.scheme = params.scheme
with self.name_scope():
self.net = nn.HybridSequential()
if isinstance(self.scheme, EmbedderScheme):
blocks = self.schemes[self.scheme]
else:
# if scheme is specified directly, convert to MX layer if it's not a callable object
# NOTE: if layer object is callable, it must return a gluon block when invoked
blocks = [convert_layer(l) for l in self.scheme]
for block in blocks:
self.net.add(block())
if params.batchnorm:
self.net.add(nn.BatchNorm())
if params.activation_function:
self.net.add(nn.Activation(params.activation_function))
if params.dropout:
self.net.add(nn.Dropout(rate=params.dropout))
@property
def schemes(self) -> dict:
"""
Schemes are the pre-defined network architectures of various depths and complexities that can be used for the
InputEmbedder. Should be implemented in child classes, and are used to create Block when InputEmbedder is
initialised.
:return: dictionary of schemes, with key of type EmbedderScheme enum and value being list of mxnet.gluon.Block.
"""
raise NotImplementedError("Inheriting embedder must define schemes matching its allowed default "
"configurations.")
def hybrid_forward(self, F: ModuleType, x: nd_sym_type, *args, **kwargs) -> nd_sym_type:
"""
Used for forward pass through embedder network.
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
:param x: environment state, where first dimension is batch_size, then dimensions are data type dependent.
:return: embedding of environment state, where shape is (batch_size, channels).
"""
# `input_rescaling` and `input_offset` set on inheriting embedder
x = x / self.input_rescaling
x = x - self.input_offset
if self.input_clipping is not None:
x.clip(a_min=self.input_clipping[0], a_max=self.input_clipping[1])
x = self.net(x)
return x.flatten()

View File

@@ -0,0 +1,76 @@
from typing import Union
from types import ModuleType
import mxnet as mx
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.mxnet_components.embedders.embedder import InputEmbedder
from rl_coach.architectures.mxnet_components.layers import Conv2d
from rl_coach.base_parameters import EmbedderScheme
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
class ImageEmbedder(InputEmbedder):
def __init__(self, params: InputEmbedderParameters):
"""
An image embedder is an input embedder that takes an image input from the state and produces a vector
embedding by passing it through a neural network.
:param params: parameters object containing input_clipping, input_rescaling, batchnorm, activation_function
and dropout properties.
"""
super(ImageEmbedder, self).__init__(params)
self.input_rescaling = params.input_rescaling['image']
self.input_offset = params.input_offset['image']
@property
def schemes(self) -> dict:
"""
Schemes are the pre-defined network architectures of various depths and complexities that can be used. Are used
to create Block when ImageEmbedder is initialised.
:return: dictionary of schemes, with key of type EmbedderScheme enum and value being list of mxnet.gluon.Block.
"""
return {
EmbedderScheme.Empty:
[],
EmbedderScheme.Shallow:
[
Conv2d(num_filters=32, kernel_size=8, strides=4)
],
# Use for Atari DQN
EmbedderScheme.Medium:
[
Conv2d(num_filters=32, kernel_size=8, strides=4),
Conv2d(num_filters=64, kernel_size=4, strides=2),
Conv2d(num_filters=64, kernel_size=3, strides=1)
],
# Use for Carla
EmbedderScheme.Deep:
[
Conv2d(num_filters=32, kernel_size=5, strides=2),
Conv2d(num_filters=32, kernel_size=3, strides=1),
Conv2d(num_filters=64, kernel_size=3, strides=2),
Conv2d(num_filters=64, kernel_size=3, strides=1),
Conv2d(num_filters=128, kernel_size=3, strides=2),
Conv2d(num_filters=128, kernel_size=3, strides=1),
Conv2d(num_filters=256, kernel_size=3, strides=2),
Conv2d(num_filters=256, kernel_size=3, strides=1)
]
}
def hybrid_forward(self, F: ModuleType, x: nd_sym_type, *args, **kwargs) -> nd_sym_type:
"""
Used for forward pass through embedder network.
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
:param x: image representing environment state, of shape (batch_size, in_channels, height, width).
:return: embedding of environment state, of shape (batch_size, channels).
"""
if len(x.shape) != 4 and self.scheme != EmbedderScheme.Empty:
raise ValueError("Image embedders expect the input size to have 4 dimensions. The given size is: {}"
.format(x.shape))
return super(ImageEmbedder, self).hybrid_forward(F, x, *args, **kwargs)

View File

@@ -0,0 +1,71 @@
from typing import Union
from types import ModuleType
import mxnet as mx
from mxnet import nd, sym
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.mxnet_components.embedders.embedder import InputEmbedder
from rl_coach.architectures.mxnet_components.layers import Dense
from rl_coach.base_parameters import EmbedderScheme
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
class VectorEmbedder(InputEmbedder):
def __init__(self, params: InputEmbedderParameters):
"""
An vector embedder is an input embedder that takes an vector input from the state and produces a vector
embedding by passing it through a neural network.
:param params: parameters object containing input_clipping, input_rescaling, batchnorm, activation_function
and dropout properties.
"""
super(VectorEmbedder, self).__init__(params)
self.input_rescaling = params.input_rescaling['vector']
self.input_offset = params.input_offset['vector']
@property
def schemes(self):
"""
Schemes are the pre-defined network architectures of various depths and complexities that can be used. Are used
to create Block when VectorEmbedder is initialised.
:return: dictionary of schemes, with key of type EmbedderScheme enum and value being list of mxnet.gluon.Block.
"""
return {
EmbedderScheme.Empty:
[],
EmbedderScheme.Shallow:
[
Dense(units=128)
],
# Use for DQN
EmbedderScheme.Medium:
[
Dense(units=256)
],
# Use for Carla
EmbedderScheme.Deep:
[
Dense(units=128),
Dense(units=128),
Dense(units=128)
]
}
def hybrid_forward(self, F: ModuleType, x: nd_sym_type, *args, **kwargs) -> nd_sym_type:
"""
Used for forward pass through embedder network.
:param F: backend api, either `nd` or `sym` (if block has been hybridized).
:type F: nd or sym
:param x: vector representing environment state, of shape (batch_size, in_channels).
:return: embedding of environment state, of shape (batch_size, channels).
"""
if isinstance(x, nd.NDArray) and len(x.shape) != 2 and self.scheme != EmbedderScheme.Empty:
raise ValueError("Vector embedders expect the input size to have 2 dimensions. The given size is: {}"
.format(x.shape))
return super(VectorEmbedder, self).hybrid_forward(F, x, *args, **kwargs)

View File

@@ -0,0 +1,501 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
from itertools import chain
from typing import List, Tuple, Union
from types import ModuleType
import numpy as np
import mxnet as mx
from mxnet import nd, sym
from mxnet.gluon import HybridBlock
from mxnet.ndarray import NDArray
from mxnet.symbol import Symbol
from rl_coach.base_parameters import NetworkParameters
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.head_parameters import HeadParameters, PPOHeadParameters
from rl_coach.architectures.head_parameters import PPOVHeadParameters, VHeadParameters, QHeadParameters
from rl_coach.architectures.middleware_parameters import MiddlewareParameters
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters, LSTMMiddlewareParameters
from rl_coach.architectures.mxnet_components.architecture import MxnetArchitecture
from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, VectorEmbedder
from rl_coach.architectures.mxnet_components.heads import Head, HeadLoss, PPOHead, PPOVHead, VHead, QHead
from rl_coach.architectures.mxnet_components.middlewares import FCMiddleware, LSTMMiddleware
from rl_coach.architectures.mxnet_components import utils
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace
class GeneralMxnetNetwork(MxnetArchitecture):
"""
A generalized version of all possible networks implemented using mxnet.
"""
def __init__(self,
agent_parameters: AgentParameters,
spaces: SpacesDefinition,
name: str,
global_network=None,
network_is_local: bool=True,
network_is_trainable: bool=False):
"""
:param agent_parameters: the agent parameters
:param spaces: the spaces definition of the agent
:param name: the name of the network
:param global_network: the global network replica that is shared between all the workers
:param network_is_local: is the network global (shared between workers) or local (dedicated to the worker)
:param network_is_trainable: is the network trainable (we can apply gradients on it)
"""
self.network_wrapper_name = name.split('/')[0]
self.network_parameters = agent_parameters.network_wrappers[self.network_wrapper_name]
if self.network_parameters.use_separate_networks_per_head:
self.num_heads_per_network = 1
self.num_networks = len(self.network_parameters.heads_parameters)
else:
self.num_heads_per_network = len(self.network_parameters.heads_parameters)
self.num_networks = 1
super().__init__(agent_parameters, spaces, name, global_network,
network_is_local, network_is_trainable)
def construct_model(self):
# validate the configuration
if len(self.network_parameters.input_embedders_parameters) == 0:
raise ValueError("At least one input type should be defined")
if len(self.network_parameters.heads_parameters) == 0:
raise ValueError("At least one output type should be defined")
if self.network_parameters.middleware_parameters is None:
raise ValueError("Exactly one middleware type should be defined")
self.model = GeneralModel(
num_networks=self.num_networks,
num_heads_per_network=self.num_heads_per_network,
network_is_local=self.network_is_local,
network_name=self.network_wrapper_name,
agent_parameters=self.ap,
network_parameters=self.network_parameters,
spaces=self.spaces)
self.losses = self.model.losses()
# Learning rate
lr_scheduler = None
if self.network_parameters.learning_rate_decay_rate != 0:
lr_scheduler = mx.lr_scheduler.FactorScheduler(
step=self.network_parameters.learning_rate_decay_steps,
factor=self.network_parameters.learning_rate_decay_rate)
# Optimizer
# FIXME Does this code for distributed training make sense?
if self.distributed_training and self.network_is_local and self.network_parameters.shared_optimizer:
# distributed training + is a local network + optimizer shared -> take the global optimizer
self.optimizer = self.global_network.optimizer
elif (self.distributed_training and self.network_is_local and not self.network_parameters.shared_optimizer)\
or self.network_parameters.shared_optimizer or not self.distributed_training:
if self.network_parameters.optimizer_type == 'Adam':
self.optimizer = mx.optimizer.Adam(
learning_rate=self.network_parameters.learning_rate,
beta1=self.network_parameters.adam_optimizer_beta1,
beta2=self.network_parameters.adam_optimizer_beta2,
epsilon=self.network_parameters.optimizer_epsilon,
lr_scheduler=lr_scheduler)
elif self.network_parameters.optimizer_type == 'RMSProp':
self.optimizer = mx.optimizer.RMSProp(
learning_rate=self.network_parameters.learning_rate,
gamma1=self.network_parameters.rms_prop_optimizer_decay,
epsilon=self.network_parameters.optimizer_epsilon,
lr_scheduler=lr_scheduler)
elif self.network_parameters.optimizer_type == 'LBFGS':
raise NotImplementedError('LBFGS optimizer not implemented')
else:
raise Exception("{} is not a valid optimizer type".format(self.network_parameters.optimizer_type))
@property
def output_heads(self):
return self.model.output_heads
def _get_activation(activation_function_string: str):
"""
Map the activation function from a string to the mxnet framework equivalent
:param activation_function_string: the type of the activation function
:return: mxnet activation function string
"""
return utils.get_mxnet_activation_name(activation_function_string)
def _sanitize_activation(params: Union[InputEmbedderParameters, MiddlewareParameters, HeadParameters]) ->\
Union[InputEmbedderParameters, MiddlewareParameters, HeadParameters]:
"""
Change activation function to the mxnet specific value
:param params: any parameter that has activation_function property
:return: copy of params with activation function correctly set
"""
params_copy = copy.copy(params)
params_copy.activation_function = _get_activation(params.activation_function)
return params_copy
def _get_input_embedder(spaces: SpacesDefinition,
input_name: str,
embedder_params: InputEmbedderParameters) -> ModuleType:
"""
Given an input embedder parameters class, creates the input embedder and returns it
:param input_name: the name of the input to the embedder (used for retrieving the shape). The input should
be a value within the state or the action.
:param embedder_params: the parameters of the class of the embedder
:return: the embedder instance
"""
allowed_inputs = copy.copy(spaces.state.sub_spaces)
allowed_inputs["action"] = copy.copy(spaces.action)
allowed_inputs["goal"] = copy.copy(spaces.goal)
if input_name not in allowed_inputs.keys():
raise ValueError("The key for the input embedder ({}) must match one of the following keys: {}"
.format(input_name, allowed_inputs.keys()))
type = "vector"
if isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
type = "image"
def sanitize_params(params: InputEmbedderParameters):
params_copy = _sanitize_activation(params)
# params_copy.input_rescaling = params_copy.input_rescaling[type]
# params_copy.input_offset = params_copy.input_offset[type]
params_copy.name = input_name
return params_copy
embedder_params = sanitize_params(embedder_params)
if type == 'vector':
module = VectorEmbedder(embedder_params)
elif type == 'image':
module = ImageEmbedder(embedder_params)
else:
raise KeyError('Unsupported embedder type: {}'.format(type))
return module
def _get_middleware(middleware_params: MiddlewareParameters) -> ModuleType:
"""
Given a middleware type, creates the middleware and returns it
:param middleware_params: the paramaeters of the middleware class
:return: the middleware instance
"""
middleware_params = _sanitize_activation(middleware_params)
if isinstance(middleware_params, FCMiddlewareParameters):
module = FCMiddleware(middleware_params)
elif isinstance(middleware_params, LSTMMiddlewareParameters):
module = LSTMMiddleware(middleware_params)
else:
raise KeyError('Unsupported middleware type: {}'.format(type(middleware_params)))
return module
def _get_output_head(
head_params: HeadParameters,
head_idx: int,
head_type_index: int,
agent_params: AgentParameters,
spaces: SpacesDefinition,
network_name: str,
is_local: bool) -> Head:
"""
Given a head type, creates the head and returns it
:param head_params: the parameters of the head to create
:param head_idx: the head index
:param head_type_index: the head type index (same index if head_param.num_output_head_copies>0)
:param agent_params: agent parameters
:param spaces: state and action space definitions
:param network_name: name of the network
:param is_local:
:return: head block
"""
head_params = _sanitize_activation(head_params)
if isinstance(head_params, PPOHeadParameters):
module = PPOHead(
agent_parameters=agent_params,
spaces=spaces,
network_name=network_name,
head_type_idx=head_type_index,
loss_weight=head_params.loss_weight,
is_local=is_local,
activation_function=head_params.activation_function,
dense_layer=head_params.dense_layer)
elif isinstance(head_params, VHeadParameters):
module = VHead(
agent_parameters=agent_params,
spaces=spaces,
network_name=network_name,
head_type_idx=head_type_index,
loss_weight=head_params.loss_weight,
is_local=is_local,
activation_function=head_params.activation_function,
dense_layer=head_params.dense_layer)
elif isinstance(head_params, PPOVHeadParameters):
module = PPOVHead(
agent_parameters=agent_params,
spaces=spaces,
network_name=network_name,
head_type_idx=head_type_index,
loss_weight=head_params.loss_weight,
is_local=is_local,
activation_function=head_params.activation_function,
dense_layer=head_params.dense_layer)
elif isinstance(head_params, QHeadParameters):
module = QHead(
agent_parameters=agent_params,
spaces=spaces,
network_name=network_name,
head_type_idx=head_type_index,
loss_weight=head_params.loss_weight,
is_local=is_local,
activation_function=head_params.activation_function,
dense_layer=head_params.dense_layer)
else:
raise KeyError('Unsupported head type: {}'.format(type(head_params)))
return module
class ScaledGradHead(HybridBlock):
"""
Wrapper block for applying gradient scaling to input before feeding the head network
"""
def __init__(self,
head_index: int,
head_type_index: int,
network_name: str,
spaces: SpacesDefinition,
network_is_local: bool,
agent_params: AgentParameters,
head_params: HeadParameters) -> None:
"""
:param head_idx: the head index
:param head_type_index: the head type index (same index if head_param.num_output_head_copies>0)
:param network_name: name of the network
:param spaces: state and action space definitions
:param network_is_local: whether network is local
:param agent_params: agent parameters
:param head_params: head parameters
"""
super(ScaledGradHead, self).__init__()
head_params = _sanitize_activation(head_params)
with self.name_scope():
self.head = _get_output_head(
head_params=head_params,
head_idx=head_index,
head_type_index=head_type_index,
agent_params=agent_params,
spaces=spaces,
network_name=network_name,
is_local=network_is_local)
self.gradient_rescaler = self.params.get_constant(
name='gradient_rescaler',
value=np.array([float(head_params.rescale_gradient_from_head_by_factor)]))
# self.gradient_rescaler = self.params.get(
# name='gradient_rescaler',
# shape=(1,),
# init=mx.init.Constant(float(head_params.rescale_gradient_from_head_by_factor)))
def hybrid_forward(self,
F: ModuleType,
x: Union[NDArray, Symbol],
gradient_rescaler: Union[NDArray, Symbol]) -> Tuple[Union[NDArray, Symbol], ...]:
""" Overrides gluon.HybridBlock.hybrid_forward
:param nd or sym F: ndarray or symbol module
:param x: head input
:param gradient_rescaler: gradient rescaler for partial blocking of gradient
:return: head output
"""
grad_scaled_x = F.broadcast_mul((1 - gradient_rescaler), F.BlockGrad(x)) + F.broadcast_mul(gradient_rescaler, x)
out = self.head(grad_scaled_x)
return out
class SingleModel(HybridBlock):
"""
Block that connects a single embedder, with middleware and one to multiple heads
"""
def __init__(self,
network_is_local: bool,
network_name: str,
agent_parameters: AgentParameters,
in_emb_param_dict: {str: InputEmbedderParameters},
embedding_merger_type: EmbeddingMergerType,
middleware_param: MiddlewareParameters,
head_param_list: [HeadParameters],
head_type_idx_start: int,
spaces: SpacesDefinition,
*args, **kwargs):
"""
:param network_is_local: True if network is local
:param network_name: name of the network
:param agent_parameters: agent parameters
:param in_emb_param_dict: dictionary of embedder name to embedding parameters
:param embedding_merger_type: type of merging output of embedders: concatenate or sum
:param middleware_param: middleware parameters
:param head_param_list: list of head parameters, one per head type
:param head_type_idx_start: start index for head type index counting
:param spaces: state and action space definition
"""
super(SingleModel, self).__init__(*args, **kwargs)
self._embedding_merger_type = embedding_merger_type
self._input_embedders = list() # type: List[HybridBlock]
self._output_heads = list() # type: List[ScaledGradHead]
with self.name_scope():
for input_name in sorted(in_emb_param_dict):
input_type = in_emb_param_dict[input_name]
input_embedder = _get_input_embedder(spaces, input_name, input_type)
self.register_child(input_embedder)
self._input_embedders.append(input_embedder)
self.middleware = _get_middleware(middleware_param)
for i, head_param in enumerate(head_param_list):
for head_copy_idx in range(head_param.num_output_head_copies):
# create output head and add it to the output heads list
output_head = ScaledGradHead(
head_index=(head_type_idx_start + i) * head_param.num_output_head_copies + head_copy_idx,
head_type_index=head_type_idx_start + i,
network_name=network_name,
spaces=spaces,
network_is_local=network_is_local,
agent_params=agent_parameters,
head_params=head_param)
self.register_child(output_head)
self._output_heads.append(output_head)
def hybrid_forward(self, F, *inputs: Union[NDArray, Symbol]) -> Tuple[Union[NDArray, Symbol], ...]:
""" Overrides gluon.HybridBlock.hybrid_forward
:param nd or sym F: ndarray or symbol block
:param inputs: model inputs, one for each embedder
:return: head outputs in a tuple
"""
# Input Embeddings
state_embedding = list()
for input, embedder in zip(inputs, self._input_embedders):
state_embedding.append(embedder(input))
# Merger
if len(state_embedding) == 1:
state_embedding = state_embedding[0]
else:
if self._embedding_merger_type == EmbeddingMergerType.Concat:
state_embedding = F.concat(*state_embedding, dim=1, name='merger') # NC or NCHW layout
elif self._embedding_merger_type == EmbeddingMergerType.Sum:
state_embedding = F.add_n(*state_embedding, name='merger')
# Middleware
state_embedding = self.middleware(state_embedding)
# Head
outputs = tuple()
for head in self._output_heads:
outputs += (head(state_embedding),)
return outputs
@property
def input_embedders(self) -> List[HybridBlock]:
"""
:return: list of input embedders
"""
return self._input_embedders
@property
def output_heads(self) -> List[Head]:
"""
:return: list of output heads
"""
return [h.head for h in self._output_heads]
class GeneralModel(HybridBlock):
"""
Block that creates multiple single models
"""
def __init__(self,
num_networks: int,
num_heads_per_network: int,
network_is_local: bool,
network_name: str,
agent_parameters: AgentParameters,
network_parameters: NetworkParameters,
spaces: SpacesDefinition,
*args, **kwargs):
"""
:param num_networks: number of networks to create
:param num_heads_per_network: number of heads per network to create
:param network_is_local: True if network is local
:param network_name: name of the network
:param agent_parameters: agent parameters
:param network_parameters: network parameters
:param spaces: state and action space definitions
"""
super(GeneralModel, self).__init__(*args, **kwargs)
with self.name_scope():
self.nets = list()
for network_idx in range(num_networks):
head_type_idx_start = network_idx * num_heads_per_network
head_type_idx_end = head_type_idx_start + num_heads_per_network
net = SingleModel(
head_type_idx_start=head_type_idx_start,
network_name=network_name,
network_is_local=network_is_local,
agent_parameters=agent_parameters,
in_emb_param_dict=network_parameters.input_embedders_parameters,
embedding_merger_type=network_parameters.embedding_merger_type,
middleware_param=network_parameters.middleware_parameters,
head_param_list=network_parameters.heads_parameters[head_type_idx_start:head_type_idx_end],
spaces=spaces)
self.register_child(net)
self.nets.append(net)
def hybrid_forward(self, F, *inputs):
""" Overrides gluon.HybridBlock.hybrid_forward
:param nd or sym F: ndarray or symbol block
:param inputs: model inputs, one for each embedder. Passed to all networks.
:return: head outputs in a tuple
"""
outputs = tuple()
for net in self.nets:
out = net(*inputs)
outputs += out
return outputs
@property
def output_heads(self) -> List[Head]:
""" Return all heads in a single list
Note: There is a one-to-one mapping between output_heads and losses
:return: list of heads
"""
return list(chain.from_iterable(net.output_heads for net in self.nets))
def losses(self) -> List[HeadLoss]:
""" Construct loss blocks for network training
Note: There is a one-to-one mapping between output_heads and losses
:return: list of loss blocks
"""
return [h.loss() for net in self.nets for h in net.output_heads]

View File

@@ -0,0 +1,14 @@
from .head import Head, HeadLoss
from .q_head import QHead
from .ppo_head import PPOHead
from .ppo_v_head import PPOVHead
from .v_head import VHead
__all__ = [
'Head',
'HeadLoss',
'QHead',
'PPOHead',
'PPOVHead',
'VHead'
]

View File

@@ -0,0 +1,181 @@
from typing import Dict, List, Union, Tuple
from mxnet.gluon import nn, loss
from mxnet.ndarray import NDArray
from mxnet.symbol import Symbol
from rl_coach.base_parameters import AgentParameters
from rl_coach.spaces import SpacesDefinition
LOSS_OUT_TYPE_LOSS = 'loss'
LOSS_OUT_TYPE_REGULARIZATION = 'regularization'
class LossInputSchema(object):
"""
Helper class to contain schema for loss hybrid_forward input
"""
def __init__(self, head_outputs: List[str], agent_inputs: List[str], targets: List[str]):
"""
:param head_outputs: list of argument names in hybrid_forward that are outputs of the head.
The order and number MUST MATCH the output from the head.
:param agent_inputs: list of argument names in hybrid_forward that are inputs from the agent.
The order and number MUST MATCH `output_<head_type_idx>_<order>` for this head.
:param targets: list of argument names in hybrid_forward that are targets for the loss.
The order and number MUST MATCH targets passed from the agent.
"""
self._head_outputs = head_outputs
self._agent_inputs = agent_inputs
self._targets = targets
@property
def head_outputs(self):
return self._head_outputs
@property
def agent_inputs(self):
return self._agent_inputs
@property
def targets(self):
return self._targets
class HeadLoss(loss.Loss):
"""
ABC for loss functions of each head. Child class must implement input_schema() and loss_forward()
"""
def __init__(self, *args, **kwargs):
super(HeadLoss, self).__init__(*args, **kwargs)
self._output_schema = None # type: List[str]
@property
def input_schema(self) -> LossInputSchema:
"""
:return: schema for input of hybrid_forward. Read docstring for LossInputSchema for details.
"""
raise NotImplementedError
@property
def output_schema(self) -> List[str]:
"""
:return: schema for output of hybrid_forward. Must contain 'loss' and 'regularization' keys at least once.
The order and total number must match that of returned values from the loss. 'loss' and 'regularization'
are special keys. Any other string is treated as auxiliary outputs and must include match auxiliary
fetch names returned by the head.
"""
return self._output_schema
def forward(self, *args):
"""
Override forward() so that number of outputs can be checked against the schema
"""
outputs = super(HeadLoss, self).forward(*args)
if isinstance(outputs, tuple) or isinstance(outputs, list):
num_outputs = len(outputs)
else:
assert isinstance(outputs, NDArray) or isinstance(outputs, Symbol)
num_outputs = 1
assert num_outputs == len(self.output_schema), "Number of outputs don't match schema ({} != {})".format(
num_outputs, len(self.output_schema))
return outputs
def _loss_output(self, outputs: List[Tuple[Union[NDArray, Symbol], str]]):
"""
Must be called on the output from hybrid_forward().
Saves the returned output as the schema and returns output values in a list
:return: list of output values
"""
output_schema = [o[1] for o in outputs]
assert self._output_schema is None or self._output_schema == output_schema
self._output_schema = output_schema
return tuple(o[0] for o in outputs)
def hybrid_forward(self, F, x, *args, **kwargs):
"""
Passes the cal to loss_forward() and constructs output schema from its output by calling loss_output()
"""
return self._loss_output(self.loss_forward(F, x, *args, **kwargs))
def loss_forward(self, F, x, *args, **kwargs) -> List[Tuple[Union[NDArray, Symbol], str]]:
"""
Similar to hybrid_forward, but returns list of (NDArray, type_str)
"""
raise NotImplementedError
class Head(nn.HybridBlock):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition,
network_name: str, head_type_idx: int=0, loss_weight: float=1., is_local: bool=True,
activation_function: str='relu', dense_layer: None=None):
"""
A head is the final part of the network. It takes the embedding from the middleware embedder and passes it
through a neural network to produce the output of the network. There can be multiple heads in a network, and
each one has an assigned loss function. The heads are algorithm dependent.
:param agent_parameters: containing algorithm parameters such as clip_likelihood_ratio_using_epsilon
and beta_entropy.
:param spaces: containing action spaces used for defining size of network output.
:param network_name: name of head network. currently unused.
:param head_type_idx: index of head network. currently unused.
:param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
:param is_local: flag to denote if network is local. currently unused.
:param activation_function: activation function to use between layers. currently unused.
:param dense_layer: type of dense layer to use in network. currently unused.
"""
super(Head, self).__init__()
self.head_type_idx = head_type_idx
self.network_name = network_name
self.loss_weight = loss_weight
self.is_local = is_local
self.ap = agent_parameters
self.spaces = spaces
self.return_type = None
self.activation_function = activation_function
self.dense_layer = dense_layer
self._num_outputs = None
def loss(self) -> HeadLoss:
"""
Returns loss block to be used for specific head implementation.
:return: loss block (can be called as function) for outputs returned by the head network.
"""
raise NotImplementedError()
@property
def num_outputs(self):
""" Returns number of outputs that forward() call will return
:return:
"""
assert self._num_outputs is not None, 'must call forward() once to configure number of outputs'
return self._num_outputs
def forward(self, *args):
"""
Override forward() so that number of outputs can be automatically set
"""
outputs = super(Head, self).forward(*args)
if isinstance(outputs, tuple):
num_outputs = len(outputs)
else:
assert isinstance(outputs, NDArray) or isinstance(outputs, Symbol)
num_outputs = 1
if self._num_outputs is None:
self._num_outputs = num_outputs
else:
assert self._num_outputs == num_outputs, 'Number of outputs cannot change ({} != {})'.format(
self._num_outputs, num_outputs)
assert self._num_outputs == len(self.loss().input_schema.head_outputs)
return outputs
def hybrid_forward(self, F, x, *args, **kwargs):
"""
Used for forward pass through head network.
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
:param x: middleware state representation, of shape (batch_size, in_channels).
:return: final output of network, that will be used in loss calculations.
"""
raise NotImplementedError()

View File

@@ -0,0 +1,669 @@
from typing import List, Tuple, Union
from types import ModuleType
import math
import mxnet as mx
from mxnet.gluon import nn
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import ActionProbabilities
from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace
from rl_coach.utils import eps
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
from rl_coach.architectures.mxnet_components.utils import hybrid_clip
LOSS_OUT_TYPE_KL = 'kl_divergence'
LOSS_OUT_TYPE_ENTROPY = 'entropy'
LOSS_OUT_TYPE_LIKELIHOOD_RATIO = 'likelihood_ratio'
LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO = 'clipped_likelihood_ratio'
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
class MultivariateNormalDist:
def __init__(self,
num_var: int,
mean: nd_sym_type,
sigma: nd_sym_type,
F: ModuleType=mx.nd) -> None:
"""
Distribution object for Multivariate Normal. Works with batches.
Optionally works with batches and time steps, but be consistent in usage: i.e. if using time_step,
mean, sigma and data for log_prob must all include a time_step dimension.
:param num_var: number of variables in distribution
:param mean: mean for each variable,
of shape (num_var) or
of shape (batch_size, num_var) or
of shape (batch_size, time_step, num_var).
:param sigma: covariance matrix,
of shape (num_var, num_var) or
of shape (batch_size, num_var, num_var) or
of shape (batch_size, time_step, num_var, num_var).
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
"""
self.num_var = num_var
self.mean = mean
self.sigma = sigma
self.F = F
def inverse_using_cholesky(self, matrix: nd_sym_type) -> nd_sym_type:
"""
Calculate inverses for a batch of matrices using Cholesky decomposition method.
:param matrix: matrix (or matrices) to invert,
of shape (num_var, num_var) or
of shape (batch_size, num_var, num_var) or
of shape (batch_size, time_step, num_var, num_var).
:return: inverted matrix (or matrices),
of shape (num_var, num_var) or
of shape (batch_size, num_var, num_var) or
of shape (batch_size, time_step, num_var, num_var).
"""
cholesky_factor = self.F.linalg.potrf(matrix)
return self.F.linalg.potri(cholesky_factor)
def log_det(self, matrix: nd_sym_type) -> nd_sym_type:
"""
Calculate log of the determinant for a batch of matrices using Cholesky decomposition method.
:param matrix: matrix (or matrices) to invert,
of shape (num_var, num_var) or
of shape (batch_size, num_var, num_var) or
of shape (batch_size, time_step, num_var, num_var).
:return: inverted matrix (or matrices),
of shape (num_var, num_var) or
of shape (batch_size, num_var, num_var) or
of shape (batch_size, time_step, num_var, num_var).
"""
cholesky_factor = self.F.linalg.potrf(matrix)
return 2 * self.F.linalg.sumlogdiag(cholesky_factor)
def log_prob(self, x: nd_sym_type) -> nd_sym_type:
"""
Calculate the log probability of data given the current distribution.
See http://www.notenoughthoughts.net/posts/normal-log-likelihood-gradient.html
and https://discuss.mxnet.io/t/multivariate-gaussian-log-density-operator/1169/7
:param x: input data,
of shape (num_var) or
of shape (batch_size, num_var) or
of shape (batch_size, time_step, num_var).
:return: log_probability,
of shape (1) or
of shape (batch_size) or
of shape (batch_size, time_step).
"""
a = (self.num_var / 2) * math.log(2 * math.pi)
log_det_sigma = self.log_det(self.sigma)
b = (1 / 2) * log_det_sigma
sigma_inv = self.inverse_using_cholesky(self.sigma)
# deviation from mean, and dev_t is equivalent to transpose on last two dims.
dev = (x - self.mean).expand_dims(-1)
dev_t = (x - self.mean).expand_dims(-2)
# since batch_dot only works with ndarrays with ndim of 3,
# and we could have ndarrays with ndim of 4,
# we flatten batch_size and time_step into single dim.
dev_flat = dev.reshape(shape=(-1, 0, 0), reverse=1)
sigma_inv_flat = sigma_inv.reshape(shape=(-1, 0, 0), reverse=1)
dev_t_flat = dev_t.reshape(shape=(-1, 0, 0), reverse=1)
c = (1 / 2) * self.F.batch_dot(self.F.batch_dot(dev_t_flat, sigma_inv_flat), dev_flat)
# and now reshape back to (batch_size, time_step) if required.
c = c.reshape_like(b)
log_likelihood = -a - b - c
return log_likelihood
def entropy(self) -> nd_sym_type:
"""
Calculate entropy of current distribution.
See http://www.nowozin.net/sebastian/blog/the-entropy-of-a-normal-distribution.html
:return: entropy,
of shape (1) or
of shape (batch_size) or
of shape (batch_size, time_step).
"""
# todo: check if differential entropy is correct
log_det_sigma = self.log_det(self.sigma)
return (self.num_var / 2) + ((self.num_var / 2) * math.log(2 * math.pi)) + ((1 / 2) * log_det_sigma)
def kl_div(self, alt_dist) -> nd_sym_type:
"""
Calculated KL-Divergence with another MultivariateNormalDist distribution
See https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
Specifically https://wikimedia.org/api/rest_v1/media/math/render/svg/a3bf3b4917bd1fcb8be48d6d6139e2e387bdc7d3
:param alt_dist: alternative distribution used for kl divergence calculation
:type alt_dist: MultivariateNormalDist
:return: KL-Divergence, of shape (1,)
"""
sigma_a_inv = self.F.linalg.potri(self.F.linalg.potrf(self.sigma))
sigma_b_inv = self.F.linalg.potri(self.F.linalg.potrf(alt_dist.sigma))
term1a = mx.nd.batch_dot(sigma_b_inv, self.sigma)
# sum of diagonal for batch of matrices
term1 = (self.F.eye(self.num_var).broadcast_like(term1a) * term1a).sum(axis=-1).sum(axis=-1)
mean_diff = (alt_dist.mean - self.mean).expand_dims(-1)
mean_diff_t = (alt_dist.mean - self.mean).expand_dims(-2)
term2 = self.F.batch_dot(self.F.batch_dot(mean_diff_t, sigma_b_inv), mean_diff).reshape_like(term1)
term3 = (2 * self.F.linalg.sumlogdiag(self.F.linalg.potrf(alt_dist.sigma))) -\
(2 * self.F.linalg.sumlogdiag(self.F.linalg.potrf(self.sigma)))
return 0.5 * (term1 + term2 - self.num_var + term3)
class CategoricalDist:
def __init__(self, n_classes: int, probs: nd_sym_type, F: ModuleType=mx.nd) -> None:
"""
Distribution object for Categorical data.
Optionally works with batches and time steps, but be consistent in usage: i.e. if using time_step,
mean, sigma and data for log_prob must all include a time_step dimension.
:param n_classes: number of classes in distribution
:param probs: probabilities for each class,
of shape (n_classes),
of shape (batch_size, n_classes) or
of shape (batch_size, time_step, n_classes)
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
"""
self.n_classes = n_classes
self.probs = probs
self.F = F
def log_prob(self, actions: nd_sym_type) -> nd_sym_type:
"""
Calculate the log probability of data given the current distribution.
:param actions: actions, with int8 data type,
of shape (1) if probs was (n_classes),
of shape (batch_size) if probs was (batch_size, n_classes) and
of shape (batch_size, time_step) if probs was (batch_size, time_step, n_classes)
:return: log_probability,
of shape (1) if probs was (n_classes),
of shape (batch_size) if probs was (batch_size, n_classes) and
of shape (batch_size, time_step) if probs was (batch_size, time_step, n_classes)
"""
action_mask = actions.one_hot(depth=self.n_classes)
action_probs = (self.probs * action_mask).sum(axis=-1)
return action_probs.log()
def entropy(self) -> nd_sym_type:
"""
Calculate entropy of current distribution.
:return: entropy,
of shape (1) if probs was (n_classes),
of shape (batch_size) if probs was (batch_size, n_classes) and
of shape (batch_size, time_step) if probs was (batch_size, time_step, n_classes)
"""
# todo: look into numerical stability
return -(self.probs.log()*self.probs).sum(axis=-1)
def kl_div(self, alt_dist) -> nd_sym_type:
"""
Calculated KL-Divergence with another Categorical distribution
:param alt_dist: alternative distribution used for kl divergence calculation
:type alt_dist: CategoricalDist
:return: KL-Divergence
"""
logits_a = self.probs.clip(a_min=eps, a_max=1 - eps).log()
logits_b = alt_dist.probs.clip(a_min=eps, a_max=1 - eps).log()
t = self.probs * (logits_a - logits_b)
t = self.F.where(condition=(alt_dist.probs == 0), x=self.F.ones_like(alt_dist.probs) * math.inf, y=t)
t = self.F.where(condition=(self.probs == 0), x=self.F.zeros_like(self.probs), y=t)
return t.sum(axis=-1)
class DiscretePPOHead(nn.HybridBlock):
def __init__(self, num_actions: int) -> None:
"""
Head block for Discrete Proximal Policy Optimization, to calculate probabilities for each action given
middleware representation of the environment state.
:param num_actions: number of actions in action space.
"""
super(DiscretePPOHead, self).__init__()
with self.name_scope():
self.dense = nn.Dense(units=num_actions, flatten=False)
def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
"""
Used for forward pass through head network.
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
:param x: middleware state representation,
of shape (batch_size, in_channels) or
of shape (batch_size, time_step, in_channels).
:return: batch of probabilities for each action,
of shape (batch_size, num_actions) or
of shape (batch_size, time_step, num_actions).
"""
policy_values = self.dense(x)
policy_probs = F.softmax(policy_values)
return policy_probs
class ContinuousPPOHead(nn.HybridBlock):
def __init__(self, num_actions: int) -> None:
"""
Head block for Continuous Proximal Policy Optimization, to calculate probabilities for each action given
middleware representation of the environment state.
:param num_actions: number of actions in action space.
"""
super(ContinuousPPOHead, self).__init__()
with self.name_scope():
# todo: change initialization strategy
self.dense = nn.Dense(units=num_actions, flatten=False)
# all samples (across batch, and time step) share the same covariance, which is learnt,
# but since we assume the action probability variables are independent,
# only the diagonal entries of the covariance matrix are specified.
self.log_std = self.params.get('log_std',
shape=num_actions,
init=mx.init.Zero(),
allow_deferred_init=True)
# todo: is_local?
def hybrid_forward(self, F: ModuleType, x: nd_sym_type, log_std: nd_sym_type) -> List[nd_sym_type]:
"""
Used for forward pass through head network.
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
:param x: middleware state representation,
of shape (batch_size, in_channels) or
of shape (batch_size, time_step, in_channels).
:return: batch of probabilities for each action,
of shape (batch_size, action_mean) or
of shape (batch_size, time_step, action_mean).
"""
policy_means = self.dense(x)
policy_std = log_std.exp()
return [policy_means, policy_std]
class ClippedPPOLossDiscrete(HeadLoss):
def __init__(self,
num_actions: int,
clip_likelihood_ratio_using_epsilon: float,
beta: float=0,
use_kl_regularization: bool=False,
initial_kl_coefficient: float=1,
kl_cutoff: float=0,
high_kl_penalty_coefficient: float=1,
weight: float=1,
batch_axis: int=0) -> None:
"""
Loss for discrete version of Clipped PPO.
:param num_actions: number of actions in action space.
:param clip_likelihood_ratio_using_epsilon: epsilon to use for likelihood ratio clipping.
:param beta: loss coefficient applied to entropy
:param use_kl_regularization: option to add kl divergence loss
:param initial_kl_coefficient: loss coefficient applied kl divergence loss (also see high_kl_penalty_coefficient).
:param kl_cutoff: threshold for using high_kl_penalty_coefficient
:param high_kl_penalty_coefficient: loss coefficient applied to kv divergence above kl_cutoff
:param weight: scalar used to adjust relative weight of loss (if using this loss with others).
:param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
"""
super(ClippedPPOLossDiscrete, self).__init__(weight=weight, batch_axis=batch_axis)
self.weight = weight
self.num_actions = num_actions
self.clip_likelihood_ratio_using_epsilon = clip_likelihood_ratio_using_epsilon
self.beta = beta
self.use_kl_regularization = use_kl_regularization
self.initial_kl_coefficient = initial_kl_coefficient if self.use_kl_regularization else 0.0
self.kl_coefficient = self.params.get('kl_coefficient',
shape=(1,),
init=mx.init.Constant([initial_kl_coefficient,]),
differentiable=False)
self.kl_cutoff = kl_cutoff
self.high_kl_penalty_coefficient = high_kl_penalty_coefficient
@property
def input_schema(self) -> LossInputSchema:
return LossInputSchema(
head_outputs=['new_policy_probs'],
agent_inputs=['actions', 'old_policy_probs', 'clip_param_rescaler'],
targets=['advantages']
)
def loss_forward(self,
F: ModuleType,
new_policy_probs: nd_sym_type,
actions: nd_sym_type,
old_policy_probs: nd_sym_type,
clip_param_rescaler: nd_sym_type,
advantages: nd_sym_type,
kl_coefficient: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
"""
Used for forward pass through loss computations.
Works with batches of data, and optionally time_steps, but be consistent in usage: i.e. if using time_step,
new_policy_probs, old_policy_probs, actions and advantages all must include a time_step dimension.
NOTE: order of input arguments MUST NOT CHANGE because it matches the order
parameters are passed in ppo_agent:train_network()
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
:param new_policy_probs: action probabilities predicted by DiscretePPOHead network,
of shape (batch_size, num_actions) or
of shape (batch_size, time_step, num_actions).
:param old_policy_probs: action probabilities for previous policy,
of shape (batch_size, num_actions) or
of shape (batch_size, time_step, num_actions).
:param actions: true actions taken during rollout,
of shape (batch_size) or
of shape (batch_size, time_step).
:param clip_param_rescaler: scales epsilon to use for likelihood ratio clipping.
:param advantages: change in state value after taking action (a.k.a advantage)
of shape (batch_size) or
of shape (batch_size, time_step).
:param kl_coefficient: loss coefficient applied kl divergence loss (also see high_kl_penalty_coefficient).
:return: loss, of shape (batch_size).
"""
old_policy_dist = CategoricalDist(self.num_actions, old_policy_probs, F=F)
action_probs_wrt_old_policy = old_policy_dist.log_prob(actions)
new_policy_dist = CategoricalDist(self.num_actions, new_policy_probs, F=F)
action_probs_wrt_new_policy = new_policy_dist.log_prob(actions)
entropy_loss = - self.beta * new_policy_dist.entropy().mean()
if self.use_kl_regularization:
kl_div = old_policy_dist.kl_div(new_policy_dist).mean()
weighted_kl_div = kl_coefficient * kl_div
high_kl_div = F.stack(F.zeros_like(kl_div), kl_div - self.kl_cutoff).max().square()
weighted_high_kl_div = self.high_kl_penalty_coefficient * high_kl_div
kl_div_loss = weighted_kl_div + weighted_high_kl_div
else:
kl_div_loss = F.zeros(shape=(1,))
# working with log probs, so minus first, then exponential (same as division)
likelihood_ratio = (action_probs_wrt_new_policy - action_probs_wrt_old_policy).exp()
if self.clip_likelihood_ratio_using_epsilon is not None:
# clipping of likelihood ratio
min_value = 1 - self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler
max_value = 1 + self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler
# can't use F.clip (with variable clipping bounds), hence custom implementation
clipped_likelihood_ratio = hybrid_clip(F, likelihood_ratio, clip_lower=min_value, clip_upper=max_value)
# lower bound of original, and clipped versions or each scaled advantage
# element-wise min between the two ndarrays
unclipped_scaled_advantages = likelihood_ratio * advantages
clipped_scaled_advantages = clipped_likelihood_ratio * advantages
scaled_advantages = F.stack(unclipped_scaled_advantages, clipped_scaled_advantages).min(axis=0)
else:
scaled_advantages = likelihood_ratio * advantages
clipped_likelihood_ratio = F.zeros_like(likelihood_ratio)
# for each batch, calculate expectation of scaled_advantages across time steps,
# but want code to work with data without time step too, so reshape to add timestep if doesn't exist.
scaled_advantages_w_time = scaled_advantages.reshape(shape=(0, -1))
expected_scaled_advantages = scaled_advantages_w_time.mean(axis=1)
# want to maximize expected_scaled_advantages, add minus so can minimize.
surrogate_loss = (-expected_scaled_advantages * self.weight).mean()
return [
(surrogate_loss, LOSS_OUT_TYPE_LOSS),
(entropy_loss + kl_div_loss, LOSS_OUT_TYPE_REGULARIZATION),
(kl_div_loss, LOSS_OUT_TYPE_KL),
(entropy_loss, LOSS_OUT_TYPE_ENTROPY),
(likelihood_ratio, LOSS_OUT_TYPE_LIKELIHOOD_RATIO),
(clipped_likelihood_ratio, LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO)
]
class ClippedPPOLossContinuous(HeadLoss):
def __init__(self,
num_actions: int,
clip_likelihood_ratio_using_epsilon: float,
beta: float=0,
use_kl_regularization: bool=False,
initial_kl_coefficient: float=1,
kl_cutoff: float=0,
high_kl_penalty_coefficient: float=1,
weight: float=1,
batch_axis: int=0):
"""
Loss for continuous version of Clipped PPO.
:param num_actions: number of actions in action space.
:param clip_likelihood_ratio_using_epsilon: epsilon to use for likelihood ratio clipping.
:param beta: loss coefficient applied to entropy
:param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
:param use_kl_regularization: option to add kl divergence loss
:param initial_kl_coefficient: initial loss coefficient applied kl divergence loss (also see high_kl_penalty_coefficient).
:param kl_cutoff: threshold for using high_kl_penalty_coefficient
:param high_kl_penalty_coefficient: loss coefficient applied to kv divergence above kl_cutoff
:param weight: scalar used to adjust relative weight of loss (if using this loss with others).
:param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
"""
super(ClippedPPOLossContinuous, self).__init__(weight=weight, batch_axis=batch_axis)
self.weight = weight
self.num_actions = num_actions
self.clip_likelihood_ratio_using_epsilon = clip_likelihood_ratio_using_epsilon
self.beta = beta
self.use_kl_regularization = use_kl_regularization
self.initial_kl_coefficient = initial_kl_coefficient if self.use_kl_regularization else 0.0
self.kl_coefficient = self.params.get('kl_coefficient',
shape=(1,),
init=mx.init.Constant([initial_kl_coefficient,]),
differentiable=False)
self.kl_cutoff = kl_cutoff
self.high_kl_penalty_coefficient = high_kl_penalty_coefficient
@property
def input_schema(self) -> LossInputSchema:
return LossInputSchema(
head_outputs=['new_policy_means','new_policy_stds'],
agent_inputs=['actions', 'old_policy_means', 'old_policy_stds', 'clip_param_rescaler'],
targets=['advantages']
)
def loss_forward(self,
F: ModuleType,
new_policy_means: nd_sym_type,
new_policy_stds: nd_sym_type,
actions: nd_sym_type,
old_policy_means: nd_sym_type,
old_policy_stds: nd_sym_type,
clip_param_rescaler: nd_sym_type,
advantages: nd_sym_type,
kl_coefficient: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
"""
Used for forward pass through loss computations.
Works with batches of data, and optionally time_steps, but be consistent in usage: i.e. if using time_step,
new_policy_means, old_policy_means, actions and advantages all must include a time_step dimension.
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
:param new_policy_means: action means predicted by MultivariateNormalDist network,
of shape (batch_size, num_actions) or
of shape (batch_size, time_step, num_actions).
:param new_policy_stds: action standard deviation returned by head,
of shape (batch_size, num_actions) or
of shape (batch_size, time_step, num_actions).
:param actions: true actions taken during rollout,
of shape (batch_size) or
of shape (batch_size, time_step).
:param old_policy_means: action means for previous policy,
of shape (batch_size, num_actions) or
of shape (batch_size, time_step, num_actions).
:param old_policy_stds: action standard deviation returned by head previously,
of shape (batch_size, num_actions) or
of shape (batch_size, time_step, num_actions).
:param clip_param_rescaler: scales epsilon to use for likelihood ratio clipping.
:param advantages: change in state value after taking action (a.k.a advantage)
of shape (batch_size) or
of shape (batch_size, time_step).
:param kl_coefficient: loss coefficient applied kl divergence loss (also see high_kl_penalty_coefficient).
:return: loss, of shape (batch_size).
"""
old_var = old_policy_stds ** 2
# sets diagonal in (batch size and time step) covariance matrices
old_covar = mx.nd.eye(N=self.num_actions) * (old_var + eps).broadcast_like(old_policy_means).expand_dims(-2)
old_policy_dist = MultivariateNormalDist(self.num_actions, old_policy_means, old_covar, F=F)
action_probs_wrt_old_policy = old_policy_dist.log_prob(actions)
new_var = new_policy_stds ** 2
# sets diagonal in (batch size and time step) covariance matrices
new_covar = mx.nd.eye(N=self.num_actions) * (new_var + eps).broadcast_like(new_policy_means).expand_dims(-2)
new_policy_dist = MultivariateNormalDist(self.num_actions, new_policy_means, new_covar, F=F)
action_probs_wrt_new_policy = new_policy_dist.log_prob(actions)
entropy_loss = - self.beta * new_policy_dist.entropy().mean()
if self.use_kl_regularization:
kl_div = old_policy_dist.kl_div(new_policy_dist).mean()
weighted_kl_div = kl_coefficient * kl_div
high_kl_div = F.stack(F.zeros_like(kl_div), kl_div - self.kl_cutoff).max().square()
weighted_high_kl_div = self.high_kl_penalty_coefficient * high_kl_div
kl_div_loss = weighted_kl_div + weighted_high_kl_div
else:
kl_div_loss = F.zeros(shape=(1,))
# working with log probs, so minus first, then exponential (same as division)
likelihood_ratio = (action_probs_wrt_new_policy - action_probs_wrt_old_policy).exp()
if self.clip_likelihood_ratio_using_epsilon is not None:
# clipping of likelihood ratio
min_value = 1 - self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler
max_value = 1 + self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler
# can't use F.clip (with variable clipping bounds), hence custom implementation
clipped_likelihood_ratio = hybrid_clip(F, likelihood_ratio, clip_lower=min_value, clip_upper=max_value)
# lower bound of original, and clipped versions or each scaled advantage
# element-wise min between the two ndarrays
unclipped_scaled_advantages = likelihood_ratio * advantages
clipped_scaled_advantages = clipped_likelihood_ratio * advantages
scaled_advantages = F.stack(unclipped_scaled_advantages, clipped_scaled_advantages).min(axis=0)
else:
scaled_advantages = likelihood_ratio * advantages
clipped_likelihood_ratio = F.zeros_like(likelihood_ratio)
# for each batch, calculate expectation of scaled_advantages across time steps,
# but want code to work with data without time step too, so reshape to add timestep if doesn't exist.
scaled_advantages_w_time = scaled_advantages.reshape(shape=(0, -1))
expected_scaled_advantages = scaled_advantages_w_time.mean(axis=1)
# want to maximize expected_scaled_advantages, add minus so can minimize.
surrogate_loss = (-expected_scaled_advantages * self.weight).mean()
return [
(surrogate_loss, LOSS_OUT_TYPE_LOSS),
(entropy_loss + kl_div_loss, LOSS_OUT_TYPE_REGULARIZATION),
(kl_div_loss, LOSS_OUT_TYPE_KL),
(entropy_loss, LOSS_OUT_TYPE_ENTROPY),
(likelihood_ratio, LOSS_OUT_TYPE_LIKELIHOOD_RATIO),
(clipped_likelihood_ratio, LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO)
]
class PPOHead(Head):
def __init__(self,
agent_parameters: AgentParameters,
spaces: SpacesDefinition,
network_name: str,
head_type_idx: int=0,
loss_weight: float=1.,
is_local: bool=True,
activation_function: str='tanh',
dense_layer: None=None) -> None:
"""
Head block for Proximal Policy Optimization, to calculate probabilities for each action given middleware
representation of the environment state.
:param agent_parameters: containing algorithm parameters such as clip_likelihood_ratio_using_epsilon
and beta_entropy.
:param spaces: containing action spaces used for defining size of network output.
:param network_name: name of head network. currently unused.
:param head_type_idx: index of head network. currently unused.
:param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
:param is_local: flag to denote if network is local. currently unused.
:param activation_function: activation function to use between layers. currently unused.
:param dense_layer: type of dense layer to use in network. currently unused.
"""
super().__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.return_type = ActionProbabilities
self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon
self.beta = agent_parameters.algorithm.beta_entropy
self.use_kl_regularization = agent_parameters.algorithm.use_kl_regularization
if self.use_kl_regularization:
self.initial_kl_coefficient = agent_parameters.algorithm.initial_kl_coefficient
self.kl_cutoff = 2 * agent_parameters.algorithm.target_kl_divergence
self.high_kl_penalty_coefficient = agent_parameters.algorithm.high_kl_penalty_coefficient
else:
self.initial_kl_coefficient, self.kl_cutoff, self.high_kl_penalty_coefficient = (None, None, None)
self._loss = []
if isinstance(self.spaces.action, DiscreteActionSpace):
self.net = DiscretePPOHead(num_actions=len(self.spaces.action.actions))
elif isinstance(self.spaces.action, BoxActionSpace):
self.net = ContinuousPPOHead(num_actions=len(self.spaces.action.actions))
else:
raise ValueError("Only discrete or continuous action spaces are supported for PPO.")
def hybrid_forward(self,
F: ModuleType,
x: nd_sym_type) -> nd_sym_type:
"""
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
:param x: middleware embedding
:return: policy parameters/probabilities
"""
return self.net(x)
def loss(self) -> mx.gluon.loss.Loss:
"""
Specifies loss block to be used for this policy head.
:return: loss block (can be called as function) for action probabilities returned by this policy network.
"""
if isinstance(self.spaces.action, DiscreteActionSpace):
loss = ClippedPPOLossDiscrete(len(self.spaces.action.actions),
self.clip_likelihood_ratio_using_epsilon,
self.beta,
self.use_kl_regularization, self.initial_kl_coefficient,
self.kl_cutoff, self.high_kl_penalty_coefficient,
self.loss_weight)
elif isinstance(self.spaces.action, BoxActionSpace):
loss = ClippedPPOLossContinuous(len(self.spaces.action.actions),
self.clip_likelihood_ratio_using_epsilon,
self.beta,
self.use_kl_regularization, self.initial_kl_coefficient,
self.kl_cutoff, self.high_kl_penalty_coefficient,
self.loss_weight)
else:
raise ValueError("Only discrete or continuous action spaces are supported for PPO.")
loss.initialize()
# set a property so can assign_kl_coefficient in future,
# make a list, otherwise it would be added as a child of Head Block (due to type check)
self._loss = [loss]
return loss
@property
def kl_divergence(self):
return self.head_type_idx, LOSS_OUT_TYPE_KL
@property
def entropy(self):
return self.head_type_idx, LOSS_OUT_TYPE_ENTROPY
@property
def likelihood_ratio(self):
return self.head_type_idx, LOSS_OUT_TYPE_LIKELIHOOD_RATIO
@property
def clipped_likelihood_ratio(self):
return self.head_type_idx, LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO
def assign_kl_coefficient(self, kl_coefficient: float) -> None:
self._loss[0].kl_coefficient.set_data(mx.nd.array((kl_coefficient,)))

View File

@@ -0,0 +1,123 @@
from typing import List, Tuple, Union
from types import ModuleType
import mxnet as mx
from mxnet.gluon import nn
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import ActionProbabilities
from rl_coach.spaces import SpacesDefinition
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
class PPOVHeadLoss(HeadLoss):
def __init__(self, clip_likelihood_ratio_using_epsilon: float, weight: float=1, batch_axis: int=0) -> None:
"""
Loss for PPO Value network.
Schulman implemented this extension in OpenAI baselines for PPO2
See https://github.com/openai/baselines/blob/master/baselines/ppo2/ppo2.py#L72
:param clip_likelihood_ratio_using_epsilon: epsilon to use for likelihood ratio clipping.
:param weight: scalar used to adjust relative weight of loss (if using this loss with others).
:param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
"""
super(PPOVHeadLoss, self).__init__(weight=weight, batch_axis=batch_axis)
self.weight = weight
self.clip_likelihood_ratio_using_epsilon = clip_likelihood_ratio_using_epsilon
@property
def input_schema(self) -> LossInputSchema:
return LossInputSchema(
head_outputs=['new_policy_values'],
agent_inputs=['old_policy_values'],
targets=['target_values']
)
def loss_forward(self,
F: ModuleType,
new_policy_values: nd_sym_type,
old_policy_values: nd_sym_type,
target_values: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
"""
Used for forward pass through loss computations.
Calculates two losses (L2 and a clipped difference L2 loss) and takes the maximum of the two.
Works with batches of data, and optionally time_steps, but be consistent in usage: i.e. if using time_step,
new_policy_values, old_policy_values and target_values all must include a time_step dimension.
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
:param new_policy_values: values predicted by PPOVHead network,
of shape (batch_size) or
of shape (batch_size, time_step).
:param old_policy_values: values predicted by old value network,
of shape (batch_size) or
of shape (batch_size, time_step).
:param target_values: actual state values,
of shape (batch_size) or
of shape (batch_size, time_step).
:return: loss, of shape (batch_size).
"""
# L2 loss
value_loss_1 = (new_policy_values - target_values).square()
# Clipped difference L2 loss
diff = new_policy_values - old_policy_values
clipped_diff = diff.clip(a_min=-self.clip_likelihood_ratio_using_epsilon,
a_max=self.clip_likelihood_ratio_using_epsilon)
value_loss_2 = (old_policy_values + clipped_diff - target_values).square()
# Maximum of the two losses, element-wise maximum.
value_loss_max = mx.nd.stack(value_loss_1, value_loss_2).max(axis=0)
# Aggregate over temporal axis, adding if doesn't exist (hense reshape)
value_loss_max_w_time = value_loss_max.reshape(shape=(0, -1))
value_loss = value_loss_max_w_time.mean(axis=1)
# Weight the loss (and average over samples of batch)
value_loss_weighted = value_loss.mean() * self.weight
return [(value_loss_weighted, LOSS_OUT_TYPE_LOSS)]
class PPOVHead(Head):
def __init__(self,
agent_parameters: AgentParameters,
spaces: SpacesDefinition,
network_name: str,
head_type_idx: int=0,
loss_weight: float=1.,
is_local: bool = True,
activation_function: str='relu',
dense_layer: None=None) -> None:
"""
PPO Value Head for predicting state values.
:param agent_parameters: containing algorithm parameters, but currently unused.
:param spaces: containing action spaces, but currently unused.
:param network_name: name of head network. currently unused.
:param head_type_idx: index of head network. currently unused.
:param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
:param is_local: flag to denote if network is local. currently unused.
:param activation_function: activation function to use between layers. currently unused.
:param dense_layer: type of dense layer to use in network. currently unused.
"""
super(PPOVHead, self).__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight, is_local,
activation_function, dense_layer=dense_layer)
self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon
self.return_type = ActionProbabilities
with self.name_scope():
self.dense = nn.Dense(units=1)
def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
"""
Used for forward pass through value head network.
:param (mx.nd or mx.sym) F: backend api (mx.sym if block has been hybridized).
:param x: middleware state representation, of shape (batch_size, in_channels).
:return: final value output of network, of shape (batch_size).
"""
return self.dense(x).squeeze()
def loss(self) -> mx.gluon.loss.Loss:
"""
Specifies loss block to be used for specific value head implementation.
:return: loss block (can be called as function) for outputs returned by the value head network.
"""
return PPOVHeadLoss(self.clip_likelihood_ratio_using_epsilon, weight=self.loss_weight)

View File

@@ -0,0 +1,106 @@
from typing import Union, List, Tuple
from types import ModuleType
import mxnet as mx
from mxnet.gluon.loss import Loss, HuberLoss, L2Loss
from mxnet.gluon import nn
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import QActionStateValue
from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
class QHeadLoss(HeadLoss):
def __init__(self, loss_type: Loss=L2Loss, weight: float=1, batch_axis: int=0) -> None:
"""
Loss for Q-Value Head.
:param loss_type: loss function with default of mean squared error (i.e. L2Loss).
:param weight: scalar used to adjust relative weight of loss (if using this loss with others).
:param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
"""
super(QHeadLoss, self).__init__(weight=weight, batch_axis=batch_axis)
with self.name_scope():
self.loss_fn = loss_type(weight=weight, batch_axis=batch_axis)
@property
def input_schema(self) -> LossInputSchema:
return LossInputSchema(
head_outputs=['pred'],
agent_inputs=[],
targets=['target']
)
def loss_forward(self,
F: ModuleType,
pred: nd_sym_type,
target: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
"""
Used for forward pass through loss computations.
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
:param pred: state-action q-values predicted by QHead network, of shape (batch_size, num_actions).
:param target: actual state-action q-values, of shape (batch_size, num_actions).
:return: loss, of shape (batch_size).
"""
loss = self.loss_fn(pred, target).mean()
return [(loss, LOSS_OUT_TYPE_LOSS)]
class QHead(Head):
def __init__(self,
agent_parameters: AgentParameters,
spaces: SpacesDefinition,
network_name: str,
head_type_idx: int=0,
loss_weight: float=1.,
is_local: bool=True,
activation_function: str='relu',
dense_layer: None=None,
loss_type: Union[HuberLoss, L2Loss]=L2Loss) -> None:
"""
Q-Value Head for predicting state-action Q-Values.
:param agent_parameters: containing algorithm parameters, but currently unused.
:param spaces: containing action spaces used for defining size of network output.
:param network_name: name of head network. currently unused.
:param head_type_idx: index of head network. currently unused.
:param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
:param is_local: flag to denote if network is local. currently unused.
:param activation_function: activation function to use between layers. currently unused.
:param dense_layer: type of dense layer to use in network. currently unused.
:param loss_type: loss function to use.
"""
super(QHead, self).__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight,
is_local, activation_function, dense_layer)
if isinstance(self.spaces.action, BoxActionSpace):
self.num_actions = 1
elif isinstance(self.spaces.action, DiscreteActionSpace):
self.num_actions = len(self.spaces.action.actions)
self.return_type = QActionStateValue
assert (loss_type == L2Loss) or (loss_type == HuberLoss), "Only expecting L2Loss or HuberLoss."
self.loss_type = loss_type
with self.name_scope():
self.dense = nn.Dense(units=self.num_actions)
def loss(self) -> Loss:
"""
Specifies loss block to be used for specific value head implementation.
:return: loss block (can be called as function) for outputs returned by the head network.
"""
return QHeadLoss(loss_type=self.loss_type, weight=self.loss_weight)
def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
"""
Used for forward pass through Q-Value head network.
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
:param x: middleware state representation, of shape (batch_size, in_channels).
:return: predicted state-action q-values, of shape (batch_size, num_actions).
"""
return self.dense(x)

View File

@@ -0,0 +1,100 @@
from typing import Union, List, Tuple
from types import ModuleType
import mxnet as mx
from mxnet.gluon.loss import Loss, HuberLoss, L2Loss
from mxnet.gluon import nn
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import VStateValue
from rl_coach.spaces import SpacesDefinition
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
class VHeadLoss(HeadLoss):
def __init__(self, loss_type: Loss=L2Loss, weight: float=1, batch_axis: int=0) -> None:
"""
Loss for Value Head.
:param loss_type: loss function with default of mean squared error (i.e. L2Loss).
:param weight: scalar used to adjust relative weight of loss (if using this loss with others).
:param batch_axis: axis used for mini-batch (default is 0) and excluded from loss aggregation.
"""
super(VHeadLoss, self).__init__(weight=weight, batch_axis=batch_axis)
with self.name_scope():
self.loss_fn = loss_type(weight=weight, batch_axis=batch_axis)
@property
def input_schema(self) -> LossInputSchema:
return LossInputSchema(
head_outputs=['pred'],
agent_inputs=[],
targets=['target']
)
def loss_forward(self,
F: ModuleType,
pred: nd_sym_type,
target: nd_sym_type) -> List[Tuple[nd_sym_type, str]]:
"""
Used for forward pass through loss computations.
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
:param pred: state values predicted by VHead network, of shape (batch_size).
:param target: actual state values, of shape (batch_size).
:return: loss, of shape (batch_size).
"""
loss = self.loss_fn(pred, target).mean()
return [(loss, LOSS_OUT_TYPE_LOSS)]
class VHead(Head):
def __init__(self,
agent_parameters: AgentParameters,
spaces: SpacesDefinition,
network_name: str,
head_type_idx: int=0,
loss_weight: float=1.,
is_local: bool=True,
activation_function: str='relu',
dense_layer: None=None,
loss_type: Union[HuberLoss, L2Loss]=L2Loss):
"""
Value Head for predicting state values.
:param agent_parameters: containing algorithm parameters, but currently unused.
:param spaces: containing action spaces, but currently unused.
:param network_name: name of head network. currently unused.
:param head_type_idx: index of head network. currently unused.
:param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
:param is_local: flag to denote if network is local. currently unused.
:param activation_function: activation function to use between layers. currently unused.
:param dense_layer: type of dense layer to use in network. currently unused.
:param loss_type: loss function with default of mean squared error (i.e. L2Loss), or alternatively HuberLoss.
"""
super(VHead, self).__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight,
is_local, activation_function, dense_layer)
assert (loss_type == L2Loss) or (loss_type == HuberLoss), "Only expecting L2Loss or HuberLoss."
self.loss_type = loss_type
self.return_type = VStateValue
with self.name_scope():
self.dense = nn.Dense(units=1)
def loss(self) -> Loss:
"""
Specifies loss block to be used for specific value head implementation.
:return: loss block (can be called as function) for outputs returned by the head network.
"""
return VHeadLoss(loss_type=self.loss_type, weight=self.loss_weight)
def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
"""
Used for forward pass through value head network.
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
:param x: middleware state representation, of shape (batch_size, in_channels).
:return: final output of value network, of shape (batch_size).
"""
return self.dense(x).squeeze()

View File

@@ -0,0 +1,99 @@
"""
Module implementing basic layers in mxnet
"""
from types import FunctionType
from mxnet.gluon import nn
from rl_coach.architectures import layers
from rl_coach.architectures.mxnet_components import utils
# define global dictionary for storing layer type to layer implementation mapping
mx_layer_dict = dict()
def reg_to_mx(layer_type) -> FunctionType:
""" function decorator that registers layer implementation
:return: decorated function
"""
def reg_impl_decorator(func):
assert layer_type not in mx_layer_dict
mx_layer_dict[layer_type] = func
return func
return reg_impl_decorator
def convert_layer(layer):
"""
If layer is callable, return layer, otherwise convert to MX type
:param layer: layer to be converted
:return: converted layer if not callable, otherwise layer itself
"""
if callable(layer):
return layer
return mx_layer_dict[type(layer)](layer)
class Conv2d(layers.Conv2d):
def __init__(self, num_filters: int, kernel_size: int, strides: int):
super(Conv2d, self).__init__(num_filters=num_filters, kernel_size=kernel_size, strides=strides)
def __call__(self) -> nn.Conv2D:
"""
returns a conv2d block
:return: conv2d block
"""
return nn.Conv2D(channels=self.num_filters, kernel_size=self.kernel_size, strides=self.strides)
@staticmethod
@reg_to_mx(layers.Conv2d)
def to_mx(base: layers.Conv2d):
return Conv2d(num_filters=base.num_filters, kernel_size=base.kernel_size, strides=base.strides)
class BatchnormActivationDropout(layers.BatchnormActivationDropout):
def __init__(self, batchnorm: bool=False, activation_function=None, dropout_rate: float=0):
super(BatchnormActivationDropout, self).__init__(
batchnorm=batchnorm, activation_function=activation_function, dropout_rate=dropout_rate)
def __call__(self):
"""
returns a list of mxnet batchnorm, activation and dropout layers
:return: batchnorm, activation and dropout layers
"""
block = nn.HybridSequential()
if self.batchnorm:
block.add(nn.BatchNorm())
if self.activation_function:
block.add(nn.Activation(activation=utils.get_mxnet_activation_name(self.activation_function)))
if self.dropout_rate:
block.add(nn.Dropout(self.dropout_rate))
return block
@staticmethod
@reg_to_mx(layers.BatchnormActivationDropout)
def to_mx(base: layers.BatchnormActivationDropout):
return BatchnormActivationDropout(
batchnorm=base.batchnorm,
activation_function=base.activation_function,
dropout_rate=base.dropout_rate)
class Dense(layers.Dense):
def __init__(self, units: int):
super(Dense, self).__init__(units=units)
def __call__(self):
"""
returns a mxnet dense layer
:return: dense layer
"""
# Set flatten to False for consistent behavior with tf.layers.dense
return nn.Dense(self.units, flatten=False)
@staticmethod
@reg_to_mx(layers.Dense)
def to_mx(base: layers.Dense):
return Dense(units=base.units)

View File

@@ -0,0 +1,4 @@
from .fc_middleware import FCMiddleware
from .lstm_middleware import LSTMMiddleware
__all__ = ["FCMiddleware", "LSTMMiddleware"]

View File

@@ -0,0 +1,52 @@
"""
Module that defines the fully-connected middleware class
"""
from rl_coach.architectures.mxnet_components.layers import Dense
from rl_coach.architectures.mxnet_components.middlewares.middleware import Middleware
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
from rl_coach.base_parameters import MiddlewareScheme
class FCMiddleware(Middleware):
def __init__(self, params: FCMiddlewareParameters):
"""
FCMiddleware or Fully-Connected Middleware can be used in the middle part of the network. It takes the
embeddings from the input embedders, after they were aggregated in some method (for example, concatenation)
and passes it through a neural network which can be customizable but shared between the heads of the network.
:param params: parameters object containing batchnorm, activation_function and dropout properties.
"""
super(FCMiddleware, self).__init__(params)
@property
def schemes(self) -> dict:
"""
Schemes are the pre-defined network architectures of various depths and complexities that can be used for the
Middleware. Are used to create Block when FCMiddleware is initialised.
:return: dictionary of schemes, with key of type MiddlewareScheme enum and value being list of mxnet.gluon.Block.
"""
return {
MiddlewareScheme.Empty:
[],
# Use for PPO
MiddlewareScheme.Shallow:
[
Dense(units=64)
],
# Use for DQN
MiddlewareScheme.Medium:
[
Dense(units=512)
],
MiddlewareScheme.Deep:
[
Dense(units=128),
Dense(units=128),
Dense(units=128)
]
}

View File

@@ -0,0 +1,80 @@
"""
Module that defines the LSTM middleware class
"""
from typing import Union
from types import ModuleType
import mxnet as mx
from mxnet.gluon import rnn
from rl_coach.architectures.mxnet_components.layers import Dense
from rl_coach.architectures.mxnet_components.middlewares.middleware import Middleware
from rl_coach.architectures.middleware_parameters import LSTMMiddlewareParameters
from rl_coach.base_parameters import MiddlewareScheme
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
class LSTMMiddleware(Middleware):
def __init__(self, params: LSTMMiddlewareParameters):
"""
LSTMMiddleware or Long Short Term Memory Middleware can be used in the middle part of the network. It takes the
embeddings from the input embedders, after they were aggregated in some method (for example, concatenation)
and passes it through a neural network which can be customizable but shared between the heads of the network.
:param params: parameters object containing batchnorm, activation_function, dropout and
number_of_lstm_cells properties.
"""
super(LSTMMiddleware, self).__init__(params)
self.number_of_lstm_cells = params.number_of_lstm_cells
with self.name_scope():
self.lstm = rnn.LSTM(hidden_size=self.number_of_lstm_cells)
@property
def schemes(self) -> dict:
"""
Schemes are the pre-defined network architectures of various depths and complexities that can be used for the
Middleware. Are used to create Block when LSTMMiddleware is initialised, and are applied before the LSTM.
:return: dictionary of schemes, with key of type MiddlewareScheme enum and value being list of mxnet.gluon.Block.
"""
return {
MiddlewareScheme.Empty:
[],
# Use for PPO
MiddlewareScheme.Shallow:
[
Dense(units=64)
],
# Use for DQN
MiddlewareScheme.Medium:
[
Dense(units=512)
],
MiddlewareScheme.Deep:
[
Dense(units=128),
Dense(units=128),
Dense(units=128)
]
}
def hybrid_forward(self,
F: ModuleType,
x: nd_sym_type,
*args, **kwargs) -> nd_sym_type:
"""
Used for forward pass through LSTM middleware network.
Applies dense layers from selected scheme before passing result to LSTM layer.
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
:param x: state embedding, of shape (batch_size, in_channels).
:return: state middleware embedding, where shape is (batch_size, channels).
"""
x_ntc = x.reshape(shape=(0, 0, -1))
emb_ntc = super(LSTMMiddleware, self).hybrid_forward(F, x_ntc, *args, **kwargs)
emb_tnc = emb_ntc.transpose(axes=(1, 0, 2))
return self.lstm(emb_tnc)

View File

@@ -0,0 +1,61 @@
from typing import Union
from types import ModuleType
import mxnet as mx
from mxnet.gluon import nn
from rl_coach.architectures.middleware_parameters import MiddlewareParameters
from rl_coach.architectures.mxnet_components.layers import convert_layer
from rl_coach.base_parameters import MiddlewareScheme
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
class Middleware(nn.HybridBlock):
def __init__(self, params: MiddlewareParameters):
"""
Middleware is the middle part of the network. It takes the embeddings from the input embedders,
after they were aggregated in some method (for example, concatenation) and passes it through a neural network
which can be customizable but shared between the heads of the network.
:param params: parameters object containing batchnorm, activation_function and dropout properties.
"""
super(Middleware, self).__init__()
self.scheme = params.scheme
with self.name_scope():
self.net = nn.HybridSequential()
if isinstance(self.scheme, MiddlewareScheme):
blocks = self.schemes[self.scheme]
else:
# if scheme is specified directly, convert to MX layer if it's not a callable object
# NOTE: if layer object is callable, it must return a gluon block when invoked
blocks = [convert_layer(l) for l in self.scheme]
for block in blocks:
self.net.add(block())
if params.batchnorm:
self.net.add(nn.BatchNorm())
if params.activation_function:
self.net.add(nn.Activation(params.activation_function))
if params.dropout:
self.net.add(nn.Dropout(rate=params.dropout))
@property
def schemes(self) -> dict:
"""
Schemes are the pre-defined network architectures of various depths and complexities that can be used for the
Middleware. Should be implemented in child classes, and are used to create Block when Middleware is initialised.
:return: dictionary of schemes, with key of type MiddlewareScheme enum and value being list of mxnet.gluon.Block.
"""
raise NotImplementedError("Inheriting embedder must define schemes matching its allowed default "
"configurations.")
def hybrid_forward(self, F: ModuleType, x: nd_sym_type, *args, **kwargs) -> nd_sym_type:
"""
Used for forward pass through middleware network.
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
:param x: state embedding, of shape (batch_size, in_channels).
:return: state middleware embedding, where shape is (batch_size, channels).
"""
return self.net(x)

View File

@@ -0,0 +1,280 @@
"""
Module defining utility functions
"""
import inspect
from typing import Any, Dict, Generator, Iterable, List, Tuple, Union
from types import ModuleType
import mxnet as mx
from mxnet import nd
from mxnet.ndarray import NDArray
import numpy as np
from rl_coach.core_types import GradientClippingMethod
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
def to_mx_ndarray(data: Union[list, tuple, np.ndarray, NDArray, int, float]) ->\
Union[List[NDArray], Tuple[NDArray], NDArray]:
"""
Convert data to mx.nd.NDArray. Data can be a list or tuple of np.ndarray, int, or float or
it can be np.ndarray, int, or float
:param data: input data to be converted
:return: converted output data
"""
if isinstance(data, list):
data = [to_mx_ndarray(d) for d in data]
elif isinstance(data, tuple):
data = tuple(to_mx_ndarray(d) for d in data)
elif isinstance(data, np.ndarray):
data = nd.array(data)
elif isinstance(data, NDArray):
pass
elif isinstance(data, int) or isinstance(data, float):
data = nd.array([data])
else:
raise TypeError('Unsupported data type: {}'.format(type(data)))
return data
def asnumpy_or_asscalar(data: Union[NDArray, list, tuple]) -> Union[np.ndarray, np.number, list, tuple]:
"""
Convert NDArray (or list or tuple of NDArray) to numpy. If shape is (1,), then convert to scalar instead.
NOTE: This behavior is consistent with tensorflow
:param data: NDArray or list or tuple of NDArray
:return: data converted to numpy ndarray or to numpy scalar
"""
if isinstance(data, list):
data = [asnumpy_or_asscalar(d) for d in data]
elif isinstance(data, tuple):
data = tuple(asnumpy_or_asscalar(d) for d in data)
elif isinstance(data, NDArray):
data = data.asscalar() if data.shape == (1,) else data.asnumpy()
else:
raise TypeError('Unsupported data type: {}'.format(type(data)))
return data
def global_norm(arrays: Union[Generator[NDArray, NDArray, NDArray], List[NDArray], Tuple[NDArray]]) -> NDArray:
"""
Calculate global norm on list or tuple of NDArrays using this formula:
`global_norm = sqrt(sum([l2norm(p)**2 for p in parameters]))`
:param arrays: list or tuple of parameters to calculate global norm on
:return: single-value NDArray
"""
def _norm(array):
if array.stype == 'default':
x = array.reshape((-1,))
return nd.dot(x, x)
return array.norm().square()
total_norm = nd.add_n(*[_norm(arr) for arr in arrays])
total_norm = nd.sqrt(total_norm)
return total_norm
def split_outputs_per_head(outputs: Tuple[NDArray], heads: list) -> List[List[NDArray]]:
"""
Split outputs into outputs per head
:param outputs: list of all outputs
:param heads: list of all heads
:return: list of outputs for each head
"""
head_outputs = []
for h in heads:
head_outputs.append(list(outputs[:h.num_outputs]))
outputs = outputs[h.num_outputs:]
assert len(outputs) == 0
return head_outputs
def split_targets_per_loss(targets: list, losses: list) -> List[list]:
"""
Splits targets into targets per loss
:param targets: list of all targets (typically numpy ndarray)
:param losses: list of all losses
:return: list of targets for each loss
"""
loss_targets = list()
for l in losses:
loss_data_len = len(l.input_schema.targets)
assert len(targets) >= loss_data_len, "Data length doesn't match schema"
loss_targets.append(targets[:loss_data_len])
targets = targets[loss_data_len:]
assert len(targets) == 0
return loss_targets
def get_loss_agent_inputs(inputs: Dict[str, np.ndarray], head_type_idx: int, loss: Any) -> List[np.ndarray]:
"""
Collects all inputs with prefix 'output_<head_idx>_' and matches them against agent_inputs in loss input schema.
:param inputs: list of all agent inputs
:param head_type_idx: head-type index of the corresponding head
:param loss: corresponding loss
:return: list of agent inputs for this loss. This list matches the length in loss input schema.
"""
loss_inputs = list()
for k in sorted(inputs.keys()):
if k.startswith('output_{}_'.format(head_type_idx)):
loss_inputs.append(inputs[k])
# Enforce that number of inputs for head_type are the same as agent_inputs specified by loss input_schema
assert len(loss_inputs) == len(loss.input_schema.agent_inputs), "agent_input length doesn't match schema"
return loss_inputs
def align_loss_args(
head_outputs: List[NDArray],
agent_inputs: List[np.ndarray],
targets: List[np.ndarray],
loss: Any) -> List[np.ndarray]:
"""
Creates a list of arguments from head_outputs, agent_inputs, and targets aligned with parameters of
loss.loss_forward() based on their name in loss input_schema
:param head_outputs: list of all head_outputs for this loss
:param agent_inputs: list of all agent_inputs for this loss
:param targets: list of all targets for this loss
:param loss: corresponding loss
:return: list of arguments in correct order to be passed to loss
"""
arg_list = list()
schema = loss.input_schema
assert len(schema.head_outputs) == len(head_outputs)
assert len(schema.agent_inputs) == len(agent_inputs)
assert len(schema.targets) == len(targets)
prev_found = True
for arg_name in inspect.getfullargspec(loss.loss_forward).args[2:]: # First two args are self and F
found = False
for schema_list, data in [(schema.head_outputs, head_outputs),
(schema.agent_inputs, agent_inputs),
(schema.targets, targets)]:
try:
arg_list.append(data[schema_list.index(arg_name)])
found = True
break
except ValueError:
continue
assert not found or prev_found, "missing arguments detected!"
prev_found = found
return arg_list
def to_tuple(data: Union[tuple, list, Any]):
"""
If input is list, it is converted to tuple. If it's tuple, it is returned untouched. Otherwise
returns a single-element tuple of the data.
:return: tuple-ified data
"""
if isinstance(data, tuple):
pass
elif isinstance(data, list):
data = tuple(data)
else:
data = (data,)
return data
def to_list(data: Union[tuple, list, Any]):
"""
If input is tuple, it is converted to list. If it's list, it is returned untouched. Otherwise
returns a single-element list of the data.
:return: list-ified data
"""
if isinstance(data, list):
pass
elif isinstance(data, tuple):
data = list(data)
else:
data = [data]
return data
def loss_output_dict(output: List[NDArray], schema: List[str]) -> Dict[str, List[NDArray]]:
"""
Creates a dictionary for loss output based on the output schema. If two output values have the same
type string in the schema they are concatenated in the same dicrionary item.
:param output: list of output values
:param schema: list of type-strings for output values
:return: dictionary of keyword to list of NDArrays
"""
assert len(output) == len(schema)
output_dict = dict()
for name, val in zip(schema, output):
if name in output_dict:
output_dict[name].append(val)
else:
output_dict[name] = [val]
return output_dict
def clip_grad(
grads: Union[Generator[NDArray, NDArray, NDArray], List[NDArray], Tuple[NDArray]],
clip_method: GradientClippingMethod,
clip_val: float,
inplace=True) -> List[NDArray]:
"""
Clip gradient values inplace
:param grads: gradients to be clipped
:param clip_method: clipping method
:param clip_val: clipping value. Interpreted differently depending on clipping method.
:param inplace: modify grads if True, otherwise create NDArrays
:return: clipped gradients
"""
output = list(grads) if inplace else list(nd.empty(g.shape) for g in grads)
if clip_method == GradientClippingMethod.ClipByGlobalNorm:
norm_unclipped_grads = global_norm(grads)
scale = clip_val / (norm_unclipped_grads.asscalar() + 1e-8) # todo: use branching operators?
if scale < 1.0:
for g, o in zip(grads, output):
nd.broadcast_mul(g, nd.array([scale]), out=o)
elif clip_method == GradientClippingMethod.ClipByValue:
for g, o in zip(grads, output):
g.clip(-clip_val, clip_val, out=o)
elif clip_method == GradientClippingMethod.ClipByNorm:
for g, o in zip(grads, output):
nd.broadcast_mul(g, nd.minimum(1.0, clip_val / (g.norm() + 1e-8)), out=o)
else:
raise KeyError('Unsupported gradient clipping method')
return output
def hybrid_clip(F: ModuleType, x: nd_sym_type, clip_lower: nd_sym_type, clip_upper: nd_sym_type) -> nd_sym_type:
"""
Apply clipping to input x between clip_lower and clip_upper.
Added because F.clip doesn't support clipping bounds that are mx.nd.NDArray or mx.sym.Symbol.
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
:param x: input data
:param clip_lower: lower bound used for clipping, should be of shape (1,)
:param clip_upper: upper bound used for clipping, should be of shape (1,)
:return: clipped data
"""
x_clip_lower = clip_lower.broadcast_like(x)
x_clip_upper = clip_upper.broadcast_like(x)
x_clipped = F.stack(x, x_clip_lower, axis=0).max(axis=0)
x_clipped = F.stack(x_clipped, x_clip_upper, axis=0).min(axis=0)
return x_clipped
def get_mxnet_activation_name(activation_name: str):
"""
Convert coach activation name to mxnet specific activation name
:param activation_name: name of the activation inc coach
:return: name of the activation in mxnet
"""
activation_functions = {
'relu': 'relu',
'tanh': 'tanh',
'sigmoid': 'sigmoid',
# FIXME Add other activations
# 'elu': tf.nn.elu,
'selu': 'softrelu',
# 'leaky_relu': tf.nn.leaky_relu,
'none': None
}
assert activation_name in activation_functions, \
"Activation function must be one of the following {}. instead it was: {}".format(
activation_functions.keys(), activation_name)
return activation_functions[activation_name]

View File

@@ -183,16 +183,7 @@ class NetworkWrapper(object):
target_network or global_network) and the second element is the inputs
:return: the outputs of all the networks in the same order as the inputs were given
"""
feed_dict = {}
fetches = []
for idx, (network, input) in enumerate(network_input_tuples):
feed_dict.update(network.create_feed_dict(input))
fetches += network.outputs
outputs = self.sess.run(fetches, feed_dict)
return outputs
return type(self.online_network).parallel_predict(self.sess, network_input_tuples)
def get_local_variables(self):
"""

View File

@@ -13,8 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import os
import time
from typing import Any, List, Tuple, Dict
import numpy as np
import tensorflow as tf
@@ -544,6 +545,26 @@ class TensorFlowArchitecture(Architecture):
output = squeeze_list(output)
return output
@staticmethod
def parallel_predict(sess: Any,
network_input_tuples: List[Tuple['TensorFlowArchitecture', Dict[str, np.ndarray]]]) ->\
List[np.ndarray]:
"""
:param sess: active session to use for prediction
:param network_input_tuples: tuple of network and corresponding input
:return: list of outputs from all networks
"""
feed_dict = {}
fetches = []
for network, input in network_input_tuples:
feed_dict.update(network.create_feed_dict(input))
fetches += network.outputs
outputs = sess.run(fetches, feed_dict)
return outputs
def train_on_batch(self, inputs, targets, scaler=1., additional_fetches=None, importance_weights=None):
"""
Given a batch of examples and targets, runs a forward pass & backward pass and then applies the gradients

View File

@@ -22,7 +22,7 @@ from distutils.dir_util import copy_tree, remove_tree
from typing import List, Tuple
import contextlib
from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, \
from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \
VisualizationParameters, \
Parameters, PresetValidationParameters
from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
@@ -161,7 +161,8 @@ class GraphManager(object):
"""
raise NotImplementedError("")
def create_worker_or_parameters_server(self, task_parameters: DistributedTaskParameters):
@staticmethod
def _create_worker_or_parameters_server_tf(task_parameters: DistributedTaskParameters):
import tensorflow as tf
config = tf.ConfigProto()
config.allow_soft_placement = True # allow placing ops on cpu if they are not fit for gpu
@@ -170,7 +171,8 @@ class GraphManager(object):
config.intra_op_parallelism_threads = 1
config.inter_op_parallelism_threads = 1
from rl_coach.architectures.tensorflow_components.distributed_tf_utils import create_and_start_parameters_server, \
from rl_coach.architectures.tensorflow_components.distributed_tf_utils import \
create_and_start_parameters_server, \
create_cluster_spec, create_worker_server_and_device
# create cluster spec
@@ -190,7 +192,16 @@ class GraphManager(object):
raise ValueError("The job type should be either ps or worker and not {}"
.format(task_parameters.job_type))
def create_session(self, task_parameters: DistributedTaskParameters):
@staticmethod
def create_worker_or_parameters_server(task_parameters: DistributedTaskParameters):
if task_parameters.framework_type == Frameworks.tensorflow:
GraphManager._create_worker_or_parameters_server_tf(task_parameters)
elif task_parameters.framework_type == Frameworks.mxnet:
raise NotImplementedError('Distributed training not implemented for MXNet')
else:
raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))
def _create_session_tf(self, task_parameters: TaskParameters):
import tensorflow as tf
config = tf.ConfigProto()
config.allow_soft_placement = True # allow placing ops on cpu if they are not fit for gpu
@@ -235,6 +246,15 @@ class GraphManager(object):
if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
self.save_graph()
def create_session(self, task_parameters: TaskParameters):
if task_parameters.framework_type == Frameworks.tensorflow:
self._create_session_tf(task_parameters)
elif task_parameters.framework_type == Frameworks.mxnet:
self.set_session(sess=None) # Initialize all modules
# TODO add checkpoint loading
else:
raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))
def save_graph(self) -> None:
"""
Save the TF graph to a protobuf description file in the experiment directory
@@ -490,27 +510,35 @@ class GraphManager(object):
self.train_and_act(self.steps_between_evaluation_periods)
self.evaluate(self.evaluation_steps)
def _restore_checkpoint_tf(self, checkpoint_dir: str):
import tensorflow as tf
checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
variables = {}
for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
# Load the variable
var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
# Set the new name
new_name = var_name
new_name = new_name.replace('global/', 'online/')
variables[new_name] = var
for v in self.variables_to_restore:
self.sess.run(v.assign(variables[v.name.split(':')[0]]))
def restore_checkpoint(self):
self.verify_graph_was_created()
# TODO: find better way to load checkpoints that were saved with a global network into the online network
if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
import tensorflow as tf
checkpoint_dir = self.task_parameters.checkpoint_restore_dir
checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
variables = {}
for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
# Load the variable
var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
# Set the new name
new_name = var_name
new_name = new_name.replace('global/', 'online/')
variables[new_name] = var
for v in self.variables_to_restore:
self.sess.run(v.assign(variables[v.name.split(':')[0]]))
if self.task_parameters.framework_type == Frameworks.tensorflow:
self._restore_checkpoint_tf(self.task_parameters.checkpoint_restore_dir)
elif self.task_parameters.framework_type == Frameworks.mxnet:
# TODO implement checkpoint restore
pass
else:
raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type))
def occasionally_save_checkpoint(self):
# only the chief process saves checkpoints
@@ -529,7 +557,10 @@ class GraphManager(object):
self.checkpoint_id,
self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]))
if not isinstance(self.task_parameters, DistributedTaskParameters):
saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
if self.checkpoint_saver is not None:
saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
else:
saved_checkpoint_path = "<Not Saved>"
else:
saved_checkpoint_path = checkpoint_path

View File

@@ -49,8 +49,8 @@ agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoach
agent_params.exploration = EGreedyParameters()
agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
ObservationNormalizationFilter(name='normalize_observation'))
# agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
# ObservationNormalizationFilter(name='normalize_observation'))
###############
# Environment #

View File

@@ -0,0 +1,21 @@
import mxnet as mx
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from rl_coach.base_parameters import EmbedderScheme
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.mxnet_components.embedders.image_embedder import ImageEmbedder
@pytest.mark.unit_test
def test_image_embedder():
params = InputEmbedderParameters(scheme=EmbedderScheme.Medium)
emb = ImageEmbedder(params=params)
emb.initialize()
input_data = mx.nd.random.uniform(low=0, high=1, shape=(10, 3, 244, 244))
output = emb(input_data)
assert len(output.shape) == 2 # since last block was flatten
assert output.shape[0] == 10 # since batch_size is 10

View File

@@ -0,0 +1,22 @@
import mxnet as mx
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.mxnet_components.embedders.vector_embedder import VectorEmbedder
from rl_coach.base_parameters import EmbedderScheme
@pytest.mark.unit_test
def test_vector_embedder():
params = InputEmbedderParameters(scheme=EmbedderScheme.Medium)
emb = VectorEmbedder(params=params)
emb.initialize()
input_data = mx.nd.random.uniform(low=0, high=255, shape=(10, 100))
output = emb(input_data)
assert len(output.shape) == 2 # since last block was flatten
assert output.shape[0] == 10 # since batch_size is 10
assert output.shape[1] == 256 # since last dense layer has 256 units

View File

@@ -0,0 +1,406 @@
import mxnet as mx
import numpy as np
import os
import pytest
from scipy import stats as sp_stats
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from rl_coach.architectures.head_parameters import PPOHeadParameters
from rl_coach.architectures.mxnet_components.heads.ppo_head import CategoricalDist, MultivariateNormalDist,\
DiscretePPOHead, ClippedPPOLossDiscrete, ClippedPPOLossContinuous, PPOHead
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAlgorithmParameters, ClippedPPOAgentParameters
from rl_coach.spaces import SpacesDefinition, DiscreteActionSpace
@pytest.mark.unit_test
def test_multivariate_normal_dist_shape():
num_var = 2
means = mx.nd.array((0, 1))
covar = mx.nd.array(((1, 0),(0, 0.5)))
data = mx.nd.array((0.5, 0.8))
policy_dist = MultivariateNormalDist(num_var, means, covar)
log_probs = policy_dist.log_prob(data)
assert log_probs.ndim == 1
assert log_probs.shape[0] == 1
@pytest.mark.unit_test
def test_multivariate_normal_dist_batch_shape():
num_var = 2
batch_size = 3
means = mx.nd.random.uniform(shape=(batch_size, num_var))
# create batch of covariance matrices only defined on diagonal
std = mx.nd.array((1, 0.5)).broadcast_like(means).expand_dims(-2)
covar = mx.nd.eye(N=num_var) * std
data = mx.nd.random.uniform(shape=(batch_size, num_var))
policy_dist = MultivariateNormalDist(num_var, means, covar)
log_probs = policy_dist.log_prob(data)
assert log_probs.ndim == 1
assert log_probs.shape[0] == batch_size
@pytest.mark.unit_test
def test_multivariate_normal_dist_batch_time_shape():
num_var = 2
batch_size = 3
time_steps = 4
means = mx.nd.random.uniform(shape=(batch_size, time_steps, num_var))
# create batch (per time step) of covariance matrices only defined on diagonal
std = mx.nd.array((1, 0.5)).broadcast_like(means).expand_dims(-2)
covar = mx.nd.eye(N=num_var) * std
data = mx.nd.random.uniform(shape=(batch_size, time_steps, num_var))
policy_dist = MultivariateNormalDist(num_var, means, covar)
log_probs = policy_dist.log_prob(data)
assert log_probs.ndim == 2
assert log_probs.shape[0] == batch_size
assert log_probs.shape[1] == time_steps
@pytest.mark.unit_test
def test_multivariate_normal_dist_kl_div():
n_classes = 2
dist_a = MultivariateNormalDist(num_var=n_classes,
mean = mx.nd.array([0.2, 0.8]).expand_dims(0),
sigma = mx.nd.array([[1, 0.5], [0.5, 0.5]]).expand_dims(0))
dist_b = MultivariateNormalDist(num_var=n_classes,
mean = mx.nd.array([0.3, 0.7]).expand_dims(0),
sigma = mx.nd.array([[1, 0.2], [0.2, 0.5]]).expand_dims(0))
actual = dist_a.kl_div(dist_b).asnumpy()
np.testing.assert_almost_equal(actual=actual, desired=0.195100128)
@pytest.mark.unit_test
def test_multivariate_normal_dist_kl_div_batch():
n_classes = 2
dist_a = MultivariateNormalDist(num_var=n_classes,
mean = mx.nd.array([[0.2, 0.8],
[0.2, 0.8]]),
sigma = mx.nd.array([[[1, 0.5], [0.5, 0.5]],
[[1, 0.5], [0.5, 0.5]]]))
dist_b = MultivariateNormalDist(num_var=n_classes,
mean = mx.nd.array([[0.3, 0.7],
[0.3, 0.7]]),
sigma = mx.nd.array([[[1, 0.2], [0.2, 0.5]],
[[1, 0.2], [0.2, 0.5]]]))
actual = dist_a.kl_div(dist_b).asnumpy()
np.testing.assert_almost_equal(actual=actual, desired=[0.195100128, 0.195100128])
@pytest.mark.unit_test
def test_categorical_dist_shape():
num_actions = 2
# actions taken, of shape (batch_size, time_steps)
actions = mx.nd.array((1,))
# action probabilities, of shape (batch_size, time_steps, num_actions)
policy_probs = mx.nd.array((0.8, 0.2))
policy_dist = CategoricalDist(num_actions, policy_probs)
action_probs = policy_dist.log_prob(actions)
assert action_probs.ndim == 1
assert action_probs.shape[0] == 1
@pytest.mark.unit_test
def test_categorical_dist_batch_shape():
batch_size = 3
num_actions = 2
# actions taken, of shape (batch_size, time_steps)
actions = mx.nd.array((0, 1, 0))
# action probabilities, of shape (batch_size, time_steps, num_actions)
policy_probs = mx.nd.array(((0.8, 0.2), (0.5, 0.5), (0.5, 0.5)))
policy_dist = CategoricalDist(num_actions, policy_probs)
action_probs = policy_dist.log_prob(actions)
assert action_probs.ndim == 1
assert action_probs.shape[0] == batch_size
@pytest.mark.unit_test
def test_categorical_dist_batch_time_shape():
batch_size = 3
time_steps = 4
num_actions = 2
# actions taken, of shape (batch_size, time_steps)
actions = mx.nd.array(((0, 1, 0, 0),
(1, 1, 0, 0),
(0, 0, 0, 0)))
# action probabilities, of shape (batch_size, time_steps, num_actions)
policy_probs = mx.nd.array((((0.8, 0.2), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)),
((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)),
((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5))))
policy_dist = CategoricalDist(num_actions, policy_probs)
action_probs = policy_dist.log_prob(actions)
assert action_probs.ndim == 2
assert action_probs.shape[0] == batch_size
assert action_probs.shape[1] == time_steps
@pytest.mark.unit_test
def test_categorical_dist_batch():
n_classes = 2
probs = mx.nd.array(((0.8, 0.2),
(0.7, 0.3),
(0.5, 0.5)))
dist = CategoricalDist(n_classes, probs)
# check log_prob
actions = mx.nd.array((0, 1, 0))
manual_log_prob = np.array((-0.22314353, -1.20397282, -0.69314718))
np.testing.assert_almost_equal(actual=dist.log_prob(actions).asnumpy(), desired=manual_log_prob)
# check entropy
sp_entropy = np.array([sp_stats.entropy(pk=(0.8, 0.2)),
sp_stats.entropy(pk=(0.7, 0.3)),
sp_stats.entropy(pk=(0.5, 0.5))])
np.testing.assert_almost_equal(actual=dist.entropy().asnumpy(), desired=sp_entropy)
@pytest.mark.unit_test
def test_categorical_dist_kl_div():
n_classes = 3
dist_a = CategoricalDist(n_classes=n_classes, probs=mx.nd.array([0.4, 0.2, 0.4]))
dist_b = CategoricalDist(n_classes=n_classes, probs=mx.nd.array([0.3, 0.4, 0.3]))
dist_c = CategoricalDist(n_classes=n_classes, probs=mx.nd.array([0.2, 0.6, 0.2]))
dist_d = CategoricalDist(n_classes=n_classes, probs=mx.nd.array([0.0, 1.0, 0.0]))
np.testing.assert_almost_equal(actual=dist_a.kl_div(dist_b).asnumpy(), desired=0.09151624)
np.testing.assert_almost_equal(actual=dist_a.kl_div(dist_c).asnumpy(), desired=0.33479536)
np.testing.assert_almost_equal(actual=dist_c.kl_div(dist_a).asnumpy(), desired=0.38190854)
np.testing.assert_almost_equal(actual=dist_a.kl_div(dist_d).asnumpy(), desired=np.nan)
np.testing.assert_almost_equal(actual=dist_d.kl_div(dist_a).asnumpy(), desired=1.60943782)
@pytest.mark.unit_test
def test_categorical_dist_kl_div_batch():
n_classes = 3
dist_a = CategoricalDist(n_classes=n_classes, probs=mx.nd.array([[0.4, 0.2, 0.4],
[0.4, 0.2, 0.4],
[0.4, 0.2, 0.4]]))
dist_b = CategoricalDist(n_classes=n_classes, probs=mx.nd.array([[0.3, 0.4, 0.3],
[0.3, 0.4, 0.3],
[0.3, 0.4, 0.3]]))
actual = dist_a.kl_div(dist_b).asnumpy()
np.testing.assert_almost_equal(actual=actual, desired=[0.09151624, 0.09151624, 0.09151624])
@pytest.mark.unit_test
def test_clipped_ppo_loss_continuous_batch():
# check lower loss for policy with better probabilities:
# i.e. higher probability on high advantage actions, low probability on low advantage actions.
loss_fn = ClippedPPOLossContinuous(num_actions=2,
clip_likelihood_ratio_using_epsilon=0.2)
loss_fn.initialize()
# actual actions taken, of shape (batch_size)
actions = mx.nd.array(((0.5, -0.5), (0.2, 0.3), (0.4, 2.0)))
# advantages from taking action, of shape (batch_size)
advantages = mx.nd.array((2, -2, 1))
# action probabilities, of shape (batch_size, num_actions)
old_policy_means = mx.nd.array(((1, 0), (0, 0), (-1, 0)))
new_policy_means_worse = mx.nd.array(((2, 0), (0, 0), (-1, 0)))
new_policy_means_better = mx.nd.array(((0.5, 0), (0, 0), (-1, 0)))
policy_stds = mx.nd.array(((1, 1), (1, 1), (1, 1)))
clip_param_rescaler = mx.nd.array((1,))
loss_worse = loss_fn(new_policy_means_worse, policy_stds,
actions, old_policy_means, policy_stds,
clip_param_rescaler, advantages)
loss_better = loss_fn(new_policy_means_better, policy_stds,
actions, old_policy_means, policy_stds,
clip_param_rescaler, advantages)
assert len(loss_worse) == 6 # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO)
loss_worse_val = loss_worse[0]
assert loss_worse_val.ndim == 1
assert loss_worse_val.shape[0] == 1
assert len(loss_better) == 6 # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO)
loss_better_val = loss_better[0]
assert loss_better_val.ndim == 1
assert loss_better_val.shape[0] == 1
assert loss_worse_val > loss_better_val
@pytest.mark.unit_test
def test_clipped_ppo_loss_discrete_batch():
# check lower loss for policy with better probabilities:
# i.e. higher probability on high advantage actions, low probability on low advantage actions.
loss_fn = ClippedPPOLossDiscrete(num_actions=2,
clip_likelihood_ratio_using_epsilon=None,
use_kl_regularization=True,
initial_kl_coefficient=1)
loss_fn.initialize()
# actual actions taken, of shape (batch_size)
actions = mx.nd.array((0, 1, 0))
# advantages from taking action, of shape (batch_size)
advantages = mx.nd.array((-2, 2, 1))
# action probabilities, of shape (batch_size, num_actions)
old_policy_probs = mx.nd.array(((0.7, 0.3), (0.2, 0.8), (0.4, 0.6)))
new_policy_probs_worse = mx.nd.array(((0.9, 0.1), (0.2, 0.8), (0.4, 0.6)))
new_policy_probs_better = mx.nd.array(((0.5, 0.5), (0.2, 0.8), (0.4, 0.6)))
clip_param_rescaler = mx.nd.array((1,))
loss_worse = loss_fn(new_policy_probs_worse, actions, old_policy_probs, clip_param_rescaler, advantages)
loss_better = loss_fn(new_policy_probs_better, actions, old_policy_probs, clip_param_rescaler, advantages)
assert len(loss_worse) == 6 # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO)
lw_loss, lw_reg, lw_kl, lw_ent, lw_lr, lw_clip_lr = loss_worse
assert lw_loss.ndim == 1
assert lw_loss.shape[0] == 1
assert len(loss_better) == 6 # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO)
lb_loss, lb_reg, lb_kl, lb_ent, lb_lr, lb_clip_lr = loss_better
assert lb_loss.ndim == 1
assert lb_loss.shape[0] == 1
assert lw_loss > lb_loss
assert lw_kl > lb_kl
@pytest.mark.unit_test
def test_clipped_ppo_loss_discrete_batch_kl_div():
# check lower loss for policy with better probabilities:
# i.e. higher probability on high advantage actions, low probability on low advantage actions.
loss_fn = ClippedPPOLossDiscrete(num_actions=2,
clip_likelihood_ratio_using_epsilon=None,
use_kl_regularization=True,
initial_kl_coefficient=0.5)
loss_fn.initialize()
# actual actions taken, of shape (batch_size)
actions = mx.nd.array((0, 1, 0))
# advantages from taking action, of shape (batch_size)
advantages = mx.nd.array((-2, 2, 1))
# action probabilities, of shape (batch_size, num_actions)
old_policy_probs = mx.nd.array(((0.7, 0.3), (0.2, 0.8), (0.4, 0.6)))
new_policy_probs_worse = mx.nd.array(((0.9, 0.1), (0.2, 0.8), (0.4, 0.6)))
new_policy_probs_better = mx.nd.array(((0.5, 0.5), (0.2, 0.8), (0.4, 0.6)))
clip_param_rescaler = mx.nd.array((1,))
loss_worse = loss_fn(new_policy_probs_worse, actions, old_policy_probs, clip_param_rescaler, advantages)
loss_better = loss_fn(new_policy_probs_better, actions, old_policy_probs, clip_param_rescaler, advantages)
assert len(loss_worse) == 6 # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO)
lw_loss, lw_reg, lw_kl, lw_ent, lw_lr, lw_clip_lr = loss_worse
assert lw_kl.ndim == 1
assert lw_kl.shape[0] == 1
assert len(loss_better) == 6 # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO)
lb_loss, lb_reg, lb_kl, lb_ent, lb_lr, lb_clip_lr = loss_better
assert lb_kl.ndim == 1
assert lb_kl.shape[0] == 1
assert lw_kl > lb_kl
assert lw_reg > lb_reg
@pytest.mark.unit_test
def test_clipped_ppo_loss_discrete_batch_time():
batch_size = 3
time_steps = 4
num_actions = 2
# actions taken, of shape (batch_size, time_steps)
actions = mx.nd.array(((0, 1, 0, 0),
(1, 1, 0, 0),
(0, 0, 0, 0)))
# advantages from taking action, of shape (batch_size, time_steps)
advantages = mx.nd.array(((-2, 2, 1, 0),
(-1, 1, 0, 1),
(-1, 0, 1, 0)))
# action probabilities, of shape (batch_size, num_actions)
old_policy_probs = mx.nd.array((((0.8, 0.2), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)),
((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)),
((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5))))
new_policy_probs_worse = mx.nd.array((((0.9, 0.1), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)),
((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)),
((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5))))
new_policy_probs_better = mx.nd.array((((0.2, 0.8), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)),
((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)),
((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5))))
# check lower loss for policy with better probabilities:
# i.e. higher probability on high advantage actions, low probability on low advantage actions.
loss_fn = ClippedPPOLossDiscrete(num_actions=num_actions,
clip_likelihood_ratio_using_epsilon=0.2)
loss_fn.initialize()
clip_param_rescaler = mx.nd.array((1,))
loss_worse = loss_fn(new_policy_probs_worse, actions, old_policy_probs, clip_param_rescaler, advantages)
loss_better = loss_fn(new_policy_probs_better, actions, old_policy_probs, clip_param_rescaler, advantages)
assert len(loss_worse) == 6 # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO)
loss_worse_val = loss_worse[0]
assert loss_worse_val.ndim == 1
assert loss_worse_val.shape[0] == 1
assert len(loss_better) == 6 # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO)
loss_better_val = loss_better[0]
assert loss_better_val.ndim == 1
assert loss_better_val.shape[0] == 1
assert loss_worse_val > loss_better_val
@pytest.mark.unit_test
def test_clipped_ppo_loss_discrete_weight():
actions = mx.nd.array((0, 1, 0))
advantages = mx.nd.array((-2, 2, 1))
old_policy_probs = mx.nd.array(((0.7, 0.3), (0.2, 0.8), (0.4, 0.6)))
new_policy_probs = mx.nd.array(((0.9, 0.1), (0.2, 0.8), (0.4, 0.6)))
clip_param_rescaler = mx.nd.array((1,))
loss_fn = ClippedPPOLossDiscrete(num_actions=2,
clip_likelihood_ratio_using_epsilon=0.2)
loss_fn.initialize()
loss = loss_fn(new_policy_probs, actions, old_policy_probs, clip_param_rescaler, advantages)
loss_fn_weighted = ClippedPPOLossDiscrete(num_actions=2,
clip_likelihood_ratio_using_epsilon=0.2,
weight=0.5)
loss_fn_weighted.initialize()
loss_weighted = loss_fn_weighted(new_policy_probs, actions, old_policy_probs, clip_param_rescaler, advantages)
assert loss[0] == loss_weighted[0] * 2
@pytest.mark.unit_test
def test_clipped_ppo_loss_discrete_hybridize():
loss_fn = ClippedPPOLossDiscrete(num_actions=2,
clip_likelihood_ratio_using_epsilon=0.2)
loss_fn.initialize()
loss_fn.hybridize()
actions = mx.nd.array((0, 1, 0))
advantages = mx.nd.array((-2, 2, 1))
old_policy_probs = mx.nd.array(((0.7, 0.3), (0.2, 0.8), (0.4, 0.6)))
new_policy_probs = mx.nd.array(((0.9, 0.1), (0.2, 0.8), (0.4, 0.6)))
clip_param_rescaler = mx.nd.array((1,))
loss = loss_fn(new_policy_probs, actions, old_policy_probs, clip_param_rescaler, advantages)
assert loss[0] == mx.nd.array((-0.142857153,))
@pytest.mark.unit_test
def test_discrete_ppo_head():
head = DiscretePPOHead(num_actions=2)
head.initialize()
middleware_data = mx.nd.random.uniform(shape=(10, 100))
probs = head(middleware_data)
assert probs.ndim == 2 # (batch_size, num_actions)
assert probs.shape[0] == 10 # since batch_size is 10
assert probs.shape[1] == 2 # since num_actions is 2
@pytest.mark.unit_test
def test_ppo_head():
agent_parameters = ClippedPPOAgentParameters()
num_actions = 5
action_space = DiscreteActionSpace(num_actions=num_actions)
spaces = SpacesDefinition(state=None, goal=None, action=action_space, reward=None)
head = PPOHead(agent_parameters=agent_parameters,
spaces=spaces,
network_name="test_ppo_head")
head.initialize()
batch_size = 15
middleware_data = mx.nd.random.uniform(shape=(batch_size, 100))
probs = head(middleware_data)
assert probs.ndim == 2 # (batch_size, num_actions)
assert probs.shape[0] == batch_size
assert probs.shape[1] == num_actions

View File

@@ -0,0 +1,90 @@
import mxnet as mx
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from rl_coach.architectures.mxnet_components.heads.ppo_v_head import PPOVHead, PPOVHeadLoss
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAlgorithmParameters, ClippedPPOAgentParameters
from rl_coach.spaces import SpacesDefinition, DiscreteActionSpace
@pytest.mark.unit_test
def test_ppo_v_head_loss_batch():
loss_fn = PPOVHeadLoss(clip_likelihood_ratio_using_epsilon=0.1)
total_return = mx.nd.array((5, -3, 0))
old_policy_values = mx.nd.array((3, -1, -1))
new_policy_values_worse = mx.nd.array((2, 0, -1))
new_policy_values_better = mx.nd.array((4, -2, -1))
loss_worse = loss_fn(new_policy_values_worse, old_policy_values, total_return)
loss_better = loss_fn(new_policy_values_better, old_policy_values, total_return)
assert len(loss_worse) == 1 # (LOSS)
loss_worse_val = loss_worse[0]
assert loss_worse_val.ndim == 1
assert loss_worse_val.shape[0] == 1
assert len(loss_better) == 1 # (LOSS)
loss_better_val = loss_better[0]
assert loss_better_val.ndim == 1
assert loss_better_val.shape[0] == 1
assert loss_worse_val > loss_better_val
@pytest.mark.unit_test
def test_ppo_v_head_loss_batch_time():
loss_fn = PPOVHeadLoss(clip_likelihood_ratio_using_epsilon=0.1)
total_return = mx.nd.array(((3, 1, 1, 0),
(1, 0, 0, 1),
(3, 0, 1, 0)))
old_policy_values = mx.nd.array(((2, 1, 1, 0),
(1, 0, 0, 1),
(0, 0, 1, 0)))
new_policy_values_worse = mx.nd.array(((2, 1, 1, 0),
(1, 0, 0, 1),
(2, 0, 1, 0)))
new_policy_values_better = mx.nd.array(((3, 1, 1, 0),
(1, 0, 0, 1),
(2, 0, 1, 0)))
loss_worse = loss_fn(new_policy_values_worse, old_policy_values, total_return)
loss_better = loss_fn(new_policy_values_better, old_policy_values, total_return)
assert len(loss_worse) == 1 # (LOSS)
loss_worse_val = loss_worse[0]
assert loss_worse_val.ndim == 1
assert loss_worse_val.shape[0] == 1
assert len(loss_better) == 1 # (LOSS)
loss_better_val = loss_better[0]
assert loss_better_val.ndim == 1
assert loss_better_val.shape[0] == 1
assert loss_worse_val > loss_better_val
@pytest.mark.unit_test
def test_ppo_v_head_loss_weight():
total_return = mx.nd.array((5, -3, 0))
old_policy_values = mx.nd.array((3, -1, -1))
new_policy_values = mx.nd.array((4, -2, -1))
loss_fn = PPOVHeadLoss(clip_likelihood_ratio_using_epsilon=0.2, weight=1)
loss = loss_fn(new_policy_values, old_policy_values, total_return)
loss_fn_weighted = PPOVHeadLoss(clip_likelihood_ratio_using_epsilon=0.2, weight=0.5)
loss_weighted = loss_fn_weighted(new_policy_values, old_policy_values, total_return)
assert loss[0].sum() == loss_weighted[0].sum() * 2
@pytest.mark.unit_test
def test_ppo_v_head():
agent_parameters = ClippedPPOAgentParameters()
action_space = DiscreteActionSpace(num_actions=5)
spaces = SpacesDefinition(state=None, goal=None, action=action_space, reward=None)
value_net = PPOVHead(agent_parameters=agent_parameters,
spaces=spaces,
network_name="test_ppo_v_head")
value_net.initialize()
batch_size = 15
middleware_data = mx.nd.random.uniform(shape=(batch_size, 100))
values = value_net(middleware_data)
assert values.ndim == 1 # (batch_size)
assert values.shape[0] == batch_size

View File

@@ -0,0 +1,60 @@
import mxnet as mx
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from rl_coach.architectures.mxnet_components.heads.q_head import QHead, QHeadLoss
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
from rl_coach.spaces import SpacesDefinition, DiscreteActionSpace
@pytest.mark.unit_test
def test_q_head_loss():
loss_fn = QHeadLoss()
# example with batch_size of 3, and num_actions of 2
target_q_values = mx.nd.array(((3, 5), (-1, -2), (0, 2)))
pred_q_values_worse = mx.nd.array(((6, 5), (-1, -2), (0, 2)))
pred_q_values_better = mx.nd.array(((4, 5), (-2, -2), (1, 2)))
loss_worse = loss_fn(pred_q_values_worse, target_q_values)
loss_better = loss_fn(pred_q_values_better, target_q_values)
assert len(loss_worse) == 1 # (LOSS)
loss_worse_val = loss_worse[0]
assert loss_worse_val.ndim == 1
assert loss_worse_val.shape[0] == 1
assert len(loss_better) == 1 # (LOSS)
loss_better_val = loss_better[0]
assert loss_better_val.ndim == 1
assert loss_better_val.shape[0] == 1
assert loss_worse_val > loss_better_val
@pytest.mark.unit_test
def test_v_head_loss_weight():
target_q_values = mx.nd.array(((3, 5), (-1, -2), (0, 2)))
pred_q_values = mx.nd.array(((4, 5), (-2, -2), (1, 2)))
loss_fn = QHeadLoss()
loss = loss_fn(pred_q_values, target_q_values)
loss_fn_weighted = QHeadLoss(weight=0.5)
loss_weighted = loss_fn_weighted(pred_q_values, target_q_values)
assert loss[0] == loss_weighted[0]*2
@pytest.mark.unit_test
def test_ppo_v_head():
agent_parameters = ClippedPPOAgentParameters()
num_actions = 5
action_space = DiscreteActionSpace(num_actions=num_actions)
spaces = SpacesDefinition(state=None, goal=None, action=action_space, reward=None)
value_net = QHead(agent_parameters=agent_parameters,
spaces=spaces,
network_name="test_q_head")
value_net.initialize()
batch_size = 15
middleware_data = mx.nd.random.uniform(shape=(batch_size, 100))
values = value_net(middleware_data)
assert values.ndim == 2 # (batch_size, num_actions)
assert values.shape[0] == batch_size
assert values.shape[1] == num_actions

View File

@@ -0,0 +1,57 @@
import mxnet as mx
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from rl_coach.architectures.mxnet_components.heads.v_head import VHead, VHeadLoss
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAlgorithmParameters, ClippedPPOAgentParameters
from rl_coach.spaces import SpacesDefinition, DiscreteActionSpace
@pytest.mark.unit_test
def test_v_head_loss():
loss_fn = VHeadLoss()
target_values = mx.nd.array((3, -1, 0))
pred_values_worse = mx.nd.array((0, 0, 1))
pred_values_better = mx.nd.array((2, -1, 0))
loss_worse = loss_fn(pred_values_worse, target_values)
loss_better = loss_fn(pred_values_better, target_values)
assert len(loss_worse) == 1 # (LOSS)
loss_worse_val = loss_worse[0]
assert loss_worse_val.ndim == 1
assert loss_worse_val.shape[0] == 1
assert len(loss_better) == 1 # (LOSS)
loss_better_val = loss_better[0]
assert loss_better_val.ndim == 1
assert loss_better_val.shape[0] == 1
assert loss_worse_val > loss_better_val
@pytest.mark.unit_test
def test_v_head_loss_weight():
target_values = mx.nd.array((3, -1, 0))
pred_values = mx.nd.array((0, 0, 1))
loss_fn = VHeadLoss()
loss = loss_fn(pred_values, target_values)
loss_fn_weighted = VHeadLoss(weight=0.5)
loss_weighted = loss_fn_weighted(pred_values, target_values)
assert loss[0] == loss_weighted[0]*2
@pytest.mark.unit_test
def test_ppo_v_head():
agent_parameters = ClippedPPOAgentParameters()
action_space = DiscreteActionSpace(num_actions=5)
spaces = SpacesDefinition(state=None, goal=None, action=action_space, reward=None)
value_net = VHead(agent_parameters=agent_parameters,
spaces=spaces,
network_name="test_v_head")
value_net.initialize()
batch_size = 15
middleware_data = mx.nd.random.uniform(shape=(batch_size, 100))
values = value_net(middleware_data)
assert values.ndim == 1 # (batch_size)
assert values.shape[0] == batch_size

View File

@@ -0,0 +1,22 @@
import mxnet as mx
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from rl_coach.base_parameters import MiddlewareScheme
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
from rl_coach.architectures.mxnet_components.middlewares.fc_middleware import FCMiddleware
@pytest.mark.unit_test
def test_fc_middleware():
params = FCMiddlewareParameters(scheme=MiddlewareScheme.Medium)
mid = FCMiddleware(params=params)
mid.initialize()
embedded_data = mx.nd.random.uniform(low=0, high=1, shape=(10, 100))
output = mid(embedded_data)
assert output.ndim == 2 # since last block was flatten
assert output.shape[0] == 10 # since batch_size is 10
assert output.shape[1] == 512 # since last layer of middleware (middle scheme) had 512 units

View File

@@ -0,0 +1,25 @@
import mxnet as mx
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from rl_coach.base_parameters import MiddlewareScheme
from rl_coach.architectures.middleware_parameters import LSTMMiddlewareParameters
from rl_coach.architectures.mxnet_components.middlewares.lstm_middleware import LSTMMiddleware
@pytest.mark.unit_test
def test_lstm_middleware():
params = LSTMMiddlewareParameters(number_of_lstm_cells=25, scheme=MiddlewareScheme.Medium)
mid = LSTMMiddleware(params=params)
mid.initialize()
# NTC
embedded_data = mx.nd.random.uniform(low=0, high=1, shape=(10, 15, 20))
# NTC -> TNC
output = mid(embedded_data)
assert output.ndim == 3 # since last block was flatten
assert output.shape[0] == 15 # since t is 15
assert output.shape[1] == 10 # since batch_size is 10
assert output.shape[2] == 25 # since number_of_lstm_cells is 25

View File

@@ -0,0 +1,144 @@
import pytest
import mxnet as mx
from mxnet import nd
import numpy as np
from rl_coach.architectures.mxnet_components.utils import *
@pytest.mark.unit_test
def test_to_mx_ndarray():
# scalar
assert to_mx_ndarray(1.2) == nd.array([1.2])
# list of one scalar
assert to_mx_ndarray([1.2]) == [nd.array([1.2])]
# list of multiple scalars
assert to_mx_ndarray([1.2, 3.4]) == [nd.array([1.2]), nd.array([3.4])]
# list of lists of scalars
assert to_mx_ndarray([[1.2], [3.4]]) == [[nd.array([1.2])], [nd.array([3.4])]]
# numpy
assert np.array_equal(to_mx_ndarray(np.array([[1.2], [3.4]])).asnumpy(), nd.array([[1.2], [3.4]]).asnumpy())
# tuple
assert to_mx_ndarray(((1.2,), (3.4,))) == ((nd.array([1.2]),), (nd.array([3.4]),))
@pytest.mark.unit_test
def test_asnumpy_or_asscalar():
# scalar float32
assert asnumpy_or_asscalar(nd.array([1.2])) == np.float32(1.2)
# scalar int32
assert asnumpy_or_asscalar(nd.array([2], dtype=np.int32)) == np.int32(2)
# list of one scalar
assert asnumpy_or_asscalar([nd.array([1.2])]) == [np.float32(1.2)]
# list of multiple scalars
assert asnumpy_or_asscalar([nd.array([1.2]), nd.array([3.4])]) == [np.float32([1.2]), np.float32([3.4])]
# list of lists of scalars
assert asnumpy_or_asscalar([[nd.array([1.2])], [nd.array([3.4])]]) == [[np.float32([1.2])], [np.float32([3.4])]]
# tensor
assert np.array_equal(asnumpy_or_asscalar(nd.array([[1.2], [3.4]])), np.array([[1.2], [3.4]], dtype=np.float32))
# tuple
assert (asnumpy_or_asscalar(((nd.array([1.2]),), (nd.array([3.4]),))) ==
((np.array([1.2], dtype=np.float32),), (np.array([3.4], dtype=np.float32),)))
@pytest.mark.unit_test
def test_global_norm():
data = list()
for i in range(1, 6):
data.append(np.ones((i * 10, i * 10)) * i)
gnorm = np.asscalar(np.sqrt(sum([np.sum(np.square(d)) for d in data])))
assert np.isclose(gnorm, global_norm([nd.array(d) for d in data]).asscalar())
@pytest.mark.unit_test
def test_split_outputs_per_head():
class TestHead:
def __init__(self, num_outputs):
self.num_outputs = num_outputs
assert split_outputs_per_head((1, 2, 3, 4), [TestHead(2), TestHead(1), TestHead(1)]) == [[1, 2], [3], [4]]
class DummySchema:
def __init__(self, num_head_outputs, num_agent_inputs, num_targets):
self.head_outputs = ['head_output_{}'.format(i) for i in range(num_head_outputs)]
self.agent_inputs = ['agent_input_{}'.format(i) for i in range(num_agent_inputs)]
self.targets = ['target_{}'.format(i) for i in range(num_targets)]
class DummyLoss:
def __init__(self, num_head_outputs, num_agent_inputs, num_targets):
self.input_schema = DummySchema(num_head_outputs, num_agent_inputs, num_targets)
@pytest.mark.unit_test
def test_split_targets_per_loss():
assert split_targets_per_loss([1, 2, 3, 4],
[DummyLoss(10, 100, 2), DummyLoss(20, 200, 1), DummyLoss(30, 300, 1)]) == \
[[1, 2], [3], [4]]
@pytest.mark.unit_test
def test_get_loss_agent_inputs():
input_dict = {'output_0_0': [1, 2], 'output_0_1': [3, 4], 'output_1_0': [5]}
assert get_loss_agent_inputs(input_dict, 0, DummyLoss(10, 2, 100)) == [[1, 2], [3, 4]]
assert get_loss_agent_inputs(input_dict, 1, DummyLoss(20, 1, 200)) == [[5]]
@pytest.mark.unit_test
def test_align_loss_args():
class TestLossFwd(DummyLoss):
def __init__(self, num_targets, num_agent_inputs, num_head_outputs):
super(TestLossFwd, self).__init__(num_targets, num_agent_inputs, num_head_outputs)
def loss_forward(self, F, head_output_2, head_output_1, agent_input_2, target_0, agent_input_1, param1, param2):
pass
assert align_loss_args([1, 2, 3], [4, 5, 6, 7], [8, 9], TestLossFwd(3, 4, 2)) == [3, 2, 6, 8, 5]
@pytest.mark.unit_test
def test_to_tuple():
assert to_tuple(123) == (123,)
assert to_tuple((1, 2, 3)) == (1, 2, 3)
assert to_tuple([1, 2, 3]) == (1, 2, 3)
@pytest.mark.unit_test
def test_to_list():
assert to_list(123) == [123]
assert to_list((1, 2, 3)) == [1, 2, 3]
assert to_list([1, 2, 3]) == [1, 2, 3]
@pytest.mark.unit_test
def test_loss_output_dict():
assert loss_output_dict([1, 2, 3], ['loss', 'loss', 'reg']) == {'loss': [1, 2], 'reg': [3]}
@pytest.mark.unit_test
def test_clip_grad():
a = np.array([1, 2, -3])
b = np.array([4, 5, -6])
clip = 2
gscale = np.minimum(1.0, clip / np.sqrt(np.sum(np.square(a)) + np.sum(np.square(b))))
for lhs, rhs in zip(clip_grad([nd.array(a), nd.array(b)], GradientClippingMethod.ClipByGlobalNorm, clip_val=clip),
[a, b]):
assert np.allclose(lhs.asnumpy(), rhs * gscale)
for lhs, rhs in zip(clip_grad([nd.array(a), nd.array(b)], GradientClippingMethod.ClipByValue, clip_val=clip),
[a, b]):
assert np.allclose(lhs.asnumpy(), np.clip(rhs, -clip, clip))
for lhs, rhs in zip(clip_grad([nd.array(a), nd.array(b)], GradientClippingMethod.ClipByNorm, clip_val=clip),
[a, b]):
scale = np.minimum(1.0, clip / np.sqrt(np.sum(np.square(rhs))))
assert np.allclose(lhs.asnumpy(), rhs * scale)
@pytest.mark.unit_test
def test_hybrid_clip():
x = mx.nd.array((0.5, 1.5, 2.5))
a = mx.nd.array((1,))
b = mx.nd.array((2,))
clipped = hybrid_clip(F=mx.nd, x=x, clip_lower=a, clip_upper=b)
assert (np.isclose(a= clipped.asnumpy(), b=(1, 1.5, 2))).all()

View File

@@ -68,6 +68,17 @@ if not using_GPU:
else:
install_requires.append('tensorflow-gpu==1.9.0')
# Framework-specific dependencies.
extras = {
'mxnet': ['mxnet-cu90mkl>=1.3.0']
}
all_deps = []
for group_name in extras:
all_deps += extras[group_name]
extras['all'] = all_deps
setup(
name='rl-coach',
version='0.10.0',
@@ -78,6 +89,7 @@ setup(
packages=find_packages(),
python_requires=">=3.5.*",
install_requires=install_requires,
extras_require=extras,
package_data={'rl_coach': ['dashboard_components/*.css',
'environments/doom/*.cfg',
'environments/doom/*.wad',