"""
|
|
Module implementing basic layers in mxnet
|
|
"""
|
|
|
|
from types import FunctionType
|
|
|
|
from mxnet.gluon import nn
|
|
|
|
from rl_coach.architectures import layers
|
|
from rl_coach.architectures.mxnet_components import utils
|
|
|
|
|
|
# define global dictionary for storing layer type to layer implementation mapping
|
|
mx_layer_dict = dict()
|
|
|
|
|
|
def reg_to_mx(layer_type) -> FunctionType:
    """
    Function decorator that registers a layer implementation for the given layer type
    :param layer_type: rl_coach layer class to register an MXNet implementation for
    :return: decorated function
    """
    def reg_impl_decorator(func):
        assert layer_type not in mx_layer_dict
        mx_layer_dict[layer_type] = func
        return func
    return reg_impl_decorator


def convert_layer(layer):
    """
    If layer is callable, return it as-is; otherwise convert it to the registered MXNet type
    :param layer: layer to be converted
    :return: the layer itself if it is callable, otherwise the converted MXNet layer
    """
    if callable(layer):
        return layer
    return mx_layer_dict[type(layer)](layer)


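# Illustrative sketch (not part of the original module): convert_layer() is how the MXNet
# architecture code turns a framework-agnostic rl_coach layer description into a Gluon block.
# The description is looked up in mx_layer_dict (populated by the @reg_to_mx decorators below),
# converted to its MXNet wrapper, and the wrapper is then called to build the actual block, e.g.:
#
#     mx_dense = convert_layer(layers.Dense(units=64))   # -> the MXNet Dense wrapper defined below
#     gluon_block = mx_dense()                           # -> mxnet.gluon.nn.Dense(64, flatten=False)
#
# An input that is already callable is returned unchanged.

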
class Conv2d(layers.Conv2d):
    def __init__(self, num_filters: int, kernel_size: int, strides: int):
        super(Conv2d, self).__init__(num_filters=num_filters, kernel_size=kernel_size, strides=strides)

    def __call__(self) -> nn.Conv2D:
        """
        returns a conv2d block
        :return: conv2d block
        """
        return nn.Conv2D(channels=self.num_filters, kernel_size=self.kernel_size, strides=self.strides)

    @staticmethod
    @reg_to_mx(layers.Conv2d)
    def to_mx(base: layers.Conv2d):
        return Conv2d(num_filters=base.num_filters, kernel_size=base.kernel_size, strides=base.strides)


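# Example (illustrative, not part of the original module): building a convolutional block
# directly from the wrapper above.
#
#     conv_block = Conv2d(num_filters=32, kernel_size=8, strides=4)()
#     # conv_block is an mxnet.gluon.nn.Conv2D(channels=32, kernel_size=8, strides=4)

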
class BatchnormActivationDropout(layers.BatchnormActivationDropout):
    def __init__(self, batchnorm: bool=False, activation_function=None, dropout_rate: float=0):
        super(BatchnormActivationDropout, self).__init__(
            batchnorm=batchnorm, activation_function=activation_function, dropout_rate=dropout_rate)

    def __call__(self):
        """
        returns an mxnet HybridSequential block with the requested batchnorm, activation and dropout layers
        :return: HybridSequential block of batchnorm, activation and dropout layers
        """
        block = nn.HybridSequential()
        if self.batchnorm:
            block.add(nn.BatchNorm())
        if self.activation_function:
            block.add(nn.Activation(activation=utils.get_mxnet_activation_name(self.activation_function)))
        if self.dropout_rate:
            block.add(nn.Dropout(self.dropout_rate))
        return block

    @staticmethod
    @reg_to_mx(layers.BatchnormActivationDropout)
    def to_mx(base: layers.BatchnormActivationDropout):
        return BatchnormActivationDropout(
            batchnorm=base.batchnorm,
            activation_function=base.activation_function,
            dropout_rate=base.dropout_rate)


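# Example (illustrative, not part of the original module): a batchnorm -> activation -> dropout
# sub-block. The activation value is assumed to be something utils.get_mxnet_activation_name()
# can map to a Gluon activation name (e.g. 'relu').
#
#     bad_block = BatchnormActivationDropout(batchnorm=True, activation_function='relu', dropout_rate=0.5)()
#     # bad_block is an nn.HybridSequential of nn.BatchNorm, nn.Activation('relu') and nn.Dropout(0.5)

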
class Dense(layers.Dense):
    def __init__(self, units: int):
        super(Dense, self).__init__(units=units)

    def __call__(self):
        """
        returns an mxnet dense layer
        :return: dense layer
        """
        # Set flatten to False for consistent behavior with tf.layers.dense
        return nn.Dense(self.units, flatten=False)

    @staticmethod
    @reg_to_mx(layers.Dense)
    def to_mx(base: layers.Dense):
        return Dense(units=base.units)
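

# End-to-end smoke test (illustrative addition, not part of the original module): exercises the
# registration/conversion path above. Requires mxnet and rl_coach to be importable.
if __name__ == '__main__':
    # Convert framework-agnostic layer descriptions and build the corresponding Gluon blocks.
    for base_layer in (layers.Dense(units=256),
                       layers.Conv2d(num_filters=16, kernel_size=3, strides=1),
                       layers.BatchnormActivationDropout(batchnorm=True, dropout_rate=0.3)):
        print(convert_layer(base_layer)())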