1
0
mirror of https://github.com/gryf/coach.git synced 2026-03-17 15:23:37 +01:00

bug fixes for OPE (#311)

This commit is contained in:
Gal Leibovich
2019-05-21 16:39:11 +03:00
committed by GitHub
parent 85d70dd7d5
commit acceb03ac0
8 changed files with 38 additions and 21 deletions

View File

@@ -148,7 +148,7 @@ class EpisodicExperienceReplay(Memory):
random.shuffle(self._buffer)
self.transitions = [t for e in self._buffer for t in e.transitions]
def get_shuffled_data_generator(self, size: int) -> List[Transition]:
def get_shuffled_training_data_generator(self, size: int) -> List[Transition]:
"""
Get an generator for iterating through the shuffled replay buffer, for processing the data in epochs.
If the requested size is larger than the number of samples available in the replay buffer then the batch will
@@ -159,15 +159,6 @@ class EpisodicExperienceReplay(Memory):
:return: a batch (list) of selected transitions from the replay buffer
"""
self.reader_writer_lock.lock_writing()
if self.last_training_set_transition_id is None:
if self.train_to_eval_ratio < 0 or self.train_to_eval_ratio >= 1:
raise ValueError('train_to_eval_ratio should be in the (0, 1] range.')
transition = self.transitions[round(self.train_to_eval_ratio * self.num_transitions_in_complete_episodes())]
episode_num, episode = self.get_episode_for_transition(transition)
self.last_training_set_episode_id = episode_num
self.last_training_set_transition_id = \
len([t for e in self.get_all_complete_episodes_from_to(0, self.last_training_set_episode_id + 1) for t in e])
shuffled_transition_indices = list(range(self.last_training_set_transition_id))
random.shuffle(shuffled_transition_indices)
@@ -483,6 +474,9 @@ class EpisodicExperienceReplay(Memory):
Gather the memory content that will be used for off-policy evaluation in episodes and transitions format
:return:
"""
self.reader_writer_lock.lock_writing_and_reading()
self._split_training_and_evaluation_datasets()
self.evaluation_dataset_as_episodes = deepcopy(
self.get_all_complete_episodes_from_to(self.get_last_training_set_episode_id() + 1,
self.num_complete_episodes()))
@@ -493,3 +487,20 @@ class EpisodicExperienceReplay(Memory):
self.evaluation_dataset_as_transitions = [t for e in self.evaluation_dataset_as_episodes
for t in e.transitions]
self.reader_writer_lock.release_writing_and_reading()
def _split_training_and_evaluation_datasets(self):
"""
If the data in the buffer was not split to training and evaluation yet, split it accordingly.
:return: None
"""
if self.last_training_set_transition_id is None:
if self.train_to_eval_ratio < 0 or self.train_to_eval_ratio >= 1:
raise ValueError('train_to_eval_ratio should be in the (0, 1] range.')
transition = self.transitions[round(self.train_to_eval_ratio * self.num_transitions_in_complete_episodes())]
episode_num, episode = self.get_episode_for_transition(transition)
self.last_training_set_episode_id = episode_num
self.last_training_set_transition_id = \
len([t for e in self.get_all_complete_episodes_from_to(0, self.last_training_set_episode_id + 1) for t in e])

View File

@@ -92,7 +92,7 @@ class ExperienceReplay(Memory):
self.reader_writer_lock.release_writing()
return batch
def get_shuffled_data_generator(self, size: int) -> List[Transition]:
def get_shuffled_training_data_generator(self, size: int) -> List[Transition]:
"""
Get an generator for iterating through the shuffled replay buffer, for processing the data in epochs.
If the requested size is larger than the number of samples available in the replay buffer then the batch will