From bbe7ac33384779f5f224090475fb451ef1f9f43a Mon Sep 17 00:00:00 2001
From: Gal Leibovich
Date: Thu, 30 Aug 2018 15:11:51 +0300
Subject: [PATCH] Rainbow DQN agent (WIP - still missing dueling and n-step) +
 adding support for Prioritized ER for C51

---
 rl_coach/agents/categorical_dqn_agent.py |  14 +-
 rl_coach/agents/rainbow_dqn_agent.py     | 125 ++++++++++++++++++
 .../heads/rainbow_dqn_head.py            |  44 ++++++
 rl_coach/presets/Atari_Rainbow.py        |  46 +++++++
 4 files changed, 228 insertions(+), 1 deletion(-)
 create mode 100644 rl_coach/agents/rainbow_dqn_agent.py
 create mode 100644 rl_coach/architectures/tensorflow_components/heads/rainbow_dqn_head.py
 create mode 100644 rl_coach/presets/Atari_Rainbow.py

diff --git a/rl_coach/agents/categorical_dqn_agent.py b/rl_coach/agents/categorical_dqn_agent.py
index debce3d..5af83c1 100644
--- a/rl_coach/agents/categorical_dqn_agent.py
+++ b/rl_coach/agents/categorical_dqn_agent.py
@@ -25,6 +25,7 @@ from rl_coach.base_parameters import AgentParameters
 from rl_coach.core_types import StateType
 from rl_coach.exploration_policies.e_greedy import EGreedyParameters
 from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
+from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay
 from rl_coach.schedules import LinearSchedule
 
 
@@ -104,11 +105,22 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
             l = (np.floor(bj)).astype(int)
             m[batches, l] = m[batches, l] + (distributed_q_st_plus_1[batches, target_actions, j] * (u - bj))
             m[batches, u] = m[batches, u] + (distributed_q_st_plus_1[batches, target_actions, j] * (bj - l))
+
         # total_loss = cross entropy between actual result above and predicted result for the given action
         TD_targets[batches, batch.actions()] = m
 
-        result = self.networks['main'].train_and_sync_networks(batch.states(network_keys), TD_targets)
+        # update errors in prioritized replay buffer
+        importance_weights = batch.info('weight') if isinstance(self.memory, PrioritizedExperienceReplay) else None
+
+        result = self.networks['main'].train_and_sync_networks(batch.states(network_keys), TD_targets,
+                                                               importance_weights=importance_weights)
+
         total_loss, losses, unclipped_grads = result[:3]
+
+        # TODO: fix this spaghetti code
+        if isinstance(self.memory, PrioritizedExperienceReplay):
+            errors = losses[0][np.arange(batch.size), batch.actions()]
+            self.memory.update_priorities(batch.info('idx'), errors)
+
         return total_loss, losses, unclipped_grads
 
diff --git a/rl_coach/agents/rainbow_dqn_agent.py b/rl_coach/agents/rainbow_dqn_agent.py
new file mode 100644
index 0000000..8fef370
--- /dev/null
+++ b/rl_coach/agents/rainbow_dqn_agent.py
@@ -0,0 +1,125 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Union
+
+import numpy as np
+
+from rl_coach.agents.categorical_dqn_agent import CategoricalDQNNetworkParameters, CategoricalDQNAlgorithmParameters, \
+    CategoricalDQNAgent, CategoricalDQNAgentParameters
+from rl_coach.agents.dqn_agent import DQNNetworkParameters, DQNAlgorithmParameters
+from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent
+from rl_coach.architectures.tensorflow_components.heads.categorical_q_head import CategoricalQHeadParameters
+from rl_coach.base_parameters import AgentParameters
+from rl_coach.exploration_policies.parameter_noise import ParameterNoiseParameters
+from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
+from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplayParameters, \
+    PrioritizedExperienceReplay
+from rl_coach.schedules import LinearSchedule
+
+from rl_coach.core_types import StateType
+from rl_coach.exploration_policies.e_greedy import EGreedyParameters
+
+
+class RainbowDQNNetworkParameters(CategoricalDQNNetworkParameters):
+    def __init__(self):
+        super().__init__()
+
+
+class RainbowDQNAlgorithmParameters(CategoricalDQNAlgorithmParameters):
+    def __init__(self):
+        super().__init__()
+
+
+class RainbowDQNExplorationParameters(ParameterNoiseParameters):
+    def __init__(self, agent_params):
+        super().__init__(agent_params)
+
+
+class RainbowDQNAgentParameters(CategoricalDQNAgentParameters):
+    def __init__(self):
+        super().__init__()
+        self.algorithm = RainbowDQNAlgorithmParameters()
+        self.exploration = RainbowDQNExplorationParameters(self)
+        self.memory = PrioritizedExperienceReplayParameters()
+        self.network_wrappers = {"main": RainbowDQNNetworkParameters()}
+
+    @property
+    def path(self):
+        return 'rl_coach.agents.rainbow_dqn_agent:RainbowDQNAgent'
+
+
+# Rainbow Deep Q Network - https://arxiv.org/abs/1710.02298
+# Agent implementation is WIP. Currently has:
+# 1. DQN
+# 2. C51
+# 3. Prioritized ER
+# 4. DDQN
+#
+# still missing:
+# 1. N-Step
+# 2. Dueling DQN
+class RainbowDQNAgent(CategoricalDQNAgent):
+    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
+        super().__init__(agent_parameters, parent)
+
+    def learn_from_batch(self, batch):
+        network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()
+
+        ddqn_selected_actions = np.argmax(self.distribution_prediction_to_q_values(
+            self.networks['main'].online_network.predict(batch.next_states(network_keys))), axis=1)
+
+        # for the action we actually took, the error is calculated by the atoms distribution
+        # for all other actions, the error is 0
+        distributed_q_st_plus_1, TD_targets = self.networks['main'].parallel_prediction([
+            (self.networks['main'].target_network, batch.next_states(network_keys)),
+            (self.networks['main'].online_network, batch.states(network_keys))
+        ])
+
+        # only update the action that we have actually done in this transition (using the Double-DQN selected actions)
+        target_actions = ddqn_selected_actions
+        m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
+
+        batches = np.arange(self.ap.network_wrappers['main'].batch_size)
+        for j in range(self.z_values.size):
+            tzj = np.fmax(np.fmin(batch.rewards() +
+                                  (1.0 - batch.game_overs()) * self.ap.algorithm.discount * self.z_values[j],
+                                  self.z_values[self.z_values.size - 1]),
+                          self.z_values[0])
+            bj = (tzj - self.z_values[0])/(self.z_values[1] - self.z_values[0])
+            u = (np.ceil(bj)).astype(int)
+            l = (np.floor(bj)).astype(int)
+            m[batches, l] = m[batches, l] + (distributed_q_st_plus_1[batches, target_actions, j] * (u - bj))
+            m[batches, u] = m[batches, u] + (distributed_q_st_plus_1[batches, target_actions, j] * (bj - l))
+
+        # total_loss = cross entropy between actual result above and predicted result for the given action
+        TD_targets[batches, batch.actions()] = m
+
+        # update errors in prioritized replay buffer
+        importance_weights = batch.info('weight') if isinstance(self.memory, PrioritizedExperienceReplay) else None
+
+        result = self.networks['main'].train_and_sync_networks(batch.states(network_keys), TD_targets,
+                                                               importance_weights=importance_weights)
+
+        total_loss, losses, unclipped_grads = result[:3]
+
+        # TODO: fix this spaghetti code
+        if isinstance(self.memory, PrioritizedExperienceReplay):
+            errors = losses[0][np.arange(batch.size), batch.actions()]
+            self.memory.update_priorities(batch.info('idx'), errors)
+
+        return total_loss, losses, unclipped_grads
+
diff --git a/rl_coach/architectures/tensorflow_components/heads/rainbow_dqn_head.py b/rl_coach/architectures/tensorflow_components/heads/rainbow_dqn_head.py
new file mode 100644
index 0000000..d559894
--- /dev/null
+++ b/rl_coach/architectures/tensorflow_components/heads/rainbow_dqn_head.py
@@ -0,0 +1,44 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tensorflow as tf
+
+from rl_coach.architectures.tensorflow_components.architecture import Dense
+from rl_coach.base_parameters import AgentParameters
+from rl_coach.spaces import SpacesDefinition
+
+from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
+from rl_coach.core_types import QActionStateValue
+
+
+class RainbowQHeadParameters(HeadParameters):
+    def __init__(self, activation_function: str ='relu', name: str='rainbow_q_head_params', dense_layer=Dense):
+        super().__init__(parameterized_class=RainbowQHead, activation_function=activation_function, name=name,
+                         dense_layer=dense_layer)
+
+
+class RainbowQHead(Head):
+    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
+                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str ='relu',
+                 dense_layer=Dense):
+        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
+                         dense_layer=dense_layer)
+        self.name = 'rainbow_dqn_head'
+        self.num_actions = len(self.spaces.action.actions)
+        self.return_type = QActionStateValue
+
+    def _build_module(self, input_layer):
+        pass
\ No newline at end of file
diff --git a/rl_coach/presets/Atari_Rainbow.py b/rl_coach/presets/Atari_Rainbow.py
new file mode 100644
index 0000000..b6ccf67
--- /dev/null
+++ b/rl_coach/presets/Atari_Rainbow.py
@@ -0,0 +1,46 @@
+from rl_coach.agents.categorical_dqn_agent import CategoricalDQNAgentParameters
+from rl_coach.agents.rainbow_dqn_agent import RainbowDQNAgentParameters
+from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
+from rl_coach.core_types import EnvironmentSteps, RunPhase
+from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
+from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
+from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
+from rl_coach.graph_managers.graph_manager import ScheduleParameters
+from rl_coach.schedules import LinearSchedule
+
+####################
+# Graph Scheduling #
+####################
+schedule_params = ScheduleParameters()
+schedule_params.improve_steps = EnvironmentSteps(50000000)
+schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
+schedule_params.evaluation_steps = EnvironmentSteps(135000)
+schedule_params.heatup_steps = EnvironmentSteps(500)
+
+#########
+# Agent #
+#########
+agent_params = RainbowDQNAgentParameters()
+agent_params.network_wrappers['main'].learning_rate = 0.00025
+agent_params.memory.beta = LinearSchedule(0.4, 1, 12500000)  # 12.5M training iterations = 50M steps = 200M frames
+
+
+###############
+# Environment #
+###############
+env_params = Atari()
+env_params.level = SingleLevelSelection(atari_deterministic_v4)
+
+vis_params = VisualizationParameters()
+vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
+vis_params.dump_mp4 = False
+
+########
+# Test #
+########
+preset_validation_params = PresetValidationParameters()
+preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
+
+graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
+                                    schedule_params=schedule_params, vis_params=vis_params,
+                                    preset_validation_params=preset_validation_params)
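
Reviewer sketch (outside the patch): the same categorical (C51) projection loop now appears in both categorical_dqn_agent.py and rainbow_dqn_agent.py, so a self-contained numpy version of that step is included here for reference. The function name project_distribution and its argument names are illustrative only and do not exist in rl_coach.

import numpy as np

def project_distribution(rewards, game_overs, next_dist, z, discount):
    """Project the shifted support r + (1 - done) * gamma * z back onto the fixed atoms z.

    rewards:    (batch,)           immediate rewards
    game_overs: (batch,)           1.0 where the episode terminated, else 0.0
    next_dist:  (batch, n_atoms)   target-network probabilities of the selected next action
    z:          (n_atoms,)         fixed, evenly spaced support values
    """
    batch_size, n_atoms = next_dist.shape
    delta_z = z[1] - z[0]
    batches = np.arange(batch_size)
    m = np.zeros((batch_size, n_atoms))
    for j in range(n_atoms):
        # shift atom j by the Bellman backup and clip it into [z_min, z_max]
        tzj = np.clip(rewards + (1.0 - game_overs) * discount * z[j], z[0], z[-1])
        bj = (tzj - z[0]) / delta_z                     # fractional index on the support
        l, u = np.floor(bj).astype(int), np.ceil(bj).astype(int)
        # split atom j's probability mass between the two neighbouring atoms;
        # as in the hunks above, the mass vanishes when bj is exactly an integer (l == u)
        m[batches, l] += next_dist[batches, j] * (u - bj)
        m[batches, u] += next_dist[batches, j] * (bj - l)
    return m

Assuming the new preset registers like the existing Atari presets, it would presumably be launched with the usual CLI, e.g. coach -p Atari_Rainbow -lvl breakout.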