mirror of
https://github.com/gryf/coach.git
synced 2026-03-17 15:23:37 +01:00
bug fixes for OPE (#311)
This commit is contained in:
@@ -148,7 +148,7 @@ class EpisodicExperienceReplay(Memory):
|
||||
random.shuffle(self._buffer)
|
||||
self.transitions = [t for e in self._buffer for t in e.transitions]
|
||||
|
||||
def get_shuffled_data_generator(self, size: int) -> List[Transition]:
|
||||
def get_shuffled_training_data_generator(self, size: int) -> List[Transition]:
|
||||
"""
|
||||
Get a generator for iterating through the shuffled replay buffer, for processing the data in epochs.
|
||||
If the requested size is larger than the number of samples available in the replay buffer then the batch will
|
||||
@@ -159,15 +159,6 @@ class EpisodicExperienceReplay(Memory):
|
||||
:return: a batch (list) of selected transitions from the replay buffer
|
||||
"""
|
||||
self.reader_writer_lock.lock_writing()
|
||||
if self.last_training_set_transition_id is None:
|
||||
if self.train_to_eval_ratio < 0 or self.train_to_eval_ratio >= 1:
|
||||
raise ValueError('train_to_eval_ratio should be in the (0, 1] range.')
|
||||
|
||||
transition = self.transitions[round(self.train_to_eval_ratio * self.num_transitions_in_complete_episodes())]
|
||||
episode_num, episode = self.get_episode_for_transition(transition)
|
||||
self.last_training_set_episode_id = episode_num
|
||||
self.last_training_set_transition_id = \
|
||||
len([t for e in self.get_all_complete_episodes_from_to(0, self.last_training_set_episode_id + 1) for t in e])
|
||||
|
||||
shuffled_transition_indices = list(range(self.last_training_set_transition_id))
|
||||
random.shuffle(shuffled_transition_indices)
|
||||
@@ -483,6 +474,9 @@ class EpisodicExperienceReplay(Memory):
|
||||
Gather the memory content that will be used for off-policy evaluation in episodes and transitions format
|
||||
:return:
|
||||
"""
|
||||
self.reader_writer_lock.lock_writing_and_reading()
|
||||
|
||||
self._split_training_and_evaluation_datasets()
|
||||
self.evaluation_dataset_as_episodes = deepcopy(
|
||||
self.get_all_complete_episodes_from_to(self.get_last_training_set_episode_id() + 1,
|
||||
self.num_complete_episodes()))
|
||||
@@ -493,3 +487,20 @@ class EpisodicExperienceReplay(Memory):
|
||||
|
||||
self.evaluation_dataset_as_transitions = [t for e in self.evaluation_dataset_as_episodes
|
||||
for t in e.transitions]
|
||||
self.reader_writer_lock.release_writing_and_reading()
|
||||
|
||||
def _split_training_and_evaluation_datasets(self):
    """
    If the data in the buffer was not split to training and evaluation datasets yet, split it accordingly.

    The split point is chosen so that whole episodes stay together: the transition at
    ``round(train_to_eval_ratio * num_transitions_in_complete_episodes())`` is located, and the
    episode containing it becomes the last training-set episode.

    :raises ValueError: if ``train_to_eval_ratio`` is outside the open interval (0, 1)
    :return: None
    """
    # Split only once; if the boundary is already known this is a no-op.
    if self.last_training_set_transition_id is None:
        # BUG FIX: the original check (`< 0 or >= 1`) accepted 0 (a degenerate split with no
        # training data) while its message promised a "(0, 1]" range even though 1 was rejected
        # (ratio == 1 would index one past the end of self.transitions below). Enforce the open
        # interval (0, 1) and make the message agree with the check.
        if self.train_to_eval_ratio <= 0 or self.train_to_eval_ratio >= 1:
            raise ValueError('train_to_eval_ratio should be in the (0, 1) range.')

        # Locate the transition at the raw split point, then snap the boundary to the end of the
        # episode containing it, so no episode is split across the training/evaluation sets.
        transition = self.transitions[round(self.train_to_eval_ratio * self.num_transitions_in_complete_episodes())]
        episode_num, episode = self.get_episode_for_transition(transition)
        self.last_training_set_episode_id = episode_num
        # The training set spans all transitions of episodes 0..last_training_set_episode_id inclusive.
        self.last_training_set_transition_id = \
            len([t for e in self.get_all_complete_episodes_from_to(0, self.last_training_set_episode_id + 1) for t in e])
|
||||
|
||||
@@ -92,7 +92,7 @@ class ExperienceReplay(Memory):
|
||||
self.reader_writer_lock.release_writing()
|
||||
return batch
|
||||
|
||||
def get_shuffled_data_generator(self, size: int) -> List[Transition]:
|
||||
def get_shuffled_training_data_generator(self, size: int) -> List[Transition]:
|
||||
"""
|
||||
Get a generator for iterating through the shuffled replay buffer, for processing the data in epochs.
|
||||
If the requested size is larger than the number of samples available in the replay buffer then the batch will
|
||||
|
||||
Reference in New Issue
Block a user