From 81bac050d7efc6d37dd28f3a287e6c21190af5cc Mon Sep 17 00:00:00 2001
From: Thom Lane
Date: Fri, 16 Nov 2018 08:15:43 -0800
Subject: [PATCH] Added Custom Initialisation for MXNet Heads (#86)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Added NormalizedRSSInitializer, using the same method as the TensorFlow
  backend, but changed the name since ‘columns’ has a different meaning in
  the dense layer weight matrix in MXNet.

* Added a unit test for NormalizedRSSInitializer.
---
 .../mxnet_components/heads/head.py       | 22 ++++++++++++++++
 .../mxnet_components/heads/ppo_head.py   | 10 +++++---
 .../mxnet_components/heads/ppo_v_head.py |  5 ++--
 .../mxnet_components/heads/v_head.py     |  5 ++--
 .../mxnet_components/heads/test_head.py  | 25 +++++++++++++++++++
 5 files changed, 59 insertions(+), 8 deletions(-)
 create mode 100644 rl_coach/tests/architectures/mxnet_components/heads/test_head.py

diff --git a/rl_coach/architectures/mxnet_components/heads/head.py b/rl_coach/architectures/mxnet_components/heads/head.py
index 4a83152..d6757ba 100644
--- a/rl_coach/architectures/mxnet_components/heads/head.py
+++ b/rl_coach/architectures/mxnet_components/heads/head.py
@@ -1,5 +1,7 @@
 from typing import Dict, List, Union, Tuple
 
+import mxnet as mx
+from mxnet.initializer import Initializer, register
 from mxnet.gluon import nn, loss
 from mxnet.ndarray import NDArray
 from mxnet.symbol import Symbol
@@ -11,6 +13,26 @@ LOSS_OUT_TYPE_LOSS = 'loss'
 LOSS_OUT_TYPE_REGULARIZATION = 'regularization'
 
 
+@register
+class NormalizedRSSInitializer(Initializer):
+    """
+    Standardizes the Root Sum of Squares (RSS) along the input channel dimension.
+    Use for Dense layer weight matrices only (i.e. do not use on convolution kernels).
+    The MXNet Dense layer weight matrix has shape (out_ch, in_ch), so we standardize along axis 1.
+    The Root Sum of Squares is set to `rss`, which is 1.0 by default.
+    Called `normalized_columns_initializer` in the TensorFlow backend (but we work with rows instead of columns in MXNet).
+    """
+    def __init__(self, rss=1.0):
+        super(NormalizedRSSInitializer, self).__init__(rss=rss)
+        self.rss = float(rss)
+
+    def _init_weight(self, name, arr):
+        mx.nd.random.normal(0, 1, out=arr)
+        sample_rss = arr.square().sum(axis=1).sqrt()
+        scalers = self.rss / sample_rss
+        arr *= scalers.expand_dims(1)
+
+
 class LossInputSchema(object):
     """
     Helper class to contain schema for loss hybrid_forward input
diff --git a/rl_coach/architectures/mxnet_components/heads/ppo_head.py b/rl_coach/architectures/mxnet_components/heads/ppo_head.py
index a6e65e3..269aec6 100644
--- a/rl_coach/architectures/mxnet_components/heads/ppo_head.py
+++ b/rl_coach/architectures/mxnet_components/heads/ppo_head.py
@@ -8,7 +8,8 @@ from rl_coach.base_parameters import AgentParameters
 from rl_coach.core_types import ActionProbabilities
 from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace
 from rl_coach.utils import eps
-from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
+from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema,\
+    NormalizedRSSInitializer
 from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
 from rl_coach.architectures.mxnet_components.utils import hybrid_clip
 
@@ -229,7 +230,8 @@ class DiscretePPOHead(nn.HybridBlock):
         """
         super(DiscretePPOHead, self).__init__()
         with self.name_scope():
-            self.dense = nn.Dense(units=num_actions, flatten=False)
+            self.dense = nn.Dense(units=num_actions, flatten=False,
+                                  weight_initializer=NormalizedRSSInitializer(0.01))
 
     def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
         """
@@ -258,8 +260,8 @@ class ContinuousPPOHead(nn.HybridBlock):
         """
         super(ContinuousPPOHead, self).__init__()
         with self.name_scope():
-            # todo: change initialization strategy
-            self.dense = nn.Dense(units=num_actions, flatten=False)
+            self.dense = nn.Dense(units=num_actions, flatten=False,
+                                  weight_initializer=NormalizedRSSInitializer(0.01))
         # all samples (across batch, and time step) share the same covariance, which is learnt,
         # but since we assume the action probability variables are independent,
         # only the diagonal entries of the covariance matrix are specified.
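
A note on the mechanics: `_init_weight` above fills the array with standard-normal draws and then rescales each row so that its root sum of squares equals `rss`. The following standalone sketch reproduces that computation with plain NDArray operations (the shapes and the 0.01 target here are illustrative, not part of the patch):

    import mxnet as mx

    out_ch, in_ch, rss = 4, 3, 0.01
    # Standard-normal weights, shaped (out_ch, in_ch) like a Dense weight matrix.
    arr = mx.nd.random.normal(0, 1, shape=(out_ch, in_ch))
    # Root sum of squares of each row; axis 1 is the input channel dimension.
    sample_rss = arr.square().sum(axis=1).sqrt()
    # Rescale every row to the target RSS, broadcasting the per-row scaler over axis 1.
    arr *= (rss / sample_rss).expand_dims(1)
    print(arr.square().sum(axis=1).sqrt())  # all entries come out at ~0.01

The 0.01 passed to the two policy heads above keeps the initial action logits small, so the initial policy starts close to uniform; the value heads below pass 1.0 instead.
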
diff --git a/rl_coach/architectures/mxnet_components/heads/ppo_v_head.py b/rl_coach/architectures/mxnet_components/heads/ppo_v_head.py
index 5512b90..7b675e4 100644
--- a/rl_coach/architectures/mxnet_components/heads/ppo_v_head.py
+++ b/rl_coach/architectures/mxnet_components/heads/ppo_v_head.py
@@ -3,7 +3,8 @@ from types import ModuleType
 import mxnet as mx
 from mxnet.gluon import nn
 
-from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
+from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema,\
+    NormalizedRSSInitializer
 from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
 from rl_coach.base_parameters import AgentParameters
 from rl_coach.core_types import ActionProbabilities
@@ -102,7 +103,7 @@ class PPOVHead(Head):
         self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon
         self.return_type = ActionProbabilities
         with self.name_scope():
-            self.dense = nn.Dense(units=1)
+            self.dense = nn.Dense(units=1, weight_initializer=NormalizedRSSInitializer(1.0))
 
     def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
         """
diff --git a/rl_coach/architectures/mxnet_components/heads/v_head.py b/rl_coach/architectures/mxnet_components/heads/v_head.py
index a04cafd..cfa765e 100644
--- a/rl_coach/architectures/mxnet_components/heads/v_head.py
+++ b/rl_coach/architectures/mxnet_components/heads/v_head.py
@@ -4,7 +4,8 @@ from types import ModuleType
 import mxnet as mx
 from mxnet.gluon.loss import Loss, HuberLoss, L2Loss
 from mxnet.gluon import nn
-from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
+from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema,\
+    NormalizedRSSInitializer
 from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
 from rl_coach.base_parameters import AgentParameters
 from rl_coach.core_types import VStateValue
@@ -79,7 +80,7 @@ class VHead(Head):
         self.loss_type = loss_type
         self.return_type = VStateValue
         with self.name_scope():
-            self.dense = nn.Dense(units=1)
+            self.dense = nn.Dense(units=1, weight_initializer=NormalizedRSSInitializer(1.0))
 
     def loss(self) -> Loss:
         """
diff --git a/rl_coach/tests/architectures/mxnet_components/heads/test_head.py b/rl_coach/tests/architectures/mxnet_components/heads/test_head.py
new file mode 100644
index 0000000..615d070
--- /dev/null
+++ b/rl_coach/tests/architectures/mxnet_components/heads/test_head.py
@@ -0,0 +1,25 @@
+import mxnet as mx
+import numpy as np
+import os
+import pytest
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+
+
+from rl_coach.architectures.mxnet_components.heads.head import NormalizedRSSInitializer
+
+
+@pytest.mark.unit_test
+def test_normalized_rss_initializer():
+    target_rss = 0.5
+    units = 10
+    dense = mx.gluon.nn.Dense(units=units, weight_initializer=NormalizedRSSInitializer(target_rss))
+    dense.initialize()
+
+    input_data = mx.random.uniform(shape=(25, 5))
+    output_data = dense(input_data)
+
+    weights = dense.weight.data()
+    assert weights.shape == (10, 5)
+    rss = weights.square().sum(axis=1).sqrt()
+    np.testing.assert_almost_equal(rss.asnumpy(), np.tile(target_rss, units))
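
For reference, a minimal usage sketch mirroring the unit test above (the layer sizes are illustrative). Gluon defers parameter initialization until the input shape is known, so the row-wise RSS can only be inspected after the first forward pass has materialized the weights:

    import mxnet as mx
    from rl_coach.architectures.mxnet_components.heads.head import NormalizedRSSInitializer

    # A policy-head-style output layer: each of the 10 weight rows starts with RSS = 0.01.
    dense = mx.gluon.nn.Dense(units=10, flatten=False,
                              weight_initializer=NormalizedRSSInitializer(0.01))
    dense.initialize()
    out = dense(mx.random.uniform(shape=(25, 5)))  # first forward pass triggers deferred init
    print(dense.weight.data().square().sum(axis=1).sqrt())  # ten entries, each ~0.01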