mirror of
https://github.com/gryf/coach.git
synced 2026-04-15 20:13:33 +02:00
N-step returns for rainbow (#67)
* n_step returns for rainbow * Rename CartPole_PPO -> CartPole_ClippedPPO
This commit is contained in:
@@ -27,6 +27,7 @@ class EpisodicExperienceReplayParameters(MemoryParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max_size = (MemoryGranularity.Transitions, 1000000)
|
||||
self.n_step = -1
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
@@ -39,12 +40,13 @@ class EpisodicExperienceReplay(Memory):
|
||||
calculations of total return and other values that depend on the sequential behavior of the transitions
|
||||
in the episode.
|
||||
"""
|
||||
def __init__(self, max_size: Tuple[MemoryGranularity, int]):
|
||||
def __init__(self, max_size: Tuple[MemoryGranularity, int]=(MemoryGranularity.Transitions, 1000000), n_step=-1):
|
||||
"""
|
||||
:param max_size: the maximum number of transitions or episodes to hold in the memory
|
||||
"""
|
||||
super().__init__(max_size)
|
||||
self._buffer = [Episode()] # list of episodes
|
||||
self.n_step = n_step
|
||||
self._buffer = [Episode(n_step=self.n_step)] # list of episodes
|
||||
self.transitions = []
|
||||
self._length = 1 # the episodic replay buffer starts with a single empty episode
|
||||
self._num_transitions = 0
|
||||
@@ -109,7 +111,7 @@ class EpisodicExperienceReplay(Memory):
|
||||
self._remove_episode(0)
|
||||
|
||||
def _update_episode(self, episode: Episode) -> None:
|
||||
episode.update_returns()
|
||||
episode.update_transitions_rewards_and_bootstrap_data()
|
||||
|
||||
def verify_last_episode_is_closed(self) -> None:
|
||||
"""
|
||||
@@ -138,7 +140,7 @@ class EpisodicExperienceReplay(Memory):
|
||||
self._length += 1
|
||||
|
||||
# create a new Episode for the next transitions to be placed into
|
||||
self._buffer.append(Episode())
|
||||
self._buffer.append(Episode(n_step=self.n_step))
|
||||
|
||||
# if update episode adds to the buffer, a new Episode needs to be ready first
|
||||
# it would be better if this were less state full
|
||||
@@ -158,12 +160,14 @@ class EpisodicExperienceReplay(Memory):
|
||||
:param transition: a transition to store
|
||||
:return: None
|
||||
"""
|
||||
|
||||
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
|
||||
super().store(transition)
|
||||
|
||||
self.reader_writer_lock.lock_writing_and_reading()
|
||||
|
||||
if len(self._buffer) == 0:
|
||||
self._buffer.append(Episode())
|
||||
self._buffer.append(Episode(n_step=self.n_step))
|
||||
last_episode = self._buffer[-1]
|
||||
last_episode.insert(transition)
|
||||
self.transitions.append(transition)
|
||||
@@ -284,7 +288,7 @@ class EpisodicExperienceReplay(Memory):
|
||||
self.reader_writer_lock.lock_writing_and_reading()
|
||||
|
||||
self.transitions = []
|
||||
self._buffer = [Episode()]
|
||||
self._buffer = [Episode(n_step=self.n_step)]
|
||||
self._length = 1
|
||||
self._num_transitions = 0
|
||||
self._num_transitions_in_complete_episodes = 0
|
||||
|
||||
@@ -139,7 +139,7 @@ class EpisodicHindsightExperienceReplay(EpisodicExperienceReplay):
|
||||
hindsight_transition.reward, hindsight_transition.game_over = \
|
||||
self.goals_space.get_reward_for_goal_and_state(goal, hindsight_transition.next_state)
|
||||
|
||||
hindsight_transition.total_return = None
|
||||
hindsight_transition.n_step_discounted_rewards = None
|
||||
episode.insert(hindsight_transition)
|
||||
|
||||
super().store_episode(episode)
|
||||
|
||||
Reference in New Issue
Block a user