1
0
mirror of https://github.com/gryf/coach.git synced 2026-03-26 21:23:31 +01:00

OPE: Weighted Importance Sampling (#299)

This commit is contained in:
Gal Leibovich
2019-05-02 19:25:42 +03:00
committed by GitHub
parent 74db141d5e
commit 582921ffe3
8 changed files with 222 additions and 51 deletions

View File

@@ -15,6 +15,8 @@
# limitations under the License.
#
import ast
from copy import deepcopy
import math
import pandas as pd
@@ -64,6 +66,10 @@ class EpisodicExperienceReplay(Memory):
self.last_training_set_episode_id = None # used in batch-rl
self.last_training_set_transition_id = None # used in batch-rl
self.train_to_eval_ratio = train_to_eval_ratio # used in batch-rl
self.evaluation_dataset_as_episodes = None
self.evaluation_dataset_as_transitions = None
self.frozen = False
def length(self, lock: bool = False) -> int:
"""
@@ -137,6 +143,8 @@ class EpisodicExperienceReplay(Memory):
Shuffle all the episodes in the replay buffer
:return:
"""
self.assert_not_frozen()
random.shuffle(self._buffer)
self.transitions = [t for e in self._buffer for t in e.transitions]
@@ -256,6 +264,7 @@ class EpisodicExperienceReplay(Memory):
:param transition: a transition to store
:return: None
"""
self.assert_not_frozen()
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)
@@ -281,6 +290,8 @@ class EpisodicExperienceReplay(Memory):
:param episode: the new episode to store
:return: None
"""
self.assert_not_frozen()
# Calling super.store() so that in case a memory backend is used, the memory backend can store this episode.
super().store_episode(episode)
@@ -322,6 +333,8 @@ class EpisodicExperienceReplay(Memory):
:param episode_index: the index of the episode to remove
:return: None
"""
self.assert_not_frozen()
if len(self._buffer) > episode_index:
episode_length = self._buffer[episode_index].length()
self._length -= 1
@@ -381,6 +394,7 @@ class EpisodicExperienceReplay(Memory):
Clean the memory by removing all the episodes
:return: None
"""
self.assert_not_frozen()
self.reader_writer_lock.lock_writing_and_reading()
self.transitions = []
@@ -409,6 +423,8 @@ class EpisodicExperienceReplay(Memory):
The csv file is assumed to include a list of transitions.
:param csv_dataset: A construct which holds the dataset parameters
"""
self.assert_not_frozen()
df = pd.read_csv(csv_dataset.filepath)
if len(df) > self.max_size[1]:
screen.warning("Warning! The number of transitions to load into the replay buffer ({}) is "
@@ -446,3 +462,34 @@ class EpisodicExperienceReplay(Memory):
progress_bar.close()
self.shuffle_episodes()
def freeze(self):
"""
Freezing the replay buffer does not allow any new transitions to be added to the memory.
Useful when working with a dataset (e.g. batch-rl or imitation learning).
:return: None
"""
self.frozen = True
def assert_not_frozen(self):
"""
Check that the memory is not frozen, and can be changed.
:return:
"""
assert self.frozen is False, "Memory is frozen, and cannot be changed."
def prepare_evaluation_dataset(self):
"""
Gather the memory content that will be used for off-policy evaluation in episodes and transitions format
:return:
"""
self.evaluation_dataset_as_episodes = deepcopy(
self.get_all_complete_episodes_from_to(self.get_last_training_set_episode_id() + 1,
self.num_complete_episodes()))
if len(self.evaluation_dataset_as_episodes) == 0:
raise ValueError('train_to_eval_ratio is too high causing the evaluation set to be empty. '
'Consider decreasing its value.')
self.evaluation_dataset_as_transitions = [t for e in self.evaluation_dataset_as_episodes
for t in e.transitions]