mirror of https://github.com/gryf/coach.git

N-step returns for rainbow (#67)

* n_step returns for rainbow
* Rename CartPole_PPO -> CartPole_ClippedPPO
Gal Leibovich authored 2018-11-07 18:33:08 +02:00, committed by GitHub
parent 35c477c922
commit 49dea39d34
18 changed files with 173 additions and 117 deletions


@@ -73,11 +73,11 @@ class PolicyOptimizationAgent(Agent):
         episode_discounted_returns = []
         for i in range(episode.length()):
             transition = episode.get_transition(i)
-            episode_discounted_returns.append(transition.total_return)
+            episode_discounted_returns.append(transition.n_step_discounted_rewards)
             self.num_episodes_where_step_has_been_seen[i] += 1
             self.mean_return_over_multiple_episodes[i] -= self.mean_return_over_multiple_episodes[i] / \
                                                           self.num_episodes_where_step_has_been_seen[i]
-            self.mean_return_over_multiple_episodes[i] += transition.total_return / \
+            self.mean_return_over_multiple_episodes[i] += transition.n_step_discounted_rewards / \
                                                           self.num_episodes_where_step_has_been_seen[i]
         self.mean_discounted_return = np.mean(episode_discounted_returns)
         self.std_discounted_return = np.std(episode_discounted_returns)
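
The -=/+= pair in the hunk above is the standard incremental running mean, mean += (x - mean) / n, applied per step index across episodes. A minimal standalone sketch of that update (function and variable names are illustrative, not part of Coach's API):

import numpy as np

def update_running_mean(running_mean, counts, step_index, new_return):
    # Fold one more observed return into the per-step mean;
    # algebraically equivalent to the two-line -=/+= update in the diff above.
    counts[step_index] += 1
    n = counts[step_index]
    running_mean[step_index] += (new_return - running_mean[step_index]) / n

# Hypothetical returns for step 0 observed over three episodes
means, counts = np.zeros(10), np.zeros(10, dtype=int)
for r in [1.0, 3.0, 5.0]:
    update_running_mean(means, counts, 0, r)
print(means[0])  # 3.0, the mean of the three returns
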
@@ -97,7 +97,7 @@ class PolicyOptimizationAgent(Agent):
         network.set_is_training(True)
         # we need to update the returns of the episode until now
-        episode.update_returns()
+        episode.update_transitions_rewards_and_bootstrap_data()
         # get t_max transitions or less if we got to a terminal state
         # will be used for both actor-critic and vanilla PG.
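
The renamed call, episode.update_transitions_rewards_and_bootstrap_data(), presumably fills in the per-transition n_step_discounted_rewards used in the first hunk. A rough sketch of the underlying n-step return computation, under assumed names and signatures rather than Coach's actual implementation:

import numpy as np

def n_step_discounted_rewards(rewards, values, gamma, n):
    # rewards: r_0 .. r_{T-1}; values: V(s_0) .. V(s_T), with values[T] = 0.0
    # for a terminal episode. For each step t this computes
    #   R_t = sum_{k < n} gamma^k * r_{t+k} + gamma^n * V(s_{t+n}),
    # truncating the window at the end of the episode.
    T = len(rewards)
    returns = np.zeros(T)
    for t in range(T):
        horizon = min(t + n, T)                                    # end of the n-step window
        ret = sum((gamma ** (k - t)) * rewards[k] for k in range(t, horizon))
        ret += (gamma ** (horizon - t)) * values[horizon]          # bootstrap term
        returns[t] = ret
    return returns

# Illustrative usage with made-up numbers (terminal episode, so no bootstrap)
print(n_step_discounted_rewards([1.0, 1.0, 1.0, 1.0], [0.0] * 5, gamma=0.99, n=3))
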