1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00
Files
coach/rl_coach/architectures/mxnet_components/embedders/embedder.py
Sina Afrooze 5fadb9c18e Adding mxnet components to rl_coach/architectures (#60)
Adding mxnet components to rl_coach architectures.

- Supports PPO and DQN
- Tested with CartPole_PPO and CartPole_DQN
- Normalizing filters don't work right now (see #49) and are disabled in CartPole_PPO preset
- Checkpointing is disabled for MXNet
2018-11-07 17:07:15 +02:00

72 lines
3.3 KiB
Python

from typing import Union
from types import ModuleType
import mxnet as mx
from mxnet.gluon import nn
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.mxnet_components.layers import convert_layer
from rl_coach.base_parameters import EmbedderScheme
# Type alias for tensors flowing through `hybrid_forward`: an `mx.nd.NDArray` when the block runs
# imperatively (F is `mxnet.nd`), or an `mx.sym.Symbol` once the block has been hybridized (F is `mxnet.sym`).
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
class InputEmbedder(nn.HybridBlock):
    """
    Base Gluon block for input embedders. Child classes supply `schemes`, `input_rescaling` and `input_offset`;
    this class builds the layer stack from the configured scheme and implements the shared forward pass
    (rescale -> offset -> optional clip -> network -> flatten).
    """

    def __init__(self, params: InputEmbedderParameters):
        """
        An input embedder is the first part of the network, which takes the input from the state and produces a vector
        embedding by passing it through a neural network. The embedder will mostly be input type dependent, and there
        can be multiple embedders in a single network.

        :param params: parameters object containing name, scheme, input_clipping, batchnorm, activation_function
            and dropout properties.
        :raises NotImplementedError: if `params.scheme` is an `EmbedderScheme` and the inheriting class does not
            override `schemes`.
        """
        super().__init__()
        self.embedder_name = params.name
        self.input_clipping = params.input_clipping
        self.scheme = params.scheme

        # Create all child blocks under this block's name scope so their parameter names are prefixed consistently.
        with self.name_scope():
            self.net = nn.HybridSequential()
            if isinstance(self.scheme, EmbedderScheme):
                # Pre-defined architecture: look it up in the table provided by the inheriting embedder.
                blocks = self.schemes[self.scheme]
            else:
                # if scheme is specified directly, convert to MX layer if it's not a callable object
                # NOTE: if layer object is callable, it must return a gluon block when invoked
                blocks = [convert_layer(layer) for layer in self.scheme]
            for block in blocks:
                # Each entry is a factory; calling it yields the gluon block to append.
                self.net.add(block())
                # Optional post-processing appended after every layer of the scheme.
                if params.batchnorm:
                    self.net.add(nn.BatchNorm())
                if params.activation_function:
                    self.net.add(nn.Activation(params.activation_function))
                # NOTE(review): `params.dropout` is used both as an on/off flag and as the dropout rate;
                # confirm it holds a float rate (a bool True would mean rate=1.0).
                if params.dropout:
                    self.net.add(nn.Dropout(rate=params.dropout))

    @property
    def schemes(self) -> dict:
        """
        Schemes are the pre-defined network architectures of various depths and complexities that can be used for the
        InputEmbedder. Should be implemented in child classes, and are used to create Block when InputEmbedder is
        initialised.

        :return: dictionary of schemes, with key of type EmbedderScheme enum and value being list of mxnet.gluon.Block.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Inheriting embedder must define schemes matching its allowed default "
                                  "configurations.")

    def hybrid_forward(self, F: ModuleType, x: nd_sym_type, *args, **kwargs) -> nd_sym_type:
        """
        Used for forward pass through embedder network.

        :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
        :param x: environment state, where first dimension is batch_size, then dimensions are data type dependent.
        :return: embedding of environment state, where shape is (batch_size, channels).
        """
        # `input_rescaling` and `input_offset` set on inheriting embedder
        x = x / self.input_rescaling
        x = x - self.input_offset
        if self.input_clipping is not None:
            # clip returns a new array/symbol rather than modifying x in place, so the result must be
            # reassigned; otherwise clipping is silently skipped.
            x = F.clip(x, a_min=self.input_clipping[0], a_max=self.input_clipping[1])
        x = self.net(x)
        return x.flatten()