mirror of
https://github.com/gryf/coach.git
synced 2026-02-15 13:35:55 +01:00
Added Custom Initialisation for MXNet Heads (#86)
* Added NormalizedRSSInitializer, using same method as TensorFlow backend, but changed name since ‘columns’ have different meaning in dense layer weight matrix in MXNet. * Added unit test for NormalizedRSSInitializer.
This commit is contained in:
committed by
Scott Leishman
parent
101c55d37d
commit
81bac050d7
@@ -0,0 +1,25 @@
|
||||
import mxnet as mx
|
||||
import numpy as np
|
||||
import os
|
||||
import pytest
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
|
||||
from rl_coach.architectures.mxnet_components.heads.head import NormalizedRSSInitializer
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_normalized_rss_initializer():
|
||||
target_rss = 0.5
|
||||
units = 10
|
||||
dense = mx.gluon.nn.Dense(units=units, weight_initializer=NormalizedRSSInitializer(target_rss))
|
||||
dense.initialize()
|
||||
|
||||
input_data = mx.random.uniform(shape=(25, 5))
|
||||
output_data = dense(input_data)
|
||||
|
||||
weights = dense.weight.data()
|
||||
assert weights.shape == (10, 5)
|
||||
rss = weights.square().sum(axis=1).sqrt()
|
||||
np.testing.assert_almost_equal(rss.asnumpy(), np.tile(target_rss, units))
|
||||
Reference in New Issue
Block a user