
N-step returns for rainbow (#67)

* n_step returns for rainbow
* Rename CartPole_PPO -> CartPole_ClippedPPO
Gal Leibovich
2018-11-07 18:33:08 +02:00
committed by GitHub
parent 35c477c922
commit 49dea39d34
18 changed files with 173 additions and 117 deletions
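
For context, an n-step return truncates the Monte Carlo sum after n rewards and bootstraps the tail with the critic's estimate of the state n steps ahead; taking n up to the episode length recovers the full discounted return that PPO uses as a value target below. The following is a minimal sketch of that computation, not Coach's actual `Batch.n_step_discounted_rewards` implementation; the helper name, signature, and bootstrapping details are assumptions for illustration:

import numpy as np

def n_step_discounted_returns(rewards, values, gamma=0.99, n=5):
    """Illustrative n-step return for a single episode: the discounted sum
    of up to n rewards, bootstrapped with V(s_{t+n}) whenever the window
    ends before the episode does."""
    T = len(rewards)
    returns = np.zeros(T, dtype=np.float32)
    for t in range(T):
        horizon = min(t + n, T)
        # discounted sum of the rewards inside the n-step window
        ret = sum(gamma ** (k - t) * rewards[k] for k in range(t, horizon))
        if horizon < T:
            # episode continues past the window: bootstrap with the critic
            ret += gamma ** n * values[horizon]
        returns[t] = ret
    return returns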

@@ -112,11 +112,11 @@ class PPOAgent(ActorCriticAgent):
         # current_states_with_timestep = self.concat_state_and_timestep(batch)
         current_state_values = self.networks['critic'].online_network.predict(batch.states(network_keys)).squeeze()
+        total_returns = batch.n_step_discounted_rewards()
         # calculate advantages
         advantages = []
         if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE:
-            advantages = batch.total_returns() - current_state_values
+            advantages = total_returns - current_state_values
         elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE:
             # get bootstraps
             episode_start_idx = 0
@@ -155,6 +155,7 @@ class PPOAgent(ActorCriticAgent):
         # current_states_with_timestep = self.concat_state_and_timestep(dataset)
         mix_fraction = self.ap.algorithm.value_targets_mix_fraction
+        total_returns = batch.n_step_discounted_rewards(True)
         for j in range(epochs):
             curr_batch_size = batch.size
             if self.networks['critic'].online_network.optimizer_type != 'LBFGS':
@@ -165,7 +166,7 @@ class PPOAgent(ActorCriticAgent):
                     k: v[i * curr_batch_size:(i + 1) * curr_batch_size]
                     for k, v in batch.states(network_keys).items()
                 }
-                total_return_batch = batch.total_returns(True)[i * curr_batch_size:(i + 1) * curr_batch_size]
+                total_return_batch = total_returns[i * curr_batch_size:(i + 1) * curr_batch_size]
                 old_policy_values = force_list(self.networks['critic'].target_network.predict(
                     current_states_batch).squeeze())
                 if self.networks['critic'].online_network.optimizer_type != 'LBFGS':
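
Note that the last two hunks also hoist the return computation out of the training loops: `batch.total_returns(True)` used to be re-evaluated for every minibatch, while `total_returns` is now computed once and sliced per minibatch. A quick usage example of the sketch above, mirroring the A_VALUE branch in the first hunk (the reward and value arrays are made-up data):

rewards = np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float32)
values = np.array([0.5, 0.6, 0.7, 0.8], dtype=np.float32)  # critic's V(s_t)
total_returns = n_step_discounted_returns(rewards, values, gamma=0.99, n=2)
advantages = total_returns - values  # as in `total_returns - current_state_values`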