mirror of
https://github.com/gryf/coach.git
synced 2026-03-14 05:35:55 +01:00
removing some of the presets from the trace tests + more robust replay buffer loading
This commit is contained in:
@@ -16,6 +16,8 @@
|
||||
|
||||
from typing import List, Tuple, Union, Dict, Any
|
||||
import pickle
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -235,5 +237,17 @@ class ExperienceReplay(Memory):
|
||||
:param file_path: The path to a pickle file to restore
|
||||
"""
|
||||
with open(file_path, 'rb') as file:
|
||||
self.transitions = pickle.load(file)
|
||||
self._num_transitions = len(self.transitions)
|
||||
transitions = pickle.load(file)
|
||||
num_transitions = len(transitions)
|
||||
start_time = time.time()
|
||||
for transition_idx, transition in enumerate(transitions):
|
||||
self.store(transition)
|
||||
|
||||
# print progress
|
||||
if transition_idx % 100 == 0:
|
||||
percentage = int((100 * transition_idx) / num_transitions)
|
||||
sys.stdout.write("\rProgress: ({}/{})".format(transition_idx, num_transitions))
|
||||
sys.stdout.write(' Time (sec): {}'.format(round(time.time() - start_time, 2)))
|
||||
sys.stdout.write(' {}%|{}{}| '.format(percentage, '#' * int(percentage / 10),
|
||||
' ' * (10 - int(percentage / 10))))
|
||||
sys.stdout.flush()
|
||||
|
||||
Reference in New Issue
Block a user