mirror of
https://github.com/gryf/coach.git
synced 2026-03-15 06:03:33 +01:00
load and save function for non-episodic replay buffers + carla improvements + network bug fixes
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
#
|
||||
|
||||
from typing import List, Tuple, Union, Dict, Any
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -218,3 +219,21 @@ class ExperienceReplay(Memory):
|
||||
self.reader_writer_lock.release_writing()
|
||||
|
||||
return mean
|
||||
|
||||
def save(self, file_path: str) -> None:
|
||||
"""
|
||||
Save the replay buffer contents to a pickle file
|
||||
:param file_path: the path to the file that will be used to store the pickled transitions
|
||||
"""
|
||||
with open(file_path, 'wb') as file:
|
||||
pickle.dump(self.transitions, file)
|
||||
|
||||
def load(self, file_path: str) -> None:
|
||||
"""
|
||||
Restore the replay buffer contents from a pickle file.
|
||||
The pickle file is assumed to include a list of transitions.
|
||||
:param file_path: The path to a pickle file to restore
|
||||
"""
|
||||
with open(file_path, 'rb') as file:
|
||||
self.transitions = pickle.load(file)
|
||||
self._num_transitions = len(self.transitions)
|
||||
|
||||
Reference in New Issue
Block a user