1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

replace ExperienceReplay._num_transitions with len(ExperienceReplay.transitions)

This commit is contained in:
Zach Dwiel
2018-09-07 14:56:43 -04:00
committed by zach dwiel
parent cccfe88f9b
commit 9f1f9e5ab4

View File

@@ -51,7 +51,6 @@ class ExperienceReplay(Memory):
if max_size[0] != MemoryGranularity.Transitions: if max_size[0] != MemoryGranularity.Transitions:
raise ValueError("Experience replay size can only be configured in terms of transitions") raise ValueError("Experience replay size can only be configured in terms of transitions")
self.transitions = [] self.transitions = []
self._num_transitions = 0
self.allow_duplicates_in_batch_sampling = allow_duplicates_in_batch_sampling self.allow_duplicates_in_batch_sampling = allow_duplicates_in_batch_sampling
self.reader_writer_lock = ReaderWriterLock() self.reader_writer_lock = ReaderWriterLock()
@@ -66,7 +65,7 @@ class ExperienceReplay(Memory):
""" """
Get the number of transitions in the ER Get the number of transitions in the ER
""" """
return self._num_transitions return len(self.transitions)
def sample(self, size: int) -> List[Transition]: def sample(self, size: int) -> List[Transition]:
""" """
@@ -119,7 +118,6 @@ class ExperienceReplay(Memory):
if lock: if lock:
self.reader_writer_lock.lock_writing_and_reading() self.reader_writer_lock.lock_writing_and_reading()
self._num_transitions += 1
self.transitions.append(transition) self.transitions.append(transition)
self._enforce_max_length() self._enforce_max_length()
@@ -149,6 +147,7 @@ class ExperienceReplay(Memory):
def remove_transition(self, transition_index: int, lock: bool=True) -> None: def remove_transition(self, transition_index: int, lock: bool=True) -> None:
""" """
Remove the transition in the given index. Remove the transition in the given index.
This does not remove the transition from the segment trees! it is just used to remove the transition This does not remove the transition from the segment trees! it is just used to remove the transition
from the transitions list from the transitions list
:param transition_index: the index of the transition to remove :param transition_index: the index of the transition to remove
@@ -158,7 +157,6 @@ class ExperienceReplay(Memory):
self.reader_writer_lock.lock_writing_and_reading() self.reader_writer_lock.lock_writing_and_reading()
if self.num_transitions() > transition_index: if self.num_transitions() > transition_index:
self._num_transitions -= 1
del self.transitions[transition_index] del self.transitions[transition_index]
if lock: if lock:
@@ -191,7 +189,6 @@ class ExperienceReplay(Memory):
self.reader_writer_lock.lock_writing_and_reading() self.reader_writer_lock.lock_writing_and_reading()
self.transitions = [] self.transitions = []
self._num_transitions = 0
if lock: if lock:
self.reader_writer_lock.release_writing_and_reading() self.reader_writer_lock.release_writing_and_reading()