1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

N-step returns for rainbow (#67)

* n_step returns for rainbow
* Rename CartPole_PPO -> CartPole_ClippedPPO
This commit is contained in:
Gal Leibovich
2018-11-07 18:33:08 +02:00
committed by GitHub
parent 35c477c922
commit 49dea39d34
18 changed files with 173 additions and 117 deletions

View File

@@ -70,6 +70,7 @@ class PALAgent(ValueOptimizationAgent):
# calculate TD error
TD_targets = np.copy(q_st_online)
total_returns = batch.n_step_discounted_rewards()
for i in range(self.ap.network_wrappers['main'].batch_size):
TD_targets[i, batch.actions()[i]] = batch.rewards()[i] + \
(1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * \
@@ -83,7 +84,7 @@ class PALAgent(ValueOptimizationAgent):
TD_targets[i, batch.actions()[i]] -= self.alpha * advantage_learning_update
# mixing monte carlo updates
monte_carlo_target = batch.total_returns()[i]
monte_carlo_target = total_returns[i]
TD_targets[i, batch.actions()[i]] = (1 - self.monte_carlo_mixing_rate) * TD_targets[i, batch.actions()[i]] \
+ self.monte_carlo_mixing_rate * monte_carlo_target