mirror of
https://github.com/gryf/coach.git
synced 2026-02-23 10:35:46 +01:00
Batch RL Tutorial (#372)
This commit is contained in:
@@ -15,6 +15,8 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import ast
|
||||
|
||||
import pickle
|
||||
from copy import deepcopy
|
||||
|
||||
import math
|
||||
@@ -141,14 +143,27 @@ class EpisodicExperienceReplay(Memory):
|
||||
|
||||
def shuffle_episodes(self):
|
||||
"""
|
||||
Shuffle all the episodes in the replay buffer
|
||||
Shuffle all the complete episodes in the replay buffer, while deleting the last non-complete episode
|
||||
:return:
|
||||
"""
|
||||
self.reader_writer_lock.lock_writing()
|
||||
|
||||
self.assert_not_frozen()
|
||||
|
||||
# unlike the standard usage of the EpisodicExperienceReplay, where we always leave an empty episode after
|
||||
# the last full one, so that new transitions will have where to be added, in this case we delibrately remove
|
||||
# that empty last episode, as we are about to shuffle the memory, and we don't want it to be shuffled in
|
||||
self.remove_last_episode(lock=False)
|
||||
|
||||
random.shuffle(self._buffer)
|
||||
self.transitions = [t for e in self._buffer for t in e.transitions]
|
||||
|
||||
# create a new Episode for the next transitions to be placed into
|
||||
self._buffer.append(Episode(n_step=self.n_step))
|
||||
self._length += 1
|
||||
|
||||
self.reader_writer_lock.release_writing()
|
||||
|
||||
def get_shuffled_training_data_generator(self, size: int) -> List[Transition]:
|
||||
"""
|
||||
Get an generator for iterating through the shuffled replay buffer, for processing the data in epochs.
|
||||
@@ -201,10 +216,10 @@ class EpisodicExperienceReplay(Memory):
|
||||
granularity, size = self.max_size
|
||||
if granularity == MemoryGranularity.Transitions:
|
||||
while size != 0 and self.num_transitions() > size:
|
||||
self._remove_episode(0)
|
||||
self.remove_first_episode(lock=False)
|
||||
elif granularity == MemoryGranularity.Episodes:
|
||||
while self.length() > size:
|
||||
self._remove_episode(0)
|
||||
self.remove_first_episode(lock=False)
|
||||
|
||||
def _update_episode(self, episode: Episode) -> None:
|
||||
episode.update_transitions_rewards_and_bootstrap_data()
|
||||
@@ -321,31 +336,53 @@ class EpisodicExperienceReplay(Memory):
|
||||
|
||||
def _remove_episode(self, episode_index: int) -> None:
|
||||
"""
|
||||
Remove the episode in the given index (even if it is not complete yet)
|
||||
:param episode_index: the index of the episode to remove
|
||||
Remove either the first or the last index
|
||||
:param episode_index: the index of the episode to remove (either 0 or -1)
|
||||
:return: None
|
||||
"""
|
||||
self.assert_not_frozen()
|
||||
assert episode_index == 0 or episode_index == -1, "_remove_episode only supports removing the first or the last " \
|
||||
"episode"
|
||||
|
||||
if len(self._buffer) > episode_index:
|
||||
if len(self._buffer) > 0:
|
||||
episode_length = self._buffer[episode_index].length()
|
||||
self._length -= 1
|
||||
self._num_transitions -= episode_length
|
||||
self._num_transitions_in_complete_episodes -= episode_length
|
||||
del self.transitions[:episode_length]
|
||||
if episode_index == 0:
|
||||
del self.transitions[:episode_length]
|
||||
else: # episode_index = -1
|
||||
del self.transitions[-episode_length:]
|
||||
del self._buffer[episode_index]
|
||||
|
||||
def remove_episode(self, episode_index: int) -> None:
|
||||
def remove_first_episode(self, lock: bool = True) -> None:
|
||||
"""
|
||||
Remove the episode in the given index (even if it is not complete yet)
|
||||
:param episode_index: the index of the episode to remove
|
||||
Remove the first episode (even if it is not complete yet)
|
||||
:param lock: if true, will lock the readers writers lock. this can cause a deadlock if an inheriting class
|
||||
locks and then calls store with lock = True
|
||||
:return: None
|
||||
"""
|
||||
self.reader_writer_lock.lock_writing_and_reading()
|
||||
if lock:
|
||||
self.reader_writer_lock.lock_writing_and_reading()
|
||||
|
||||
self._remove_episode(episode_index)
|
||||
self._remove_episode(0)
|
||||
if lock:
|
||||
self.reader_writer_lock.release_writing_and_reading()
|
||||
|
||||
self.reader_writer_lock.release_writing_and_reading()
|
||||
def remove_last_episode(self, lock: bool = True) -> None:
|
||||
"""
|
||||
Remove the last episode (even if it is not complete yet)
|
||||
:param lock: if true, will lock the readers writers lock. this can cause a deadlock if an inheriting class
|
||||
locks and then calls store with lock = True
|
||||
:return: None
|
||||
"""
|
||||
if lock:
|
||||
self.reader_writer_lock.lock_writing_and_reading()
|
||||
|
||||
self._remove_episode(-1)
|
||||
|
||||
if lock:
|
||||
self.reader_writer_lock.release_writing_and_reading()
|
||||
|
||||
# for API compatibility
|
||||
def get(self, episode_index: int, lock: bool = True) -> Union[None, Episode]:
|
||||
@@ -372,15 +409,6 @@ class EpisodicExperienceReplay(Memory):
|
||||
|
||||
return episode
|
||||
|
||||
# for API compatibility
|
||||
def remove(self, episode_index: int):
|
||||
"""
|
||||
Remove the episode in the given index (even if it is not complete yet)
|
||||
:param episode_index: the index of the episode to remove
|
||||
:return: None
|
||||
"""
|
||||
self.remove_episode(episode_index)
|
||||
|
||||
def clean(self) -> None:
|
||||
"""
|
||||
Clean the memory by removing all the episodes
|
||||
@@ -446,7 +474,7 @@ class EpisodicExperienceReplay(Memory):
|
||||
|
||||
transitions.append(
|
||||
Transition(state={'observation': state},
|
||||
action=current_transition['action'], reward=current_transition['reward'],
|
||||
action=int(current_transition['action']), reward=current_transition['reward'],
|
||||
next_state={'observation': next_state}, game_over=False,
|
||||
info={'all_action_probabilities':
|
||||
ast.literal_eval(current_transition['all_action_probabilities'])}),
|
||||
@@ -516,3 +544,36 @@ class EpisodicExperienceReplay(Memory):
|
||||
self.last_training_set_episode_id = episode_num
|
||||
self.last_training_set_transition_id = \
|
||||
len([t for e in self.get_all_complete_episodes_from_to(0, self.last_training_set_episode_id + 1) for t in e])
|
||||
|
||||
def save(self, file_path: str) -> None:
|
||||
"""
|
||||
Save the replay buffer contents to a pickle file
|
||||
:param file_path: the path to the file that will be used to store the pickled transitions
|
||||
"""
|
||||
with open(file_path, 'wb') as file:
|
||||
pickle.dump(self.get_all_complete_episodes(), file)
|
||||
|
||||
def load_pickled(self, file_path: str) -> None:
|
||||
"""
|
||||
Restore the replay buffer contents from a pickle file.
|
||||
The pickle file is assumed to include a list of transitions.
|
||||
:param file_path: The path to a pickle file to restore
|
||||
"""
|
||||
self.assert_not_frozen()
|
||||
|
||||
with open(file_path, 'rb') as file:
|
||||
episodes = pickle.load(file)
|
||||
num_transitions = sum([len(e.transitions) for e in episodes])
|
||||
if num_transitions > self.max_size[1]:
|
||||
screen.warning("Warning! The number of transition to load into the replay buffer ({}) is "
|
||||
"bigger than the max size of the replay buffer ({}). The excessive transitions will "
|
||||
"not be stored.".format(num_transitions, self.max_size[1]))
|
||||
|
||||
progress_bar = ProgressBar(len(episodes))
|
||||
for episode_idx, episode in enumerate(episodes):
|
||||
self.store_episode(episode)
|
||||
|
||||
# print progress
|
||||
progress_bar.update(episode_idx)
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
@@ -58,9 +58,6 @@ class Memory(object):
|
||||
def get(self, index):
|
||||
raise NotImplementedError("")
|
||||
|
||||
def remove(self, index):
|
||||
raise NotImplementedError("")
|
||||
|
||||
def length(self):
|
||||
raise NotImplementedError("")
|
||||
|
||||
|
||||
@@ -198,15 +198,6 @@ class ExperienceReplay(Memory):
|
||||
"""
|
||||
return self.get_transition(transition_index, lock)
|
||||
|
||||
# for API compatibility
|
||||
def remove(self, transition_index: int, lock: bool=True):
|
||||
"""
|
||||
Remove the transition in the given index
|
||||
:param transition_index: the index of the transition to remove
|
||||
:return: None
|
||||
"""
|
||||
self.remove_transition(transition_index, lock)
|
||||
|
||||
def clean(self, lock: bool=True) -> None:
|
||||
"""
|
||||
Clean the memory by removing all the episodes
|
||||
|
||||
Reference in New Issue
Block a user