1
0
mirror of https://github.com/gryf/coach.git synced 2026-05-01 21:40:56 +02:00

N-step returns for rainbow (#67)

* n_step returns for rainbow
* Rename CartPole_PPO -> CartPole_ClippedPPO
This commit is contained in:
Gal Leibovich
2018-11-07 18:33:08 +02:00
committed by GitHub
parent 35c477c922
commit 49dea39d34
18 changed files with 173 additions and 117 deletions
@@ -128,14 +128,14 @@ class SegmentTree(object):
self.tree[node_idx] = new_val
self._propagate(node_idx)
def get(self, val: float) -> Tuple[int, float, Any]:
def get_element_by_partial_sum(self, val: float) -> Tuple[int, float, Any]:
"""
Given a value between 0 and the tree sum, return the object which this value is in it's range.
For example, if we have 3 leaves: 10, 20, 30, and val=35, this will return the 3rd leaf, by accumulating
leaves by their order until getting to 35. This allows sampling leaves according to their proportional
probability.
:param val: a value within the range 0 and the tree sum
:return: the index of the resulting leaf in the tree, it's probability and
:return: the index of the resulting leaf in the tree, its probability and
the object itself
"""
node_idx = self._retrieve(0, val)
@@ -237,12 +237,12 @@ class PrioritizedExperienceReplay(ExperienceReplay):
# sample a batch
for i in range(size):
start_probability = segment_size * i
end_probability = segment_size * (i + 1)
segment_start = segment_size * i
segment_end = segment_size * (i + 1)
# sample leaf and calculate its weight
val = random.uniform(start_probability, end_probability)
leaf_idx, priority, transition = self.sum_tree.get(val)
val = random.uniform(segment_start, segment_end)
leaf_idx, priority, transition = self.sum_tree.get_element_by_partial_sum(val)
priority /= self.sum_tree.total_value() # P(j) = p^a / sum(p^a)
weight = (self.num_transitions() * priority) ** -self.beta.current_value # (N * P(j)) ^ -beta
normalized_weight = weight / max_weight # wj = ((N * P(j)) ^ -beta) / max wi
@@ -261,7 +261,7 @@ class PrioritizedExperienceReplay(ExperienceReplay):
self.reader_writer_lock.release_writing()
return batch
def store(self, transition: Transition) -> None:
def store(self, transition: Transition, lock=True) -> None:
"""
Store a new transition in the memory.
:param transition: a transition to store
@@ -270,7 +270,8 @@ class PrioritizedExperienceReplay(ExperienceReplay):
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)
self.reader_writer_lock.lock_writing_and_reading()
if lock:
self.reader_writer_lock.lock_writing_and_reading()
transition_priority = self.maximal_priority
self.sum_tree.add(transition_priority ** self.alpha, transition)
@@ -278,18 +279,21 @@ class PrioritizedExperienceReplay(ExperienceReplay):
self.max_tree.add(transition_priority, transition)
super().store(transition, False)
self.reader_writer_lock.release_writing_and_reading()
if lock:
self.reader_writer_lock.release_writing_and_reading()
def clean(self) -> None:
def clean(self, lock=True) -> None:
"""
Clean the memory by removing all the episodes
:return: None
"""
self.reader_writer_lock.lock_writing_and_reading()
if lock:
self.reader_writer_lock.lock_writing_and_reading()
super().clean(lock=False)
self.sum_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.SUM)
self.min_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.MIN)
self.max_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.MAX)
self.reader_writer_lock.release_writing_and_reading()
if lock:
self.reader_writer_lock.release_writing_and_reading()