Fixes for having NumpySharedRunningStats syncing on multi-node (#139)
1. Use the standard checkpoint prefix for the saved file so that the data store can grab it and sync it to S3.
2. Stop pickling the Redis reference by saving only the running-stats fields instead of the whole object state.
3. Allow a checkpoint saved by a single-node, multi-worker run to be restored into a single-worker run.
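To make fixes (1) and (2) concrete: save_state_to_checkpoint() now writes a '<checkpoint_prefix>.srs' file next to the regular framework checkpoints and pickles only the five plain numpy statistics fields, so the Redis client held by the stats object is never serialized. The sketch below is illustrative only and is not part of the commit; the directory and prefix values are made-up examples.

# Illustrative sketch only; the paths and prefix below are hypothetical examples.
import os
import pickle

import numpy as np

checkpoint_dir = '/tmp/checkpoint'   # hypothetical checkpoint directory
checkpoint_prefix = '4_Step-1234'    # hypothetical prefix shared with the framework checkpoint files
os.makedirs(checkpoint_dir, exist_ok=True)

# What the new save path persists: only the raw running statistics, no Redis handle.
stats = {'_mean': np.zeros(3), '_std': np.ones(3), '_count': 0,
         '_sum': np.zeros(3), '_sum_squares': np.zeros(3)}
with open(os.path.join(checkpoint_dir, checkpoint_prefix + '.srs'), 'wb') as f:
    pickle.dump(stats, f, pickle.HIGHEST_PROTOCOL)

# The data store can now find '4_Step-1234.srs' by the same prefix it already
# uses for the other checkpoint files and sync it to S3.
print(os.listdir(checkpoint_dir))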
@@ -21,7 +21,6 @@ import redis
 import numpy as np
 
 
 class SharedRunningStatsSubscribe(threading.Thread):
     def __init__(self, shared_running_stats):
         super().__init__()
@@ -103,13 +102,28 @@ class SharedRunningStats(ABC):
         pass
 
     @abstractmethod
-    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: int):
         pass
 
     @abstractmethod
-    def restore_state_from_checkpoint(self, checkpoint_dir: str):
+    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
         pass
 
+    def get_latest_checkpoint(self, checkpoint_dir: str) -> str:
+        latest_checkpoint_id = -1
+        latest_checkpoint = ''
+        # get all checkpoint files
+        for fname in os.listdir(checkpoint_dir):
+            path = os.path.join(checkpoint_dir, fname)
+            if os.path.isdir(path) or fname.split('.')[-1] != 'srs':
+                continue
+            checkpoint_id = int(fname.split('_')[0])
+            if checkpoint_id > latest_checkpoint_id:
+                latest_checkpoint = fname
+                latest_checkpoint_id = checkpoint_id
+
+        return latest_checkpoint
+
 
 class NumpySharedRunningStats(SharedRunningStats):
     def __init__(self, name, epsilon=1e-2, pubsub_params=None):
@@ -164,21 +178,22 @@ class NumpySharedRunningStats(SharedRunningStats):
         # no session for the numpy implementation
         pass
 
-    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
-        with open(os.path.join(checkpoint_dir, str(checkpoint_id) + '.srs'), 'wb') as f:
-            pickle.dump(self.__dict__, f, pickle.HIGHEST_PROTOCOL)
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: int):
+        dict_to_save = {'_mean': self._mean,
+                        '_std': self._std,
+                        '_count': self._count,
+                        '_sum': self._sum,
+                        '_sum_squares': self._sum_squares}
+
+        with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.srs'), 'wb') as f:
+            pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
 
-    def restore_state_from_checkpoint(self, checkpoint_dir: str):
-        latest_checkpoint = -1
-        # get all checkpoint files
-        for fname in os.listdir(checkpoint_dir):
-            path = os.path.join(checkpoint_dir, fname)
-            if os.path.isdir(path):
-                continue
-            checkpoint_id = int(fname.split('.')[0])
-            if checkpoint_id > latest_checkpoint:
-                latest_checkpoint = checkpoint_id
-
-        with open(os.path.join(checkpoint_dir, str(latest_checkpoint) + '.srs'), 'rb') as f:
-            temp_running_observation_stats = pickle.load(f)
-            self.__dict__.update(temp_running_observation_stats)
+    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir)
+
+        if latest_checkpoint_filename == '':
+            raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ")
+
+        with open(os.path.join(checkpoint_dir, str(latest_checkpoint_filename)), 'rb') as f:
+            saved_dict = pickle.load(f)
+            self.__dict__.update(saved_dict)
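Fix (3) is the restore path: a fresh single-worker run only needs the checkpoint directory; it locates the newest .srs file by the numeric id at the start of the filename and loads the saved dict back into the stats object. The standalone sketch below mirrors that logic without importing rl_coach (the helper name get_latest_srs and the path are made up); it can be run after the save sketch above.

# Standalone sketch of the restore path; not the library code itself.
import os
import pickle


def get_latest_srs(checkpoint_dir: str) -> str:
    # Same idea as get_latest_checkpoint(): pick the '<id>_...'.srs file with the highest id.
    latest_id, latest = -1, ''
    for fname in os.listdir(checkpoint_dir):
        if os.path.isdir(os.path.join(checkpoint_dir, fname)) or not fname.endswith('.srs'):
            continue
        checkpoint_id = int(fname.split('_')[0])   # '4_Step-1234.srs' -> 4
        if checkpoint_id > latest_id:
            latest_id, latest = checkpoint_id, fname
    return latest


checkpoint_dir = '/tmp/checkpoint'                 # hypothetical directory from the sketch above
fname = get_latest_srs(checkpoint_dir)
if fname == '':
    raise ValueError("Could not find NumpySharedRunningStats checkpoint file.")
with open(os.path.join(checkpoint_dir, fname), 'rb') as f:
    stats = pickle.load(f)   # dict with _mean, _std, _count, _sum, _sum_squares
print(sorted(stats.keys()))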