diff --git a/rl_coach/memories/non_episodic/experience_replay.py b/rl_coach/memories/non_episodic/experience_replay.py index 798854c..4887e49 100644 --- a/rl_coach/memories/non_episodic/experience_replay.py +++ b/rl_coach/memories/non_episodic/experience_replay.py @@ -51,7 +51,6 @@ class ExperienceReplay(Memory): if max_size[0] != MemoryGranularity.Transitions: raise ValueError("Experience replay size can only be configured in terms of transitions") self.transitions = [] - self._num_transitions = 0 self.allow_duplicates_in_batch_sampling = allow_duplicates_in_batch_sampling self.reader_writer_lock = ReaderWriterLock() @@ -66,7 +65,7 @@ class ExperienceReplay(Memory): """ Get the number of transitions in the ER """ - return self._num_transitions + return len(self.transitions) def sample(self, size: int) -> List[Transition]: """ @@ -119,7 +118,6 @@ class ExperienceReplay(Memory): if lock: self.reader_writer_lock.lock_writing_and_reading() - self._num_transitions += 1 self.transitions.append(transition) self._enforce_max_length() @@ -149,6 +147,7 @@ class ExperienceReplay(Memory): def remove_transition(self, transition_index: int, lock: bool=True) -> None: """ Remove the transition in the given index. + This does not remove the transition from the segment trees! it is just used to remove the transition from the transitions list :param transition_index: the index of the transition to remove @@ -158,7 +157,6 @@ class ExperienceReplay(Memory): self.reader_writer_lock.lock_writing_and_reading() if self.num_transitions() > transition_index: - self._num_transitions -= 1 del self.transitions[transition_index] if lock: @@ -191,7 +189,6 @@ class ExperienceReplay(Memory): self.reader_writer_lock.lock_writing_and_reading() self.transitions = [] - self._num_transitions = 0 if lock: self.reader_writer_lock.release_writing_and_reading()