1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

add missing hidden layer in rainbow_q_head

This commit is contained in:
Gal Leibovich
2018-08-30 19:34:27 +03:00
parent ea294de7fd
commit ebe574e463
2 changed files with 5 additions and 3 deletions

View File

@@ -43,12 +43,14 @@ class RainbowQHead(Head):
def _build_module(self, input_layer):
# state value tower - V
with tf.variable_scope("state_value"):
state_value = self.dense_layer(self.num_atoms)(input_layer, name='fc1')
state_value = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
state_value = self.dense_layer(self.num_atoms)(state_value, name='fc2')
state_value = tf.expand_dims(state_value, axis=1)
# action advantage tower - A
with tf.variable_scope("action_advantage"):
action_advantage = self.dense_layer(self.num_actions * self.num_atoms)(input_layer, name='fc1')
action_advantage = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
action_advantage = self.dense_layer(self.num_actions * self.num_atoms)(action_advantage, name='fc2')
action_advantage = tf.reshape(action_advantage, (tf.shape(input_layer)[0], self.num_actions,
self.num_atoms))
action_mean = tf.reduce_mean(action_advantage, axis=1, keepdims=True)

View File

@@ -15,7 +15,7 @@ schedule_params = ScheduleParameters()
schedule_params.improve_steps = EnvironmentSteps(50000000)
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
schedule_params.evaluation_steps = EnvironmentSteps(135000)
schedule_params.heatup_steps = EnvironmentSteps(500)
schedule_params.heatup_steps = EnvironmentSteps(50000)
#########
# Agent #