mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

N-step returns for rainbow (#67)

* n_step returns for rainbow
* Rename CartPole_PPO -> CartPole_ClippedPPO
Gal Leibovich
2018-11-07 18:33:08 +02:00
committed by GitHub
parent 35c477c922
commit 49dea39d34
18 changed files with 173 additions and 117 deletions


@@ -116,8 +116,10 @@ class ClippedPPOAgent(ActorCriticAgent):
         # calculate advantages
         advantages = []
         value_targets = []
+        total_returns = batch.n_step_discounted_rewards()
+
         if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE:
-            advantages = batch.total_returns() - current_state_values
+            advantages = total_returns - current_state_values
         elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE:
             # get bootstraps
             episode_start_idx = 0
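For reference, the n-step discounted return bootstraps from the critic's value estimate after n steps instead of summing rewards to the end of the episode. Below is a minimal sketch of that idea with hypothetical names; the actual computation in Coach lives in Batch.n_step_discounted_rewards and additionally handles episode boundaries and terminal states.

import numpy as np

def n_step_discounted_rewards(rewards, state_values, gamma=0.99, n=5):
    """Sketch of n-step discounted returns with value bootstrapping.

    rewards[t]      - reward received at step t
    state_values[t] - value estimate V(s_t); state_values[T] is the estimate
                      for the state after the last reward (0 if terminal)
    Entry t of the result is
        r_t + gamma*r_{t+1} + ... + gamma^(n-1)*r_{t+n-1} + gamma^n * V(s_{t+n}),
    truncated at the end of the trajectory.
    """
    T = len(rewards)
    returns = np.zeros(T)
    for t in range(T):
        horizon = min(n, T - t)
        ret = 0.0
        for k in range(horizon):
            ret += (gamma ** k) * rewards[t + k]
        ret += (gamma ** horizon) * state_values[t + horizon]
        returns[t] = ret
    return returns

# Example: 6-step trajectory, terminal after the last reward (bootstrap value 0).
rewards = np.array([1.0, 0.0, 0.0, 1.0, 0.0, 1.0])
values = np.array([0.9, 0.8, 0.7, 0.9, 0.5, 0.6, 0.0])  # V(s_0) .. V(s_6)
print(n_step_discounted_rewards(rewards, values, gamma=0.99, n=3))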
@@ -181,11 +183,13 @@ class ClippedPPOAgent(ActorCriticAgent):
                 result = self.networks['main'].target_network.predict({k: v[start:end] for k, v in batch.states(network_keys).items()})
                 old_policy_distribution = result[1:]

+                total_returns = batch.n_step_discounted_rewards(expand_dims=True)
+
                 # calculate gradients and apply on both the local policy network and on the global policy network
                 if self.ap.algorithm.estimate_state_value_using_gae:
                     value_targets = np.expand_dims(gae_based_value_targets, -1)
                 else:
-                    value_targets = batch.total_returns(expand_dims=True)[start:end]
+                    value_targets = total_returns[start:end]

                 inputs = copy.copy({k: v[start:end] for k, v in batch.states(network_keys).items()})
                 inputs['output_1_0'] = actions
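Taken together, the two hunks compute the n-step returns once per batch and reuse them twice: for the A_VALUE advantage (return minus the critic's current estimate) and, unless GAE-based value targets are enabled, as the regression targets for the value head in each minibatch slice. A hedged sketch of that pattern with hypothetical, made-up values:

import numpy as np

# Hypothetical batch quantities; in Coach these come from Batch and the critic network.
total_returns = np.array([[1.5], [0.7], [2.1], [0.3]])          # n-step returns, shape (N, 1)
current_state_values = np.array([[1.2], [0.9], [1.8], [0.4]])   # V(s_t) from the critic

# A_VALUE rescaler: advantage is the n-step return minus the current value estimate.
advantages = total_returns - current_state_values

# Minibatch loop: the same precomputed returns are sliced per minibatch as value targets.
batch_size = 2
for start in range(0, len(total_returns), batch_size):
    end = start + batch_size
    value_targets = total_returns[start:end]
    # ... advantages[start:end] and value_targets would feed the actor-critic update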