From f9ee526536cc0ffbc42078c53565ffe35c51e1e3 Mon Sep 17 00:00:00 2001 From: Gal Leibovich Date: Sun, 16 Dec 2018 16:06:44 +0200 Subject: [PATCH] Fix for issue #128 - circular DQN import (#130) --- rl_coach/agents/dqn_agent.py | 1 + rl_coach/base_parameters.py | 3 +++ rl_coach/exploration_policies/parameter_noise.py | 4 ++-- .../filters/observation/observation_normalization_filter.py | 1 + 4 files changed, 7 insertions(+), 2 deletions(-) diff --git a/rl_coach/agents/dqn_agent.py b/rl_coach/agents/dqn_agent.py index a60aac2..b234e88 100644 --- a/rl_coach/agents/dqn_agent.py +++ b/rl_coach/agents/dqn_agent.py @@ -36,6 +36,7 @@ class DQNAlgorithmParameters(AlgorithmParameters): self.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(10000) self.num_consecutive_playing_steps = EnvironmentSteps(4) self.discount = 0.99 + self.supports_parameter_noise = True class DQNNetworkParameters(NetworkParameters): diff --git a/rl_coach/base_parameters.py b/rl_coach/base_parameters.py index d3b5999..da368c3 100644 --- a/rl_coach/base_parameters.py +++ b/rl_coach/base_parameters.py @@ -211,6 +211,9 @@ class AlgorithmParameters(Parameters): # Should the workers wait for full episode self.act_for_full_episodes = False + # Support for parameter noise + self.supports_parameter_noise = False + class PresetValidationParameters(Parameters): def __init__(self, diff --git a/rl_coach/exploration_policies/parameter_noise.py b/rl_coach/exploration_policies/parameter_noise.py index 34381c4..7854329 100644 --- a/rl_coach/exploration_policies/parameter_noise.py +++ b/rl_coach/exploration_policies/parameter_noise.py @@ -18,7 +18,6 @@ from typing import List, Dict import numpy as np -from rl_coach.agents.dqn_agent import DQNAgentParameters from rl_coach.architectures.layers import NoisyNetDense from rl_coach.base_parameters import AgentParameters, NetworkParameters from rl_coach.spaces import ActionSpace, BoxActionSpace, DiscreteActionSpace @@ -30,7 +29,8 @@ from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy, class ParameterNoiseParameters(ExplorationParameters): def __init__(self, agent_params: AgentParameters): super().__init__() - if not isinstance(agent_params, DQNAgentParameters): + + if not agent_params.algorithm.supports_parameter_noise: raise ValueError("Currently only DQN variants are supported for using an exploration type of " "ParameterNoise.") diff --git a/rl_coach/filters/observation/observation_normalization_filter.py b/rl_coach/filters/observation/observation_normalization_filter.py index db9e104..791b345 100644 --- a/rl_coach/filters/observation/observation_normalization_filter.py +++ b/rl_coach/filters/observation/observation_normalization_filter.py @@ -87,3 +87,4 @@ class ObservationNormalizationFilter(ObservationFilter): def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str): self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix) + \ No newline at end of file