1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00
Files
coach/rl_coach/tests/architectures/mxnet_components/heads/test_head.py
Thom Lane 81bac050d7 Added Custom Initialisation for MXNet Heads (#86)
* Added NormalizedRSSInitializer, which uses the same method as the TensorFlow backend but has a different name, since ‘columns’ have a different meaning in the dense layer weight matrix in MXNet.

* Added unit test for NormalizedRSSInitializer.
2018-11-16 08:15:43 -08:00

26 lines
745 B
Python

import mxnet as mx
import numpy as np
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from rl_coach.architectures.mxnet_components.heads.head import NormalizedRSSInitializer
@pytest.mark.unit_test
def test_normalized_rss_initializer():
    """Check that NormalizedRSSInitializer yields weight rows with the target RSS.

    A Dense layer is built with the initializer and run once on random input
    before its weight matrix is read. The root-sum-of-squares of every weight
    row is then compared against the requested target value.
    """
    target_rss = 0.5
    units = 10
    initializer = NormalizedRSSInitializer(target_rss)
    layer = mx.gluon.nn.Dense(units=units, weight_initializer=initializer)
    layer.initialize()

    # The layer is created without in_units, so a forward pass on 5-feature
    # input is performed before the weights are inspected.
    batch = mx.random.uniform(shape=(25, 5))
    _ = layer(batch)

    weight_matrix = layer.weight.data()
    assert weight_matrix.shape == (10, 5)

    # Root-sum-of-squares along axis 1: one value per output unit (row).
    squared = weight_matrix.square()
    row_rss = squared.sum(axis=1).sqrt()
    expected = np.tile(target_rss, units)
    np.testing.assert_almost_equal(row_rss.asnumpy(), expected)