mirror of
https://github.com/gryf/coach.git
synced 2026-03-18 15:53:35 +01:00
N-step returns for rainbow (#67)
* n_step returns for rainbow * Rename CartPole_PPO -> CartPole_ClippedPPO
This commit is contained in:
@@ -27,6 +27,7 @@ class EpisodicExperienceReplayParameters(MemoryParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max_size = (MemoryGranularity.Transitions, 1000000)
|
||||
self.n_step = -1
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
@@ -39,12 +40,13 @@ class EpisodicExperienceReplay(Memory):
|
||||
calculations of total return and other values that depend on the sequential behavior of the transitions
|
||||
in the episode.
|
||||
"""
|
||||
def __init__(self, max_size: Tuple[MemoryGranularity, int]):
|
||||
def __init__(self, max_size: Tuple[MemoryGranularity, int]=(MemoryGranularity.Transitions, 1000000), n_step=-1):
|
||||
"""
|
||||
:param max_size: the maximum number of transitions or episodes to hold in the memory
|
||||
"""
|
||||
super().__init__(max_size)
|
||||
self._buffer = [Episode()] # list of episodes
|
||||
self.n_step = n_step
|
||||
self._buffer = [Episode(n_step=self.n_step)] # list of episodes
|
||||
self.transitions = []
|
||||
self._length = 1 # the episodic replay buffer starts with a single empty episode
|
||||
self._num_transitions = 0
|
||||
@@ -109,7 +111,7 @@ class EpisodicExperienceReplay(Memory):
|
||||
self._remove_episode(0)
|
||||
|
||||
def _update_episode(self, episode: Episode) -> None:
|
||||
episode.update_returns()
|
||||
episode.update_transitions_rewards_and_bootstrap_data()
|
||||
|
||||
def verify_last_episode_is_closed(self) -> None:
|
||||
"""
|
||||
@@ -138,7 +140,7 @@ class EpisodicExperienceReplay(Memory):
|
||||
self._length += 1
|
||||
|
||||
# create a new Episode for the next transitions to be placed into
|
||||
self._buffer.append(Episode())
|
||||
self._buffer.append(Episode(n_step=self.n_step))
|
||||
|
||||
# if update episode adds to the buffer, a new Episode needs to be ready first
|
||||
# it would be better if this were less stateful
|
||||
@@ -158,12 +160,14 @@ class EpisodicExperienceReplay(Memory):
|
||||
:param transition: a transition to store
|
||||
:return: None
|
||||
"""
|
||||
|
||||
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
|
||||
super().store(transition)
|
||||
|
||||
self.reader_writer_lock.lock_writing_and_reading()
|
||||
|
||||
if len(self._buffer) == 0:
|
||||
self._buffer.append(Episode())
|
||||
self._buffer.append(Episode(n_step=self.n_step))
|
||||
last_episode = self._buffer[-1]
|
||||
last_episode.insert(transition)
|
||||
self.transitions.append(transition)
|
||||
@@ -284,7 +288,7 @@ class EpisodicExperienceReplay(Memory):
|
||||
self.reader_writer_lock.lock_writing_and_reading()
|
||||
|
||||
self.transitions = []
|
||||
self._buffer = [Episode()]
|
||||
self._buffer = [Episode(n_step=self.n_step)]
|
||||
self._length = 1
|
||||
self._num_transitions = 0
|
||||
self._num_transitions_in_complete_episodes = 0
|
||||
|
||||
@@ -139,7 +139,7 @@ class EpisodicHindsightExperienceReplay(EpisodicExperienceReplay):
|
||||
hindsight_transition.reward, hindsight_transition.game_over = \
|
||||
self.goals_space.get_reward_for_goal_and_state(goal, hindsight_transition.next_state)
|
||||
|
||||
hindsight_transition.total_return = None
|
||||
hindsight_transition.n_step_discounted_rewards = None
|
||||
episode.insert(hindsight_transition)
|
||||
|
||||
super().store_episode(episode)
|
||||
|
||||
@@ -128,14 +128,14 @@ class SegmentTree(object):
|
||||
self.tree[node_idx] = new_val
|
||||
self._propagate(node_idx)
|
||||
|
||||
def get(self, val: float) -> Tuple[int, float, Any]:
|
||||
def get_element_by_partial_sum(self, val: float) -> Tuple[int, float, Any]:
|
||||
"""
|
||||
Given a value between 0 and the tree sum, return the object whose cumulative range contains this value.
|
||||
For example, if we have 3 leaves: 10, 20, 30, and val=35, this will return the 3rd leaf, by accumulating
|
||||
leaves by their order until getting to 35. This allows sampling leaves according to their proportional
|
||||
probability.
|
||||
:param val: a value within the range 0 and the tree sum
|
||||
:return: the index of the resulting leaf in the tree, it's probability and
|
||||
:return: the index of the resulting leaf in the tree, its probability and
|
||||
the object itself
|
||||
"""
|
||||
node_idx = self._retrieve(0, val)
|
||||
@@ -237,12 +237,12 @@ class PrioritizedExperienceReplay(ExperienceReplay):
|
||||
|
||||
# sample a batch
|
||||
for i in range(size):
|
||||
start_probability = segment_size * i
|
||||
end_probability = segment_size * (i + 1)
|
||||
segment_start = segment_size * i
|
||||
segment_end = segment_size * (i + 1)
|
||||
|
||||
# sample leaf and calculate its weight
|
||||
val = random.uniform(start_probability, end_probability)
|
||||
leaf_idx, priority, transition = self.sum_tree.get(val)
|
||||
val = random.uniform(segment_start, segment_end)
|
||||
leaf_idx, priority, transition = self.sum_tree.get_element_by_partial_sum(val)
|
||||
priority /= self.sum_tree.total_value() # P(j) = p^a / sum(p^a)
|
||||
weight = (self.num_transitions() * priority) ** -self.beta.current_value # (N * P(j)) ^ -beta
|
||||
normalized_weight = weight / max_weight # wj = ((N * P(j)) ^ -beta) / max wi
|
||||
@@ -261,7 +261,7 @@ class PrioritizedExperienceReplay(ExperienceReplay):
|
||||
self.reader_writer_lock.release_writing()
|
||||
return batch
|
||||
|
||||
def store(self, transition: Transition) -> None:
|
||||
def store(self, transition: Transition, lock=True) -> None:
|
||||
"""
|
||||
Store a new transition in the memory.
|
||||
:param transition: a transition to store
|
||||
@@ -270,7 +270,8 @@ class PrioritizedExperienceReplay(ExperienceReplay):
|
||||
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
|
||||
super().store(transition)
|
||||
|
||||
self.reader_writer_lock.lock_writing_and_reading()
|
||||
if lock:
|
||||
self.reader_writer_lock.lock_writing_and_reading()
|
||||
|
||||
transition_priority = self.maximal_priority
|
||||
self.sum_tree.add(transition_priority ** self.alpha, transition)
|
||||
@@ -278,18 +279,21 @@ class PrioritizedExperienceReplay(ExperienceReplay):
|
||||
self.max_tree.add(transition_priority, transition)
|
||||
super().store(transition, False)
|
||||
|
||||
self.reader_writer_lock.release_writing_and_reading()
|
||||
if lock:
|
||||
self.reader_writer_lock.release_writing_and_reading()
|
||||
|
||||
def clean(self) -> None:
|
||||
def clean(self, lock=True) -> None:
|
||||
"""
|
||||
Clean the memory by removing all the episodes
|
||||
:return: None
|
||||
"""
|
||||
self.reader_writer_lock.lock_writing_and_reading()
|
||||
if lock:
|
||||
self.reader_writer_lock.lock_writing_and_reading()
|
||||
|
||||
super().clean(lock=False)
|
||||
self.sum_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.SUM)
|
||||
self.min_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.MIN)
|
||||
self.max_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.MAX)
|
||||
|
||||
self.reader_writer_lock.release_writing_and_reading()
|
||||
if lock:
|
||||
self.reader_writer_lock.release_writing_and_reading()
|
||||
|
||||
Reference in New Issue
Block a user