mirror of https://github.com/gryf/coach.git (synced 2025-12-18 11:40:18 +01:00)
bug fix in dueling network + revert to TF 1.6 for CPU due to requirements compatibility issues
@@ -39,15 +39,15 @@ class DuelingQHead(QHead):
     def _build_module(self, input_layer):
         # state value tower - V
         with tf.variable_scope("state_value"):
-            state_value = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
-            state_value = self.dense_layer(1)(state_value, name='fc2')
-            # state_value = tf.expand_dims(state_value, axis=-1)
+            self.state_value = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
+            self.state_value = self.dense_layer(1)(self.state_value, name='fc2')
 
         # action advantage tower - A
         with tf.variable_scope("action_advantage"):
-            action_advantage = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
-            action_advantage = self.dense_layer(self.num_actions)(action_advantage, name='fc2')
-            action_advantage = action_advantage - tf.reduce_mean(action_advantage, axis=1)
+            self.action_advantage = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
+            self.action_advantage = self.dense_layer(self.num_actions)(self.action_advantage, name='fc2')
+            self.action_mean = tf.reduce_mean(self.action_advantage, axis=1, keep_dims=True)
+            self.action_advantage = self.action_advantage - self.action_mean
 
         # merge to state-action value function Q
-        self.output = tf.add(state_value, action_advantage, name='output')
+        self.output = tf.add(self.state_value, self.action_advantage, name='output')
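For context on the bug fix: without keep_dims=True, tf.reduce_mean(..., axis=1) drops the reduced dimension, so the mean has shape (batch,) rather than (batch, 1) and the following subtraction and tf.add do not broadcast per row as the dueling aggregation Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)) requires. A minimal NumPy sketch (shapes and variable names are hypothetical, not taken from the commit) illustrating the intended broadcasting:

import numpy as np

batch, num_actions = 4, 3
state_value = np.random.randn(batch, 1)                  # V tower output, shape (batch, 1)
action_advantage = np.random.randn(batch, num_actions)   # A tower output, shape (batch, num_actions)

# keepdims keeps the mean at shape (batch, 1) so it broadcasts per row,
# mirroring the keep_dims=True argument added in this commit (TF 1.x naming).
action_mean = action_advantage.mean(axis=1, keepdims=True)
q_values = state_value + (action_advantage - action_mean)  # shape (batch, num_actions)

# Without keepdims the mean has shape (batch,): the subtraction then either
# raises a broadcast error or, when batch == num_actions, silently subtracts
# along the wrong axis.
print(q_values.shape)  # (4, 3)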