mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
fix for finding the right filter checkpoint to restore + do not update internal filter state when evaluating + fix SharedRunningStats checkpoint filenames (#147)
This commit is contained in:
@@ -109,13 +109,13 @@ class SharedRunningStats(ABC):
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
||||
pass
|
||||
|
||||
def get_latest_checkpoint(self, checkpoint_dir: str) -> str:
|
||||
def get_latest_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str) -> str:
|
||||
latest_checkpoint_id = -1
|
||||
latest_checkpoint = ''
|
||||
# get all checkpoint files
|
||||
for fname in os.listdir(checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, fname)
|
||||
if os.path.isdir(path) or fname.split('.')[-1] != 'srs':
|
||||
if os.path.isdir(path) or fname.split('.')[-1] != 'srs' or checkpoint_prefix not in fname:
|
||||
continue
|
||||
checkpoint_id = int(fname.split('_')[0])
|
||||
if checkpoint_id > latest_checkpoint_id:
|
||||
@@ -189,7 +189,7 @@ class NumpySharedRunningStats(SharedRunningStats):
|
||||
pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
||||
latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir)
|
||||
latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
|
||||
if latest_checkpoint_filename == '':
|
||||
raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ")
|
||||
|
||||
Reference in New Issue
Block a user