1
0
mirror of https://github.com/gryf/coach.git synced 2026-04-26 10:41:28 +02:00

Added Custom Initialisation for MXNet Heads (#86)

* Added NormalizedRSSInitializer, using the same method as the TensorFlow backend, but changed the name since 'columns' have a different meaning in a dense layer weight matrix in MXNet.

* Added unit test for NormalizedRSSInitializer.
This commit is contained in:
Thom Lane
2018-11-16 08:15:43 -08:00
committed by Scott Leishman
parent 101c55d37d
commit 81bac050d7
5 changed files with 59 additions and 8 deletions
@@ -3,7 +3,8 @@ from types import ModuleType
import mxnet as mx
from mxnet.gluon import nn
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema,\
NormalizedRSSInitializer
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import ActionProbabilities
@@ -102,7 +103,7 @@ class PPOVHead(Head):
self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon
self.return_type = ActionProbabilities
with self.name_scope():
self.dense = nn.Dense(units=1)
self.dense = nn.Dense(units=1, weight_initializer=NormalizedRSSInitializer(1.0))
def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
"""