diff --git a/rl_coach/agents/rainbow_dqn_agent.py b/rl_coach/agents/rainbow_dqn_agent.py
index b45415e..e39024a 100644
--- a/rl_coach/agents/rainbow_dqn_agent.py
+++ b/rl_coach/agents/rainbow_dqn_agent.py
@@ -22,7 +22,8 @@ from rl_coach.agents.categorical_dqn_agent import CategoricalDQNAlgorithmParamet
     CategoricalDQNAgent, CategoricalDQNAgentParameters
 from rl_coach.agents.dqn_agent import DQNNetworkParameters
 from rl_coach.architectures.tensorflow_components.heads.rainbow_q_head import RainbowQHeadParameters
-
+from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
+from rl_coach.base_parameters import MiddlewareScheme
 from rl_coach.exploration_policies.parameter_noise import ParameterNoiseParameters
 from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplayParameters, \
     PrioritizedExperienceReplay
@@ -32,6 +33,7 @@ class RainbowDQNNetworkParameters(DQNNetworkParameters):
     def __init__(self):
         super().__init__()
         self.heads_parameters = [RainbowQHeadParameters()]
+        self.middleware_parameters = FCMiddlewareParameters(scheme=MiddlewareScheme.Empty)
 
 
 class RainbowDQNAlgorithmParameters(CategoricalDQNAlgorithmParameters):
@@ -44,6 +46,11 @@ class RainbowDQNExplorationParameters(ParameterNoiseParameters):
         super().__init__(agent_params)
 
 
+class RainbowDQNMemoryParameters(PrioritizedExperienceReplayParameters):
+    def __init__(self):
+        super().__init__()
+
+
 class RainbowDQNAgentParameters(CategoricalDQNAgentParameters):
     def __init__(self):
         super().__init__()
@@ -58,8 +65,8 @@ class RainbowDQNAgentParameters(CategoricalDQNAgentParameters):
 
 
 # Rainbow Deep Q Network - https://arxiv.org/abs/1710.02298
-# Agent implementation is WIP. Currently has:
-# 1. DQN
+# Agent implementation is WIP. It is currently composed of:
+# 1. NoisyNets
 # 2. C51
 # 3. Prioritized ER
 # 4. DDQN
diff --git a/rl_coach/presets/Atari_Rainbow.py b/rl_coach/presets/Atari_Rainbow.py
index 3187fd7..8399cf6 100644
--- a/rl_coach/presets/Atari_Rainbow.py
+++ b/rl_coach/presets/Atari_Rainbow.py
@@ -1,4 +1,3 @@
-from rl_coach.agents.categorical_dqn_agent import CategoricalDQNAgentParameters
 from rl_coach.agents.rainbow_dqn_agent import RainbowDQNAgentParameters
 from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
 from rl_coach.core_types import EnvironmentSteps, RunPhase
@@ -13,17 +12,20 @@ from rl_coach.schedules import LinearSchedule
 ####################
 schedule_params = ScheduleParameters()
 schedule_params.improve_steps = EnvironmentSteps(50000000)
-schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
-schedule_params.evaluation_steps = EnvironmentSteps(135000)
-schedule_params.heatup_steps = EnvironmentSteps(50000)
+schedule_params.steps_between_evaluation_periods = EnvironmentSteps(1000000)
+schedule_params.evaluation_steps = EnvironmentSteps(125000)
+schedule_params.heatup_steps = EnvironmentSteps(20000)
 
 #########
 # Agent #
 #########
 agent_params = RainbowDQNAgentParameters()
-agent_params.network_wrappers['main'].learning_rate = 0.00025
-agent_params.memory.beta = LinearSchedule(0.4, 1, 12500000)  # 12.5M training iterations = 50M steps = 200M frames
+agent_params.network_wrappers['main'].learning_rate = 0.0000625
+agent_params.network_wrappers['main'].optimizer_epsilon = 1.5e-4
+agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(32000 // 4)  # 32k frames
+agent_params.memory.beta = LinearSchedule(0.4, 1, 12500000)  # 12.5M training iterations = 50M steps = 200M frames
+agent_params.memory.alpha = 0.5
 
 ###############
 # Environment #