From ebe574e4639bc32f689e75e83c3a0bb109305f21 Mon Sep 17 00:00:00 2001
From: Gal Leibovich
Date: Thu, 30 Aug 2018 19:34:27 +0300
Subject: [PATCH] add missing hidden layer in rainbow_q_head

---
 .../tensorflow_components/heads/rainbow_q_head.py | 6 ++++--
 rl_coach/presets/Atari_Rainbow.py                 | 2 +-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/rl_coach/architectures/tensorflow_components/heads/rainbow_q_head.py b/rl_coach/architectures/tensorflow_components/heads/rainbow_q_head.py
index d24b723..4dde21a 100644
--- a/rl_coach/architectures/tensorflow_components/heads/rainbow_q_head.py
+++ b/rl_coach/architectures/tensorflow_components/heads/rainbow_q_head.py
@@ -43,12 +43,14 @@ class RainbowQHead(Head):
     def _build_module(self, input_layer):
         # state value tower - V
         with tf.variable_scope("state_value"):
-            state_value = self.dense_layer(self.num_atoms)(input_layer, name='fc1')
+            state_value = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
+            state_value = self.dense_layer(self.num_atoms)(state_value, name='fc2')
             state_value = tf.expand_dims(state_value, axis=1)
 
         # action advantage tower - A
         with tf.variable_scope("action_advantage"):
-            action_advantage = self.dense_layer(self.num_actions * self.num_atoms)(input_layer, name='fc1')
+            action_advantage = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
+            action_advantage = self.dense_layer(self.num_actions * self.num_atoms)(action_advantage, name='fc2')
             action_advantage = tf.reshape(action_advantage, (tf.shape(input_layer)[0], self.num_actions,
                                                              self.num_atoms))
             action_mean = tf.reduce_mean(action_advantage, axis=1, keepdims=True)
diff --git a/rl_coach/presets/Atari_Rainbow.py b/rl_coach/presets/Atari_Rainbow.py
index b6ccf67..3187fd7 100644
--- a/rl_coach/presets/Atari_Rainbow.py
+++ b/rl_coach/presets/Atari_Rainbow.py
@@ -15,7 +15,7 @@ schedule_params = ScheduleParameters()
 schedule_params.improve_steps = EnvironmentSteps(50000000)
 schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
 schedule_params.evaluation_steps = EnvironmentSteps(135000)
-schedule_params.heatup_steps = EnvironmentSteps(500)
+schedule_params.heatup_steps = EnvironmentSteps(50000)
 
 #########
 # Agent #
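
Note: for context, below is a minimal, self-contained sketch of the head this patch
modifies, written in plain TF1-style tf.layers rather than Coach's self.dense_layer
wrapper. The 512-unit hidden width and the per-tower fc1/fc2 structure mirror the
patch; the ReLU activation, the placeholder values num_actions=4 and num_atoms=51,
and the final dueling aggregation with a softmax over atoms are NOT in the diff and
are assumed here from the standard Rainbow/C51 formulation.

import tensorflow as tf

def rainbow_q_head_sketch(input_layer, num_actions=4, num_atoms=51):
    """Dueling distributional head: returns (batch, actions, atoms) probabilities."""
    # state value tower - V: the patch adds the 512-unit hidden layer (fc1),
    # so the tower no longer branches linearly off the shared embedding
    with tf.variable_scope("state_value"):
        v = tf.layers.dense(input_layer, 512, activation=tf.nn.relu, name='fc1')
        v = tf.layers.dense(v, num_atoms, name='fc2')
        v = tf.expand_dims(v, axis=1)                       # (batch, 1, atoms)

    # action advantage tower - A: same added hidden layer, then one atom
    # distribution per action
    with tf.variable_scope("action_advantage"):
        a = tf.layers.dense(input_layer, 512, activation=tf.nn.relu, name='fc1')
        a = tf.layers.dense(a, num_actions * num_atoms, name='fc2')
        a = tf.reshape(a, (tf.shape(input_layer)[0], num_actions, num_atoms))
        a_mean = tf.reduce_mean(a, axis=1, keepdims=True)   # (batch, 1, atoms)

    # assumed aggregation (not shown in the diff): per-atom dueling combination
    # Q_logits(s, a) = V(s) + A(s, a) - mean_a A(s, a), then softmax over atoms
    logits = v + a - a_mean                                 # (batch, actions, atoms)
    return tf.nn.softmax(logits, axis=-1)

The second hunk raises heatup from 500 to 50,000 environment steps, i.e. the amount
of experience collected to seed the replay buffer before training begins.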