1
0
mirror of https://github.com/gryf/coach.git synced 2026-04-09 06:33:33 +02:00

Added Custom Initialisation for MXNet Heads (#86)

* Added NormalizedRSSInitializer, using the same method as the TensorFlow backend, but under a changed name, since 'columns' have a different meaning in a dense layer's weight matrix in MXNet.

* Added unit test for NormalizedRSSInitializer.
This commit is contained in:
Thom Lane
2018-11-16 08:15:43 -08:00
committed by Scott Leishman
parent 101c55d37d
commit 81bac050d7
5 changed files with 59 additions and 8 deletions

View File

@@ -1,5 +1,7 @@
from typing import Dict, List, Union, Tuple
import mxnet as mx
from mxnet.initializer import Initializer, register
from mxnet.gluon import nn, loss
from mxnet.ndarray import NDArray
from mxnet.symbol import Symbol
@@ -11,6 +13,26 @@ LOSS_OUT_TYPE_LOSS = 'loss'
LOSS_OUT_TYPE_REGULARIZATION = 'regularization'
@register
class NormalizedRSSInitializer(Initializer):
    """
    Weight initializer that fixes each output unit's Root Sum of Squares (RSS).

    Draws standard-normal samples, then rescales every row of the weight
    matrix so its RSS equals ``rss`` (default 1.0). Dense layer weights in
    MXNet are shaped (out_ch, in_ch), so the normalization runs along axis 1
    (the input-channel dimension). Intended for Dense weight matrices only —
    not Convolution kernels. This mirrors the TensorFlow backend's
    `normalized_columns_initializer`, renamed because MXNet stores the
    equivalent data in rows rather than columns.
    """
    def __init__(self, rss=1.0):
        # Forward rss to the base Initializer so it is serialized with the config.
        super(NormalizedRSSInitializer, self).__init__(rss=rss)
        self.rss = float(rss)

    def _init_weight(self, name, arr):
        # Fill `arr` in place with samples from N(0, 1).
        mx.nd.random.normal(0, 1, out=arr)
        # Per-row root-sum-of-squares across the input-channel axis.
        row_rss = arr.square().sum(axis=1).sqrt()
        # Rescale each row so its RSS becomes self.rss (broadcast over axis 1).
        arr *= (self.rss / row_rss).expand_dims(1)
class LossInputSchema(object):
"""
Helper class to contain schema for loss hybrid_forward input