diff --git a/rl_coach/memories/non_episodic/distributed_experience_replay.py b/rl_coach/memories/non_episodic/distributed_experience_replay.py
index c5c928b..d6bdf9c 100644
--- a/rl_coach/memories/non_episodic/distributed_experience_replay.py
+++ b/rl_coach/memories/non_episodic/distributed_experience_replay.py
@@ -51,7 +51,6 @@ class DistributedExperienceReplay(Memory):
         super().__init__(max_size)
         if max_size[0] != MemoryGranularity.Transitions:
             raise ValueError("Experience replay size can only be configured in terms of transitions")
-        # self.transitions = []
         self.allow_duplicates_in_batch_sampling = allow_duplicates_in_batch_sampling
         self.db = db
 
@@ -67,10 +66,8 @@ class DistributedExperienceReplay(Memory):
         """
         Get the number of transitions in the ER
         """
-        # Replace with distributed store len
         return self.redis_connection.info(section='keyspace')['db{}'.format(self.db)]['keys']
-        # return len(self.transitions)
-
+
     def sample(self, size: int) -> List[Transition]:
         """
         Sample a batch of transitions form the replay buffer. If the requested size is larger than the number
@@ -84,8 +81,6 @@ class DistributedExperienceReplay(Memory):
             while len(transition_idx) != size:
                 key = self.redis_connection.randomkey()
                 transition_idx[key] = pickle.loads(self.redis_connection.get(key))
-            # transition_idx = np.random.randint(self.num_transitions(), size=size)
-
         else:
             if self.num_transitions() >= size:
                 while len(transition_idx) != size:
@@ -93,12 +88,10 @@ class DistributedExperienceReplay(Memory):
                     if key in transition_idx:
                         continue
                     transition_idx[key] = pickle.loads(self.redis_connection.get(key))
-                # transition_idx = np.random.choice(self.num_transitions(), size=size, replace=False)
             else:
                 raise ValueError("The replay buffer cannot be sampled since there are not enough transitions yet. "
                                  "There are currently {} transitions".format(self.num_transitions()))
 
-        # Replace with distributed store
         batch = transition_idx.values()
 
         return batch
@@ -125,10 +118,7 @@ class DistributedExperienceReplay(Memory):
         locks and then calls store with lock = True
         :return: None
         """
-        # Replace with distributed store
-
         self.redis_connection.set(uuid.uuid4(), pickle.dumps(transition))
-        # self.transitions.append(transition)
         self._enforce_max_length()
 
     def get_transition(self, transition_index: int, lock: bool=True) -> Union[None, Transition]:
@@ -138,8 +128,6 @@ class DistributedExperienceReplay(Memory):
         :param lock: use write locking if this is a shared memory
         :return: the corresponding transition
         """
-        # Replace with distributed store
-        import pytest; pytest.set_trace()
         return pickle.loads(self.redis_connection.get(transition_index))
 
     def remove_transition(self, transition_index: int, lock: bool=True) -> None:
@@ -151,8 +139,6 @@ class DistributedExperienceReplay(Memory):
         :param transition_index: the index of the transition to remove
         :return: None
         """
-        # Replace with distributed store
-        import pytest; pytest.set_trace()
         self.redis_connection.delete(transition_index)
 
     # for API compatibility
@@ -162,8 +148,6 @@ class DistributedExperienceReplay(Memory):
         :param transition_index: the index of the transition to return
         :return: the corresponding transition
         """
-        # Replace with distributed store
-        import pytest; pytest.set_trace()
         return self.get_transition(transition_index, lock)
 
     # for API compatibility
@@ -173,8 +157,6 @@ class DistributedExperienceReplay(Memory):
         :param transition_index: the index of the transition to remove
         :return: None
         """
-        # Replace with distributed store
-        import pytest; pytest.set_trace()
         self.remove_transition(transition_index, lock)
 
     def clean(self, lock: bool=True) -> None:
@@ -182,7 +164,6 @@ class DistributedExperienceReplay(Memory):
         Clean the memory by removing all the episodes
         :return: None
         """
-        import pytest; pytest.set_trace()
         self.redis_connection.flushall()
         # self.transitions = []
 
@@ -191,9 +172,6 @@ class DistributedExperienceReplay(Memory):
         Get the mean reward in the replay buffer
         :return: the mean reward
         """
-        # Replace with distributed store
-        import pytest; pytest.set_trace()
         mean = np.mean([pickle.loads(self.redis_connection.get(key)).reward for key in self.redis_connection.keys()])
-        # mean = np.mean([transition.reward for transition in self.transitions])
-
+
         return mean