mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Added Custom Initialisation for MXNet Heads (#86)
* Added NormalizedRSSInitializer, using same method as TensorFlow backend, but changed name since ‘columns’ have different meaning in dense layer weight matrix in MXNet. * Added unit test for NormalizedRSSInitializer.
This commit is contained in:
committed by
Scott Leishman
parent
101c55d37d
commit
81bac050d7
@@ -1,5 +1,7 @@
|
||||
from typing import Dict, List, Union, Tuple
|
||||
|
||||
import mxnet as mx
|
||||
from mxnet.initializer import Initializer, register
|
||||
from mxnet.gluon import nn, loss
|
||||
from mxnet.ndarray import NDArray
|
||||
from mxnet.symbol import Symbol
|
||||
@@ -11,6 +13,26 @@ LOSS_OUT_TYPE_LOSS = 'loss'
|
||||
LOSS_OUT_TYPE_REGULARIZATION = 'regularization'
|
||||
|
||||
|
||||
@register
class NormalizedRSSInitializer(Initializer):
    """
    Standardizes the Root Sum of Squares along the input channel dimension.

    Intended for Dense layer weight matrices only (i.e. do not use on Convolution
    kernels). An MXNet Dense layer weight matrix has shape (out_ch, in_ch), so the
    standardization is applied across axis 1: every row's Root Sum of Squares is
    set to `rss` (1.0 by default).

    Equivalent to `normalized_columns_initializer` in the TensorFlow backend; the
    name differs because 'columns' have a different meaning in an MXNet dense
    layer weight matrix (we work with rows instead).
    """
    def __init__(self, rss=1.0):
        # Forward rss to the base class so the kwarg is captured for
        # Initializer serialization/printing.
        super(NormalizedRSSInitializer, self).__init__(rss=rss)
        self.rss = float(rss)

    def _init_weight(self, name, arr):
        """Fill `arr` in-place with N(0, 1) samples, then rescale each row to the target RSS."""
        mx.nd.random.normal(0, 1, out=arr)
        # Per-row root-sum-of-squares of the freshly sampled weights (axis 1 = in_ch).
        row_rss = arr.square().sum(axis=1).sqrt()
        # Scale every row so its RSS equals self.rss.
        arr *= (self.rss / row_rss).expand_dims(1)
|
||||
|
||||
|
||||
class LossInputSchema(object):
|
||||
"""
|
||||
Helper class to contain schema for loss hybrid_forward input
|
||||
|
||||
@@ -8,7 +8,8 @@ from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import ActionProbabilities
|
||||
from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace
|
||||
from rl_coach.utils import eps
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema,\
|
||||
NormalizedRSSInitializer
|
||||
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
|
||||
from rl_coach.architectures.mxnet_components.utils import hybrid_clip
|
||||
|
||||
@@ -229,7 +230,8 @@ class DiscretePPOHead(nn.HybridBlock):
|
||||
"""
|
||||
super(DiscretePPOHead, self).__init__()
|
||||
with self.name_scope():
|
||||
self.dense = nn.Dense(units=num_actions, flatten=False)
|
||||
self.dense = nn.Dense(units=num_actions, flatten=False,
|
||||
weight_initializer=NormalizedRSSInitializer(0.01))
|
||||
|
||||
def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
@@ -258,8 +260,8 @@ class ContinuousPPOHead(nn.HybridBlock):
|
||||
"""
|
||||
super(ContinuousPPOHead, self).__init__()
|
||||
with self.name_scope():
|
||||
# todo: change initialization strategy
|
||||
self.dense = nn.Dense(units=num_actions, flatten=False)
|
||||
self.dense = nn.Dense(units=num_actions, flatten=False,
|
||||
weight_initializer=NormalizedRSSInitializer(0.01))
|
||||
# all samples (across batch, and time step) share the same covariance, which is learnt,
|
||||
# but since we assume the action probability variables are independent,
|
||||
# only the diagonal entries of the covariance matrix are specified.
|
||||
|
||||
@@ -3,7 +3,8 @@ from types import ModuleType
|
||||
|
||||
import mxnet as mx
|
||||
from mxnet.gluon import nn
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema,\
|
||||
NormalizedRSSInitializer
|
||||
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import ActionProbabilities
|
||||
@@ -102,7 +103,7 @@ class PPOVHead(Head):
|
||||
self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon
|
||||
self.return_type = ActionProbabilities
|
||||
with self.name_scope():
|
||||
self.dense = nn.Dense(units=1)
|
||||
self.dense = nn.Dense(units=1, weight_initializer=NormalizedRSSInitializer(1.0))
|
||||
|
||||
def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
|
||||
@@ -4,7 +4,8 @@ from types import ModuleType
|
||||
import mxnet as mx
|
||||
from mxnet.gluon.loss import Loss, HuberLoss, L2Loss
|
||||
from mxnet.gluon import nn
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema,\
|
||||
NormalizedRSSInitializer
|
||||
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import VStateValue
|
||||
@@ -79,7 +80,7 @@ class VHead(Head):
|
||||
self.loss_type = loss_type
|
||||
self.return_type = VStateValue
|
||||
with self.name_scope():
|
||||
self.dense = nn.Dense(units=1)
|
||||
self.dense = nn.Dense(units=1, weight_initializer=NormalizedRSSInitializer(1.0))
|
||||
|
||||
def loss(self) -> Loss:
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
import mxnet as mx
|
||||
import numpy as np
|
||||
import os
|
||||
import pytest
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
|
||||
from rl_coach.architectures.mxnet_components.heads.head import NormalizedRSSInitializer
|
||||
|
||||
|
||||
@pytest.mark.unit_test
def test_normalized_rss_initializer():
    """Every weight row produced by NormalizedRSSInitializer has a root-sum-of-squares of `target_rss`."""
    target_rss = 0.5
    units = 10
    input_dim = 5

    dense = mx.gluon.nn.Dense(units=units, weight_initializer=NormalizedRSSInitializer(target_rss))
    dense.initialize()

    # Gluon infers weight shapes lazily; a forward pass triggers the actual
    # (deferred) initialization of the weight array.
    input_data = mx.random.uniform(shape=(25, input_dim))
    dense(input_data)

    weights = dense.weight.data()
    # MXNet Dense weights are laid out as (out_ch, in_ch).
    assert weights.shape == (units, input_dim)
    # Root sum of squares across the input-channel axis should match the target
    # for every output unit.
    rss = weights.square().sum(axis=1).sqrt()
    np.testing.assert_almost_equal(rss.asnumpy(), np.tile(target_rss, units))
|
||||
Reference in New Issue
Block a user