Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)
New feature: implementation of Quantile Regression DQN (https://arxiv.org/pdf/1710.10044v1.pdf)
API change: Distributional DQN renamed to Categorical DQN
@@ -195,7 +195,8 @@ python3 coach.py -p Hopper_A3C -n 16
 * [Dueling Q Network](https://arxiv.org/abs/1511.06581)
 * [Mixed Monte Carlo (MMC)](https://arxiv.org/abs/1703.01310)
 * [Persistent Advantage Learning (PAL)](https://arxiv.org/abs/1512.04860)
-* [Distributional Deep Q Network](https://arxiv.org/abs/1707.06887)
+* [Categorical Deep Q Network (C51)](https://arxiv.org/abs/1707.06887)
+* [Quantile Regression Deep Q Network (QR-DQN)](https://arxiv.org/pdf/1710.10044v1.pdf)
 * [Bootstrapped Deep Q Network](https://arxiv.org/abs/1602.04621)
 * [N-Step Q Learning](https://arxiv.org/abs/1602.01783) | **Distributed**
 * [Neural Episodic Control (NEC)](https://arxiv.org/abs/1703.01988)
@@ -22,7 +22,7 @@ from agents.ddpg_agent import *
 from agents.ddqn_agent import *
 from agents.dfp_agent import *
 from agents.dqn_agent import *
-from agents.distributional_dqn_agent import *
+from agents.categorical_dqn_agent import *
 from agents.mmc_agent import *
 from agents.n_step_q_agent import *
 from agents.naf_agent import *
@@ -32,3 +32,4 @@ from agents.policy_gradients_agent import *
 from agents.policy_optimization_agent import *
 from agents.ppo_agent import *
 from agents.value_optimization_agent import *
+from agents.qr_dqn_agent import *
@@ -17,8 +17,8 @@
 from agents.value_optimization_agent import *
 
 
-# Distributional Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf
-class DistributionalDQNAgent(ValueOptimizationAgent):
+# Categorical Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf
+class CategoricalDQNAgent(ValueOptimizationAgent):
     def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
         ValueOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
         self.z_values = np.linspace(self.tp.agent.v_min, self.tp.agent.v_max, self.tp.agent.atoms)
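For orientation (a sketch, not part of the diff): Categorical DQN keeps the `z_values` support above fixed and learns a probability for each atom, while QR-DQN fixes the atom probabilities at 1/N and learns the quantile locations instead. A minimal NumPy comparison, assuming the default `v_min = -10.0`, `v_max = 10.0`, `atoms = 51` set in the configuration hunk further below:

```python
import numpy as np

v_min, v_max, atoms = -10.0, 10.0, 51   # defaults from the CategoricalDQN parameters below

# Categorical DQN (C51): fixed support of 51 atoms spaced 0.4 apart;
# the network outputs a probability p_i per atom, and Q(s, a) = sum_i p_i * z_i
z_values = np.linspace(v_min, v_max, atoms)

# QR-DQN: probabilities are fixed at 1/N and the network outputs the quantile
# locations theta_i, so Q(s, a) is simply the mean of the theta_i
quantile_probabilities = np.ones(atoms) / float(atoms)
```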
agents/qr_dqn_agent.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from agents.value_optimization_agent import *
+
+
+# Quantile Regression Deep Q Network - https://arxiv.org/pdf/1710.10044v1.pdf
+class QuantileRegressionDQNAgent(ValueOptimizationAgent):
+    def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
+        ValueOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
+        self.quantile_probabilities = np.ones(self.tp.agent.atoms) / float(self.tp.agent.atoms)
+
+    # prediction's format is (batch,actions,atoms)
+    def get_q_values(self, quantile_values):
+        return np.dot(quantile_values, self.quantile_probabilities)
+
+    def learn_from_batch(self, batch):
+        current_states, next_states, actions, rewards, game_overs, _ = self.extract_batch(batch)
+
+        # get the quantiles of the next states and current states
+        next_state_quantiles = self.main_network.target_network.predict(next_states)
+        current_quantiles = self.main_network.online_network.predict(current_states)
+
+        # get the optimal actions to take for the next states
+        target_actions = np.argmax(self.get_q_values(next_state_quantiles), axis=1)
+
+        # calculate the Bellman update
+        batch_idx = list(range(self.tp.batch_size))
+        rewards = np.expand_dims(rewards, -1)
+        game_overs = np.expand_dims(game_overs, -1)
+        TD_targets = rewards + (1.0 - game_overs) * self.tp.agent.discount \
+                     * next_state_quantiles[batch_idx, target_actions]
+
+        # get the locations of the selected actions within the batch for indexing purposes
+        actions_locations = [[b, a] for b, a in zip(batch_idx, actions)]
+
+        # calculate the cumulative quantile probabilities and reorder them to fit the sorted quantiles order
+        cumulative_probabilities = np.array(range(self.tp.agent.atoms + 1)) / float(self.tp.agent.atoms)  # tau_i
+        quantile_midpoints = 0.5 * (cumulative_probabilities[1:] + cumulative_probabilities[:-1])  # tau^hat_i
+        quantile_midpoints = np.tile(quantile_midpoints, (self.tp.batch_size, 1))
+        for idx in range(self.tp.batch_size):
+            quantile_midpoints[idx, :] = quantile_midpoints[idx, np.argsort(current_quantiles[batch_idx, actions])[idx]]
+
+        # train
+        result = self.main_network.train_and_sync_networks([current_states, actions_locations, quantile_midpoints], TD_targets)
+        total_loss = result[0]
+
+        return total_loss
+
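A quick NumPy illustration (not from the commit) of the tau / tau-hat computation in `learn_from_batch` above, using `atoms = 4` instead of the default 51 for readability:

```python
import numpy as np

atoms = 4  # illustrative; the default configuration below uses atoms = 51

# cumulative quantile probabilities tau_0 .. tau_N
cumulative_probabilities = np.array(range(atoms + 1)) / float(atoms)
print(cumulative_probabilities)    # [0.   0.25 0.5  0.75 1.  ]

# quantile midpoints tau^hat_i, one per learned quantile location
quantile_midpoints = 0.5 * (cumulative_probabilities[1:] + cumulative_probabilities[:-1])
print(quantile_midpoints)          # [0.125 0.375 0.625 0.875]
```

The per-sample loop then permutes these midpoints to follow the sort order of the online network's quantiles for the taken action, as the code comment states, before they are fed to the head as `quantile_midpoints`.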
@@ -80,7 +80,8 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
             OutputTypes.NAF: NAFHead,
             OutputTypes.PPO: PPOHead,
             OutputTypes.PPO_V : PPOVHead,
-            OutputTypes.DistributionalQ: DistributionalQHead
+            OutputTypes.CategoricalQ: CategoricalQHead,
+            OutputTypes.QuantileRegressionQ: QuantileRegressionQHead
         }
         return output_mapping[head_type](self.tp, head_idx, loss_weight, self.network_is_local)
 
@@ -462,10 +462,10 @@ class PPOVHead(Head):
         tf.losses.add_loss(self.loss)
 
 
-class DistributionalQHead(Head):
+class CategoricalQHead(Head):
     def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
         Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
-        self.name = 'distributional_dqn_head'
+        self.name = 'categorical_dqn_head'
         self.num_actions = tuning_parameters.env_instance.action_space_size
         self.num_atoms = tuning_parameters.agent.atoms
 
@@ -484,3 +484,47 @@ class DistributionalQHead(Head):
         self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.target, logits=values_distribution)
         tf.losses.add_loss(self.loss)
 
+
+class QuantileRegressionQHead(Head):
+    def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
+        Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
+        self.name = 'quantile_regression_dqn_head'
+        self.num_actions = tuning_parameters.env_instance.action_space_size
+        self.num_atoms = tuning_parameters.agent.atoms  # we use atom / quantile interchangeably
+        self.huber_loss_interval = 1  # k
+
+    def _build_module(self, input_layer):
+        self.actions = tf.placeholder(tf.int32, [None, 2], name="actions")
+        self.quantile_midpoints = tf.placeholder(tf.float32, [None, self.num_atoms], name="quantile_midpoints")
+        self.input = [self.actions, self.quantile_midpoints]
+
+        # the output of the head is the N unordered quantile locations {theta_1, ..., theta_N}
+        quantiles_locations = tf.layers.dense(input_layer, self.num_actions * self.num_atoms)
+        quantiles_locations = tf.reshape(quantiles_locations, (tf.shape(quantiles_locations)[0], self.num_actions, self.num_atoms))
+        self.output = quantiles_locations
+
+        self.quantiles = tf.placeholder(tf.float32, shape=(None, self.num_atoms), name="quantiles")
+        self.target = self.quantiles
+
+        # only the quantiles of the taken action are taken into account
+        quantiles_for_used_actions = tf.gather_nd(quantiles_locations, self.actions)
+
+        # reorder the output quantiles and the target quantiles as a preparation step for calculating the loss
+        # the output quantiles vector and the quantile midpoints are tiled as rows of a NxN matrix (N = num quantiles)
+        # the target quantiles vector is tiled as column of a NxN matrix
+        theta_i = tf.tile(tf.expand_dims(quantiles_for_used_actions, -1), [1, 1, self.num_atoms])
+        T_theta_j = tf.tile(tf.expand_dims(self.target, -2), [1, self.num_atoms, 1])
+        tau_i = tf.tile(tf.expand_dims(self.quantile_midpoints, -1), [1, 1, self.num_atoms])
+
+        # Huber loss of T(theta_j) - theta_i
+        abs_error = tf.abs(T_theta_j - theta_i)
+        quadratic = tf.minimum(abs_error, self.huber_loss_interval)
+        huber_loss = self.huber_loss_interval * (abs_error - quadratic) + 0.5 * quadratic ** 2
+
+        # Quantile Huber loss
+        quantile_huber_loss = tf.abs(tau_i - tf.cast(abs_error < 0, dtype=tf.float32)) * huber_loss
+
+        # Quantile regression loss (the probability for each quantile is 1/num_quantiles)
+        quantile_regression_loss = tf.reduce_sum(quantile_huber_loss) / float(self.num_atoms)
+        self.loss = quantile_regression_loss
+        tf.losses.add_loss(self.loss)
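The loss built above is the quantile Huber loss from the QR-DQN paper, ρ^κ_τ(u) = |τ − 1{u < 0}| · L_κ(u), applied pairwise to u = T(θ_j) − θ_i and averaged over the target quantiles. A small NumPy sketch of that formula for a single transition (not from the commit; note that the paper takes the indicator on the signed error u, whereas the committed line applies it to `abs_error`, which is never negative):

```python
import numpy as np

def quantile_huber_loss(theta, targets, k=1.0):
    """Quantile Huber loss for a single transition.

    theta:   (N,) predicted quantile locations for the taken action
    targets: (N,) Bellman-updated target quantile locations T(theta_j)
    """
    n = len(theta)
    tau_hat = (np.arange(n) + 0.5) / n            # quantile midpoints tau^hat_i

    u = targets[None, :] - theta[:, None]         # pairwise errors u_ij = T(theta_j) - theta_i

    abs_u = np.abs(u)                             # Huber loss L_k(u)
    quadratic = np.minimum(abs_u, k)
    huber = k * (abs_u - quadratic) + 0.5 * quadratic ** 2

    weight = np.abs(tau_hat[:, None] - (u < 0))   # asymmetric weight |tau^hat_i - 1{u < 0}|

    return np.sum(weight * huber) / n             # average over the target (j) dimension

# example with 4 quantiles per action
theta = np.array([0.0, 1.0, 2.0, 3.0])
targets = np.array([0.5, 1.5, 2.5, 3.5])
print(quantile_huber_loss(theta, targets))
```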
@@ -42,7 +42,8 @@ class OutputTypes(object):
     NAF = 7
     PPO = 8
     PPO_V = 9
-    DistributionalQ = 10
+    CategoricalQ = 10
+    QuantileRegressionQ = 11
 
 
 class MiddlewareTypes(object):
@@ -307,14 +308,20 @@ class BootstrappedDQN(DQN):
     num_output_head_copies = 10
 
 
-class DistributionalDQN(DQN):
-    type = 'DistributionalDQNAgent'
-    output_types = [OutputTypes.DistributionalQ]
+class CategoricalDQN(DQN):
+    type = 'CategoricalDQNAgent'
+    output_types = [OutputTypes.CategoricalQ]
     v_min = -10.0
     v_max = 10.0
     atoms = 51
 
 
+class QuantileRegressionDQN(DQN):
+    type = 'QuantileRegressionDQNAgent'
+    output_types = [OutputTypes.QuantileRegressionQ]
+    atoms = 51
+
+
 class NEC(AgentParameters):
     type = 'NECAgent'
     optimizer_type = 'RMSProp'
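As with the other parameter classes in this hunk, the new defaults can presumably be tuned by subclassing; a hypothetical example (the class name `QuantileRegressionDQN_32` is illustrative, not part of the commit):

```python
# hypothetical variant: QR-DQN with 32 quantiles instead of the default 51
class QuantileRegressionDQN_32(QuantileRegressionDQN):
    atoms = 32  # read as tuning_parameters.agent.atoms by both the agent and the head
```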
@@ -1,4 +1,4 @@
-# Distributional DQN
+# Categorical DQN
 
 **Actions space:** Discrete
 
@@ -15,7 +15,7 @@ pages:
 - 'DQN' : algorithms/value_optimization/dqn.md
 - 'Double DQN' : algorithms/value_optimization/double_dqn.md
 - 'Dueling DQN' : algorithms/value_optimization/dueling_dqn.md
-- 'Distributional DQN' : algorithms/value_optimization/distributional_dqn.md
+- 'Categorical DQN' : algorithms/value_optimization/categorical_dqn.md
 - 'Mixed Monte Carlo' : algorithms/value_optimization/mmc.md
 - 'Persistent Advantage Learning' : algorithms/value_optimization/pal.md
 - 'Neural Episodic Control' : algorithms/value_optimization/nec.md
presets.py
@@ -70,6 +70,18 @@ class Doom_Basic_DQN(Preset):
         self.num_heatup_steps = 1000
 
 
+
+class Doom_Basic_QRDQN(Preset):
+    def __init__(self):
+        Preset.__init__(self, QuantileRegressionDQN, Doom, ExplorationParameters)
+        self.env.level = 'basic'
+        self.agent.num_steps_between_copying_online_weights_to_target = 1000
+        self.learning_rate = 0.00025
+        self.agent.num_episodes_in_experience_replay = 200
+        self.num_heatup_steps = 1000
+
+
+
 class Doom_Basic_OneStepQ(Preset):
     def __init__(self):
         Preset.__init__(self, NStepQ, Doom, ExplorationParameters)
@@ -340,9 +352,9 @@ class CartPole_DQN(Preset):
         self.test_min_return_threshold = 150
 
 
-class CartPole_DistributionalDQN(Preset):
+class CartPole_C51(Preset):
     def __init__(self):
-        Preset.__init__(self, DistributionalDQN, GymVectorObservation, ExplorationParameters)
+        Preset.__init__(self, CategoricalDQN, GymVectorObservation, ExplorationParameters)
         self.env.level = 'CartPole-v0'
         self.agent.num_steps_between_copying_online_weights_to_target = 100
         self.learning_rate = 0.00025
@@ -356,6 +368,18 @@ class CartPole_DistributionalDQN(Preset):
         self.test_min_return_threshold = 150
 
 
+class CartPole_QRDQN(Preset):
+    def __init__(self):
+        Preset.__init__(self, QuantileRegressionDQN, GymVectorObservation, ExplorationParameters)
+        self.env.level = 'CartPole-v0'
+        self.agent.num_steps_between_copying_online_weights_to_target = 100
+        self.learning_rate = 0.00025
+        self.agent.num_episodes_in_experience_replay = 200
+        self.num_heatup_steps = 1000
+        self.exploration.epsilon_decay_steps = 3000
+        self.agent.discount = 1.0
+
+
 # The below preset matches the hyper-parameters setting as in the original DQN paper.
 # This a very resource intensive preset, and might easily blow up your RAM (> 100GB of usage).
 # Try reducing the number of transitions in the experience replay (50e3 might be a reasonable number to start with),
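Following the `python3 coach.py -p <preset_name>` usage shown in the README hunk above, the new presets would presumably be launched as `python3 coach.py -p CartPole_QRDQN` or `python3 coach.py -p Doom_Basic_QRDQN`.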
@@ -377,9 +401,9 @@ class Breakout_DQN(Preset):
         self.evaluate_every_x_episodes = 100
 
 
-class Breakout_DistributionalDQN(Preset):
+class Breakout_C51(Preset):
     def __init__(self):
-        Preset.__init__(self, DistributionalDQN, Atari, ExplorationParameters)
+        Preset.__init__(self, CategoricalDQN, Atari, ExplorationParameters)
         self.env.level = 'BreakoutDeterministic-v4'
         self.agent.num_steps_between_copying_online_weights_to_target = 10000
         self.learning_rate = 0.00025