
Added Custom Initialisation for MXNet Heads (#86)

* Added NormalizedRSSInitializer, using the same method as the TensorFlow backend, but under a different name, since 'columns' have a different meaning in a dense layer's weight matrix in MXNet (see the NumPy sketch below).

* Added a unit test for NormalizedRSSInitializer.
Thom Lane
2018-11-16 08:15:43 -08:00
committed by Scott Leishman
parent 101c55d37d
commit 81bac050d7
5 changed files with 59 additions and 8 deletions
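Before the diffs, here is what the new initializer computes, reduced to plain NumPy (a minimal illustration; the names `normalized_rss_rows` and `target_rss` are my own, and the actual implementation is in the first diff below):

import numpy as np

def normalized_rss_rows(shape, target_rss=1.0):
    # shape is (out_ch, in_ch), like an MXNet Dense weight matrix
    w = np.random.randn(*shape)
    rss = np.sqrt(np.square(w).sum(axis=1, keepdims=True))  # per-row root sum of squares
    return w * (target_rss / rss)

w = normalized_rss_rows((10, 5), target_rss=0.5)
print(np.sqrt(np.square(w).sum(axis=1)))  # every entry equals 0.5, up to float error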

View File

@@ -1,5 +1,7 @@
from typing import Dict, List, Union, Tuple
import mxnet as mx
from mxnet.initializer import Initializer, register
from mxnet.gluon import nn, loss
from mxnet.ndarray import NDArray
from mxnet.symbol import Symbol
@@ -11,6 +13,26 @@ LOSS_OUT_TYPE_LOSS = 'loss'
LOSS_OUT_TYPE_REGULARIZATION = 'regularization'
@register
class NormalizedRSSInitializer(Initializer):
    """
    Standardizes the Root Sum of Squares (RSS) along the input channel dimension.
    Used for Dense layer weight matrices only (i.e. do not use on Convolution kernels).
    An MXNet Dense layer weight matrix has shape (out_ch, in_ch), so standardization is applied across axis 1.
    The Root Sum of Squares is set via `rss`, which is 1.0 by default.
    Called `normalized_columns_initializer` in the TensorFlow backend (but we work with rows instead of columns for MXNet).
    """
    def __init__(self, rss=1.0):
        super(NormalizedRSSInitializer, self).__init__(rss=rss)
        self.rss = float(rss)

    def _init_weight(self, name, arr):
        # Sample standard normal weights, then rescale each row (one row per
        # output unit) so that its root sum of squares equals `rss`.
        mx.nd.random.normal(0, 1, out=arr)
        sample_rss = arr.square().sum(axis=1).sqrt()
        scalers = self.rss / sample_rss
        arr *= scalers.expand_dims(1)


class LossInputSchema(object):
    """
    Helper class to contain schema for loss hybrid_forward input
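The docstring above points at the TensorFlow backend's `normalized_columns_initializer`. For comparison, that helper is commonly written as below (a sketch from memory of the usual TF1-style version, not code from this repository); it normalizes across axis 0 because TensorFlow stores dense kernels as (in_ch, out_ch), which is exactly the rows-vs-columns difference the new name acknowledges:

import numpy as np
import tensorflow as tf

def normalized_columns_initializer(std=1.0):
    # TF dense kernels have shape (in_ch, out_ch), so each *column* is one output unit.
    def _initializer(shape, dtype=None, partition_info=None):
        out = np.random.randn(*shape).astype(np.float32)
        out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
        return tf.constant(out)
    return _initializer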

View File

@@ -8,7 +8,8 @@ from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import ActionProbabilities
from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace
from rl_coach.utils import eps
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema,\
    NormalizedRSSInitializer
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
from rl_coach.architectures.mxnet_components.utils import hybrid_clip
@@ -229,7 +230,8 @@ class DiscretePPOHead(nn.HybridBlock):
"""
super(DiscretePPOHead, self).__init__()
with self.name_scope():
self.dense = nn.Dense(units=num_actions, flatten=False)
self.dense = nn.Dense(units=num_actions, flatten=False,
weight_initializer=NormalizedRSSInitializer(0.01))
def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
"""
@@ -258,8 +260,8 @@ class ContinuousPPOHead(nn.HybridBlock):
"""
super(ContinuousPPOHead, self).__init__()
with self.name_scope():
# todo: change initialization strategy
self.dense = nn.Dense(units=num_actions, flatten=False)
self.dense = nn.Dense(units=num_actions, flatten=False,
weight_initializer=NormalizedRSSInitializer(0.01))
# all samples (across batch, and time step) share the same covariance, which is learnt,
# but since we assume the action probability variables are independent,
# only the diagonal entries of the covariance matrix are specified.
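As the trailing comment says, only the diagonal of the covariance is learnt, so drawing an action amounts to sampling each dimension from an independent Gaussian. A minimal sketch of that step (hypothetical values and names, not code from this file):

import numpy as np

mean = np.array([0.1, -0.3])       # per-dimension action means, e.g. from the dense layer above
log_std = np.array([-1.0, -0.5])   # learnt diagonal entries, shared across batch and time steps
action = mean + np.exp(log_std) * np.random.randn(*mean.shape)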

View File

@@ -3,7 +3,8 @@ from types import ModuleType
import mxnet as mx
from mxnet.gluon import nn
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema,\
    NormalizedRSSInitializer
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import ActionProbabilities
@@ -102,7 +103,7 @@ class PPOVHead(Head):
        self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon
        self.return_type = ActionProbabilities
        with self.name_scope():
            self.dense = nn.Dense(units=1)
            self.dense = nn.Dense(units=1, weight_initializer=NormalizedRSSInitializer(1.0))

    def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
        """

View File

@@ -4,7 +4,8 @@ from types import ModuleType
import mxnet as mx
from mxnet.gluon.loss import Loss, HuberLoss, L2Loss
from mxnet.gluon import nn
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema,\
    NormalizedRSSInitializer
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import VStateValue
@@ -79,7 +80,7 @@ class VHead(Head):
        self.loss_type = loss_type
        self.return_type = VStateValue
        with self.name_scope():
            self.dense = nn.Dense(units=1)
            self.dense = nn.Dense(units=1, weight_initializer=NormalizedRSSInitializer(1.0))

    def loss(self) -> Loss:
        """

View File

@@ -0,0 +1,25 @@
import mxnet as mx
import numpy as np
import os
import pytest
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))

from rl_coach.architectures.mxnet_components.heads.head import NormalizedRSSInitializer


@pytest.mark.unit_test
def test_normalized_rss_initializer():
    target_rss = 0.5
    units = 10
    dense = mx.gluon.nn.Dense(units=units, weight_initializer=NormalizedRSSInitializer(target_rss))
    dense.initialize()
    # Gluon defers parameter creation until input shapes are known, so the
    # forward pass below is what actually triggers the weight initialization.
    input_data = mx.random.uniform(shape=(25, 5))
    output_data = dense(input_data)
    weights = dense.weight.data()
    assert weights.shape == (10, 5)
    rss = weights.square().sum(axis=1).sqrt()
    np.testing.assert_almost_equal(rss.asnumpy(), np.tile(target_rss, units))