mirror of
https://github.com/gryf/coach.git
synced 2026-04-11 07:33:37 +02:00
Added Custom Initialisation for MXNet Heads (#86)
* Added NormalizedRSSInitializer, using same method as TensorFlow backend, but changed name since ‘columns’ have different meaning in dense layer weight matrix in MXNet. * Added unit test for NormalizedRSSInitializer.
This commit is contained in:
committed by
Scott Leishman
parent
101c55d37d
commit
81bac050d7
@@ -8,7 +8,8 @@ from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import ActionProbabilities
|
||||
from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace
|
||||
from rl_coach.utils import eps
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema,\
|
||||
NormalizedRSSInitializer
|
||||
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
|
||||
from rl_coach.architectures.mxnet_components.utils import hybrid_clip
|
||||
|
||||
@@ -229,7 +230,8 @@ class DiscretePPOHead(nn.HybridBlock):
|
||||
"""
|
||||
super(DiscretePPOHead, self).__init__()
|
||||
with self.name_scope():
|
||||
self.dense = nn.Dense(units=num_actions, flatten=False)
|
||||
self.dense = nn.Dense(units=num_actions, flatten=False,
|
||||
weight_initializer=NormalizedRSSInitializer(0.01))
|
||||
|
||||
def hybrid_forward(self, F: ModuleType, x: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
@@ -258,8 +260,8 @@ class ContinuousPPOHead(nn.HybridBlock):
|
||||
"""
|
||||
super(ContinuousPPOHead, self).__init__()
|
||||
with self.name_scope():
|
||||
# todo: change initialization strategy
|
||||
self.dense = nn.Dense(units=num_actions, flatten=False)
|
||||
self.dense = nn.Dense(units=num_actions, flatten=False,
|
||||
weight_initializer=NormalizedRSSInitializer(0.01))
|
||||
# all samples (across batch, and time step) share the same covariance, which is learnt,
|
||||
# but since we assume the action probability variables are independent,
|
||||
# only the diagonal entries of the covariance matrix are specified.
|
||||
|
||||
Reference in New Issue
Block a user