1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

QR-DQN bug fix and imporvements (#30)

* bug fix - QR-DQN using error instead of abs-error in the quantile huber loss

* improvement - QR-DQN sorting the quantile only once instead of batch_size times

* new feature - adding the Breakout QRDQN preset (verified to achieve good results)
This commit is contained in:
Itai Caspi
2017-11-29 14:01:59 +02:00
committed by GitHub
parent 7bdba396d2
commit 11faf19649
3 changed files with 23 additions and 3 deletions

View File

@@ -522,12 +522,13 @@ class QuantileRegressionQHead(Head):
tau_i = tf.tile(tf.expand_dims(self.quantile_midpoints, -1), [1, 1, self.num_atoms])
# Huber loss of T(theta_j) - theta_i
abs_error = tf.abs(T_theta_j - theta_i)
error = T_theta_j - theta_i
abs_error = tf.abs(error)
quadratic = tf.minimum(abs_error, self.huber_loss_interval)
huber_loss = self.huber_loss_interval * (abs_error - quadratic) + 0.5 * quadratic ** 2
# Quantile Huber loss
quantile_huber_loss = tf.abs(tau_i - tf.cast(abs_error < 0, dtype=tf.float32)) * huber_loss
quantile_huber_loss = tf.abs(tau_i - tf.cast(error < 0, dtype=tf.float32)) * huber_loss
# Quantile regression loss (the probability for each quantile is 1/num_quantiles)
quantile_regression_loss = tf.reduce_sum(quantile_huber_loss) / float(self.num_atoms)