Adding mxnet components to rl_coach/architectures (#60)
Adding mxnet components to rl_coach architectures.
- Supports PPO and DQN
- Tested with the CartPole_PPO and CartPole_DQN presets
- Normalizing filters don't work right now (see #49) and are disabled in the CartPole_PPO preset
- Checkpointing is disabled for MXNet
rl_coach/architectures/mxnet_components/layers.py (new file, 99 lines)
@@ -0,0 +1,99 @@
"""
Module implementing basic layers in mxnet
"""

from types import FunctionType

from mxnet.gluon import nn

from rl_coach.architectures import layers
from rl_coach.architectures.mxnet_components import utils


# define global dictionary for storing layer type to layer implementation mapping
mx_layer_dict = dict()


def reg_to_mx(layer_type) -> FunctionType:
    """ function decorator that registers layer implementation
    :return: decorated function
    """
    def reg_impl_decorator(func):
        assert layer_type not in mx_layer_dict
        mx_layer_dict[layer_type] = func
        return func
    return reg_impl_decorator


def convert_layer(layer):
    """
    If layer is callable, return layer, otherwise convert to MX type
    :param layer: layer to be converted
    :return: converted layer if not callable, otherwise layer itself
    """
    if callable(layer):
        return layer
    return mx_layer_dict[type(layer)](layer)
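
# --- Editorial illustration (not part of this file) ---------------------------
# The registry above drives the to_mx converters defined below. Assuming the
# Dense spec and its wrapper defined later in this file, and assuming spec
# instances are not callable (so convert_layer routes them through the
# registry), conversion would flow as:
#
#     d = convert_layer(layers.Dense(units=64))  # looks up the registered
#                                                # to_mx converter and returns
#                                                # the MXNet-side Dense wrapper
#     block = d()                                # -> nn.Dense(64, flatten=False)
# ------------------------------------------------------------------------------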

class Conv2d(layers.Conv2d):
    def __init__(self, num_filters: int, kernel_size: int, strides: int):
        super(Conv2d, self).__init__(num_filters=num_filters, kernel_size=kernel_size, strides=strides)

    def __call__(self) -> nn.Conv2D:
        """
        returns a conv2d block
        :return: conv2d block
        """
        return nn.Conv2D(channels=self.num_filters, kernel_size=self.kernel_size, strides=self.strides)

    @staticmethod
    @reg_to_mx(layers.Conv2d)
    def to_mx(base: layers.Conv2d):
        return Conv2d(num_filters=base.num_filters, kernel_size=base.kernel_size, strides=base.strides)

class BatchnormActivationDropout(layers.BatchnormActivationDropout):
    def __init__(self, batchnorm: bool=False, activation_function=None, dropout_rate: float=0):
        super(BatchnormActivationDropout, self).__init__(
            batchnorm=batchnorm, activation_function=activation_function, dropout_rate=dropout_rate)

    def __call__(self):
        """
        returns a HybridSequential block of mxnet batchnorm, activation and dropout layers
        :return: block with batchnorm, activation and dropout layers
        """
        block = nn.HybridSequential()
        if self.batchnorm:
            block.add(nn.BatchNorm())
        if self.activation_function:
            block.add(nn.Activation(activation=utils.get_mxnet_activation_name(self.activation_function)))
        if self.dropout_rate:
            block.add(nn.Dropout(self.dropout_rate))
        return block

    @staticmethod
    @reg_to_mx(layers.BatchnormActivationDropout)
    def to_mx(base: layers.BatchnormActivationDropout):
        return BatchnormActivationDropout(
            batchnorm=base.batchnorm,
            activation_function=base.activation_function,
            dropout_rate=base.dropout_rate)
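
# --- Editorial illustration (not part of this file) ---------------------------
# With every option enabled, the factory above expands a single spec into a
# three-stage gluon block (assuming 'relu' is an activation name understood by
# utils.get_mxnet_activation_name):
#
#     bad = BatchnormActivationDropout(batchnorm=True,
#                                      activation_function='relu',
#                                      dropout_rate=0.5)
#     seq = bad()  # HybridSequential: BatchNorm -> Activation -> Dropout(0.5)
# ------------------------------------------------------------------------------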

class Dense(layers.Dense):
    def __init__(self, units: int):
        super(Dense, self).__init__(units=units)

    def __call__(self):
        """
        returns an mxnet dense layer
        :return: dense layer
        """
        # Set flatten to False for consistent behavior with tf.layers.dense
        return nn.Dense(self.units, flatten=False)

    @staticmethod
    @reg_to_mx(layers.Dense)
    def to_mx(base: layers.Dense):
        return Dense(units=base.units)
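
Taken together, a minimal end-to-end sketch of how these pieces compose into a gluon network. The spec list, variable names, and the 'relu' activation name are illustrative assumptions, not part of the PR:

from mxnet.gluon import nn
from rl_coach.architectures import layers as spec
from rl_coach.architectures.mxnet_components.layers import convert_layer

# Build a small gluon network from framework-agnostic layer specs by routing
# each spec through the registry, then calling the resulting factory to get
# an actual gluon block.
net = nn.HybridSequential()
layer_specs = [
    spec.Dense(units=64),
    spec.BatchnormActivationDropout(activation_function='relu'),
    spec.Dense(units=2),
]
for s in layer_specs:
    net.add(convert_layer(s)())
net.initialize()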