1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

replace ExperienceReplay._num_transitions with len(ExperienceReplay.transitions)

This commit is contained in:
Zach Dwiel
2018-09-07 14:56:43 -04:00
committed by zach dwiel
parent cccfe88f9b
commit 9f1f9e5ab4

View File

@@ -51,7 +51,6 @@ class ExperienceReplay(Memory):
if max_size[0] != MemoryGranularity.Transitions:
raise ValueError("Experience replay size can only be configured in terms of transitions")
self.transitions = []
self._num_transitions = 0
self.allow_duplicates_in_batch_sampling = allow_duplicates_in_batch_sampling
self.reader_writer_lock = ReaderWriterLock()
@@ -66,7 +65,7 @@ class ExperienceReplay(Memory):
"""
Get the number of transitions in the ER
"""
return self._num_transitions
return len(self.transitions)
def sample(self, size: int) -> List[Transition]:
"""
@@ -119,7 +118,6 @@ class ExperienceReplay(Memory):
if lock:
self.reader_writer_lock.lock_writing_and_reading()
self._num_transitions += 1
self.transitions.append(transition)
self._enforce_max_length()
@@ -149,6 +147,7 @@ class ExperienceReplay(Memory):
def remove_transition(self, transition_index: int, lock: bool=True) -> None:
"""
Remove the transition in the given index.
This does not remove the transition from the segment trees! it is just used to remove the transition
from the transitions list
:param transition_index: the index of the transition to remove
@@ -158,7 +157,6 @@ class ExperienceReplay(Memory):
self.reader_writer_lock.lock_writing_and_reading()
if self.num_transitions() > transition_index:
self._num_transitions -= 1
del self.transitions[transition_index]
if lock:
@@ -191,7 +189,6 @@ class ExperienceReplay(Memory):
self.reader_writer_lock.lock_writing_and_reading()
self.transitions = []
self._num_transitions = 0
if lock:
self.reader_writer_lock.release_writing_and_reading()