diff --git a/rl_coach/architectures/tensorflow_components/heads/dueling_q_head.py b/rl_coach/architectures/tensorflow_components/heads/dueling_q_head.py
index ce7ad58..7090236 100644
--- a/rl_coach/architectures/tensorflow_components/heads/dueling_q_head.py
+++ b/rl_coach/architectures/tensorflow_components/heads/dueling_q_head.py
@@ -47,7 +47,7 @@ class DuelingQHead(QHead):
         with tf.variable_scope("action_advantage"):
             action_advantage = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
             action_advantage = self.dense_layer(self.num_actions)(action_advantage, name='fc2')
-            action_advantage = action_advantage - tf.reduce_mean(action_advantage)
+            action_advantage = action_advantage - tf.reduce_mean(action_advantage, axis=1)
 
         # merge to state-action value function Q
         self.output = tf.add(state_value, action_advantage, name='output')
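
Reviewer note (not part of the patch): the goal of the change is to center the advantage per sample rather than over the whole batch, but `tf.reduce_mean(action_advantage, axis=1)` returns a tensor of shape `[batch_size]`, which does not broadcast against the `[batch_size, num_actions]` advantage tensor (broadcasting aligns trailing dimensions). The NumPy sketch below illustrates the shape behaviour only; the sizes (batch of 4, 3 actions) are illustrative assumptions, not values from the PR.

```python
import numpy as np

batch_size, num_actions = 4, 3  # illustrative sizes, not from the PR
advantage = np.random.randn(batch_size, num_actions).astype(np.float32)

# Keeping the reduced axis gives a (4, 1) mean that broadcasts over the
# actions dimension, centering each row of the advantage tensor.
centered = advantage - advantage.mean(axis=1, keepdims=True)
assert centered.shape == (batch_size, num_actions)
assert np.allclose(centered.mean(axis=1), 0.0, atol=1e-5)

# Dropping the axis gives a (4,) mean, which aligns with the trailing
# (actions) axis and fails to broadcast once batch_size != num_actions.
try:
    advantage - advantage.mean(axis=1)
except ValueError as err:
    print("broadcast error:", err)
```

Assuming a per-sample mean is the intended behaviour, the TF equivalent would be `tf.reduce_mean(action_advantage, axis=1, keepdims=True)` (spelled `keep_dims=True` on older TF 1.x releases).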