1
0
mirror of https://github.com/gryf/coach.git synced 2026-03-18 15:53:35 +01:00

N-step returns for rainbow (#67)

* n_step returns for rainbow
* Rename CartPole_PPO -> CartPole_ClippedPPO
This commit is contained in:
Gal Leibovich
2018-11-07 18:33:08 +02:00
committed by GitHub
parent 35c477c922
commit 49dea39d34
18 changed files with 173 additions and 117 deletions

View File

@@ -27,6 +27,7 @@ class EpisodicExperienceReplayParameters(MemoryParameters):
def __init__(self):
super().__init__()
self.max_size = (MemoryGranularity.Transitions, 1000000)
self.n_step = -1
@property
def path(self):
@@ -39,12 +40,13 @@ class EpisodicExperienceReplay(Memory):
calculations of total return and other values that depend on the sequential behavior of the transitions
in the episode.
"""
def __init__(self, max_size: Tuple[MemoryGranularity, int]):
def __init__(self, max_size: Tuple[MemoryGranularity, int]=(MemoryGranularity.Transitions, 1000000), n_step=-1):
"""
:param max_size: the maximum number of transitions or episodes to hold in the memory
"""
super().__init__(max_size)
self._buffer = [Episode()] # list of episodes
self.n_step = n_step
self._buffer = [Episode(n_step=self.n_step)] # list of episodes
self.transitions = []
self._length = 1 # the episodic replay buffer starts with a single empty episode
self._num_transitions = 0
@@ -109,7 +111,7 @@ class EpisodicExperienceReplay(Memory):
self._remove_episode(0)
def _update_episode(self, episode: Episode) -> None:
episode.update_returns()
episode.update_transitions_rewards_and_bootstrap_data()
def verify_last_episode_is_closed(self) -> None:
"""
@@ -138,7 +140,7 @@ class EpisodicExperienceReplay(Memory):
self._length += 1
# create a new Episode for the next transitions to be placed into
self._buffer.append(Episode())
self._buffer.append(Episode(n_step=self.n_step))
# if update episode adds to the buffer, a new Episode needs to be ready first
# it would be better if this were less state full
@@ -158,12 +160,14 @@ class EpisodicExperienceReplay(Memory):
:param transition: a transition to store
:return: None
"""
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)
self.reader_writer_lock.lock_writing_and_reading()
if len(self._buffer) == 0:
self._buffer.append(Episode())
self._buffer.append(Episode(n_step=self.n_step))
last_episode = self._buffer[-1]
last_episode.insert(transition)
self.transitions.append(transition)
@@ -284,7 +288,7 @@ class EpisodicExperienceReplay(Memory):
self.reader_writer_lock.lock_writing_and_reading()
self.transitions = []
self._buffer = [Episode()]
self._buffer = [Episode(n_step=self.n_step)]
self._length = 1
self._num_transitions = 0
self._num_transitions_in_complete_episodes = 0

View File

@@ -139,7 +139,7 @@ class EpisodicHindsightExperienceReplay(EpisodicExperienceReplay):
hindsight_transition.reward, hindsight_transition.game_over = \
self.goals_space.get_reward_for_goal_and_state(goal, hindsight_transition.next_state)
hindsight_transition.total_return = None
hindsight_transition.n_step_discounted_rewards = None
episode.insert(hindsight_transition)
super().store_episode(episode)

View File

@@ -128,14 +128,14 @@ class SegmentTree(object):
self.tree[node_idx] = new_val
self._propagate(node_idx)
def get(self, val: float) -> Tuple[int, float, Any]:
def get_element_by_partial_sum(self, val: float) -> Tuple[int, float, Any]:
"""
Given a value between 0 and the tree sum, return the object which this value is in it's range.
For example, if we have 3 leaves: 10, 20, 30, and val=35, this will return the 3rd leaf, by accumulating
leaves by their order until getting to 35. This allows sampling leaves according to their proportional
probability.
:param val: a value within the range 0 and the tree sum
:return: the index of the resulting leaf in the tree, it's probability and
:return: the index of the resulting leaf in the tree, its probability and
the object itself
"""
node_idx = self._retrieve(0, val)
@@ -237,12 +237,12 @@ class PrioritizedExperienceReplay(ExperienceReplay):
# sample a batch
for i in range(size):
start_probability = segment_size * i
end_probability = segment_size * (i + 1)
segment_start = segment_size * i
segment_end = segment_size * (i + 1)
# sample leaf and calculate its weight
val = random.uniform(start_probability, end_probability)
leaf_idx, priority, transition = self.sum_tree.get(val)
val = random.uniform(segment_start, segment_end)
leaf_idx, priority, transition = self.sum_tree.get_element_by_partial_sum(val)
priority /= self.sum_tree.total_value() # P(j) = p^a / sum(p^a)
weight = (self.num_transitions() * priority) ** -self.beta.current_value # (N * P(j)) ^ -beta
normalized_weight = weight / max_weight # wj = ((N * P(j)) ^ -beta) / max wi
@@ -261,7 +261,7 @@ class PrioritizedExperienceReplay(ExperienceReplay):
self.reader_writer_lock.release_writing()
return batch
def store(self, transition: Transition) -> None:
def store(self, transition: Transition, lock=True) -> None:
"""
Store a new transition in the memory.
:param transition: a transition to store
@@ -270,7 +270,8 @@ class PrioritizedExperienceReplay(ExperienceReplay):
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)
self.reader_writer_lock.lock_writing_and_reading()
if lock:
self.reader_writer_lock.lock_writing_and_reading()
transition_priority = self.maximal_priority
self.sum_tree.add(transition_priority ** self.alpha, transition)
@@ -278,18 +279,21 @@ class PrioritizedExperienceReplay(ExperienceReplay):
self.max_tree.add(transition_priority, transition)
super().store(transition, False)
self.reader_writer_lock.release_writing_and_reading()
if lock:
self.reader_writer_lock.release_writing_and_reading()
def clean(self) -> None:
def clean(self, lock=True) -> None:
"""
Clean the memory by removing all the episodes
:return: None
"""
self.reader_writer_lock.lock_writing_and_reading()
if lock:
self.reader_writer_lock.lock_writing_and_reading()
super().clean(lock=False)
self.sum_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.SUM)
self.min_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.MIN)
self.max_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.MAX)
self.reader_writer_lock.release_writing_and_reading()
if lock:
self.reader_writer_lock.release_writing_and_reading()