mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Bug fix in the dueling network + revert to TF 1.6 for CPU due to requirements compatibility issues

commit 2c62a40466
parent 3a399d1361
Author: itaicaspi-intel
Date: 2018-09-02 13:38:16 +03:00

2 changed files with 9 additions and 9 deletions


@@ -39,15 +39,15 @@ class DuelingQHead(QHead):
     def _build_module(self, input_layer):
         # state value tower - V
         with tf.variable_scope("state_value"):
-            state_value = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
-            state_value = self.dense_layer(1)(state_value, name='fc2')
-            # state_value = tf.expand_dims(state_value, axis=-1)
+            self.state_value = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
+            self.state_value = self.dense_layer(1)(self.state_value, name='fc2')
 
         # action advantage tower - A
         with tf.variable_scope("action_advantage"):
-            action_advantage = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
-            action_advantage = self.dense_layer(self.num_actions)(action_advantage, name='fc2')
-            action_advantage = action_advantage - tf.reduce_mean(action_advantage, axis=1)
+            self.action_advantage = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
+            self.action_advantage = self.dense_layer(self.num_actions)(self.action_advantage, name='fc2')
+            self.action_mean = tf.reduce_mean(self.action_advantage, axis=1, keep_dims=True)
+            self.action_advantage = self.action_advantage - self.action_mean
 
         # merge to state-action value function Q
-        self.output = tf.add(state_value, action_advantage, name='output')
+        self.output = tf.add(self.state_value, self.action_advantage, name='output')
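
For reference, a minimal NumPy sketch (not part of the commit; shapes and values are illustrative) of why the advantage mean needs keep_dims=True: without it, the per-state mean drops its axis and can no longer be subtracted row-wise from the (batch, num_actions) advantage tensor before being added to the state value.

import numpy as np

# Stand-ins for the tensors built by the dueling head above.
state_value = np.array([[10.0], [20.0]])             # V(s), shape (2, 1)
action_advantage = np.array([[1.0, 2.0, 3.0],
                             [4.0, 5.0, 6.0]])       # A(s, a), shape (2, 3)

# Old behaviour: the mean over axis=1 collapses to shape (2,), which does
# not broadcast against the (2, 3) advantage tensor.
mean_dropped_axis = action_advantage.mean(axis=1)                  # shape (2,)

# Fixed behaviour: keeping the reduced axis (keep_dims=True in TF 1.x,
# keepdims=True in NumPy) preserves shape (2, 1), so each row of
# advantages is centered by its own mean.
mean_kept_axis = action_advantage.mean(axis=1, keepdims=True)      # shape (2, 1)

# Dueling combination: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a))
q_values = state_value + (action_advantage - mean_kept_axis)       # shape (2, 3)
print(q_values)   # [[ 9. 10. 11.]  [19. 20. 21.]]

The switch from local variables to self.state_value and self.action_advantage also leaves the intermediate tensors reachable from the head object, e.g. for inspection, though the commit message only calls out the bug fix.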