mirror of
https://github.com/gryf/coach.git
synced 2026-03-29 16:13:31 +02:00
Batch RL (#238)
This commit is contained in:
@@ -14,10 +14,10 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import List, Tuple, Union, Dict, Any
|
||||
from typing import Iterator, List, Tuple, Union
|
||||
import pickle
|
||||
import sys
|
||||
import time
|
||||
import random
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -72,7 +72,6 @@ class ExperienceReplay(Memory):
|
||||
Sample a batch of transitions from the replay buffer. If the requested size is larger than the number
|
||||
of samples available in the replay buffer then the batch will return empty.
|
||||
:param size: the size of the batch to sample
|
||||
:param beta: the beta parameter used for importance sampling
|
||||
:return: a batch (list) of selected transitions from the replay buffer
|
||||
"""
|
||||
self.reader_writer_lock.lock_writing()
|
||||
@@ -92,6 +91,32 @@ class ExperienceReplay(Memory):
|
||||
self.reader_writer_lock.release_writing()
|
||||
return batch
|
||||
|
||||
def get_shuffled_data_generator(self, size: int) -> "Iterator[List[Transition]]":
    """
    Get a generator for iterating through the shuffled replay buffer, for processing the data in epochs.
    Every yielded batch holds exactly `size` transitions. Transitions left over after dividing the buffer
    into batches of `size` are deliberately dropped, so if the requested size is larger than the number of
    samples available in the replay buffer the generator yields nothing.

    Because this is a generator, no work is done (and the lock is not taken) until the first batch is
    requested. At that point a shuffled snapshot of the buffer is taken while holding the lock, the lock is
    released exactly once, and all batches are then served from the snapshot.

    Usage example:

        for batch in memory.get_shuffled_data_generator(32):
            print(batch)

    :param size: the size of each batch to return
    :return: a generator yielding batches (lists) of transitions from the replay buffer
    """
    # Hold the lock only while reading self.transitions. The try/finally guarantees a single matching
    # release, even when an exception is raised or when there are no full batches to yield.
    self.reader_writer_lock.lock_writing()
    try:
        shuffled_transition_indices = list(range(len(self.transitions)))
        random.shuffle(shuffled_transition_indices)
        shuffled_transitions = [self.transitions[j] for j in shuffled_transition_indices]
    finally:
        self.reader_writer_lock.release_writing()

    # we deliberately drop some of the ending data which is left after dividing to batches of size `size`
    # (floor division; use math.ceil here instead to also emit the final partial batch)
    num_batches = len(shuffled_transitions) // size
    for batch_idx in range(num_batches):
        yield shuffled_transitions[batch_idx * size: (batch_idx + 1) * size]
|
||||
|
||||
def _enforce_max_length(self) -> None:
|
||||
"""
|
||||
Make sure that the size of the replay buffer does not pass the maximum size allowed.
|
||||
@@ -215,7 +240,7 @@ class ExperienceReplay(Memory):
|
||||
with open(file_path, 'wb') as file:
|
||||
pickle.dump(self.transitions, file)
|
||||
|
||||
def load(self, file_path: str) -> None:
|
||||
def load_pickled(self, file_path: str) -> None:
|
||||
"""
|
||||
Restore the replay buffer contents from a pickle file.
|
||||
The pickle file is assumed to include a list of transitions.
|
||||
@@ -238,3 +263,4 @@ class ExperienceReplay(Memory):
|
||||
progress_bar.update(transition_idx)
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user