1
0
mirror of https://github.com/gryf/coach.git synced 2026-04-13 01:13:32 +02:00

OPE: Weighted Importance Sampling (#299)

This commit is contained in:
Gal Leibovich
2019-05-02 19:25:42 +03:00
committed by GitHub
parent 74db141d5e
commit 582921ffe3
8 changed files with 222 additions and 51 deletions

View File

@@ -54,6 +54,7 @@ class ExperienceReplay(Memory):
self.allow_duplicates_in_batch_sampling = allow_duplicates_in_batch_sampling
self.reader_writer_lock = ReaderWriterLock()
self.frozen = False
def length(self) -> int:
"""
@@ -135,6 +136,8 @@ class ExperienceReplay(Memory):
locks and then calls store with lock = True
:return: None
"""
self.assert_not_frozen()
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)
if lock:
@@ -175,6 +178,8 @@ class ExperienceReplay(Memory):
:param transition_index: the index of the transition to remove
:return: None
"""
self.assert_not_frozen()
if lock:
self.reader_writer_lock.lock_writing_and_reading()
@@ -207,6 +212,8 @@ class ExperienceReplay(Memory):
Clean the memory by removing all the episodes
:return: None
"""
self.assert_not_frozen()
if lock:
self.reader_writer_lock.lock_writing_and_reading()
@@ -242,6 +249,8 @@ class ExperienceReplay(Memory):
The pickle file is assumed to include a list of transitions.
:param file_path: The path to a pickle file to restore
"""
self.assert_not_frozen()
with open(file_path, 'rb') as file:
transitions = pickle.load(file)
num_transitions = len(transitions)
@@ -260,3 +269,17 @@ class ExperienceReplay(Memory):
progress_bar.close()
def freeze(self):
    """
    Put the memory into a frozen state by raising its `frozen` flag.
    Operations that call assert_not_frozen() first will refuse to run once the flag is set.
    Handy when the buffer holds a fixed dataset (e.g. batch-rl or imitation learning).
    :return: None
    """
    # The flag is only ever read through assert_not_frozen(); nothing is copied or locked here.
    self.frozen = True
def assert_not_frozen(self):
    """
    Check that the memory is not frozen, and can be changed.
    The check is done with an explicit raise rather than an `assert` statement, so it is still
    enforced when Python runs with optimizations enabled (-O / PYTHONOPTIMIZE), where assert
    statements are removed.
    :return: None
    :raises AssertionError: if `self.frozen` is anything other than the literal False
    """
    # Identity comparison on purpose: any value other than False counts as frozen.
    if self.frozen is not False:
        raise AssertionError("Memory is frozen, and cannot be changed.")