mirror of https://github.com/gryf/coach.git (synced 2025-12-17 11:10:20 +01:00)
N-step returns for rainbow (#67)
* n_step returns for rainbow
* Rename CartPole_PPO -> CartPole_ClippedPPO
@@ -116,8 +116,10 @@ class ClippedPPOAgent(ActorCriticAgent):
         # calculate advantages
         advantages = []
         value_targets = []
+        total_returns = batch.n_step_discounted_rewards()
+
         if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE:
-            advantages = batch.total_returns() - current_state_values
+            advantages = total_returns - current_state_values
         elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE:
             # get bootstraps
             episode_start_idx = 0
@@ -181,11 +183,13 @@ class ClippedPPOAgent(ActorCriticAgent):
             result = self.networks['main'].target_network.predict({k: v[start:end] for k, v in batch.states(network_keys).items()})
             old_policy_distribution = result[1:]
 
+            total_returns = batch.n_step_discounted_rewards(expand_dims=True)
+
             # calculate gradients and apply on both the local policy network and on the global policy network
             if self.ap.algorithm.estimate_state_value_using_gae:
                 value_targets = np.expand_dims(gae_based_value_targets, -1)
             else:
-                value_targets = batch.total_returns(expand_dims=True)[start:end]
+                value_targets = total_returns[start:end]
 
             inputs = copy.copy({k: v[start:end] for k, v in batch.states(network_keys).items()})
             inputs['output_1_0'] = actions
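For context, here is a minimal sketch of what an n-step discounted return such as `batch.n_step_discounted_rewards()` presumably computes. The helper name `n_step_discounted_rewards` comes from the diff above; the standalone function below, its parameter names (`rewards`, `state_values`, `game_overs`, `n_step`, `discount`), and the bootstrapping details are illustrative assumptions, not Coach's actual `Batch` implementation.

import numpy as np

def n_step_return(rewards, state_values, game_overs, n_step, discount=0.99):
    # Illustrative n-step discounted return: sum the next n discounted rewards,
    # then bootstrap with the discounted value estimate of the state n steps ahead.
    # No bootstrap is applied if the episode terminates first or the batch runs out.
    num_transitions = len(rewards)
    returns = np.zeros(num_transitions, dtype=np.float64)
    for t in range(num_transitions):
        ret = 0.0
        horizon = min(n_step, num_transitions - t)
        terminated = False
        for k in range(horizon):
            ret += (discount ** k) * rewards[t + k]
            if game_overs[t + k]:
                terminated = True
                break
        if not terminated and t + horizon < num_transitions:
            ret += (discount ** horizon) * state_values[t + horizon]
        returns[t] = ret
    return returns

With n_step at least as long as the episode, this degenerates to the full Monte-Carlo return that `batch.total_returns()` provided before this change; smaller values of n trade some bias from the bootstrapped value estimate for lower variance in the advantages and value targets.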