coach/rl_coach/architectures/mxnet_components/layers.py

#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Module implementing basic layers in mxnet
"""
from types import FunctionType
from mxnet.gluon import nn
from rl_coach.architectures import layers
from rl_coach.architectures.mxnet_components import utils
# define global dictionary for storing layer type to layer implementation mapping
mx_layer_dict = dict()
def reg_to_mx(layer_type) -> FunctionType:
    """
    Function decorator that registers an MXNet layer implementation for the given generic layer type
    :param layer_type: generic layer class to register an implementation for
    :return: decorated function
    """
    def reg_impl_decorator(func):
        assert layer_type not in mx_layer_dict
        mx_layer_dict[layer_type] = func
        return func
    return reg_impl_decorator


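# Registration sketch (illustration, not part of the original module): each
# MXNet implementation below pairs itself with its generic counterpart via
#
#     @staticmethod
#     @reg_to_mx(layers.Conv2d)
#     def to_mx(base: layers.Conv2d): ...
#
# so that convert_layer can look the implementation up by the generic type.

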
def convert_layer(layer):
    """
    If the layer is already callable, return it as is; otherwise convert it to its MXNet equivalent
    :param layer: layer to be converted
    :return: the layer itself if callable, otherwise the converted layer
    """
    if callable(layer):
        return layer
    return mx_layer_dict[type(layer)](layer)


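# A usage sketch (illustration, not part of the original module):
#
#     base = layers.Conv2d(num_filters=32, kernel_size=8, strides=4)
#     mx_wrapper = convert_layer(base)  # looked up in mx_layer_dict -> Conv2d.to_mx
#     gluon_block = mx_wrapper()        # -> mxnet.gluon.nn.Conv2D
#
# A callable passed to convert_layer (e.g. an already-converted wrapper) is
# returned unchanged.

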
class Conv2d(layers.Conv2d):
    def __init__(self, num_filters: int, kernel_size: int, strides: int):
        super(Conv2d, self).__init__(num_filters=num_filters, kernel_size=kernel_size, strides=strides)

    def __call__(self) -> nn.Conv2D:
        """
        returns a conv2d block
        :return: conv2d block
        """
        return nn.Conv2D(channels=self.num_filters, kernel_size=self.kernel_size, strides=self.strides)

    @staticmethod
    @reg_to_mx(layers.Conv2d)
    def to_mx(base: layers.Conv2d):
        return Conv2d(num_filters=base.num_filters, kernel_size=base.kernel_size, strides=base.strides)


class BatchnormActivationDropout(layers.BatchnormActivationDropout):
    def __init__(self, batchnorm: bool = False, activation_function=None, dropout_rate: float = 0):
        super(BatchnormActivationDropout, self).__init__(
            batchnorm=batchnorm, activation_function=activation_function, dropout_rate=dropout_rate)

    def __call__(self):
        """
        returns an MXNet HybridSequential block chaining batchnorm, activation and dropout layers
        :return: batchnorm, activation and dropout block
        """
        block = nn.HybridSequential()
        if self.batchnorm:
            block.add(nn.BatchNorm())
        if self.activation_function:
            block.add(nn.Activation(activation=utils.get_mxnet_activation_name(self.activation_function)))
        if self.dropout_rate:
            block.add(nn.Dropout(self.dropout_rate))
        return block

    @staticmethod
    @reg_to_mx(layers.BatchnormActivationDropout)
    def to_mx(base: layers.BatchnormActivationDropout):
        return BatchnormActivationDropout(
            batchnorm=base.batchnorm,
            activation_function=base.activation_function,
            dropout_rate=base.dropout_rate)


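# For example (sketch): BatchnormActivationDropout(batchnorm=True,
# activation_function='relu', dropout_rate=0.5)() yields a HybridSequential
# holding [BatchNorm, Activation('relu'), Dropout(0.5)], in that order.

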
class Dense(layers.Dense):
    def __init__(self, units: int):
        super(Dense, self).__init__(units=units)

    def __call__(self):
        """
        returns an MXNet dense layer
        :return: dense layer
        """
        # Set flatten to False for consistent behavior with tf.layers.dense
        return nn.Dense(self.units, flatten=False)

    @staticmethod
    @reg_to_mx(layers.Dense)
    def to_mx(base: layers.Dense):
        return Dense(units=base.units)
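

# A minimal end-to-end sketch (an addition for illustration, not part of the
# original module). It assumes mxnet is installed and that 'relu' is among the
# activation names understood by utils.get_mxnet_activation_name; shapes and
# hyperparameters are arbitrary example values.
if __name__ == '__main__':
    import mxnet as mx

    net = nn.HybridSequential()
    # convert_layer dispatches each generic layer definition to its MXNet
    # wrapper; calling the wrapper then produces the actual Gluon block
    for base_layer in [layers.Conv2d(num_filters=16, kernel_size=3, strides=1),
                       layers.BatchnormActivationDropout(
                           batchnorm=True, activation_function='relu', dropout_rate=0.3),
                       layers.Dense(units=10)]:
        net.add(convert_layer(base_layer)())
    net.initialize(mx.init.Xavier())
    out = net(mx.nd.random.uniform(shape=(4, 3, 32, 32)))  # NCHW input batch
    print(out.shape)  # (4, 16, 30, 10): Dense acts on the last axis since flatten=False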