1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

fix ddpg head (#78)

This commit is contained in:
Itai Caspi
2018-11-09 18:17:04 +02:00
committed by Balaji Subramaniam
parent 3a0a1159e9
commit 3fd433ffab

View File

@@ -46,9 +46,12 @@ class DDPGActor(Head):
 def _build_module(self, input_layer):
 # mean
 pre_activation_policy_values_mean = self.dense_layer(self.num_actions)(input_layer, name='fc_mean')
-policy_values_mean = batchnorm_activation_dropout(pre_activation_policy_values_mean, self.batchnorm,
-self.activation_function,
-False, 0, is_training=False, name="BatchnormActivationDropout_0")[-1]
+policy_values_mean = batchnorm_activation_dropout(input_layer=pre_activation_policy_values_mean,
+batchnorm=self.batchnorm,
+activation_function=self.activation_function,
+dropout_rate=0,
+is_training=False,
+name="BatchnormActivationDropout_0")[-1]
 self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean')
 if self.is_local: