mirror of
https://github.com/gryf/coach.git
synced 2026-03-11 03:55:52 +01:00
pre-release 0.10.0
This commit is contained in:
15
rl_coach/memories/__init__.py
Normal file
15
rl_coach/memories/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
0
rl_coach/memories/episodic/__init__.py
Normal file
0
rl_coach/memories/episodic/__init__.py
Normal file
318
rl_coach/memories/episodic/episodic_experience_replay.py
Normal file
318
rl_coach/memories/episodic/episodic_experience_replay.py
Normal file
@@ -0,0 +1,318 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import List, Tuple, Union, Dict, Any
|
||||
|
||||
import numpy as np
|
||||
from rl_coach.utils import ReaderWriterLock
|
||||
|
||||
from rl_coach.core_types import Transition, Episode
|
||||
from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters
|
||||
|
||||
|
||||
class EpisodicExperienceReplayParameters(MemoryParameters):
    """
    Parameters object describing an EpisodicExperienceReplay memory.
    """
    def __init__(self):
        super().__init__()
        # by default the memory is bounded by a number of transitions
        granularity = MemoryGranularity.Transitions
        capacity = 1000000
        self.max_size = (granularity, capacity)

    @property
    def path(self):
        """
        :return: the '<module>:<class>' string naming the memory class these parameters describe
        """
        return 'rl_coach.memories.episodic.episodic_experience_replay:EpisodicExperienceReplay'
|
||||
|
||||
|
||||
class EpisodicExperienceReplay(Memory):
    """
    A replay buffer that stores episodes of transitions. The additional structure allows performing various
    calculations of total return and other values that depend on the sequential behavior of the transitions
    in the episode.

    Besides the list of episodes, a flat list of all the stored transitions is kept in self.transitions. This flat
    list is what sample() draws from.
    """
    def __init__(self, max_size: Tuple[MemoryGranularity, int]):
        """
        :param max_size: the maximum number of transitions or episodes to hold in the memory
        """
        super().__init__(max_size)

        self._buffer = [Episode()]  # list of episodes. the last entry is the episode that is currently being filled
        self.transitions = []  # flat list of the transitions of all the episodes, in the order of the buffer
        self._length = 1  # the episodic replay buffer starts with a single empty episode
        self._num_transitions = 0
        self._num_transitions_in_complete_episodes = 0

        self.reader_writer_lock = ReaderWriterLock()

    def length(self, lock: bool=False) -> int:
        """
        Get the number of episodes in the ER (even if they are not complete)
        :param lock: accepted for API compatibility. it is not used by this method
        :return: the number of episodes, not counting a trailing empty episode
        """
        length = self._length
        # compare by value - an identity check ('is not') against an int literal is not a reliable comparison
        if self._length != 0 and self._buffer[-1].is_empty():
            length = self._length - 1

        return length

    def num_complete_episodes(self):
        """ Get the number of complete episodes in ER """
        # every episode except for the last (open) one is complete
        return self._length - 1

    def num_transitions(self):
        """ Get the total number of transitions stored in the ER """
        return self._num_transitions

    def num_transitions_in_complete_episodes(self):
        """ Get the number of transitions which belong to episodes that were already closed """
        return self._num_transitions_in_complete_episodes

    def sample(self, size: int) -> List[Transition]:
        """
        Sample a batch of transitions from the replay buffer. The transitions are sampled uniformly, with
        replacement, out of the transitions of the complete episodes.
        :param size: the size of the batch to sample
        :return: a batch (list) of selected transitions from the replay buffer
        :raises ValueError: if there are no complete episodes in the replay buffer yet
        """
        self.reader_writer_lock.lock_writing()
        try:
            if self.num_complete_episodes() < 1:
                raise ValueError("The episodic replay buffer cannot be sampled since there are no complete episodes yet. "
                                 "There is currently 1 episodes with {} transitions".format(self._buffer[0].length()))

            transitions_idx = np.random.randint(self.num_transitions_in_complete_episodes(), size=size)
            batch = [self.transitions[i] for i in transitions_idx]
        finally:
            # release the lock on the error path as well, otherwise writers would be blocked forever
            self.reader_writer_lock.release_writing()

        return batch

    def _enforce_max_length(self) -> None:
        """
        Make sure that the size of the replay buffer does not pass the maximum size allowed.
        If it passes the max size, the oldest episode in the replay buffer will be removed.
        :return: None
        """
        granularity, size = self.max_size
        if granularity == MemoryGranularity.Transitions:
            # a size of 0 disables the limit
            while size != 0 and self.num_transitions() > size:
                self._remove_episode(0)
        elif granularity == MemoryGranularity.Episodes:
            while self.length() > size:
                self._remove_episode(0)

    def _update_episode(self, episode: Episode) -> None:
        """
        Update the values of an episode which was just closed
        :param episode: the episode to update
        :return: None
        """
        episode.update_returns()

    def verify_last_episode_is_closed(self) -> None:
        """
        Verify that there is no open episodes in the replay buffer
        :return: None
        """
        self.reader_writer_lock.lock_writing_and_reading()
        try:
            last_episode = self.get(-1, False)
            if last_episode and last_episode.length() > 0:
                self.close_last_episode(lock=False)
        finally:
            self.reader_writer_lock.release_writing_and_reading()

    def close_last_episode(self, lock=True) -> None:
        """
        Close the last episode in the replay buffer and open a new one
        :param lock: if True, the memory is locked for the duration of the operation. pass False when the caller
                     already holds the lock
        :return: None
        """
        if lock:
            self.reader_writer_lock.lock_writing_and_reading()
        try:
            last_episode = self._buffer[-1]

            self._num_transitions_in_complete_episodes += last_episode.length()
            self._length += 1

            # create a new Episode for the next transitions to be placed into
            self._buffer.append(Episode())

            # if update episode adds to the buffer, a new Episode needs to be ready first
            # it would be better if this were less state full
            self._update_episode(last_episode)

            self._enforce_max_length()
        finally:
            if lock:
                self.reader_writer_lock.release_writing_and_reading()

    def store(self, transition: Transition) -> None:
        """
        Store a new transition in the memory. If the transition game_over flag is on, this closes the episode and
        creates a new empty episode.
        Warning! using the episodic memory by storing individual transitions instead of episodes will use the default
        Episode class parameters in order to create new episodes.
        :param transition: a transition to store
        :return: None
        """
        self.reader_writer_lock.lock_writing_and_reading()
        try:
            if not self._buffer:
                self._buffer.append(Episode())
            last_episode = self._buffer[-1]
            last_episode.insert(transition)
            self.transitions.append(transition)
            self._num_transitions += 1
            if transition.game_over:
                self.close_last_episode(False)

            self._enforce_max_length()
        finally:
            self.reader_writer_lock.release_writing_and_reading()

    def store_episode(self, episode: Episode, lock: bool=True) -> None:
        """
        Store a new episode in the memory.
        :param episode: the new episode to store
        :param lock: if True, the memory is locked for the duration of the operation. pass False when the caller
                     already holds the lock
        :return: None
        """
        if lock:
            self.reader_writer_lock.lock_writing_and_reading()
        try:
            # reuse the slot of the open episode if nothing was stored in it yet
            if self._buffer[-1].length() == 0:
                self._buffer[-1] = episode
            else:
                self._buffer.append(episode)
            self.transitions.extend(episode.transitions)
            self._num_transitions += episode.length()
            self.close_last_episode(False)
        finally:
            if lock:
                self.reader_writer_lock.release_writing_and_reading()

    def get_episode(self, episode_index: int, lock: bool=True) -> Union[None, Episode]:
        """
        Returns the episode in the given index. If the episode does not exist, returns None instead.
        :param episode_index: the index of the episode to return
        :param lock: if True, the memory is locked for the duration of the operation. pass False when the caller
                     already holds the lock
        :return: the corresponding episode
        """
        if lock:
            self.reader_writer_lock.lock_writing()
        try:
            if self.length() == 0 or episode_index >= self.length():
                episode = None
            else:
                episode = self._buffer[episode_index]
        finally:
            if lock:
                self.reader_writer_lock.release_writing()
        return episode

    def _remove_episode(self, episode_index: int) -> None:
        """
        Remove the episode in the given index (even if it is not complete yet)
        :param episode_index: the index of the episode to remove
        :return: None
        """
        num_episodes = len(self._buffer)
        if num_episodes > episode_index:
            episode_length = self._buffer[episode_index].length()
            position = episode_index % num_episodes  # normalize negative indices

            # the transitions of this episode are located in the flat list right after the transitions of all the
            # episodes that precede it in the buffer
            first_transition = sum(episode.length() for episode in self._buffer[:position])

            self._length -= 1
            self._num_transitions -= episode_length
            # the last entry in the buffer is the open episode, whose transitions were never counted as complete
            if position != num_episodes - 1:
                self._num_transitions_in_complete_episodes -= episode_length
            del self.transitions[first_transition:first_transition + episode_length]
            del self._buffer[episode_index]

    def remove_episode(self, episode_index: int) -> None:
        """
        Remove the episode in the given index (even if it is not complete yet)
        :param episode_index: the index of the episode to remove
        :return: None
        """
        self.reader_writer_lock.lock_writing_and_reading()
        try:
            self._remove_episode(episode_index)
        finally:
            self.reader_writer_lock.release_writing_and_reading()

    # for API compatibility
    def get(self, episode_index: int, lock: bool=True) -> Union[None, Episode]:
        """
        Returns the episode in the given index. If the episode does not exist, returns None instead.
        :param episode_index: the index of the episode to return
        :param lock: forwarded to get_episode
        :return: the corresponding episode
        """
        return self.get_episode(episode_index, lock)

    def get_last_complete_episode(self) -> Union[None, Episode]:
        """
        Returns the last complete episode in the memory or None if there are no complete episodes
        :return: None or the last complete episode
        """
        self.reader_writer_lock.lock_writing()
        try:
            last_complete_episode_index = self.num_complete_episodes() - 1
            episode = None
            if last_complete_episode_index >= 0:
                # the lock is already held by this method, so it should not be acquired again
                episode = self.get(last_complete_episode_index, False)
        finally:
            self.reader_writer_lock.release_writing()

        return episode

    # for API compatibility
    def remove(self, episode_index: int):
        """
        Remove the episode in the given index (even if it is not complete yet)
        :param episode_index: the index of the episode to remove
        :return: None
        """
        self.remove_episode(episode_index)

    def update_last_transition_info(self, info: Dict[str, Any]) -> None:
        """
        Update the info of the last transition stored in the memory
        :param info: the new info to append to the existing info
        :return: None
        """
        self.reader_writer_lock.lock_writing_and_reading()
        try:
            episode = self._buffer[-1]
            if episode.length() == 0:
                # the open episode is empty, so the last transition is in the previous episode (if there is one)
                if len(self._buffer) < 2:
                    return
                episode = self._buffer[-2]
            episode.transitions[-1].info.update(info)
        finally:
            # runs for the early return above as well, so the lock is never left held
            self.reader_writer_lock.release_writing_and_reading()

    def clean(self) -> None:
        """
        Clean the memory by removing all the episodes
        :return: None
        """
        self.reader_writer_lock.lock_writing_and_reading()
        try:
            self.transitions = []
            self._buffer = [Episode()]
            self._length = 1
            self._num_transitions = 0
            self._num_transitions_in_complete_episodes = 0
        finally:
            self.reader_writer_lock.release_writing_and_reading()

    def mean_reward(self) -> np.ndarray:
        """
        Get the mean reward in the replay buffer
        :return: the mean reward
        """
        self.reader_writer_lock.lock_writing()
        try:
            mean = np.mean([transition.reward for transition in self.transitions])
        finally:
            self.reader_writer_lock.release_writing()
        return mean
