mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
QR-DQN bug fix and improvements (#30)
* bug fix - QR-DQN using error instead of abs-error in the quantile huber loss
* improvement - QR-DQN sorting the quantile only once instead of batch_size times
* new feature - adding the Breakout QRDQN preset (verified to achieve good results)
This commit is contained in:
@@ -51,8 +51,9 @@ class QuantileRegressionDQNAgent(ValueOptimizationAgent):
|
||||
cumulative_probabilities = np.array(range(self.tp.agent.atoms+1))/float(self.tp.agent.atoms) # tau_i
|
||||
quantile_midpoints = 0.5*(cumulative_probabilities[1:] + cumulative_probabilities[:-1]) # tau^hat_i
|
||||
quantile_midpoints = np.tile(quantile_midpoints, (self.tp.batch_size, 1))
|
||||
sorted_quantiles = np.argsort(current_quantiles[batch_idx, actions])
|
||||
for idx in range(self.tp.batch_size):
|
||||
quantile_midpoints[idx, :] = quantile_midpoints[idx, np.argsort(current_quantiles[batch_idx, actions])[idx]]
|
||||
quantile_midpoints[idx, :] = quantile_midpoints[idx, sorted_quantiles[idx]]
|
||||
|
||||
# train
|
||||
result = self.main_network.train_and_sync_networks([current_states, actions_locations, quantile_midpoints], TD_targets)
|
||||
|
||||
Reference in New Issue
Block a user