diff --git a/rl_coach/exploration_policies/additive_noise.py b/rl_coach/exploration_policies/additive_noise.py index 682021c..5f89889 100644 --- a/rl_coach/exploration_policies/additive_noise.py +++ b/rl_coach/exploration_policies/additive_noise.py @@ -88,7 +88,7 @@ class AdditiveNoise(ExplorationPolicy): action_values_mean = action_values.squeeze() # step the noise schedule - if self.phase == RunPhase.TRAIN: + if self.phase is not RunPhase.TEST: self.noise_percentage_schedule.step() # the second element of the list is assumed to be the standard deviation if isinstance(action_values, list) and len(action_values) > 1: diff --git a/rl_coach/exploration_policies/truncated_normal.py b/rl_coach/exploration_policies/truncated_normal.py index bfd0ba1..396f348 100644 --- a/rl_coach/exploration_policies/truncated_normal.py +++ b/rl_coach/exploration_policies/truncated_normal.py @@ -92,7 +92,7 @@ class TruncatedNormal(ExplorationPolicy): action_values_mean = action_values.squeeze() # step the noise schedule - if self.phase == RunPhase.TRAIN: + if self.phase is not RunPhase.TEST: self.noise_percentage_schedule.step() # the second element of the list is assumed to be the standard deviation if isinstance(action_values, list) and len(action_values) > 1: