
Adding mxnet components to rl_coach/architectures (#60)

Adding mxnet components to rl_coach architectures.

- Supports PPO and DQN
- Tested with CartPole_PPO and CartPole_DQN (see the example invocation below)
- Normalizing filters don't work right now (see #49) and are disabled in the CartPole_PPO preset
- Checkpointing is disabled for MXNet
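
A minimal example of how the tested presets would be run against the new backend, assuming the coach launcher selects the framework via its -f/--framework option (flag name assumed here, not taken from this commit):

    coach -p CartPole_DQN -f mxnet
    coach -p CartPole_PPO -f mxnet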
Authored by Sina Afrooze on 2018-11-07 07:07:15 -08:00
Committed by Itai Caspi
parent e7a91b4dc3
commit 5fadb9c18e
39 changed files with 3864 additions and 44 deletions


@@ -0,0 +1,99 @@
"""
Module implementing basic layers in mxnet
"""
from types import FunctionType
from mxnet.gluon import nn
from rl_coach.architectures import layers
from rl_coach.architectures.mxnet_components import utils
# define global dictionary for storing layer type to layer implementation mapping
mx_layer_dict = dict()
def reg_to_mx(layer_type) -> FunctionType:
""" function decorator that registers layer implementation
:return: decorated function
"""
def reg_impl_decorator(func):
assert layer_type not in mx_layer_dict
mx_layer_dict[layer_type] = func
return func
return reg_impl_decorator
def convert_layer(layer):
"""
If layer is callable, return layer, otherwise convert to MX type
:param layer: layer to be converted
:return: converted layer if not callable, otherwise layer itself
"""
if callable(layer):
return layer
return mx_layer_dict[type(layer)](layer)
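
# Example: for a base layer that is not itself callable, e.g. layers.Dense(units=64),
# convert_layer looks up layers.Dense in mx_layer_dict and calls the registered to_mx
# converter defined below; calling the returned wrapper then builds the Gluon block.
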
class Conv2d(layers.Conv2d):
    def __init__(self, num_filters: int, kernel_size: int, strides: int):
        super(Conv2d, self).__init__(num_filters=num_filters, kernel_size=kernel_size, strides=strides)

    def __call__(self) -> nn.Conv2D:
        """
        returns a conv2d block
        :return: conv2d block
        """
        return nn.Conv2D(channels=self.num_filters, kernel_size=self.kernel_size, strides=self.strides)

    @staticmethod
    @reg_to_mx(layers.Conv2d)
    def to_mx(base: layers.Conv2d):
        return Conv2d(num_filters=base.num_filters, kernel_size=base.kernel_size, strides=base.strides)


class BatchnormActivationDropout(layers.BatchnormActivationDropout):
    def __init__(self, batchnorm: bool=False, activation_function=None, dropout_rate: float=0):
        super(BatchnormActivationDropout, self).__init__(
            batchnorm=batchnorm, activation_function=activation_function, dropout_rate=dropout_rate)

    def __call__(self):
        """
        returns an mxnet HybridSequential containing the configured batchnorm, activation and dropout blocks
        :return: sequential block of batchnorm, activation and dropout layers
        """
        block = nn.HybridSequential()
        if self.batchnorm:
            block.add(nn.BatchNorm())
        if self.activation_function:
            block.add(nn.Activation(activation=utils.get_mxnet_activation_name(self.activation_function)))
        if self.dropout_rate:
            block.add(nn.Dropout(self.dropout_rate))
        return block

    @staticmethod
    @reg_to_mx(layers.BatchnormActivationDropout)
    def to_mx(base: layers.BatchnormActivationDropout):
        return BatchnormActivationDropout(
            batchnorm=base.batchnorm,
            activation_function=base.activation_function,
            dropout_rate=base.dropout_rate)


class Dense(layers.Dense):
    def __init__(self, units: int):
        super(Dense, self).__init__(units=units)

    def __call__(self):
        """
        returns an mxnet dense layer
        :return: dense layer
        """
        # Set flatten to False for consistent behavior with tf.layers.dense
        return nn.Dense(self.units, flatten=False)

    @staticmethod
    @reg_to_mx(layers.Dense)
    def to_mx(base: layers.Dense):
        return Dense(units=base.units)
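
A minimal usage sketch of the conversion mechanism above, assuming this file lands at rl_coach/architectures/mxnet_components/layers.py, that the generic base classes in rl_coach.architectures.layers accept the same constructor keywords as the wrappers, and that they are not themselves callable (so convert_layer falls through to the registry). The layer sizes and input shape are arbitrary:

    import mxnet as mx
    from mxnet.gluon import nn

    from rl_coach.architectures import layers
    from rl_coach.architectures.mxnet_components.layers import convert_layer

    # Framework-agnostic layer descriptions; the sizes are arbitrary examples
    scheme = [
        layers.Conv2d(num_filters=32, kernel_size=3, strides=1),
        layers.BatchnormActivationDropout(batchnorm=True, activation_function=None, dropout_rate=0.5),
        layers.Dense(units=64),
    ]

    # convert_layer maps each description to its MXNet wrapper; calling the wrapper
    # builds the corresponding Gluon block (a HybridSequential for batchnorm/dropout)
    net = nn.HybridSequential()
    for layer in scheme:
        net.add(convert_layer(layer)())

    net.initialize()
    out = net(mx.nd.random.uniform(shape=(1, 3, 16, 16)))  # NCHW input; Dense acts on the last axis
    print(out.shape)  # (1, 32, 14, 64)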