mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Save filters' internal state (#127)
* save filters internal state * moving the restore to be made from within NumpyRunningStats
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
import threading
|
||||
import pickle
|
||||
@@ -102,6 +102,14 @@ class SharedRunningStats(ABC):
|
||||
def set_session(self, sess):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir: str):
|
||||
pass
|
||||
|
||||
|
||||
class NumpySharedRunningStats(SharedRunningStats):
|
||||
def __init__(self, name, epsilon=1e-2, pubsub_params=None):
|
||||
@@ -156,4 +164,21 @@ class NumpySharedRunningStats(SharedRunningStats):
|
||||
# no session for the numpy implementation
|
||||
pass
|
||||
|
||||
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
|
||||
with open(os.path.join(checkpoint_dir, str(checkpoint_id) + '.srs'), 'wb') as f:
|
||||
pickle.dump(self.__dict__, f, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir: str):
|
||||
latest_checkpoint = -1
|
||||
# get all checkpoint files
|
||||
for fname in os.listdir(checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, fname)
|
||||
if os.path.isdir(path):
|
||||
continue
|
||||
checkpoint_id = int(fname.split('.')[0])
|
||||
if checkpoint_id > latest_checkpoint:
|
||||
latest_checkpoint = checkpoint_id
|
||||
|
||||
with open(os.path.join(checkpoint_dir, str(latest_checkpoint) + '.srs'), 'rb') as f:
|
||||
temp_running_observation_stats = pickle.load(f)
|
||||
self.__dict__.update(temp_running_observation_stats)
|
||||
|
||||
Reference in New Issue
Block a user