mirror of
https://github.com/gryf/coach.git
synced 2026-03-13 13:15:50 +01:00
OPE: Weighted Importance Sampling (#299)
This commit is contained in:
@@ -15,6 +15,8 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import ast
|
||||
from copy import deepcopy
|
||||
|
||||
import math
|
||||
|
||||
import pandas as pd
|
||||
@@ -64,6 +66,10 @@ class EpisodicExperienceReplay(Memory):
|
||||
self.last_training_set_episode_id = None # used in batch-rl
|
||||
self.last_training_set_transition_id = None # used in batch-rl
|
||||
self.train_to_eval_ratio = train_to_eval_ratio # used in batch-rl
|
||||
self.evaluation_dataset_as_episodes = None
|
||||
self.evaluation_dataset_as_transitions = None
|
||||
|
||||
self.frozen = False
|
||||
|
||||
def length(self, lock: bool = False) -> int:
|
||||
"""
|
||||
@@ -137,6 +143,8 @@ class EpisodicExperienceReplay(Memory):
|
||||
Shuffle all the episodes in the replay buffer
|
||||
:return:
|
||||
"""
|
||||
self.assert_not_frozen()
|
||||
|
||||
random.shuffle(self._buffer)
|
||||
self.transitions = [t for e in self._buffer for t in e.transitions]
|
||||
|
||||
@@ -256,6 +264,7 @@ class EpisodicExperienceReplay(Memory):
|
||||
:param transition: a transition to store
|
||||
:return: None
|
||||
"""
|
||||
self.assert_not_frozen()
|
||||
|
||||
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
|
||||
super().store(transition)
|
||||
@@ -281,6 +290,8 @@ class EpisodicExperienceReplay(Memory):
|
||||
:param episode: the new episode to store
|
||||
:return: None
|
||||
"""
|
||||
self.assert_not_frozen()
|
||||
|
||||
# Calling super.store() so that in case a memory backend is used, the memory backend can store this episode.
|
||||
super().store_episode(episode)
|
||||
|
||||
@@ -322,6 +333,8 @@ class EpisodicExperienceReplay(Memory):
|
||||
:param episode_index: the index of the episode to remove
|
||||
:return: None
|
||||
"""
|
||||
self.assert_not_frozen()
|
||||
|
||||
if len(self._buffer) > episode_index:
|
||||
episode_length = self._buffer[episode_index].length()
|
||||
self._length -= 1
|
||||
@@ -381,6 +394,7 @@ class EpisodicExperienceReplay(Memory):
|
||||
Clean the memory by removing all the episodes
|
||||
:return: None
|
||||
"""
|
||||
self.assert_not_frozen()
|
||||
self.reader_writer_lock.lock_writing_and_reading()
|
||||
|
||||
self.transitions = []
|
||||
@@ -409,6 +423,8 @@ class EpisodicExperienceReplay(Memory):
|
||||
The csv file is assumed to include a list of transitions.
|
||||
:param csv_dataset: A construct which holds the dataset parameters
|
||||
"""
|
||||
self.assert_not_frozen()
|
||||
|
||||
df = pd.read_csv(csv_dataset.filepath)
|
||||
if len(df) > self.max_size[1]:
|
||||
screen.warning("Warning! The number of transitions to load into the replay buffer ({}) is "
|
||||
@@ -446,3 +462,34 @@ class EpisodicExperienceReplay(Memory):
|
||||
progress_bar.close()
|
||||
|
||||
self.shuffle_episodes()
|
||||
|
||||
def freeze(self):
|
||||
"""
|
||||
Freezing the replay buffer does not allow any new transitions to be added to the memory.
|
||||
Useful when working with a dataset (e.g. batch-rl or imitation learning).
|
||||
:return: None
|
||||
"""
|
||||
self.frozen = True
|
||||
|
||||
def assert_not_frozen(self):
|
||||
"""
|
||||
Check that the memory is not frozen, and can be changed.
|
||||
:return:
|
||||
"""
|
||||
assert self.frozen is False, "Memory is frozen, and cannot be changed."
|
||||
|
||||
def prepare_evaluation_dataset(self):
|
||||
"""
|
||||
Gather the memory content that will be used for off-policy evaluation in episodes and transitions format
|
||||
:return:
|
||||
"""
|
||||
self.evaluation_dataset_as_episodes = deepcopy(
|
||||
self.get_all_complete_episodes_from_to(self.get_last_training_set_episode_id() + 1,
|
||||
self.num_complete_episodes()))
|
||||
|
||||
if len(self.evaluation_dataset_as_episodes) == 0:
|
||||
raise ValueError('train_to_eval_ratio is too high causing the evaluation set to be empty. '
|
||||
'Consider decreasing its value.')
|
||||
|
||||
self.evaluation_dataset_as_transitions = [t for e in self.evaluation_dataset_as_episodes
|
||||
for t in e.transitions]
|
||||
|
||||
@@ -54,6 +54,7 @@ class ExperienceReplay(Memory):
|
||||
self.allow_duplicates_in_batch_sampling = allow_duplicates_in_batch_sampling
|
||||
|
||||
self.reader_writer_lock = ReaderWriterLock()
|
||||
self.frozen = False
|
||||
|
||||
def length(self) -> int:
|
||||
"""
|
||||
@@ -135,6 +136,8 @@ class ExperienceReplay(Memory):
|
||||
locks and then calls store with lock = True
|
||||
:return: None
|
||||
"""
|
||||
self.assert_not_frozen()
|
||||
|
||||
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
|
||||
super().store(transition)
|
||||
if lock:
|
||||
@@ -175,6 +178,8 @@ class ExperienceReplay(Memory):
|
||||
:param transition_index: the index of the transition to remove
|
||||
:return: None
|
||||
"""
|
||||
self.assert_not_frozen()
|
||||
|
||||
if lock:
|
||||
self.reader_writer_lock.lock_writing_and_reading()
|
||||
|
||||
@@ -207,6 +212,8 @@ class ExperienceReplay(Memory):
|
||||
Clean the memory by removing all the episodes
|
||||
:return: None
|
||||
"""
|
||||
self.assert_not_frozen()
|
||||
|
||||
if lock:
|
||||
self.reader_writer_lock.lock_writing_and_reading()
|
||||
|
||||
@@ -242,6 +249,8 @@ class ExperienceReplay(Memory):
|
||||
The pickle file is assumed to include a list of transitions.
|
||||
:param file_path: The path to a pickle file to restore
|
||||
"""
|
||||
self.assert_not_frozen()
|
||||
|
||||
with open(file_path, 'rb') as file:
|
||||
transitions = pickle.load(file)
|
||||
num_transitions = len(transitions)
|
||||
@@ -260,3 +269,17 @@ class ExperienceReplay(Memory):
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
def freeze(self):
|
||||
"""
|
||||
Freezing the replay buffer does not allow any new transitions to be added to the memory.
|
||||
Useful when working with a dataset (e.g. batch-rl or imitation learning).
|
||||
:return: None
|
||||
"""
|
||||
self.frozen = True
|
||||
|
||||
def assert_not_frozen(self):
|
||||
"""
|
||||
Check that the memory is not frozen, and can be changed.
|
||||
:return:
|
||||
"""
|
||||
assert self.frozen is False, "Memory is frozen, and cannot be changed."
|
||||
|
||||
Reference in New Issue
Block a user