mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
replace ExperienceReplay._num_transitions with len(ExperienceReplay.transitions)
This commit is contained in:
@@ -51,7 +51,6 @@ class ExperienceReplay(Memory):
|
|||||||
if max_size[0] != MemoryGranularity.Transitions:
|
if max_size[0] != MemoryGranularity.Transitions:
|
||||||
raise ValueError("Experience replay size can only be configured in terms of transitions")
|
raise ValueError("Experience replay size can only be configured in terms of transitions")
|
||||||
self.transitions = []
|
self.transitions = []
|
||||||
self._num_transitions = 0
|
|
||||||
self.allow_duplicates_in_batch_sampling = allow_duplicates_in_batch_sampling
|
self.allow_duplicates_in_batch_sampling = allow_duplicates_in_batch_sampling
|
||||||
|
|
||||||
self.reader_writer_lock = ReaderWriterLock()
|
self.reader_writer_lock = ReaderWriterLock()
|
||||||
@@ -66,7 +65,7 @@ class ExperienceReplay(Memory):
|
|||||||
"""
|
"""
|
||||||
Get the number of transitions in the ER
|
Get the number of transitions in the ER
|
||||||
"""
|
"""
|
||||||
return self._num_transitions
|
return len(self.transitions)
|
||||||
|
|
||||||
def sample(self, size: int) -> List[Transition]:
|
def sample(self, size: int) -> List[Transition]:
|
||||||
"""
|
"""
|
||||||
@@ -119,7 +118,6 @@ class ExperienceReplay(Memory):
|
|||||||
if lock:
|
if lock:
|
||||||
self.reader_writer_lock.lock_writing_and_reading()
|
self.reader_writer_lock.lock_writing_and_reading()
|
||||||
|
|
||||||
self._num_transitions += 1
|
|
||||||
self.transitions.append(transition)
|
self.transitions.append(transition)
|
||||||
self._enforce_max_length()
|
self._enforce_max_length()
|
||||||
|
|
||||||
@@ -149,6 +147,7 @@ class ExperienceReplay(Memory):
|
|||||||
def remove_transition(self, transition_index: int, lock: bool=True) -> None:
|
def remove_transition(self, transition_index: int, lock: bool=True) -> None:
|
||||||
"""
|
"""
|
||||||
Remove the transition in the given index.
|
Remove the transition in the given index.
|
||||||
|
|
||||||
This does not remove the transition from the segment trees! it is just used to remove the transition
|
This does not remove the transition from the segment trees! it is just used to remove the transition
|
||||||
from the transitions list
|
from the transitions list
|
||||||
:param transition_index: the index of the transition to remove
|
:param transition_index: the index of the transition to remove
|
||||||
@@ -158,7 +157,6 @@ class ExperienceReplay(Memory):
|
|||||||
self.reader_writer_lock.lock_writing_and_reading()
|
self.reader_writer_lock.lock_writing_and_reading()
|
||||||
|
|
||||||
if self.num_transitions() > transition_index:
|
if self.num_transitions() > transition_index:
|
||||||
self._num_transitions -= 1
|
|
||||||
del self.transitions[transition_index]
|
del self.transitions[transition_index]
|
||||||
|
|
||||||
if lock:
|
if lock:
|
||||||
@@ -191,7 +189,6 @@ class ExperienceReplay(Memory):
|
|||||||
self.reader_writer_lock.lock_writing_and_reading()
|
self.reader_writer_lock.lock_writing_and_reading()
|
||||||
|
|
||||||
self.transitions = []
|
self.transitions = []
|
||||||
self._num_transitions = 0
|
|
||||||
|
|
||||||
if lock:
|
if lock:
|
||||||
self.reader_writer_lock.release_writing_and_reading()
|
self.reader_writer_lock.release_writing_and_reading()
|
||||||
|
|||||||
Reference in New Issue
Block a user