1
0
mirror of https://github.com/gryf/coach.git synced 2026-02-15 13:35:55 +01:00

Added Custom Initialisation for MXNet Heads (#86)

* Added NormalizedRSSInitializer, using same method as TensorFlow backend, but changed name since ‘columns’ have different meaning in dense layer weight matrix in MXNet.

* Added unit test for NormalizedRSSInitializer.
This commit is contained in:
Thom Lane
2018-11-16 08:15:43 -08:00
committed by Scott Leishman
parent 101c55d37d
commit 81bac050d7
5 changed files with 59 additions and 8 deletions

View File

@@ -0,0 +1,25 @@
import mxnet as mx
import numpy as np
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from rl_coach.architectures.mxnet_components.heads.head import NormalizedRSSInitializer
@pytest.mark.unit_test
def test_normalized_rss_initializer():
target_rss = 0.5
units = 10
dense = mx.gluon.nn.Dense(units=units, weight_initializer=NormalizedRSSInitializer(target_rss))
dense.initialize()
input_data = mx.random.uniform(shape=(25, 5))
output_data = dense(input_data)
weights = dense.weight.data()
assert weights.shape == (10, 5)
rss = weights.square().sum(axis=1).sqrt()
np.testing.assert_almost_equal(rss.asnumpy(), np.tile(target_rss, units))