|
||||
@@ -0,0 +1,147 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import copy
|
||||
from enum import Enum
|
||||
from typing import Tuple, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.core_types import Episode, Transition
|
||||
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters, EpisodicExperienceReplay
|
||||
from rl_coach.memories.non_episodic.experience_replay import MemoryGranularity
|
||||
from rl_coach.spaces import GoalsSpace
|
||||
|
||||
|
||||
class HindsightGoalSelectionMethod(Enum):
    """
    The strategies available for picking the transition whose state is turned into a hindsight goal.
    """
    Future = 0   # a transition that comes later in the same episode
    Final = 1    # the last transition of the episode
    Episode = 2  # any transition of the same episode
    Random = 3   # any transition stored in the replay buffer
|
||||
|
||||
|
||||
class EpisodicHindsightExperienceReplayParameters(EpisodicExperienceReplayParameters):
    """
    Parameters object describing an EpisodicHindsightExperienceReplay memory.
    The hindsight specific fields have no default and are initialized to None.
    """
    def __init__(self):
        super().__init__()
        # hindsight specific settings - left unset here
        self.hindsight_transitions_per_regular_transition = None
        self.hindsight_goal_selection_method = None
        self.goals_space = None

    @property
    def path(self):
        """
        :return: the '<module>:<class>' string naming the memory class these parameters describe
        """
        return 'rl_coach.memories.episodic.episodic_hindsight_experience_replay:EpisodicHindsightExperienceReplay'
|
||||
|
||||
|
||||
class EpisodicHindsightExperienceReplay(EpisodicExperienceReplay):
    """
    Implements Hindsight Experience Replay as described in the following paper: https://arxiv.org/pdf/1707.01495.pdf

    """
    def __init__(self, max_size: Tuple[MemoryGranularity, int],
                 hindsight_transitions_per_regular_transition: int,
                 hindsight_goal_selection_method: HindsightGoalSelectionMethod,
                 goals_space: GoalsSpace):
        """
        :param max_size: The maximum size of the memory. should be defined in a granularity of Transitions
        :param hindsight_transitions_per_regular_transition: The number of hindsight artificial transitions to generate
                                                             for each actual transition
        :param hindsight_goal_selection_method: The method that will be used for generating the goals for the
                                                hindsight transitions. Should be one of HindsightGoalSelectionMethod
        :param goals_space: A GoalsSpace which defines the base properties of the goals space
        """
        super().__init__(max_size)

        self.hindsight_transitions_per_regular_transition = hindsight_transitions_per_regular_transition
        self.hindsight_goal_selection_method = hindsight_goal_selection_method
        self.goals_space = goals_space
        self.last_episode_start_idx = 0

    def _sample_goal(self, episode_transitions: List, transition_index: int):
        """
        Sample a single goal state according to the sampling method
        :param episode_transitions: a list of all the transitions in the current episode
        :param transition_index: the transition to start sampling from
        :return: a goal corresponding to the sampled state
        :raises ValueError: if the configured goal selection method is not one of HindsightGoalSelectionMethod
        """
        method = self.hindsight_goal_selection_method
        if method == HindsightGoalSelectionMethod.Future:
            # states that were observed in the same episode after the transition that is being replayed
            selected_transition = np.random.choice(episode_transitions[transition_index+1:])
        elif method == HindsightGoalSelectionMethod.Final:
            # the final state in the episode
            selected_transition = episode_transitions[-1]
        elif method == HindsightGoalSelectionMethod.Episode:
            # a random state from the episode
            selected_transition = np.random.choice(episode_transitions)
        elif method == HindsightGoalSelectionMethod.Random:
            # a random state from the entire replay buffer
            selected_transition = np.random.choice(self.transitions)
        else:
            raise ValueError("Invalid goal selection method was used for the hindsight goal selection")
        return self.goals_space.goal_from_state(selected_transition.state)

    def _sample_goals(self, episode_transitions: List, transition_index: int):
        """
        Sample a batch of goal states according to the sampling method
        :param episode_transitions: a list of all the transitions in the current episode
        :param transition_index: the transition to start sampling from
        :return: a list of hindsight_transitions_per_regular_transition goals
        """
        return [
            self._sample_goal(episode_transitions, transition_index)
            for _ in range(self.hindsight_transitions_per_regular_transition)
        ]

    def store_episode(self, episode: Episode, lock: bool=True) -> None:
        """
        Augment the given episode with hindsight transitions and store it in the memory.
        :param episode: the new episode to store. the hindsight transitions are inserted into this episode object
        :param lock: forwarded to the parent store_episode. pass False when the caller already holds the lock
        :return: None
        :raises ValueError: if a sampled goal has a different shape than the desired goal already in the transition
        """
        # generate hindsight transitions only when an episode is finished
        # work on a copy of the list, since hindsight transitions are inserted into the episode while iterating
        last_episode_transitions = copy.copy(episode.transitions)

        # cannot create a future hindsight goal in the last transition of an episode
        if self.hindsight_goal_selection_method == HindsightGoalSelectionMethod.Future:
            relevant_base_transitions = last_episode_transitions[:-1]
        else:
            relevant_base_transitions = last_episode_transitions

        # for each transition in the last episode, create a set of hindsight transitions
        for transition_index, transition in enumerate(relevant_base_transitions):
            sampled_goals = self._sample_goals(last_episode_transitions, transition_index)
            for goal in sampled_goals:
                hindsight_transition = copy.copy(transition)

                if hindsight_transition.state['desired_goal'].shape != goal.shape:
                    raise ValueError((
                        'goal shape {goal_shape} already in transition is '
                        'different than the one sampled as a hindsight goal '
                        '{hindsight_goal_shape}.'
                    ).format(
                        goal_shape=hindsight_transition.state['desired_goal'].shape,
                        hindsight_goal_shape=goal.shape,
                    ))

                # update the goal in the transition
                hindsight_transition.state['desired_goal'] = goal
                hindsight_transition.next_state['desired_goal'] = goal

                # update the reward and terminal signal according to the goal
                hindsight_transition.reward, hindsight_transition.game_over = \
                    self.goals_space.get_reward_for_goal_and_state(goal, hindsight_transition.next_state)

                hindsight_transition.total_return = None
                episode.insert(hindsight_transition)

        # honor the caller's locking request instead of always falling back to the parent's default
        super().store_episode(episode, lock)

    def store(self, transition: Transition):
        """
        Storing single transitions is not supported by this memory.
        :param transition: unused
        :raises ValueError: always
        """
        raise ValueError("An episodic HER cannot store a single transition. Only full episodes are to be stored.")
|
||||
@@ -0,0 +1,69 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from rl_coach.core_types import Episode, Transition
|
||||
from rl_coach.memories.episodic.episodic_hindsight_experience_replay import HindsightGoalSelectionMethod, \
|
||||
EpisodicHindsightExperienceReplay, EpisodicHindsightExperienceReplayParameters
|
||||
from rl_coach.memories.non_episodic.experience_replay import MemoryGranularity
|
||||
from rl_coach.spaces import GoalsSpace
|
||||
|
||||
|
||||
class EpisodicHRLHindsightExperienceReplayParameters(EpisodicHindsightExperienceReplayParameters):
    """
    Parameters object describing an EpisodicHRLHindsightExperienceReplay memory.
    """
    def __init__(self):
        super().__init__()

    @property
    def path(self):
        """
        :return: the '<module>:<class>' string naming the memory class these parameters describe
        """
        # fully qualified with the rl_coach package, in line with the path of the other memory parameters
        return 'rl_coach.memories.episodic.episodic_hrl_hindsight_experience_replay:EpisodicHRLHindsightExperienceReplay'
|
||||
|
||||
|
||||
class EpisodicHRLHindsightExperienceReplay(EpisodicHindsightExperienceReplay):
    """
    Implements HRL Hindsight Experience Replay as described in the following paper: https://arxiv.org/abs/1805.08180

    This is the memory you should use if you want a shared hindsight experience replay buffer between multiple workers
    """
    def __init__(self, max_size: Tuple[MemoryGranularity, int],
                 hindsight_transitions_per_regular_transition: int,
                 hindsight_goal_selection_method: HindsightGoalSelectionMethod,
                 goals_space: GoalsSpace,
                 ):
        """
        :param max_size: The maximum size of the memory. should be defined in a granularity of Transitions
        :param hindsight_transitions_per_regular_transition: The number of hindsight artificial transitions to generate
                                                             for each actual transition
        :param hindsight_goal_selection_method: The method that will be used for generating the goals for the
                                                hindsight transitions. Should be one of HindsightGoalSelectionMethod
        :param goals_space: A GoalsSpace which defines the properties of the goals
        """
        super().__init__(max_size, hindsight_transitions_per_regular_transition, hindsight_goal_selection_method,
                         goals_space)

    def store_episode(self, episode: Episode, lock: bool=True) -> None:
        """
        Replace the action of each transition in the episode with the goal that was actually achieved, and then
        store the episode as a regular hindsight episode.
        :param episode: the new episode to store. its transitions are modified in place
        :param lock: forwarded to the parent store_episode. pass False when the caller already holds the lock
        :return: None
        """
        # for a layer producing sub-goals, we will replace in hindsight the action (sub-goal) given to the lower
        # level with the actual achieved goal. the achieved goal (and observation) seen is assumed to be the same
        # for all levels - we can use this level's achieved goal instead of the lower level's one
        goal_name = self.goals_space.goal_name
        for transition in episode.transitions:
            transition.action = transition.next_state[goal_name]

        # honor the caller's locking request instead of always falling back to the parent's default
        super().store_episode(episode, lock)

    def store(self, transition: Transition):
        """
        Storing single transitions is not supported by this memory.
        :param transition: unused
        :raises ValueError: always
        """
        raise ValueError("An episodic HER cannot store a single transition. Only full episodes are to be stored.")
|
||||
34
rl_coach/memories/episodic/single_episode_buffer.py
Normal file
34
rl_coach/memories/episodic/single_episode_buffer.py
Normal file
@@ -0,0 +1,34 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from rl_coach.memories.memory import MemoryGranularity, MemoryParameters
|
||||
|
||||
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplay
|
||||
|
||||
|
||||
class SingleEpisodeBufferParameters(MemoryParameters):
    """
    Parameters object describing a SingleEpisodeBuffer memory.
    """
    def __init__(self):
        super().__init__()
        # SingleEpisodeBuffer takes no max_size argument, so the attribute is dropped from the parameters
        del self.max_size

    @property
    def path(self):
        """
        :return: the '<module>:<class>' string naming the memory class these parameters describe
        """
        return 'rl_coach.memories.episodic.single_episode_buffer:SingleEpisodeBuffer'
|
||||
|
||||
|
||||
class SingleEpisodeBuffer(EpisodicExperienceReplay):
    """
    An episodic experience replay whose maximum size is fixed to a single episode.
    """
    def __init__(self):
        single_episode = (MemoryGranularity.Episodes, 1)
        super().__init__(single_episode)
|
||||
67
rl_coach/memories/memory.py
Normal file
67
rl_coach/memories/memory.py
Normal file
@@ -0,0 +1,67 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from enum import Enum
|
||||
from typing import Tuple
|
||||
|
||||
from rl_coach.base_parameters import Parameters
|
||||
|
||||
|
||||
class MemoryGranularity(Enum):
    """
    The unit in which the maximum size of a memory is expressed.
    """
    Transitions = 0  # the size is a number of transitions
    Episodes = 1     # the size is a number of episodes
|
||||
|
||||
|
||||
class MemoryParameters(Parameters):
    """
    Base parameters object for memories.
    """
    def __init__(self):
        super().__init__()
        # all the fields start at their neutral value, and are expected to be overridden where needed
        self.max_size = None
        self.shared_memory = False
        self.load_memory_from_file_path = None

    @property
    def path(self):
        """
        :return: the '<module>:<class>' string naming the memory class these parameters describe
        """
        return 'rl_coach.memories.memory:Memory'
|
||||
|
||||
|
||||
class Memory(object):
    """
    Base class for memories. It only records the maximum size and a length counter. Each operation raises
    NotImplementedError until a subclass overrides it.
    """
    def __init__(self, max_size: Tuple[MemoryGranularity, int]):
        """
        :param max_size: the maximum number of objects to hold in the memory
        """
        self._length = 0
        self.max_size = max_size

    def store(self, obj):
        """ Store an object in the memory. Not implemented in the base class """
        raise NotImplementedError("")

    def get(self, index):
        """ Get the object in the given index. Not implemented in the base class """
        raise NotImplementedError("")

    def remove(self, index):
        """ Remove the object in the given index. Not implemented in the base class """
        raise NotImplementedError("")

    def length(self):
        """ Get the number of objects in the memory. Not implemented in the base class """
        raise NotImplementedError("")

    def sample(self, size):
        """ Sample a batch of the given size from the memory. Not implemented in the base class """
        raise NotImplementedError("")

    def clean(self):
        """ Remove all the objects from the memory. Not implemented in the base class """
        raise NotImplementedError("")
|
||||
|
||||
|
||||
0
rl_coach/memories/non_episodic/__init__.py
Normal file
0
rl_coach/memories/non_episodic/__init__.py
Normal file
@@ -0,0 +1,286 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from annoy import AnnoyIndex
|
||||
|
||||
|
||||
class AnnoyDictionary(object):
|
||||
def __init__(self, dict_size, key_width, new_value_shift_coefficient=0.1, batch_size=100, key_error_threshold=0.01,
|
||||
num_neighbors=50, override_existing_keys=True, rebuild_on_every_update=False):
|
||||
self.rebuild_on_every_update = rebuild_on_every_update
|
||||
self.max_size = dict_size
|
||||
self.curr_size = 0
|
||||
self.new_value_shift_coefficient = new_value_shift_coefficient
|
||||
self.num_neighbors = num_neighbors
|
||||
self.override_existing_keys = override_existing_keys
|
||||
|
||||
self.index = AnnoyIndex(key_width, metric='euclidean')
|
||||
self.index.set_seed(1)
|
||||
|
||||
self.embeddings = np.zeros((dict_size, key_width))
|
||||
self.values = np.zeros(dict_size)
|
||||
self.additional_data = [None] * dict_size
|
||||
|
||||
self.lru_timestamps = np.zeros(dict_size)
|
||||
self.current_timestamp = 0.0
|
||||
|
||||
# keys that are in this distance will be considered as the same key
|
||||
self.key_error_threshold = key_error_threshold
|
||||
|
||||
self.initial_update_size = batch_size
|
||||
self.min_update_size = self.initial_update_size
|
||||
self.key_dimension = key_width
|
||||
self.value_dimension = 1
|
||||
self._reset_buffer()
|
||||
|
||||
self.built_capacity = 0
|
||||
|
||||
def add(self, keys, values, additional_data=None):
|
||||
if not additional_data:
|
||||
additional_data = [None] * len(keys)
|
||||
|
||||
# Adds new embeddings and values to the dictionary
|
||||
indices = []
|
||||
indices_to_remove = []
|
||||
for i in range(keys.shape[0]):
|
||||
index = self._lookup_key_index(keys[i])
|
||||
if index and self.override_existing_keys:
|
||||
# update existing value
|
||||
self.values[index] += self.new_value_shift_coefficient * (values[i] - self.values[index])
|
||||
self.additional_data[index[0][0]] = additional_data[i]
|
||||
self.lru_timestamps[index] = self.current_timestamp
|
||||
indices_to_remove.append(i)
|
||||
else:
|
||||
# add new
|
||||
if self.curr_size >= self.max_size:
|
||||
# find the LRU entry
|
||||
index = np.argmin(self.lru_timestamps)
|
||||
else:
|
||||
index = self.curr_size
|
||||
self.curr_size += 1
|
||||
self.lru_timestamps[index] = self.current_timestamp
|
||||
indices.append(index)
|
||||
|
||||
for i in reversed(indices_to_remove):
|
||||
keys = np.delete(keys, i, 0)
|
||||
values = np.delete(values, i, 0)
|
||||
del additional_data[i]
|
||||
|
||||
self.buffered_keys = np.vstack((self.buffered_keys, keys))
|
||||
self.buffered_values = np.vstack((self.buffered_values, values))
|
||||
self.buffered_indices = self.buffered_indices + indices
|
||||
self.buffered_additional_data = self.buffered_additional_data + additional_data
|
||||
|
||||
if len(self.buffered_indices) >= self.min_update_size:
|
||||
self.min_update_size = max(self.initial_update_size, int(self.curr_size * 0.02))
|
||||
self._rebuild_index()
|
||||
elif self.rebuild_on_every_update:
|
||||
self._rebuild_index()
|
||||
|
||||
self.current_timestamp += 1
|
||||
|
||||
# Returns the stored embeddings and values of the closest embeddings
|
||||
def query(self, keys, k):
|
||||
if not self.has_enough_entries(k):
|
||||
# this will only happen when the DND is not yet populated with enough entries, which is only during heatup
|
||||
# these values won't be used and therefore they are meaningless
|
||||
return [0.0], [0.0], [0], [None]
|
||||
|
||||
_, indices = self._get_k_nearest_neighbors_indices(keys, k)
|
||||
|
||||
embeddings = []
|
||||
values = []
|
||||
additional_data = []
|
||||
for ind in indices:
|
||||
self.lru_timestamps[ind] = self.current_timestamp
|
||||
embeddings.append(self.embeddings[ind])
|
||||
values.append(self.values[ind])
|
||||
curr_additional_data = []
|
||||
for sub_ind in ind:
|
||||
curr_additional_data.append(self.additional_data[sub_ind])
|
||||
additional_data.append(curr_additional_data)
|
||||
|
||||
self.current_timestamp += 1
|
||||
|
||||
return embeddings, values, indices, additional_data
|
||||
|
||||
def has_enough_entries(self, k):
    """
    Check whether a k-nearest-neighbors query can be served.

    :param k: the number of neighbors that will be requested
    :return: True only if both the number of stored entries and the number of entries covered by the
             last index build are strictly larger than k
    """
    if self.curr_size <= k:
        return False
    return self.built_capacity > k
|
||||
|
||||
def sample_embeddings(self, num_embeddings):
    """
    Draw random stored embeddings (with replacement) out of the first curr_size entries.

    :param num_embeddings: the number of embeddings to draw
    :return: the rows of self.embeddings at the randomly chosen positions
    """
    chosen_rows = np.random.choice(self.curr_size, num_embeddings)
    return self.embeddings[chosen_rows]
|
||||
|
||||
def _get_k_nearest_neighbors_indices(self, keys, k):
|
||||
distances = []
|
||||
indices = []
|
||||
for key in keys:
|
||||
index, distance = self.index.get_nns_by_vector(key, k, include_distances=True)
|
||||
distances.append(distance)
|
||||
indices.append(index)
|
||||
return distances, indices
|
||||
|
||||
def _rebuild_index(self):
|
||||
self.index.unbuild()
|
||||
self.embeddings[self.buffered_indices] = self.buffered_keys
|
||||
self.values[self.buffered_indices] = np.squeeze(self.buffered_values)
|
||||
for i, data in zip(self.buffered_indices, self.buffered_additional_data):
|
||||
self.additional_data[i] = data
|
||||
for idx, key in zip(self.buffered_indices, self.buffered_keys):
|
||||
self.index.add_item(idx, key)
|
||||
|
||||
self._reset_buffer()
|
||||
|
||||
self.index.build(self.num_neighbors)
|
||||
self.built_capacity = self.curr_size
|
||||
|
||||
def _reset_buffer(self):
|
||||
self.buffered_keys = np.zeros((0, self.key_dimension))
|
||||
self.buffered_values = np.zeros((0, self.value_dimension))
|
||||
self.buffered_indices = []
|
||||
self.buffered_additional_data = []
|
||||
|
||||
def _lookup_key_index(self, key):
|
||||
distance, index = self._get_k_nearest_neighbors_indices([key], 1)
|
||||
if distance != [[]] and distance[0][0] <= self.key_error_threshold:
|
||||
return index
|
||||
return None
|
||||
|
||||
|
||||
class QDND(object):
    """
    A collection of AnnoyDictionary instances, one per action, sharing a single configuration.
    """
    def __init__(self, dict_size, key_width, num_actions, new_value_shift_coefficient=0.1, key_error_threshold=0.01,
                 learning_rate=0.01, num_neighbors=50, return_additional_data=False, override_existing_keys=False,
                 rebuild_on_every_update=False):
        """
        :param dict_size: forwarded to each AnnoyDictionary as its first positional argument
        :param key_width: forwarded to each AnnoyDictionary as its second positional argument
        :param num_actions: the number of actions, which is also the number of dictionaries created
        :param new_value_shift_coefficient: forwarded to each AnnoyDictionary
        :param key_error_threshold: forwarded to each AnnoyDictionary
        :param learning_rate: the step size used by update_keys_and_values
        :param num_neighbors: forwarded to each AnnoyDictionary
        :param return_additional_data: if True, query also returns the additional data of the neighbors
        :param override_existing_keys: forwarded to each AnnoyDictionary
        :param rebuild_on_every_update: forwarded to each AnnoyDictionary
        """
        self.dict_size = dict_size
        self.key_width = key_width
        self.num_actions = num_actions
        self.new_value_shift_coefficient = new_value_shift_coefficient
        self.key_error_threshold = key_error_threshold
        self.learning_rate = learning_rate
        self.num_neighbors = num_neighbors
        self.return_additional_data = return_additional_data
        self.override_existing_keys = override_existing_keys
        # kept so that clean() can recreate the dictionaries with the exact same configuration
        self.rebuild_on_every_update = rebuild_on_every_update

        # create a dict for each action
        self.dicts = self._create_dicts()

    def _create_dicts(self):
        """
        Create a fresh dictionary for each action using the stored configuration
        :return: a list with num_actions new AnnoyDictionary instances
        """
        # instances which were pickled before rebuild_on_every_update was stored fall back to the default
        rebuild_on_every_update = getattr(self, 'rebuild_on_every_update', False)
        return [AnnoyDictionary(self.dict_size, self.key_width, self.new_value_shift_coefficient,
                                key_error_threshold=self.key_error_threshold, num_neighbors=self.num_neighbors,
                                override_existing_keys=self.override_existing_keys,
                                rebuild_on_every_update=rebuild_on_every_update)
                for _ in range(self.num_actions)]

    def add(self, embeddings, actions, values, additional_data=None):
        """
        Add a new set of embeddings and values to each of the underlying dictionaries
        :param embeddings: the embeddings to add, one per entry
        :param actions: the action of each entry, which selects the dictionary the entry is added to
        :param values: the value of each entry
        :param additional_data: optional additional data, one item per entry
        :return: True
        """
        embeddings = np.array(embeddings)
        actions = np.array(actions)
        values = np.array(values)
        for a in range(self.num_actions):
            idx = np.where(actions == a)
            curr_action_embeddings = embeddings[idx]
            curr_action_values = np.expand_dims(values[idx], -1)
            if additional_data:
                curr_additional_data = [additional_data[i] for i in idx[0]]
            else:
                curr_additional_data = None

            self.dicts[a].add(curr_action_embeddings, curr_action_values, curr_additional_data)
        return True

    def query(self, embeddings, action, k):
        """
        Query the dictionary of the given action for the nearest neighbors of the given embeddings
        :param embeddings: the embeddings to query for, each one is queried separately
        :param action: the action whose dictionary should be queried
        :param k: the number of neighbors to fetch for each embedding
        :return: the neighbors embeddings, values and indices (one item per queried embedding), followed by
                 their additional data if return_additional_data is set
        """
        dnd_embeddings = []
        dnd_values = []
        dnd_indices = []
        dnd_additional_data = []
        for curr_embedding in embeddings:
            embedding, value, indices, additional_data = self.dicts[action].query([curr_embedding], k)
            dnd_embeddings.append(embedding[0])
            dnd_values.append(value[0])
            dnd_indices.append(indices[0])
            dnd_additional_data.append(additional_data[0])

        if self.return_additional_data:
            return dnd_embeddings, dnd_values, dnd_indices, dnd_additional_data
        else:
            return dnd_embeddings, dnd_values, dnd_indices

    def has_enough_entries(self, k):
        """
        Check if every one of the action dictionaries reports having enough entries for a query of k neighbors
        :param k: the number of neighbors that will be requested
        :return: True if all the dictionaries have enough entries, False otherwise
        """
        return all(curr_dict.has_enough_entries(k) for curr_dict in self.dicts[:self.num_actions])

    def update_keys_and_values(self, actions, key_gradients, value_gradients, indices):
        """
        Update DND keys and values by taking a step of size learning_rate against the given gradients
        :param actions: the action of each batch item, which selects the dictionary to update
        :param key_gradients: for each batch item, the gradients of the keys (embeddings) of its entries
        :param value_gradients: for each batch item, the gradients of the values of its entries
        :param indices: for each batch item, the indices of the entries to update
        :return: None
        """
        for batch_action, batch_keys, batch_values, batch_indices in zip(actions, key_gradients, value_gradients,
                                                                         indices):
            action_dict = self.dicts[batch_action]
            # Update keys (embeddings) and values in DND
            for i, index in enumerate(batch_indices):
                action_dict.embeddings[index, :] -= self.learning_rate * batch_keys[i, :]
                action_dict.values[index] -= self.learning_rate * batch_values[i]

    def sample_embeddings(self, num_embeddings):
        """
        Sample stored embeddings, split as evenly as possible between the action dictionaries
        :param num_embeddings: the total number of embeddings to sample
        :return: a single array with all the sampled embeddings stacked vertically
        """
        num_actions = len(self.dicts)
        num_embeddings_per_action = num_embeddings // num_actions
        embeddings = np.vstack([self.dicts[action].sample_embeddings(num_embeddings_per_action)
                                for action in range(num_actions)])

        # the numbers did not divide nicely, let's just randomly sample some more embeddings
        num_missing_embeddings = num_embeddings - num_embeddings_per_action * num_actions
        if num_missing_embeddings > 0:
            action = np.random.randint(0, num_actions)
            extra_embeddings = self.dicts[action].sample_embeddings(num_missing_embeddings)
            embeddings = np.vstack([embeddings, extra_embeddings])
        return embeddings

    def clean(self):
        """
        Replace all the action dictionaries with new empty ones that keep the original configuration
        :return: None
        """
        # create a new dict for each action
        self.dicts = self._create_dicts()
|
||||
|
||||
|
||||
def load_dnd(model_dir):
    """
    Load the latest pickled DND from a directory and rebuild the index of each of its dictionaries.

    The checkpoint files are expected to be named '<integer id>.dnd', and the file with the largest id
    is the one that is loaded.

    :param model_dir: the directory holding the '.dnd' checkpoint files
    :return: the loaded DND object, with a freshly built index for each of its dictionaries
    :raises FileNotFoundError: if the directory does not hold a loadable '.dnd' checkpoint
    """
    checkpoint_ids = [int(file_name.split('.')[0]) for file_name in os.listdir(model_dir)
                      if file_name.endswith('.dnd')]
    if not checkpoint_ids:
        raise FileNotFoundError("No '.dnd' checkpoint files were found in {}".format(model_dir))
    # ids are compared as integers (12 is later than 3), and the search starts from id 0
    max_id = max([0] + checkpoint_ids)

    model_path = str(max_id) + '.dnd'
    with open(os.path.join(model_dir, model_path), 'rb') as f:
        # NOTE(review): unpickling can execute arbitrary code - only load checkpoints from a trusted source
        DND = pickle.load(f)

    for a in range(DND.num_actions):
        curr_dict = DND.dicts[a]
        # use the dimensions the dictionary was created with, falling back to the values that used to be
        # hard-coded here for objects that do not carry these attributes
        key_dimension = getattr(curr_dict, 'key_dimension', 512)
        num_neighbors = getattr(curr_dict, 'num_neighbors', 50)

        curr_dict.index = AnnoyIndex(key_dimension, metric='euclidean')
        curr_dict.index.set_seed(1)

        # only the first curr_size entries of the embeddings are added to the index
        for idx, key in enumerate(curr_dict.embeddings[:curr_dict.curr_size]):
            curr_dict.index.add_item(idx, key)

        curr_dict.index.build(num_neighbors)

    return DND
|
||||
220
rl_coach/memories/non_episodic/experience_replay.py
Normal file
220
rl_coach/memories/non_episodic/experience_replay.py
Normal file
@@ -0,0 +1,220 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import List, Tuple, Union, Dict, Any
|
||||
|
||||
import numpy as np
|
||||
from rl_coach.utils import ReaderWriterLock
|
||||
|
||||
from rl_coach.core_types import Transition
|
||||
from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters
|
||||
|
||||
|
||||
class ExperienceReplayParameters(MemoryParameters):
    """
    The parameters of the ExperienceReplay memory
    """
    def __init__(self):
        super().__init__()
        # the size is given as a (granularity, amount) pair
        size_in_transitions = 1000000
        self.max_size = (MemoryGranularity.Transitions, size_in_transitions)
        self.allow_duplicates_in_batch_sampling = True

    @property
    def path(self):
        """
        :return: the '<module path>:<class name>' string of the memory class these parameters belong to
        """
        module_path = 'rl_coach.memories.non_episodic.experience_replay'
        return module_path + ':ExperienceReplay'
|
||||
|
||||
|
||||
class ExperienceReplay(Memory):
    """
    A regular replay buffer which stores transition without any additional structure
    """
    def __init__(self, max_size: Tuple[MemoryGranularity, int], allow_duplicates_in_batch_sampling: bool=True):
        """
        :param max_size: the maximum number of transitions to hold in the memory, given as a
                         (MemoryGranularity.Transitions, amount) pair
        :param allow_duplicates_in_batch_sampling: allow having the same transition multiple times in a batch
        :raises ValueError: if the granularity of max_size is not transitions
        """
        super().__init__(max_size)
        if max_size[0] != MemoryGranularity.Transitions:
            raise ValueError("Experience replay size can only be configured in terms of transitions")
        self.transitions = []
        self._num_transitions = 0
        self.allow_duplicates_in_batch_sampling = allow_duplicates_in_batch_sampling

        self.reader_writer_lock = ReaderWriterLock()

    def length(self) -> int:
        """
        Get the number of transitions in the ER
        """
        return self.num_transitions()

    def num_transitions(self) -> int:
        """
        Get the number of transitions in the ER
        """
        return self._num_transitions

    def sample(self, size: int) -> List[Transition]:
        """
        Sample a batch of transitions from the replay buffer.
        :param size: the size of the batch to sample
        :return: a batch (list) of selected transitions from the replay buffer
        :raises ValueError: if duplicates are not allowed and the requested size is larger than the number of
                            transitions available in the replay buffer
        """
        self.reader_writer_lock.lock_writing()
        # the lock must be released even if the sampling fails, otherwise the memory stays locked forever
        try:
            if self.allow_duplicates_in_batch_sampling:
                transitions_idx = np.random.randint(self.num_transitions(), size=size)
            elif self.num_transitions() >= size:
                transitions_idx = np.random.choice(self.num_transitions(), size=size, replace=False)
            else:
                raise ValueError("The replay buffer cannot be sampled since there are not enough transitions yet. "
                                 "There are currently {} transitions".format(self.num_transitions()))

            return [self.transitions[i] for i in transitions_idx]
        finally:
            self.reader_writer_lock.release_writing()

    def _enforce_max_length(self) -> None:
        """
        Make sure that the size of the replay buffer does not pass the maximum size allowed.
        If it passes the max size, the oldest transition in the replay buffer will be removed.
        A max size of 0 disables the limit.
        This function does not use locks since it is only called internally
        :return: None
        :raises ValueError: if the granularity of max_size is not transitions
        """
        granularity, size = self.max_size
        if granularity == MemoryGranularity.Transitions:
            while size != 0 and self.num_transitions() > size:
                self.remove_transition(0, False)
        else:
            raise ValueError("The granularity of the replay buffer can only be set in terms of transitions")

    def store(self, transition: Transition, lock: bool=True) -> None:
        """
        Store a new transition in the memory.
        :param transition: a transition to store
        :param lock: if true, will lock the readers writers lock. this can cause a deadlock if an inheriting class
                     locks and then calls store with lock = True
        :return: None
        """
        if lock:
            self.reader_writer_lock.lock_writing_and_reading()
        try:
            self._num_transitions += 1
            self.transitions.append(transition)
            self._enforce_max_length()
        finally:
            if lock:
                self.reader_writer_lock.release_writing_and_reading()

    def get_transition(self, transition_index: int, lock: bool=True) -> Union[None, Transition]:
        """
        Returns the transition in the given index. If the transition does not exist, returns None instead.
        :param transition_index: the index of the transition to return
        :param lock: use write locking if this is a shared memory
        :return: the corresponding transition
        """
        if lock:
            self.reader_writer_lock.lock_writing()
        try:
            if self.length() == 0 or transition_index >= self.length():
                return None
            return self.transitions[transition_index]
        finally:
            if lock:
                self.reader_writer_lock.release_writing()

    def remove_transition(self, transition_index: int, lock: bool=True) -> None:
        """
        Remove the transition in the given index.
        This does not remove the transition from the segment trees! it is just used to remove the transition
        from the transitions list
        :param transition_index: the index of the transition to remove
        :param lock: if true, will lock the readers writers lock while removing
        :return: None
        """
        if lock:
            self.reader_writer_lock.lock_writing_and_reading()
        try:
            if self.num_transitions() > transition_index:
                self._num_transitions -= 1
                del self.transitions[transition_index]
        finally:
            if lock:
                self.reader_writer_lock.release_writing_and_reading()

    # for API compatibility
    def get(self, transition_index: int, lock: bool=True) -> Union[None, Transition]:
        """
        Returns the transition in the given index. If the transition does not exist, returns None instead.
        :param transition_index: the index of the transition to return
        :param lock: use write locking if this is a shared memory
        :return: the corresponding transition
        """
        return self.get_transition(transition_index, lock)

    # for API compatibility
    def remove(self, transition_index: int, lock: bool=True):
        """
        Remove the transition in the given index
        :param transition_index: the index of the transition to remove
        :param lock: if true, will lock the readers writers lock while removing
        :return: None
        """
        self.remove_transition(transition_index, lock)

    def update_last_transition_info(self, info: Dict[str, Any]) -> None:
        """
        Update the info of the last transition stored in the memory
        :param info: the new info to append to the existing info
        :return: None
        :raises ValueError: if the memory is empty
        """
        self.reader_writer_lock.lock_writing_and_reading()
        # the lock must be released even when the memory is empty and the error is raised
        try:
            if self.length() == 0:
                raise ValueError("There are no transition in the replay buffer")
            self.transitions[-1].info.update(info)
        finally:
            self.reader_writer_lock.release_writing_and_reading()

    def clean(self, lock: bool=True) -> None:
        """
        Clean the memory by removing all the transitions
        :param lock: if true, will lock the readers writers lock while cleaning
        :return: None
        """
        if lock:
            self.reader_writer_lock.lock_writing_and_reading()
        try:
            self.transitions = []
            self._num_transitions = 0
        finally:
            if lock:
                self.reader_writer_lock.release_writing_and_reading()

    def mean_reward(self) -> np.ndarray:
        """
        Get the mean reward in the replay buffer
        :return: the mean reward
        """
        self.reader_writer_lock.lock_writing()
        try:
            return np.mean([transition.reward for transition in self.transitions])
        finally:
            self.reader_writer_lock.release_writing()
|
||||
292
rl_coach/memories/non_episodic/prioritized_experience_replay.py
Normal file
292
rl_coach/memories/non_episodic/prioritized_experience_replay.py
Normal file
@@ -0,0 +1,292 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import operator
|
||||
import random
|
||||
from enum import Enum
|
||||
from typing import List, Tuple, Any
|
||||
|
||||
import numpy as np
|
||||
from rl_coach.memories.memory import MemoryGranularity
|
||||
from rl_coach.schedules import Schedule, ConstantSchedule
|
||||
|
||||
from rl_coach.core_types import Transition
|
||||
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters, ExperienceReplay
|
||||
|
||||
|
||||
class PrioritizedExperienceReplayParameters(ExperienceReplayParameters):
    """
    The parameters of the PrioritizedExperienceReplay memory
    """
    def __init__(self):
        super().__init__()
        # the size is given as a (granularity, amount) pair
        size_in_transitions = 1000000
        self.max_size = (MemoryGranularity.Transitions, size_in_transitions)
        # the prioritization coefficient
        self.alpha = 0.6
        # the importance sampling parameter, given as a schedule
        self.beta = ConstantSchedule(0.4)
        # a small value which is added to the priority of each transition
        self.epsilon = 1e-6

    @property
    def path(self):
        """
        :return: the '<module path>:<class name>' string of the memory class these parameters belong to
        """
        module_path = 'rl_coach.memories.non_episodic.prioritized_experience_replay'
        return module_path + ':PrioritizedExperienceReplay'
|
||||
|
||||
|
||||
class SegmentTree(object):
    """
    A tree which can be used as a min/max heap or a sum tree
    Add or update item value - O(log N)
    Sampling an item - O(log N)

    The tree is stored as a flat array of 2 * size - 1 nodes, where the children of node i are the nodes
    2 * i + 1 and 2 * i + 2, and the leaves are the last size nodes of the array.
    """
    class Operation(Enum):
        MAX = {"operator": max, "initial_value": -float("inf")}
        MIN = {"operator": min, "initial_value": float("inf")}
        SUM = {"operator": operator.add, "initial_value": 0}

    def __init__(self, size: int, operation: Operation):
        """
        :param size: the number of leaves in the tree. must be a positive power of 2
        :param operation: the operation which combines the values of two children into their parent value
        :raises ValueError: if the size is not a positive power of 2
        """
        self.next_leaf_idx_to_write = 0
        self.size = size
        if not (size > 0 and size & (size - 1) == 0):
            raise ValueError("A segment tree size must be a positive power of 2. The given size is {}".format(self.size))
        self.operation = operation
        self.tree = np.ones(2 * size - 1) * self.operation.value['initial_value']
        self.data = [None] * size

    def _propagate(self, node_idx: int) -> None:
        """
        Propagate an update of a node's value up the tree, through all of its ancestors
        :param node_idx: the index of the node that was updated
        :return: None
        """
        combine = self.operation.value['operator']
        # the root (index 0) has no parent. this also covers a tree of size 1, whose only leaf is the root
        while node_idx > 0:
            parent = (node_idx - 1) // 2
            self.tree[parent] = combine(self.tree[parent * 2 + 1], self.tree[parent * 2 + 2])
            node_idx = parent

    def _retrieve(self, root_node_idx: int, val: float)-> int:
        """
        Walk down from the given node to the leaf whose cumulative range contains val. At each node, the walk
        goes left if val is not larger than the value of the left child, and otherwise goes right after
        subtracting the value of the left child from val.
        :param root_node_idx: the index of the root node to search from
        :param val: the value to query for
        :return: the index of the resulting node
        """
        node_idx = root_node_idx
        while True:
            left = 2 * node_idx + 1
            if left >= len(self.tree):
                # the node has no children, so it is a leaf
                return node_idx
            if val <= self.tree[left]:
                node_idx = left
            else:
                val = val - self.tree[left]
                node_idx = left + 1

    def total_value(self) -> float:
        """
        Return the total value of the tree according to the tree operation. For SUM for example, this will return
        the total sum of the tree. for MIN, this will return the minimal value
        :return: the total value of the tree
        """
        return self.tree[0]

    def add(self, val: float, data: Any) -> None:
        """
        Add a new value to the tree with data assigned to it. The leaves are written in a cyclic order, so once
        all the leaves were written, the next value overwrites the first leaf.
        :param val: the new value to add to the tree
        :param data: the data that should be assigned to this value
        :return: None
        """
        self.data[self.next_leaf_idx_to_write] = data
        self.update(self.next_leaf_idx_to_write, val)

        self.next_leaf_idx_to_write += 1
        if self.next_leaf_idx_to_write >= self.size:
            self.next_leaf_idx_to_write = 0

    def update(self, leaf_idx: int, new_val: float) -> None:
        """
        Update the value of the leaf at the given index
        :param leaf_idx: the index of the leaf to update, between 0 and size - 1
        :param new_val: the new value of the leaf
        :return: None
        :raises ValueError: if the leaf index is out of range
        """
        # validate the leaf index itself - a negative leaf index would otherwise map to an internal node
        if not 0 <= leaf_idx < self.size:
            raise ValueError("The given leaf index ({}) can not be found in the tree. The available leaves are: 0-{}"
                             .format(leaf_idx, self.size - 1))
        node_idx = leaf_idx + self.size - 1

        self.tree[node_idx] = new_val
        self._propagate(node_idx)

    def get(self, val: float) -> Tuple[int, float, Any]:
        """
        Given a value between 0 and the tree sum, return the object which this value is in it's range.
        For example, if we have 3 leaves: 10, 20, 30, and val=35, this will return the 3rd leaf, by accumulating
        leaves by their order until getting to 35. This allows sampling leaves according to their proportional
        probability.
        :param val: a value within the range 0 and the tree sum
        :return: the index of the resulting leaf in the tree, the value stored in that leaf and
                 the object itself
        """
        node_idx = self._retrieve(0, val)
        leaf_idx = node_idx - self.size + 1
        data_value = self.tree[node_idx]
        data = self.data[leaf_idx]

        return leaf_idx, data_value, data

    def __str__(self):
        """
        :return: a textual representation of the tree, with one line per tree level
        """
        levels = []
        start = 0
        size = 1
        while size <= self.size:
            levels.append("{}\n".format(self.tree[start:(start + size)]))
            start += size
            size *= 2
        return "".join(levels)
|
||||
|
||||
|
||||
class PrioritizedExperienceReplay(ExperienceReplay):
    """
    This is the proportional sampling variant of the prioritized experience replay as described
    in https://arxiv.org/pdf/1511.05952.pdf.
    """
    def __init__(self, max_size: Tuple[MemoryGranularity, int], alpha: float=0.6, beta: Schedule=ConstantSchedule(0.4),
                 epsilon: float=1e-6, allow_duplicates_in_batch_sampling: bool=True):
        """
        :param max_size: the maximum number of transitions to hold in the memory. the amount is rounded up to
                         the closest power of 2, since this is the size the segment trees require
        :param alpha: the alpha prioritization coefficient
        :param beta: the beta parameter used for importance sampling
        :param epsilon: a small value added to the priority of each transition
        :param allow_duplicates_in_batch_sampling: allow having the same transition multiple times in a batch
        :raises ValueError: if the granularity of max_size is not transitions
        """
        if max_size[0] != MemoryGranularity.Transitions:
            raise ValueError("Prioritized Experience Replay currently only support setting the memory size in "
                             "transitions granularity.")
        self.power_of_2_size = 1
        while self.power_of_2_size < max_size[1]:
            self.power_of_2_size *= 2
        super().__init__((MemoryGranularity.Transitions, self.power_of_2_size), allow_duplicates_in_batch_sampling)
        self.sum_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.SUM)
        self.min_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.MIN)
        self.max_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.MAX)
        self.alpha = alpha
        self.beta = beta
        self.epsilon = epsilon
        self.maximal_priority = 1.0

    def _update_priority(self, leaf_idx: int, error: float) -> None:
        """
        Update the priority of a given transition, using its index in the tree and its error.
        This function does not use locks since it is only called internally
        :param leaf_idx: the index of the transition leaf in the tree
        :param error: the new error value
        :return: None
        :raises ValueError: if the error is negative
        """
        if error < 0:
            raise ValueError("The priorities must be non-negative values")
        priority = (error + self.epsilon)
        # the sum and min trees hold priority^alpha, while the max tree holds the raw priority
        self.sum_tree.update(leaf_idx, priority ** self.alpha)
        self.min_tree.update(leaf_idx, priority ** self.alpha)
        self.max_tree.update(leaf_idx, priority)
        self.maximal_priority = self.max_tree.total_value()

    def update_priorities(self, indices: List[int], error_values: List[float]) -> None:
        """
        Update the priorities of a batch of transitions using their indices and their new TD error terms
        :param indices: the indices of the transitions to update
        :param error_values: the new error values
        :return: None
        :raises ValueError: if the number of indices and error values do not match, or if an error is negative
        """
        # validate before locking, so that a failure does not leave the lock taken
        if len(indices) != len(error_values):
            raise ValueError("The number of indexes requested for update don't match the number of error values given")

        self.reader_writer_lock.lock_writing_and_reading()
        # _update_priority raises on negative errors, so the lock must be released in any case
        try:
            for transition_idx, error in zip(indices, error_values):
                self._update_priority(transition_idx, error)
        finally:
            self.reader_writer_lock.release_writing_and_reading()

    def sample(self, size: int) -> List[Transition]:
        """
        Sample a batch of transitions from the replay buffer. Each sampled transition gets its tree leaf index
        and its normalized importance sampling weight written into its info dictionary, under the 'idx' and
        'weight' keys.
        :param size: the size of the batch to sample
        :return: a batch (list) of selected transitions from the replay buffer
        :raises ValueError: if the requested size is larger than the number of transitions available in the
                            replay buffer
        """
        self.reader_writer_lock.lock_writing()
        # the lock must be released even if the sampling fails, otherwise the memory stays locked forever
        try:
            if self.num_transitions() < size:
                raise ValueError("The replay buffer cannot be sampled since there are not enough transitions yet. "
                                 "There are currently {} transitions".format(self.num_transitions()))

            # split the tree leaves to equal segments and sample one transition from each segment
            batch = []
            total_priority = self.sum_tree.total_value()
            num_transitions = self.num_transitions()
            beta = self.beta.current_value
            segment_size = total_priority / size

            # get the maximum weight in the memory
            min_probability = self.min_tree.total_value() / total_priority  # min P(j) = min p^a / sum(p^a)
            max_weight = (min_probability * num_transitions) ** -beta  # max wi

            # sample a batch
            for i in range(size):
                start_probability = segment_size * i
                end_probability = segment_size * (i + 1)

                # sample leaf and calculate its weight
                val = random.uniform(start_probability, end_probability)
                leaf_idx, priority, transition = self.sum_tree.get(val)
                priority /= total_priority  # P(j) = p^a / sum(p^a)
                weight = (num_transitions * priority) ** -beta  # (N * P(j)) ^ -beta
                normalized_weight = weight / max_weight  # wj = ((N * P(j)) ^ -beta) / max wi

                transition.info['idx'] = leaf_idx
                transition.info['weight'] = normalized_weight

                batch.append(transition)

            self.beta.step()

            return batch
        finally:
            self.reader_writer_lock.release_writing()

    def store(self, transition: Transition, lock: bool=True) -> None:
        """
        Store a new transition in the memory. The new transition gets the maximal priority that is currently
        known to the memory.
        :param transition: a transition to store
        :param lock: if true, will lock the readers writers lock. this can cause a deadlock if an inheriting class
                     locks and then calls store with lock = True
        :return: None
        """
        if lock:
            self.reader_writer_lock.lock_writing_and_reading()
        try:
            transition_priority = self.maximal_priority
            self.sum_tree.add(transition_priority ** self.alpha, transition)
            self.min_tree.add(transition_priority ** self.alpha, transition)
            self.max_tree.add(transition_priority, transition)
            # the lock is already handled here, so the parent must not lock again
            super().store(transition, False)
        finally:
            if lock:
                self.reader_writer_lock.release_writing_and_reading()

    def clean(self, lock: bool=True) -> None:
        """
        Clean the memory by removing all the transitions and replacing the segment trees with new empty ones
        :param lock: if true, will lock the readers writers lock while cleaning
        :return: None
        """
        if lock:
            self.reader_writer_lock.lock_writing_and_reading()
        try:
            super().clean(lock=False)
            self.sum_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.SUM)
            self.min_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.MIN)
            self.max_tree = SegmentTree(self.power_of_2_size, SegmentTree.Operation.MAX)
        finally:
            if lock:
                self.reader_writer_lock.release_writing_and_reading()
|
||||
Reference in New Issue
Block a user