mirror of
https://github.com/gryf/coach.git
synced 2026-04-13 01:13:32 +02:00
OPE: Weighted Importance Sampling (#299)
This commit is contained in:
@@ -54,6 +54,7 @@ class ExperienceReplay(Memory):
|
||||
self.allow_duplicates_in_batch_sampling = allow_duplicates_in_batch_sampling
|
||||
|
||||
self.reader_writer_lock = ReaderWriterLock()
|
||||
self.frozen = False
|
||||
|
||||
def length(self) -> int:
|
||||
"""
|
||||
@@ -135,6 +136,8 @@ class ExperienceReplay(Memory):
|
||||
locks and then calls store with lock = True
|
||||
:return: None
|
||||
"""
|
||||
self.assert_not_frozen()
|
||||
|
||||
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
|
||||
super().store(transition)
|
||||
if lock:
|
||||
@@ -175,6 +178,8 @@ class ExperienceReplay(Memory):
|
||||
:param transition_index: the index of the transition to remove
|
||||
:return: None
|
||||
"""
|
||||
self.assert_not_frozen()
|
||||
|
||||
if lock:
|
||||
self.reader_writer_lock.lock_writing_and_reading()
|
||||
|
||||
@@ -207,6 +212,8 @@ class ExperienceReplay(Memory):
|
||||
Clean the memory by removing all the episodes
|
||||
:return: None
|
||||
"""
|
||||
self.assert_not_frozen()
|
||||
|
||||
if lock:
|
||||
self.reader_writer_lock.lock_writing_and_reading()
|
||||
|
||||
@@ -242,6 +249,8 @@ class ExperienceReplay(Memory):
|
||||
The pickle file is assumed to include a list of transitions.
|
||||
:param file_path: The path to a pickle file to restore
|
||||
"""
|
||||
self.assert_not_frozen()
|
||||
|
||||
with open(file_path, 'rb') as file:
|
||||
transitions = pickle.load(file)
|
||||
num_transitions = len(transitions)
|
||||
@@ -260,3 +269,17 @@ class ExperienceReplay(Memory):
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
def freeze(self):
|
||||
"""
|
||||
Freezing the replay buffer does not allow any new transitions to be added to the memory.
|
||||
Useful when working with a dataset (e.g. batch-rl or imitation learning).
|
||||
:return: None
|
||||
"""
|
||||
self.frozen = True
|
||||
|
||||
def assert_not_frozen(self):
|
||||
"""
|
||||
Check that the memory is not frozen, and can be changed.
|
||||
:return:
|
||||
"""
|
||||
assert self.frozen is False, "Memory is frozen, and cannot be changed."
|
||||
|
||||
Reference in New Issue
Block a user