From b4bc8a476ccbb1b595dd93efa6475ce11a4a8a32 Mon Sep 17 00:00:00 2001 From: Neta Zmora <31280975+nzmora@users.noreply.github.com> Date: Mon, 17 Dec 2018 10:08:54 +0200 Subject: [PATCH] Bug fix: when enabling 'heatup_using_network_decisions', we should add the configured noise (#162) During heatup we may want to add agent-generated-noise (i.e. not "simple" random noise). This is enabled by setting 'heatup_using_network_decisions' to True. For example: agent_params = DDPGAgentParameters() agent_params.algorithm.heatup_using_network_decisions = True The fix ensures that the correct noise is added not just while in the TRAINING phase, but also during the HEATUP phase. No one has enabled 'heatup_using_network_decisions' yet, which explains why this problem arose only now (in my configuration I do enable 'heatup_using_network_decisions'). --- rl_coach/exploration_policies/additive_noise.py | 2 +- rl_coach/exploration_policies/truncated_normal.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rl_coach/exploration_policies/additive_noise.py b/rl_coach/exploration_policies/additive_noise.py index 682021c..5f89889 100644 --- a/rl_coach/exploration_policies/additive_noise.py +++ b/rl_coach/exploration_policies/additive_noise.py @@ -88,7 +88,7 @@ class AdditiveNoise(ExplorationPolicy): action_values_mean = action_values.squeeze() # step the noise schedule - if self.phase == RunPhase.TRAIN: + if self.phase is not RunPhase.TEST: self.noise_percentage_schedule.step() # the second element of the list is assumed to be the standard deviation if isinstance(action_values, list) and len(action_values) > 1: diff --git a/rl_coach/exploration_policies/truncated_normal.py b/rl_coach/exploration_policies/truncated_normal.py index bfd0ba1..396f348 100644 --- a/rl_coach/exploration_policies/truncated_normal.py +++ b/rl_coach/exploration_policies/truncated_normal.py @@ -92,7 +92,7 @@ class TruncatedNormal(ExplorationPolicy): action_values_mean = action_values.squeeze() # step the noise schedule - if self.phase == RunPhase.TRAIN: + if self.phase is not RunPhase.TEST: self.noise_percentage_schedule.step() # the second element of the list is assumed to be the standard deviation if isinstance(action_values, list) and len(action_values) > 1: