1
0
mirror of https://github.com/gryf/coach.git synced 2026-02-23 10:35:46 +01:00

Batch RL Tutorial (#372)

This commit is contained in:
Gal Leibovich
2019-07-14 18:43:48 +03:00
committed by GitHub
parent b82414138d
commit 19ad2d60a7
40 changed files with 1155 additions and 182 deletions

View File

@@ -15,6 +15,8 @@
# limitations under the License.
#
import ast
import pickle
from copy import deepcopy
import math
@@ -141,14 +143,27 @@ class EpisodicExperienceReplay(Memory):
def shuffle_episodes(self):
"""
Shuffle all the episodes in the replay buffer
Shuffle all the complete episodes in the replay buffer, while deleting the last non-complete episode
:return:
"""
self.reader_writer_lock.lock_writing()
self.assert_not_frozen()
# unlike the standard usage of the EpisodicExperienceReplay, where we always leave an empty episode after
# the last full one, so that new transitions will have where to be added, in this case we delibrately remove
# that empty last episode, as we are about to shuffle the memory, and we don't want it to be shuffled in
self.remove_last_episode(lock=False)
random.shuffle(self._buffer)
self.transitions = [t for e in self._buffer for t in e.transitions]
# create a new Episode for the next transitions to be placed into
self._buffer.append(Episode(n_step=self.n_step))
self._length += 1
self.reader_writer_lock.release_writing()
def get_shuffled_training_data_generator(self, size: int) -> List[Transition]:
"""
Get an generator for iterating through the shuffled replay buffer, for processing the data in epochs.
@@ -201,10 +216,10 @@ class EpisodicExperienceReplay(Memory):
granularity, size = self.max_size
if granularity == MemoryGranularity.Transitions:
while size != 0 and self.num_transitions() > size:
self._remove_episode(0)
self.remove_first_episode(lock=False)
elif granularity == MemoryGranularity.Episodes:
while self.length() > size:
self._remove_episode(0)
self.remove_first_episode(lock=False)
def _update_episode(self, episode: Episode) -> None:
episode.update_transitions_rewards_and_bootstrap_data()
@@ -321,31 +336,53 @@ class EpisodicExperienceReplay(Memory):
def _remove_episode(self, episode_index: int) -> None:
"""
Remove the episode in the given index (even if it is not complete yet)
:param episode_index: the index of the episode to remove
Remove either the first or the last index
:param episode_index: the index of the episode to remove (either 0 or -1)
:return: None
"""
self.assert_not_frozen()
assert episode_index == 0 or episode_index == -1, "_remove_episode only supports removing the first or the last " \
"episode"
if len(self._buffer) > episode_index:
if len(self._buffer) > 0:
episode_length = self._buffer[episode_index].length()
self._length -= 1
self._num_transitions -= episode_length
self._num_transitions_in_complete_episodes -= episode_length
del self.transitions[:episode_length]
if episode_index == 0:
del self.transitions[:episode_length]
else: # episode_index = -1
del self.transitions[-episode_length:]
del self._buffer[episode_index]
def remove_episode(self, episode_index: int) -> None:
def remove_first_episode(self, lock: bool = True) -> None:
"""
Remove the episode in the given index (even if it is not complete yet)
:param episode_index: the index of the episode to remove
Remove the first episode (even if it is not complete yet)
:param lock: if true, will lock the readers writers lock. this can cause a deadlock if an inheriting class
locks and then calls store with lock = True
:return: None
"""
self.reader_writer_lock.lock_writing_and_reading()
if lock:
self.reader_writer_lock.lock_writing_and_reading()
self._remove_episode(episode_index)
self._remove_episode(0)
if lock:
self.reader_writer_lock.release_writing_and_reading()
self.reader_writer_lock.release_writing_and_reading()
def remove_last_episode(self, lock: bool = True) -> None:
"""
Remove the last episode (even if it is not complete yet)
:param lock: if true, will lock the readers writers lock. this can cause a deadlock if an inheriting class
locks and then calls store with lock = True
:return: None
"""
if lock:
self.reader_writer_lock.lock_writing_and_reading()
self._remove_episode(-1)
if lock:
self.reader_writer_lock.release_writing_and_reading()
# for API compatibility
def get(self, episode_index: int, lock: bool = True) -> Union[None, Episode]:
@@ -372,15 +409,6 @@ class EpisodicExperienceReplay(Memory):
return episode
# for API compatibility
def remove(self, episode_index: int):
"""
Remove the episode in the given index (even if it is not complete yet)
:param episode_index: the index of the episode to remove
:return: None
"""
self.remove_episode(episode_index)
def clean(self) -> None:
"""
Clean the memory by removing all the episodes
@@ -446,7 +474,7 @@ class EpisodicExperienceReplay(Memory):
transitions.append(
Transition(state={'observation': state},
action=current_transition['action'], reward=current_transition['reward'],
action=int(current_transition['action']), reward=current_transition['reward'],
next_state={'observation': next_state}, game_over=False,
info={'all_action_probabilities':
ast.literal_eval(current_transition['all_action_probabilities'])}),
@@ -516,3 +544,36 @@ class EpisodicExperienceReplay(Memory):
self.last_training_set_episode_id = episode_num
self.last_training_set_transition_id = \
len([t for e in self.get_all_complete_episodes_from_to(0, self.last_training_set_episode_id + 1) for t in e])
def save(self, file_path: str) -> None:
"""
Save the replay buffer contents to a pickle file
:param file_path: the path to the file that will be used to store the pickled transitions
"""
with open(file_path, 'wb') as file:
pickle.dump(self.get_all_complete_episodes(), file)
def load_pickled(self, file_path: str) -> None:
"""
Restore the replay buffer contents from a pickle file.
The pickle file is assumed to include a list of transitions.
:param file_path: The path to a pickle file to restore
"""
self.assert_not_frozen()
with open(file_path, 'rb') as file:
episodes = pickle.load(file)
num_transitions = sum([len(e.transitions) for e in episodes])
if num_transitions > self.max_size[1]:
screen.warning("Warning! The number of transition to load into the replay buffer ({}) is "
"bigger than the max size of the replay buffer ({}). The excessive transitions will "
"not be stored.".format(num_transitions, self.max_size[1]))
progress_bar = ProgressBar(len(episodes))
for episode_idx, episode in enumerate(episodes):
self.store_episode(episode)
# print progress
progress_bar.update(episode_idx)
progress_bar.close()

View File

@@ -58,9 +58,6 @@ class Memory(object):
def get(self, index):
raise NotImplementedError("")
def remove(self, index):
raise NotImplementedError("")
def length(self):
raise NotImplementedError("")

View File

@@ -198,15 +198,6 @@ class ExperienceReplay(Memory):
"""
return self.get_transition(transition_index, lock)
# for API compatibility
def remove(self, transition_index: int, lock: bool=True):
"""
Remove the transition in the given index
:param transition_index: the index of the transition to remove
:return: None
"""
self.remove_transition(transition_index, lock)
def clean(self, lock: bool=True) -> None:
"""
Clean the memory by removing all the episodes