mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
N-step returns for rainbow (#67)
* n_step returns for rainbow * Rename CartPole_PPO -> CartPole_ClippedPPO
This commit is contained in:
@@ -112,11 +112,11 @@ class PPOAgent(ActorCriticAgent):
|
||||
# current_states_with_timestep = self.concat_state_and_timestep(batch)
|
||||
|
||||
current_state_values = self.networks['critic'].online_network.predict(batch.states(network_keys)).squeeze()
|
||||
|
||||
total_returns = batch.n_step_discounted_rewards()
|
||||
# calculate advantages
|
||||
advantages = []
|
||||
if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE:
|
||||
advantages = batch.total_returns() - current_state_values
|
||||
advantages = total_returns - current_state_values
|
||||
elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE:
|
||||
# get bootstraps
|
||||
episode_start_idx = 0
|
||||
@@ -155,6 +155,7 @@ class PPOAgent(ActorCriticAgent):
|
||||
# current_states_with_timestep = self.concat_state_and_timestep(dataset)
|
||||
|
||||
mix_fraction = self.ap.algorithm.value_targets_mix_fraction
|
||||
total_returns = batch.n_step_discounted_rewards(True)
|
||||
for j in range(epochs):
|
||||
curr_batch_size = batch.size
|
||||
if self.networks['critic'].online_network.optimizer_type != 'LBFGS':
|
||||
@@ -165,7 +166,7 @@ class PPOAgent(ActorCriticAgent):
|
||||
k: v[i * curr_batch_size:(i + 1) * curr_batch_size]
|
||||
for k, v in batch.states(network_keys).items()
|
||||
}
|
||||
total_return_batch = batch.total_returns(True)[i * curr_batch_size:(i + 1) * curr_batch_size]
|
||||
total_return_batch = total_returns[i * curr_batch_size:(i + 1) * curr_batch_size]
|
||||
old_policy_values = force_list(self.networks['critic'].target_network.predict(
|
||||
current_states_batch).squeeze())
|
||||
if self.networks['critic'].online_network.optimizer_type != 'LBFGS':
|
||||
|
||||
Reference in New Issue
Block a user