mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
fix ddpg head (#78)
This commit is contained in:
committed by
Balaji Subramaniam
parent
3a0a1159e9
commit
3fd433ffab
@@ -46,9 +46,12 @@ class DDPGActor(Head):
|
|||||||
def _build_module(self, input_layer):
|
def _build_module(self, input_layer):
|
||||||
# mean
|
# mean
|
||||||
pre_activation_policy_values_mean = self.dense_layer(self.num_actions)(input_layer, name='fc_mean')
|
pre_activation_policy_values_mean = self.dense_layer(self.num_actions)(input_layer, name='fc_mean')
|
||||||
policy_values_mean = batchnorm_activation_dropout(pre_activation_policy_values_mean, self.batchnorm,
|
policy_values_mean = batchnorm_activation_dropout(input_layer=pre_activation_policy_values_mean,
|
||||||
self.activation_function,
|
batchnorm=self.batchnorm,
|
||||||
False, 0, is_training=False, name="BatchnormActivationDropout_0")[-1]
|
activation_function=self.activation_function,
|
||||||
|
dropout_rate=0,
|
||||||
|
is_training=False,
|
||||||
|
name="BatchnormActivationDropout_0")[-1]
|
||||||
self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean')
|
self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean')
|
||||||
|
|
||||||
if self.is_local:
|
if self.is_local:
|
||||||
|
|||||||
Reference in New Issue
Block a user