mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)
QR-DQN bug fix and improvements (#30)
* bug fix - QR-DQN now uses the signed error instead of the absolute error in the quantile Huber loss indicator
* improvement - QR-DQN now sorts the quantiles only once instead of batch_size times
* new feature - added the Breakout QR-DQN preset (verified to achieve good results)
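For context, the loss being fixed is the quantile Huber loss from the QR-DQN paper (Dabney et al., 2017), with u = T(theta_j) - theta_i the TD error and kappa the Huber interval (huber_loss_interval in the code):

    \rho^\kappa_{\hat\tau}(u) = \left| \hat\tau - \delta_{\{u < 0\}} \right| \, L_\kappa(u),
    \qquad
    L_\kappa(u) =
    \begin{cases}
        \tfrac{1}{2} u^2 & \text{if } |u| \le \kappa \\
        \kappa \left( |u| - \tfrac{1}{2} \kappa \right) & \text{otherwise}
    \end{cases}

The indicator \delta_{\{u < 0\}} depends on the sign of u; the old head evaluated it on abs(u), which is never negative, so the asymmetric weighting was silently disabled. The second hunk below restores the signed error.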
@@ -51,8 +51,9 @@ class QuantileRegressionDQNAgent(ValueOptimizationAgent):
         cumulative_probabilities = np.array(range(self.tp.agent.atoms+1))/float(self.tp.agent.atoms) # tau_i
         quantile_midpoints = 0.5*(cumulative_probabilities[1:] + cumulative_probabilities[:-1]) # tau^hat_i
         quantile_midpoints = np.tile(quantile_midpoints, (self.tp.batch_size, 1))
+        sorted_quantiles = np.argsort(current_quantiles[batch_idx, actions])
         for idx in range(self.tp.batch_size):
-            quantile_midpoints[idx, :] = quantile_midpoints[idx, np.argsort(current_quantiles[batch_idx, actions])[idx]]
+            quantile_midpoints[idx, :] = quantile_midpoints[idx, sorted_quantiles[idx]]

         # train
         result = self.main_network.train_and_sync_networks([current_states, actions_locations, quantile_midpoints], TD_targets)
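The hunk above hoists np.argsort out of the per-sample loop: the batch is argsorted once and the precomputed rows are reused inside the loop. A minimal sketch with hypothetical shapes (not the repository's actual tensors) showing the two variants produce identical reorderings:

    import numpy as np

    batch_size, num_atoms = 4, 8
    rng = np.random.default_rng(0)
    # stand-in for current_quantiles[batch_idx, actions]
    quantiles = rng.standard_normal((batch_size, num_atoms))
    # stand-in for the tiled quantile midpoints tau^hat
    midpoints = np.tile((np.arange(num_atoms) + 0.5) / num_atoms, (batch_size, 1))

    # before: argsort over the whole batch recomputed on every iteration
    before = midpoints.copy()
    for idx in range(batch_size):
        before[idx, :] = before[idx, np.argsort(quantiles)[idx]]

    # after: argsort computed once, rows reused
    sorted_quantiles = np.argsort(quantiles)
    after = midpoints.copy()
    for idx in range(batch_size):
        after[idx, :] = after[idx, sorted_quantiles[idx]]

    assert np.array_equal(before, after)  # same result, one argsort call instead of batch_size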
@@ -522,12 +522,13 @@ class QuantileRegressionQHead(Head):
         tau_i = tf.tile(tf.expand_dims(self.quantile_midpoints, -1), [1, 1, self.num_atoms])

         # Huber loss of T(theta_j) - theta_i
-        abs_error = tf.abs(T_theta_j - theta_i)
+        error = T_theta_j - theta_i
+        abs_error = tf.abs(error)
         quadratic = tf.minimum(abs_error, self.huber_loss_interval)
         huber_loss = self.huber_loss_interval * (abs_error - quadratic) + 0.5 * quadratic ** 2

         # Quantile Huber loss
-        quantile_huber_loss = tf.abs(tau_i - tf.cast(abs_error < 0, dtype=tf.float32)) * huber_loss
+        quantile_huber_loss = tf.abs(tau_i - tf.cast(error < 0, dtype=tf.float32)) * huber_loss

         # Quantile regression loss (the probability for each quantile is 1/num_quantiles)
         quantile_regression_loss = tf.reduce_sum(quantile_huber_loss) / float(self.num_atoms)
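To see why the sign matters, here is a minimal NumPy sketch of the same arithmetic with hypothetical values (the real code operates on TensorFlow tensors):

    import numpy as np

    tau = 0.25                   # one quantile midpoint tau^hat_i
    u = np.array([-1.0, 2.0])    # TD errors T_theta_j - theta_i
    kappa = 1.0                  # huber_loss_interval

    abs_u = np.abs(u)
    quadratic = np.minimum(abs_u, kappa)
    huber = kappa * (abs_u - quadratic) + 0.5 * quadratic ** 2  # [0.5, 1.5]

    old = np.abs(tau - (abs_u < 0)) * huber  # abs_u >= 0, so the indicator never fires
    new = np.abs(tau - (u < 0)) * huber      # fires exactly for negative errors

    print(old)  # [0.125 0.375] - both errors weighted by tau = 0.25
    print(new)  # [0.375 0.375] - the negative error weighted by 1 - tau = 0.75

With the bug, every atom was weighted by tau regardless of the error's sign, collapsing the asymmetric penalty that makes quantile regression estimate quantiles rather than the mean.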
presets.py
@@ -422,6 +422,24 @@ class Breakout_C51(Preset):
         self.evaluate_every_x_episodes = 5000000


+class Breakout_QRDQN(Preset):
+    def __init__(self):
+        Preset.__init__(self, QuantileRegressionDQN, Atari, ExplorationParameters)
+        self.env.level = 'BreakoutDeterministic-v4'
+        self.agent.num_steps_between_copying_online_weights_to_target = 10000
+        self.learning_rate = 0.00025
+        self.agent.num_transitions_in_experience_replay = 1000000
+        self.exploration.initial_epsilon = 1.0
+        self.exploration.final_epsilon = 0.01
+        self.exploration.epsilon_decay_steps = 1000000
+        self.exploration.evaluation_policy = 'EGreedy'
+        self.exploration.evaluation_epsilon = 0.001
+        self.num_heatup_steps = 50000
+        self.evaluation_episodes = 1
+        self.evaluate_every_x_episodes = 50
+
+
 class Atari_DQN_TestBench(Preset):
     def __init__(self):
         Preset.__init__(self, DQN, Atari, ExplorationParameters)
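Assuming Coach's usual preset-based entry point (an assumption; the launcher is not part of this commit), the new preset would be run with something like: python coach.py -p Breakout_QRDQN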