Mirror of https://github.com/gryf/coach.git (synced 2026-03-13 13:15:50 +01:00)

OPE: Weighted Importance Sampling (#299)

This commit is contained in:
Gal Leibovich
2019-05-02 19:25:42 +03:00
committed by GitHub
parent 74db141d5e
commit 582921ffe3
8 changed files with 222 additions and 51 deletions

View File

@@ -15,6 +15,8 @@
# limitations under the License.
#
import ast
from copy import deepcopy
import math
import pandas as pd
@@ -64,6 +66,10 @@ class EpisodicExperienceReplay(Memory):
self.last_training_set_episode_id = None # used in batch-rl
self.last_training_set_transition_id = None # used in batch-rl
self.train_to_eval_ratio = train_to_eval_ratio # used in batch-rl
self.evaluation_dataset_as_episodes = None
self.evaluation_dataset_as_transitions = None
self.frozen = False
def length(self, lock: bool = False) -> int:
"""
@@ -137,6 +143,8 @@ class EpisodicExperienceReplay(Memory):
Shuffle all the episodes in the replay buffer
:return:
"""
self.assert_not_frozen()
random.shuffle(self._buffer)
self.transitions = [t for e in self._buffer for t in e.transitions]
@@ -256,6 +264,7 @@ class EpisodicExperienceReplay(Memory):
:param transition: a transition to store
:return: None
"""
self.assert_not_frozen()
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)
@@ -281,6 +290,8 @@ class EpisodicExperienceReplay(Memory):
:param episode: the new episode to store
:return: None
"""
self.assert_not_frozen()
# Calling super.store() so that in case a memory backend is used, the memory backend can store this episode.
super().store_episode(episode)
@@ -322,6 +333,8 @@ class EpisodicExperienceReplay(Memory):
:param episode_index: the index of the episode to remove
:return: None
"""
self.assert_not_frozen()
if len(self._buffer) > episode_index:
episode_length = self._buffer[episode_index].length()
self._length -= 1
@@ -381,6 +394,7 @@ class EpisodicExperienceReplay(Memory):
Clean the memory by removing all the episodes
:return: None
"""
self.assert_not_frozen()
self.reader_writer_lock.lock_writing_and_reading()
self.transitions = []
@@ -409,6 +423,8 @@ class EpisodicExperienceReplay(Memory):
The csv file is assumed to include a list of transitions.
:param csv_dataset: A construct which holds the dataset parameters
"""
self.assert_not_frozen()
df = pd.read_csv(csv_dataset.filepath)
if len(df) > self.max_size[1]:
screen.warning("Warning! The number of transitions to load into the replay buffer ({}) is "
@@ -446,3 +462,34 @@ class EpisodicExperienceReplay(Memory):
progress_bar.close()
self.shuffle_episodes()
def freeze(self):
    """
    Mark the replay buffer as frozen, so that no new transitions may be added to it.
    Useful when working with a dataset (e.g. batch-rl or imitation learning).
    :return: None
    """
    self.frozen = True
def assert_not_frozen(self):
    """
    Check that the memory is not frozen, and can be changed.
    :raises AssertionError: if the memory has been frozen via freeze()
    :return: None
    """
    # Raise explicitly instead of using an `assert` statement: asserts are
    # stripped when Python runs with optimizations (-O), which would silently
    # disable this guard. The exception type and message are kept identical
    # for backward compatibility with any caller catching AssertionError.
    if self.frozen:
        raise AssertionError("Memory is frozen, and cannot be changed.")
def prepare_evaluation_dataset(self):
    """
    Gather the memory content that will be used for off-policy evaluation, both as a
    list of episodes and as a flat list of transitions.
    :return:
    """
    # The evaluation set is everything after the last training-set episode.
    first_eval_episode_id = self.get_last_training_set_episode_id() + 1
    eval_episodes = self.get_all_complete_episodes_from_to(first_eval_episode_id,
                                                           self.num_complete_episodes())
    # Deep-copy so that later mutations of the buffer do not affect the evaluation set.
    self.evaluation_dataset_as_episodes = deepcopy(eval_episodes)

    if not self.evaluation_dataset_as_episodes:
        raise ValueError('train_to_eval_ratio is too high causing the evaluation set to be empty. '
                         'Consider decreasing its value.')

    # Flatten the episodes into a single list of transitions.
    flat_transitions = []
    for episode in self.evaluation_dataset_as_episodes:
        flat_transitions.extend(episode.transitions)
    self.evaluation_dataset_as_transitions = flat_transitions

View File

@@ -54,6 +54,7 @@ class ExperienceReplay(Memory):
self.allow_duplicates_in_batch_sampling = allow_duplicates_in_batch_sampling
self.reader_writer_lock = ReaderWriterLock()
self.frozen = False
def length(self) -> int:
"""
@@ -135,6 +136,8 @@ class ExperienceReplay(Memory):
locks and then calls store with lock = True
:return: None
"""
self.assert_not_frozen()
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)
if lock:
@@ -175,6 +178,8 @@ class ExperienceReplay(Memory):
:param transition_index: the index of the transition to remove
:return: None
"""
self.assert_not_frozen()
if lock:
self.reader_writer_lock.lock_writing_and_reading()
@@ -207,6 +212,8 @@ class ExperienceReplay(Memory):
Clean the memory by removing all the episodes
:return: None
"""
self.assert_not_frozen()
if lock:
self.reader_writer_lock.lock_writing_and_reading()
@@ -242,6 +249,8 @@ class ExperienceReplay(Memory):
The pickle file is assumed to include a list of transitions.
:param file_path: The path to a pickle file to restore
"""
self.assert_not_frozen()
with open(file_path, 'rb') as file:
transitions = pickle.load(file)
num_transitions = len(transitions)
@@ -260,3 +269,17 @@ class ExperienceReplay(Memory):
progress_bar.close()
def freeze(self):
    """
    Put the replay buffer into a read-only state in which storing new transitions is
    disallowed. Useful when working with a dataset (e.g. batch-rl or imitation learning).
    :return: None
    """
    self.frozen = True
def assert_not_frozen(self):
    """
    Check that the memory is not frozen, and can be changed.
    :raises AssertionError: if freeze() was previously called on this memory
    :return: None
    """
    # Use an explicit raise rather than an `assert` statement so the check
    # still runs when Python is started with -O (which removes asserts).
    # AssertionError is raised with the original message so existing callers
    # observe identical behavior.
    if self.frozen:
        raise AssertionError("Memory is frozen, and cannot be changed.")