1
0
mirror of https://github.com/gryf/coach.git synced 2026-04-15 20:13:33 +02:00

N-step returns for rainbow (#67)

* n_step returns for rainbow
* Rename CartPole_PPO -> CartPole_ClippedPPO
This commit is contained in:
Gal Leibovich
2018-11-07 18:33:08 +02:00
committed by GitHub
parent 35c477c922
commit 49dea39d34
18 changed files with 173 additions and 117 deletions

View File

@@ -27,6 +27,7 @@ class EpisodicExperienceReplayParameters(MemoryParameters):
def __init__(self):
super().__init__()
self.max_size = (MemoryGranularity.Transitions, 1000000)
self.n_step = -1
@property
def path(self):
@@ -39,12 +40,13 @@ class EpisodicExperienceReplay(Memory):
calculations of total return and other values that depend on the sequential behavior of the transitions
in the episode.
"""
def __init__(self, max_size: Tuple[MemoryGranularity, int]):
def __init__(self, max_size: Tuple[MemoryGranularity, int]=(MemoryGranularity.Transitions, 1000000), n_step=-1):
"""
:param max_size: the maximum number of transitions or episodes to hold in the memory
"""
super().__init__(max_size)
self._buffer = [Episode()] # list of episodes
self.n_step = n_step
self._buffer = [Episode(n_step=self.n_step)] # list of episodes
self.transitions = []
self._length = 1 # the episodic replay buffer starts with a single empty episode
self._num_transitions = 0
@@ -109,7 +111,7 @@ class EpisodicExperienceReplay(Memory):
self._remove_episode(0)
def _update_episode(self, episode: Episode) -> None:
episode.update_returns()
episode.update_transitions_rewards_and_bootstrap_data()
def verify_last_episode_is_closed(self) -> None:
"""
@@ -138,7 +140,7 @@ class EpisodicExperienceReplay(Memory):
self._length += 1
# create a new Episode for the next transitions to be placed into
self._buffer.append(Episode())
self._buffer.append(Episode(n_step=self.n_step))
# if update episode adds to the buffer, a new Episode needs to be ready first
# it would be better if this were less stateful
@@ -158,12 +160,14 @@ class EpisodicExperienceReplay(Memory):
:param transition: a transition to store
:return: None
"""
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)
self.reader_writer_lock.lock_writing_and_reading()
if len(self._buffer) == 0:
self._buffer.append(Episode())
self._buffer.append(Episode(n_step=self.n_step))
last_episode = self._buffer[-1]
last_episode.insert(transition)
self.transitions.append(transition)
@@ -284,7 +288,7 @@ class EpisodicExperienceReplay(Memory):
self.reader_writer_lock.lock_writing_and_reading()
self.transitions = []
self._buffer = [Episode()]
self._buffer = [Episode(n_step=self.n_step)]
self._length = 1
self._num_transitions = 0
self._num_transitions_in_complete_episodes = 0

View File

@@ -139,7 +139,7 @@ class EpisodicHindsightExperienceReplay(EpisodicExperienceReplay):
hindsight_transition.reward, hindsight_transition.game_over = \
self.goals_space.get_reward_for_goal_and_state(goal, hindsight_transition.next_state)
hindsight_transition.total_return = None
hindsight_transition.n_step_discounted_rewards = None
episode.insert(hindsight_transition)
super().store_episode(episode)