1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

new feature - implementation of Quantile Regression DQN (https://arxiv.org/pdf/1710.10044v1.pdf)

API change - Distributional DQN renamed to Categorical DQN
This commit is contained in:
Itai Caspi
2017-11-01 15:09:07 +02:00
parent 1ad6262307
commit a8bce9828c
10 changed files with 157 additions and 17 deletions

View File

@@ -80,7 +80,8 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
OutputTypes.NAF: NAFHead,
OutputTypes.PPO: PPOHead,
OutputTypes.PPO_V : PPOVHead,
OutputTypes.DistributionalQ: DistributionalQHead
OutputTypes.CategoricalQ: CategoricalQHead,
OutputTypes.QuantileRegressionQ: QuantileRegressionQHead
}
return output_mapping[head_type](self.tp, head_idx, loss_weight, self.network_is_local)

View File

@@ -462,10 +462,10 @@ class PPOVHead(Head):
tf.losses.add_loss(self.loss)
class DistributionalQHead(Head):
class CategoricalQHead(Head):
def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
self.name = 'distributional_dqn_head'
self.name = 'categorical_dqn_head'
self.num_actions = tuning_parameters.env_instance.action_space_size
self.num_atoms = tuning_parameters.agent.atoms
@@ -484,3 +484,47 @@ class DistributionalQHead(Head):
self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.target, logits=values_distribution)
tf.losses.add_loss(self.loss)
class QuantileRegressionQHead(Head):
    """Output head for Quantile Regression DQN (https://arxiv.org/pdf/1710.10044v1.pdf).

    For every action the head emits ``num_atoms`` quantile locations of the
    return distribution and is trained with the quantile Huber loss.
    """

    def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
        Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
        self.name = 'quantile_regression_dqn_head'
        self.num_actions = tuning_parameters.env_instance.action_space_size
        self.num_atoms = tuning_parameters.agent.atoms  # we use atom / quantile interchangeably
        self.huber_loss_interval = 1  # k

    def _build_module(self, input_layer):
        """Build the quantile outputs, the extra input placeholders and the loss.

        :param input_layer: tensor feeding the dense layer that produces the
            ``num_actions * num_atoms`` quantile locations.
        """
        # each row of `actions` is an index pair into the first two axes of the
        # (batch, num_actions, num_atoms) output, as consumed by tf.gather_nd below
        self.actions = tf.placeholder(tf.int32, [None, 2], name="actions")
        self.quantile_midpoints = tf.placeholder(tf.float32, [None, self.num_atoms], name="quantile_midpoints")
        self.input = [self.actions, self.quantile_midpoints]

        # the output of the head is the N unordered quantile locations {theta_1, ..., theta_N}
        quantiles_locations = tf.layers.dense(input_layer, self.num_actions * self.num_atoms)
        quantiles_locations = tf.reshape(quantiles_locations,
                                         (tf.shape(quantiles_locations)[0], self.num_actions, self.num_atoms))
        self.output = quantiles_locations

        self.quantiles = tf.placeholder(tf.float32, shape=(None, self.num_atoms), name="quantiles")
        self.target = self.quantiles

        # only the quantiles of the taken action are taken into account
        quantiles_for_used_actions = tf.gather_nd(quantiles_locations, self.actions)

        # reorder the output quantiles and the target quantiles as a preparation step for calculating the loss.
        # all three tensors become (batch, N, N) with N = num quantiles:
        #   theta_i and tau_i vary along axis 1 and are repeated along axis 2
        #   T_theta_j varies along axis 2 and is repeated along axis 1
        theta_i = tf.tile(tf.expand_dims(quantiles_for_used_actions, -1), [1, 1, self.num_atoms])
        T_theta_j = tf.tile(tf.expand_dims(self.target, -2), [1, self.num_atoms, 1])
        tau_i = tf.tile(tf.expand_dims(self.quantile_midpoints, -1), [1, 1, self.num_atoms])

        # Huber loss of T(theta_j) - theta_i
        error = T_theta_j - theta_i
        abs_error = tf.abs(error)
        quadratic = tf.minimum(abs_error, self.huber_loss_interval)
        huber_loss = self.huber_loss_interval * (abs_error - quadratic) + 0.5 * quadratic ** 2

        # Quantile Huber loss.
        # the indicator must be taken on the SIGNED error: the absolute error is never negative,
        # so testing it would always give 0 and remove the asymmetric weighting |tau - 1{u < 0}|
        # that makes each output converge to its own quantile
        negative_error = tf.cast(error < 0, dtype=tf.float32)
        quantile_huber_loss = tf.abs(tau_i - negative_error) * huber_loss

        # Quantile regression loss (the probability for each quantile is 1/num_quantiles)
        quantile_regression_loss = tf.reduce_sum(quantile_huber_loss) / float(self.num_atoms)
        self.loss = quantile_regression_loss
        tf.losses.add_loss(self.loss)