Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)
N-step returns for rainbow (#67)
* n_step returns for rainbow
* Rename CartPole_PPO -> CartPole_ClippedPPO
@@ -73,11 +73,11 @@ class PolicyOptimizationAgent(Agent):
         episode_discounted_returns = []
         for i in range(episode.length()):
             transition = episode.get_transition(i)
-            episode_discounted_returns.append(transition.total_return)
+            episode_discounted_returns.append(transition.n_step_discounted_rewards)
             self.num_episodes_where_step_has_been_seen[i] += 1
             self.mean_return_over_multiple_episodes[i] -= self.mean_return_over_multiple_episodes[i] / \
                                                           self.num_episodes_where_step_has_been_seen[i]
-            self.mean_return_over_multiple_episodes[i] += transition.total_return / \
+            self.mean_return_over_multiple_episodes[i] += transition.n_step_discounted_rewards / \
                                                           self.num_episodes_where_step_has_been_seen[i]
         self.mean_discounted_return = np.mean(episode_discounted_returns)
         self.std_discounted_return = np.std(episode_discounted_returns)
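
Both changed lines swap the episode's full discounted return (transition.total_return) for an n-step variant, so the per-step statistics are now averaged over n-step discounted reward sums. A minimal sketch of how such a quantity is typically computed, assuming a plain list of rewards, a discount factor gamma, and a horizon n (the function name and parameters are illustrative, not Coach's API):

import numpy as np

def n_step_discounted_rewards(rewards, gamma=0.99, n=5):
    # Illustrative only: for each step t, sum the discounted rewards of the
    # next n steps (or fewer if the episode ends first).
    returns = np.zeros(len(rewards))
    for t in range(len(rewards)):
        horizon = min(t + n, len(rewards))
        for k in range(t, horizon):
            returns[t] += (gamma ** (k - t)) * rewards[k]
    return returns

# Example: four unit rewards with a 2-step horizon and gamma = 0.9
print(n_step_discounted_rewards([1.0, 1.0, 1.0, 1.0], gamma=0.9, n=2))
# -> [1.9 1.9 1.9 1. ]
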
@@ -97,7 +97,7 @@ class PolicyOptimizationAgent(Agent):
         network.set_is_training(True)
 
         # we need to update the returns of the episode until now
-        episode.update_returns()
+        episode.update_transitions_rewards_and_bootstrap_data()
 
         # get t_max transitions or less if the we got to a terminal state
         # will be used for both actor-critic and vanilla PG.
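
The renamed call suggests the episode now recomputes per-transition n-step rewards together with the data needed to bootstrap from a value estimate when the n-step window does not reach a terminal state. A hedged sketch of that idea, assuming a reward list, a value estimate for the state n steps ahead, and the same illustrative naming as above (not Coach's actual implementation):

def n_step_return_with_bootstrap(rewards, bootstrap_value, gamma=0.99, n=5, t=0):
    # Illustrative only: discounted rewards over the next n steps from step t,
    # plus a bootstrapped value estimate if the episode continues past the window.
    horizon = min(t + n, len(rewards))
    ret = sum(gamma ** (k - t) * rewards[k] for k in range(t, horizon))
    if t + n < len(rewards):  # window ends before the terminal step -> bootstrap
        ret += gamma ** n * bootstrap_value
    return ret
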