From a8bce9828c358babf0e63aac0594ce180d047144 Mon Sep 17 00:00:00 2001
From: Itai Caspi
Date: Wed, 1 Nov 2017 15:09:07 +0200
Subject: [PATCH] new feature - implementation of Quantile Regression DQN
 (https://arxiv.org/pdf/1710.10044v1.pdf)
 API change - Distributional DQN renamed to Categorical DQN

---
 README.md                                          |  3 +-
 agents/__init__.py                                 |  3 +-
 ..._dqn_agent.py => categorical_dqn_agent.py}      |  4 +-
 agents/qr_dqn_agent.py                             | 62 +++++++++++++++++++
 .../tensorflow_components/general_network.py       |  3 +-
 architectures/tensorflow_components/heads.py       | 48 +++++++++++++-
 configurations.py                                  | 15 +++--
 ...stributional_dqn.md => categorical_dqn.md}      |  2 +-
 docs/mkdocs.yml                                    |  2 +-
 presets.py                                         | 32 ++++++++--
 10 files changed, 157 insertions(+), 17 deletions(-)
 rename agents/{distributional_dqn_agent.py => categorical_dqn_agent.py} (95%)
 create mode 100644 agents/qr_dqn_agent.py
 rename docs/docs/algorithms/value_optimization/{distributional_dqn.md => categorical_dqn.md} (98%)

diff --git a/README.md b/README.md
index 866ff5c..ea00fa5 100644
--- a/README.md
+++ b/README.md
@@ -195,7 +195,8 @@ python3 coach.py -p Hopper_A3C -n 16
 * [Dueling Q Network](https://arxiv.org/abs/1511.06581)
 * [Mixed Monte Carlo (MMC)](https://arxiv.org/abs/1703.01310)
 * [Persistent Advantage Learning (PAL)](https://arxiv.org/abs/1512.04860)
-* [Distributional Deep Q Network ](https://arxiv.org/abs/1707.06887)
+* [Categorical Deep Q Network (C51)](https://arxiv.org/abs/1707.06887)
+* [Quantile Regression Deep Q Network (QR-DQN)](https://arxiv.org/pdf/1710.10044v1.pdf)
 * [Bootstrapped Deep Q Network](https://arxiv.org/abs/1602.04621)
 * [N-Step Q Learning](https://arxiv.org/abs/1602.01783) | **Distributed**
 * [Neural Episodic Control (NEC) ](https://arxiv.org/abs/1703.01988)
diff --git a/agents/__init__.py b/agents/__init__.py
index c8d342a..b1ae8d3 100644
--- a/agents/__init__.py
+++ b/agents/__init__.py
@@ -22,7 +22,7 @@ from agents.ddpg_agent import *
 from agents.ddqn_agent import *
 from agents.dfp_agent import *
 from agents.dqn_agent import *
-from agents.distributional_dqn_agent import *
+from agents.categorical_dqn_agent import *
 from agents.mmc_agent import *
 from agents.n_step_q_agent import *
 from agents.naf_agent import *
@@ -32,3 +32,4 @@ from agents.policy_gradients_agent import *
 from agents.policy_optimization_agent import *
 from agents.ppo_agent import *
 from agents.value_optimization_agent import *
+from agents.qr_dqn_agent import *
diff --git a/agents/distributional_dqn_agent.py b/agents/categorical_dqn_agent.py
similarity index 95%
rename from agents/distributional_dqn_agent.py
rename to agents/categorical_dqn_agent.py
index d7c0088..dec8ba2 100644
--- a/agents/distributional_dqn_agent.py
+++ b/agents/categorical_dqn_agent.py
@@ -17,8 +17,8 @@
 from agents.value_optimization_agent import *
 
 
-# Distributional Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf
-class DistributionalDQNAgent(ValueOptimizationAgent):
+# Categorical Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf
+class CategoricalDQNAgent(ValueOptimizationAgent):
     def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
         ValueOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
         self.z_values = np.linspace(self.tp.agent.v_min, self.tp.agent.v_max, self.tp.agent.atoms)
diff --git a/agents/qr_dqn_agent.py b/agents/qr_dqn_agent.py
new file mode 100644
index 0000000..f23e383
--- /dev/null
+++ b/agents/qr_dqn_agent.py
@@ -0,0 +1,62 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.value_optimization_agent import *
+
+
+# Quantile Regression Deep Q Network - https://arxiv.org/pdf/1710.10044v1.pdf
+class QuantileRegressionDQNAgent(ValueOptimizationAgent):
+    def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+        ValueOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+        self.quantile_probabilities = np.ones(self.tp.agent.atoms) / float(self.tp.agent.atoms)
+
+    # the prediction's format is (batch, actions, atoms)
+    def get_q_values(self, quantile_values):
+        return np.dot(quantile_values, self.quantile_probabilities)
+
+    def learn_from_batch(self, batch):
+        current_states, next_states, actions, rewards, game_overs, _ = self.extract_batch(batch)
+
+        # get the quantiles of the next states and current states
+        next_state_quantiles = self.main_network.target_network.predict(next_states)
+        current_quantiles = self.main_network.online_network.predict(current_states)
+
+        # get the optimal actions to take for the next states
+        target_actions = np.argmax(self.get_q_values(next_state_quantiles), axis=1)
+
+        # calculate the Bellman update
+        batch_idx = list(range(self.tp.batch_size))
+        rewards = np.expand_dims(rewards, -1)
+        game_overs = np.expand_dims(game_overs, -1)
+        TD_targets = rewards + (1.0 - game_overs) * self.tp.agent.discount \
+                               * next_state_quantiles[batch_idx, target_actions]
+
+        # get the locations of the selected actions within the batch for indexing purposes
+        actions_locations = [[b, a] for b, a in zip(batch_idx, actions)]
+
+        # calculate the cumulative quantile probabilities and reorder them to fit the sorted quantiles order
+        cumulative_probabilities = np.array(range(self.tp.agent.atoms + 1)) / float(self.tp.agent.atoms)  # tau_i
+        quantile_midpoints = 0.5 * (cumulative_probabilities[1:] + cumulative_probabilities[:-1])  # tau^hat_i
+        quantile_midpoints = np.tile(quantile_midpoints, (self.tp.batch_size, 1))
+        for idx in range(self.tp.batch_size):
+            quantile_midpoints[idx, :] = quantile_midpoints[idx, np.argsort(current_quantiles[batch_idx, actions])[idx]]
+
+        # train
+        result = self.main_network.train_and_sync_networks([current_states, actions_locations, quantile_midpoints], TD_targets)
+        total_loss = result[0]
+
+        return total_loss
+
diff --git a/architectures/tensorflow_components/general_network.py b/architectures/tensorflow_components/general_network.py
index a3ff5f1..9b20082 100644
--- a/architectures/tensorflow_components/general_network.py
+++ b/architectures/tensorflow_components/general_network.py
@@ -80,7 +80,8 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
             OutputTypes.NAF: NAFHead,
             OutputTypes.PPO: PPOHead,
             OutputTypes.PPO_V : PPOVHead,
-            OutputTypes.DistributionalQ: DistributionalQHead
+            OutputTypes.CategoricalQ: CategoricalQHead,
+            OutputTypes.QuantileRegressionQ: QuantileRegressionQHead
         }
         return output_mapping[head_type](self.tp, head_idx, loss_weight, self.network_is_local)
 
diff --git a/architectures/tensorflow_components/heads.py b/architectures/tensorflow_components/heads.py
index ab2bc2c..1ebc10c 100644
--- a/architectures/tensorflow_components/heads.py
+++ b/architectures/tensorflow_components/heads.py
@@ -462,10 +462,10 @@ class PPOVHead(Head):
         tf.losses.add_loss(self.loss)
 
 
-class DistributionalQHead(Head):
+class CategoricalQHead(Head):
     def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
         Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
-        self.name = 'distributional_dqn_head'
+        self.name = 'categorical_dqn_head'
         self.num_actions = tuning_parameters.env_instance.action_space_size
         self.num_atoms = tuning_parameters.agent.atoms
 
@@ -484,3 +484,47 @@ class DistributionalQHead(Head):
         self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.target, logits=values_distribution)
         tf.losses.add_loss(self.loss)
 
+
+class QuantileRegressionQHead(Head):
+    def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
+        Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
+        self.name = 'quantile_regression_dqn_head'
+        self.num_actions = tuning_parameters.env_instance.action_space_size
+        self.num_atoms = tuning_parameters.agent.atoms  # we use atom / quantile interchangeably
+        self.huber_loss_interval = 1  # k
+
+    def _build_module(self, input_layer):
+        self.actions = tf.placeholder(tf.int32, [None, 2], name="actions")
+        self.quantile_midpoints = tf.placeholder(tf.float32, [None, self.num_atoms], name="quantile_midpoints")
+        self.input = [self.actions, self.quantile_midpoints]
+
+        # the output of the head is the N unordered quantile locations {theta_1, ..., theta_N}
+        quantiles_locations = tf.layers.dense(input_layer, self.num_actions * self.num_atoms)
+        quantiles_locations = tf.reshape(quantiles_locations, (tf.shape(quantiles_locations)[0], self.num_actions, self.num_atoms))
+        self.output = quantiles_locations
+
+        self.quantiles = tf.placeholder(tf.float32, shape=(None, self.num_atoms), name="quantiles")
+        self.target = self.quantiles
+
+        # only the quantiles of the taken action are taken into account
+        quantiles_for_used_actions = tf.gather_nd(quantiles_locations, self.actions)
+
+        # reorder the output quantiles and the target quantiles as a preparation step for calculating the loss
+        # the output quantiles vector and the quantile midpoints are tiled as rows of an NxN matrix (N = num quantiles)
+        # the target quantiles vector is tiled as columns of an NxN matrix
+        theta_i = tf.tile(tf.expand_dims(quantiles_for_used_actions, -1), [1, 1, self.num_atoms])
+        T_theta_j = tf.tile(tf.expand_dims(self.target, -2), [1, self.num_atoms, 1])
+        tau_i = tf.tile(tf.expand_dims(self.quantile_midpoints, -1), [1, 1, self.num_atoms])
+
+        # Huber loss of T(theta_j) - theta_i
+        abs_error = tf.abs(T_theta_j - theta_i)
+        quadratic = tf.minimum(abs_error, self.huber_loss_interval)
+        huber_loss = self.huber_loss_interval * (abs_error - quadratic) + 0.5 * quadratic ** 2
+
+        # Quantile Huber loss - the indicator is applied to the sign of the TD error T(theta_j) - theta_i
+        quantile_huber_loss = tf.abs(tau_i - tf.cast((T_theta_j - theta_i) < 0, dtype=tf.float32)) * huber_loss
+
+        # Quantile regression loss (the probability for each quantile is 1/num_quantiles)
+        quantile_regression_loss = tf.reduce_sum(quantile_huber_loss) / float(self.num_atoms)
+        self.loss = quantile_regression_loss
+        tf.losses.add_loss(self.loss)
diff --git a/configurations.py b/configurations.py
index 976cf17..b7f9953 100644
--- a/configurations.py
+++ b/configurations.py
@@ -42,7 +42,8 @@ class OutputTypes(object):
     NAF = 7
     PPO = 8
     PPO_V = 9
-    DistributionalQ = 10
+    CategoricalQ = 10
+    QuantileRegressionQ = 11
 
 
 class MiddlewareTypes(object):
@@ -307,14 +308,20 @@ class BootstrappedDQN(DQN):
     num_output_head_copies = 10
 
 
-class DistributionalDQN(DQN):
-    type = 'DistributionalDQNAgent'
-    output_types = [OutputTypes.DistributionalQ]
+class CategoricalDQN(DQN):
+    type = 'CategoricalDQNAgent'
+    output_types = [OutputTypes.CategoricalQ]
     v_min = -10.0
     v_max = 10.0
     atoms = 51
 
 
+class QuantileRegressionDQN(DQN):
+    type = 'QuantileRegressionDQNAgent'
+    output_types = [OutputTypes.QuantileRegressionQ]
+    atoms = 51
+
+
 class NEC(AgentParameters):
     type = 'NECAgent'
     optimizer_type = 'RMSProp'
diff --git a/docs/docs/algorithms/value_optimization/distributional_dqn.md b/docs/docs/algorithms/value_optimization/categorical_dqn.md
similarity index 98%
rename from docs/docs/algorithms/value_optimization/distributional_dqn.md
rename to docs/docs/algorithms/value_optimization/categorical_dqn.md
index 009a518..e4b2983 100644
--- a/docs/docs/algorithms/value_optimization/distributional_dqn.md
+++ b/docs/docs/algorithms/value_optimization/categorical_dqn.md
@@ -1,4 +1,4 @@
-# Distributional DQN
+# Categorical DQN
 
 **Actions space:** Discrete
 
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index b775463..48fca7d 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -15,7 +15,7 @@ pages:
     - 'DQN' : algorithms/value_optimization/dqn.md
     - 'Double DQN' : algorithms/value_optimization/double_dqn.md
     - 'Dueling DQN' : algorithms/value_optimization/dueling_dqn.md
-    - 'Distributional DQN' : algorithms/value_optimization/distributional_dqn.md
+    - 'Categorical DQN' : algorithms/value_optimization/categorical_dqn.md
     - 'Mixed Monte Carlo' : algorithms/value_optimization/mmc.md
     - 'Persistent Advantage Learning' : algorithms/value_optimization/pal.md
     - 'Neural Episodic Control' : algorithms/value_optimization/nec.md
diff --git a/presets.py b/presets.py
index 51bba41..fad9a0e 100644
--- a/presets.py
+++ b/presets.py
@@ -70,6 +70,18 @@ class Doom_Basic_DQN(Preset):
         self.num_heatup_steps = 1000
 
 
+
+class Doom_Basic_QRDQN(Preset):
+    def __init__(self):
+        Preset.__init__(self, QuantileRegressionDQN, Doom, ExplorationParameters)
+        self.env.level = 'basic'
+        self.agent.num_steps_between_copying_online_weights_to_target = 1000
+        self.learning_rate = 0.00025
+        self.agent.num_episodes_in_experience_replay = 200
+        self.num_heatup_steps = 1000
+
+
+
 class Doom_Basic_OneStepQ(Preset):
     def __init__(self):
         Preset.__init__(self, NStepQ, Doom, ExplorationParameters)
@@ -340,9 +352,9 @@ class CartPole_DQN(Preset):
         self.test_min_return_threshold = 150
 
 
-class CartPole_DistributionalDQN(Preset):
+class CartPole_C51(Preset):
     def __init__(self):
-        Preset.__init__(self, DistributionalDQN, GymVectorObservation, ExplorationParameters)
+        Preset.__init__(self, CategoricalDQN, GymVectorObservation, ExplorationParameters)
         self.env.level = 'CartPole-v0'
         self.agent.num_steps_between_copying_online_weights_to_target = 100
         self.learning_rate = 0.00025
@@ -356,6 +368,18 @@ class CartPole_DistributionalDQN(Preset):
         self.test_min_return_threshold = 150
 
 
+class CartPole_QRDQN(Preset):
+    def __init__(self):
+        Preset.__init__(self, QuantileRegressionDQN, GymVectorObservation, ExplorationParameters)
+        self.env.level = 'CartPole-v0'
+        self.agent.num_steps_between_copying_online_weights_to_target = 100
+        self.learning_rate = 0.00025
+        self.agent.num_episodes_in_experience_replay = 200
+        self.num_heatup_steps = 1000
+        self.exploration.epsilon_decay_steps = 3000
+        self.agent.discount = 1.0
+
+
 # The below preset matches the hyper-parameters setting as in the original DQN paper.
 # This a very resource intensive preset, and might easily blow up your RAM (> 100GB of usage).
 # Try reducing the number of transitions in the experience replay (50e3 might be a reasonable number to start with),
@@ -377,9 +401,9 @@ class Breakout_DQN(Preset):
         self.evaluate_every_x_episodes = 100
 
 
-class Breakout_DistributionalDQN(Preset):
+class Breakout_C51(Preset):
     def __init__(self):
-        Preset.__init__(self, DistributionalDQN, Atari, ExplorationParameters)
+        Preset.__init__(self, CategoricalDQN, Atari, ExplorationParameters)
         self.env.level = 'BreakoutDeterministic-v4'
         self.agent.num_steps_between_copying_online_weights_to_target = 10000
         self.learning_rate = 0.00025
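
The sketch below is for review convenience only and is not part of the patch. It restates, in plain NumPy and for a single transition, the quantile Huber loss that QuantileRegressionQHead builds in TensorFlow above (batch handling and action indexing stripped out). The function name and arguments are illustrative, not Coach API.

import numpy as np


def quantile_huber_loss(theta, bellman_targets, kappa=1.0):
    # theta: (N,) predicted quantile locations for the action that was taken
    # bellman_targets: (N,) r + gamma * target-network quantiles of the next state's greedy action
    # kappa: Huber interval (the head above uses k = 1)
    num_quantiles = theta.shape[0]

    # quantile midpoints tau^hat_i = (tau_(i-1) + tau_i) / 2, with tau_i = i / N
    tau = np.arange(num_quantiles + 1) / num_quantiles
    tau_hat = 0.5 * (tau[1:] + tau[:-1])

    # pairwise TD errors u[i, j] = T(theta_j) - theta_i
    u = bellman_targets[np.newaxis, :] - theta[:, np.newaxis]

    # Huber loss L_kappa(u)
    abs_u = np.abs(u)
    quadratic = np.minimum(abs_u, kappa)
    huber = kappa * (abs_u - quadratic) + 0.5 * quadratic ** 2

    # quantile Huber loss |tau^hat_i - 1{u < 0}| * L_kappa(u)
    rho = np.abs(tau_hat[:, np.newaxis] - (u < 0.0).astype(np.float64)) * huber

    # average over the target quantiles j, sum over the predicted quantiles i
    return rho.mean(axis=1).sum()


# example with 51 quantiles, matching atoms = 51 in the QuantileRegressionDQN parameters
predicted = np.sort(np.random.randn(51))
targets = 1.0 + 0.99 * np.sort(np.random.randn(51))
print(quantile_huber_loss(predicted, targets))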