1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

N-step returns for rainbow (#67)

* n_step returns for rainbow
* Rename CartPole_PPO -> CartPole_ClippedPPO
This commit is contained in:
Gal Leibovich
2018-11-07 18:33:08 +02:00
committed by GitHub
parent 35c477c922
commit 49dea39d34
18 changed files with 173 additions and 117 deletions

View File

@@ -116,7 +116,6 @@ class Agent(AgentInterface):
self.output_filter.set_device(device)
self.pre_network_filter.set_device(device)
# initialize all internal variables
self._phase = RunPhase.HEATUP
self.total_shaped_reward_in_current_episode = 0
@@ -143,7 +142,7 @@ class Agent(AgentInterface):
self.accumulated_shaped_rewards_across_evaluation_episodes = 0
self.num_successes_across_evaluation_episodes = 0
self.num_evaluation_episodes_completed = 0
-self.current_episode_buffer = Episode(discount=self.ap.algorithm.discount)
+self.current_episode_buffer = Episode(discount=self.ap.algorithm.discount, n_step=self.ap.algorithm.n_step)
# TODO: add agents observation rendering for debugging purposes (not the same as the environment rendering)
# environment parameters
@@ -452,10 +451,10 @@ class Agent(AgentInterface):
:return: None
"""
self.current_episode_buffer.is_complete = True
-self.current_episode_buffer.update_returns()
+self.current_episode_buffer.update_transitions_rewards_and_bootstrap_data()
for transition in self.current_episode_buffer.transitions:
-self.discounted_return.add_sample(transition.total_return)
+self.discounted_return.add_sample(transition.n_step_discounted_rewards)
if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
self.current_episode += 1
@@ -497,7 +496,7 @@ class Agent(AgentInterface):
self.curr_state = {}
self.current_episode_steps_counter = 0
self.episode_running_info = {}
-self.current_episode_buffer = Episode(discount=self.ap.algorithm.discount)
+self.current_episode_buffer = Episode(discount=self.ap.algorithm.discount, n_step=self.ap.algorithm.n_step)
if self.exploration_policy:
self.exploration_policy.reset()
self.input_filter.reset()