1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

fix for finding the right filter checkpoint to restore + do not update internal filter state when evaluating + fix SharedRunningStats checkpoint filenames (#147)

This commit is contained in:
Gal Leibovich
2018-12-17 21:36:27 +02:00
committed by GitHub
parent b4bc8a476c
commit 4c914c057c
4 changed files with 19 additions and 15 deletions

View File

@@ -109,13 +109,13 @@ class SharedRunningStats(ABC):
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
pass
def get_latest_checkpoint(self, checkpoint_dir: str) -> str:
def get_latest_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str) -> str:
latest_checkpoint_id = -1
latest_checkpoint = ''
# get all checkpoint files
for fname in os.listdir(checkpoint_dir):
path = os.path.join(checkpoint_dir, fname)
if os.path.isdir(path) or fname.split('.')[-1] != 'srs':
if os.path.isdir(path) or fname.split('.')[-1] != 'srs' or checkpoint_prefix not in fname:
continue
checkpoint_id = int(fname.split('_')[0])
if checkpoint_id > latest_checkpoint_id:
@@ -189,7 +189,7 @@ class NumpySharedRunningStats(SharedRunningStats):
pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir)
latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir, checkpoint_prefix)
if latest_checkpoint_filename == '':
raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ")