mirror of
https://github.com/gryf/coach.git
synced 2026-01-06 13:54:21 +01:00
Added Custom Initialisation for MXNet Heads (#86)
* Added NormalizedRSSInitializer, using same method as TensorFlow backend, but changed name since ‘columns’ have different meaning in dense layer weight matrix in MXNet. * Added unit test for NormalizedRSSInitializer.
This commit is contained in:
committed by
Scott Leishman
parent
101c55d37d
commit
81bac050d7
@@ -4,7 +4,8 @@ from types import ModuleType
|
||||
import mxnet as mx
|
||||
from mxnet.gluon.loss import Loss, HuberLoss, L2Loss
|
||||
from mxnet.gluon import nn
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema,\
|
||||
NormalizedRSSInitializer
|
||||
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import VStateValue
|
||||
@@ -79,7 +80,7 @@ class VHead(Head):
|
||||
self.loss_type = loss_type
|
||||
self.return_type = VStateValue
|
||||
with self.name_scope():
|
||||
self.dense = nn.Dense(units=1)
|
||||
self.dense = nn.Dense(units=1, weight_initializer=NormalizedRSSInitializer(1.0))
|
||||
|
||||
def loss(self) -> Loss:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user