1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Save filters' internal state (#127)

* save filters' internal state

* move the restore logic into NumpySharedRunningStats
This commit is contained in:
Gal Leibovich
2018-11-20 17:21:48 +02:00
committed by GitHub
parent 67eb9e4c28
commit a112ee69f6
13 changed files with 173 additions and 14 deletions

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from abc import ABC, abstractmethod
import threading
import pickle
@@ -102,6 +102,14 @@ class SharedRunningStats(ABC):
def set_session(self, sess):
pass
@abstractmethod
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
    """Persist this object's state as part of a checkpoint.

    Abstract hook with no behavior of its own; concrete subclasses
    provide the actual serialization.

    :param checkpoint_dir: directory the state should be written to
    :param checkpoint_id: identifier of the checkpoint being saved
    """
@abstractmethod
def restore_state_from_checkpoint(self, checkpoint_dir: str):
    """Reload this object's state from a previously written checkpoint.

    Abstract hook with no behavior of its own; concrete subclasses
    provide the actual deserialization.

    :param checkpoint_dir: directory the state should be read from
    """
class NumpySharedRunningStats(SharedRunningStats):
def __init__(self, name, epsilon=1e-2, pubsub_params=None):
@@ -156,4 +164,21 @@ class NumpySharedRunningStats(SharedRunningStats):
# no session for the numpy implementation
pass
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
    """Pickle this object's attribute dictionary into the checkpoint directory.

    The state is written to ``<checkpoint_dir>/<checkpoint_id>.srs`` using
    the highest pickle protocol available.

    :param checkpoint_dir: directory the state file is written to
    :param checkpoint_id: identifier used as the state file's base name
    """
    state_file_name = str(checkpoint_id) + '.srs'
    state_file_path = os.path.join(checkpoint_dir, state_file_name)
    with open(state_file_path, 'wb') as state_file:
        # the whole instance __dict__ is serialized, not a selected subset
        pickle.dump(self.__dict__, state_file, pickle.HIGHEST_PROTOCOL)
def restore_state_from_checkpoint(self, checkpoint_dir: str):
    """Load the most recent saved state found in ``checkpoint_dir``.

    Scans ``checkpoint_dir`` for files named ``<checkpoint_id>.srs`` (the
    naming used by ``save_state_to_checkpoint``), selects the one with the
    highest integer id, unpickles it and merges the resulting dictionary
    into ``self.__dict__``.

    Sub-directories, files without the ``.srs`` suffix and ``.srs`` files
    whose base name is not an integer are ignored, so other files stored
    in the same directory do not break the restore.

    :param checkpoint_dir: directory that holds the ``.srs`` state files
    :raises FileNotFoundError: if the directory contains no
        ``<checkpoint_id>.srs`` file
    """
    suffix = '.srs'
    latest_checkpoint = -1
    latest_fname = None

    for fname in os.listdir(checkpoint_dir):
        # only state files written by save_state_to_checkpoint are candidates
        if not fname.endswith(suffix):
            continue
        if os.path.isdir(os.path.join(checkpoint_dir, fname)):
            continue
        try:
            checkpoint_id = int(fname[:-len(suffix)])
        except ValueError:
            # e.g. 'notes.srs' - not a numbered checkpoint, skip it
            continue
        if checkpoint_id > latest_checkpoint:
            latest_checkpoint = checkpoint_id
            latest_fname = fname

    if latest_fname is None:
        raise FileNotFoundError(
            "no '<checkpoint_id>{}' state file found in {!r}".format(suffix, checkpoint_dir))

    # open the file that was actually matched, not a name rebuilt from the id
    with open(os.path.join(checkpoint_dir, latest_fname), 'rb') as f:
        # NOTE(security): pickle.load executes arbitrary code embedded in the
        # file - only restore from checkpoint directories you trust.
        temp_running_observation_stats = pickle.load(f)
    self.__dict__.update(temp_running_observation_stats)