mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
1. Having the standard checkpoint prefix in order for the data store to grab it, and sync it to S3. 2. Removing the reference to Redis so that it won't try to pickle that in. 3. Enable restoring a checkpoint into a single-worker run, which was saved by a single-node-multiple-worker run.
200 lines
6.2 KiB
Python
200 lines
6.2 KiB
Python
#
|
|
# Copyright (c) 2017 Intel Corporation
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import os
|
|
from abc import ABC, abstractmethod
|
|
import threading
|
|
import pickle
|
|
import redis
|
|
import numpy as np
|
|
|
|
|
|
class SharedRunningStatsSubscribe(threading.Thread):
|
|
def __init__(self, shared_running_stats):
|
|
super().__init__()
|
|
self.shared_running_stats = shared_running_stats
|
|
self.redis_address = self.shared_running_stats.pubsub.params.redis_address
|
|
self.redis_port = self.shared_running_stats.pubsub.params.redis_port
|
|
self.redis_connection = redis.Redis(self.redis_address, self.redis_port)
|
|
self.pubsub = self.redis_connection.pubsub()
|
|
self.channel = self.shared_running_stats.channel
|
|
self.pubsub.subscribe(self.channel)
|
|
|
|
def run(self):
|
|
for message in self.pubsub.listen():
|
|
try:
|
|
obj = pickle.loads(message['data'])
|
|
self.shared_running_stats.push_val(obj)
|
|
except Exception:
|
|
continue
|
|
|
|
|
|
class SharedRunningStats(ABC):
|
|
def __init__(self, name="", pubsub_params=None):
|
|
self.name = name
|
|
self.pubsub = None
|
|
if pubsub_params:
|
|
self.channel = "channel-srs-{}".format(self.name)
|
|
from rl_coach.memories.backend.memory_impl import get_memory_backend
|
|
self.pubsub = get_memory_backend(pubsub_params)
|
|
subscribe_thread = SharedRunningStatsSubscribe(self)
|
|
subscribe_thread.daemon = True
|
|
subscribe_thread.start()
|
|
|
|
@abstractmethod
|
|
def set_params(self, shape=[1], clip_values=None):
|
|
pass
|
|
|
|
def push(self, x):
|
|
if self.pubsub:
|
|
self.pubsub.redis_connection.publish(self.channel, pickle.dumps(x))
|
|
return
|
|
|
|
self.push_val(x)
|
|
|
|
@abstractmethod
|
|
def push_val(self, x):
|
|
pass
|
|
|
|
@property
|
|
@abstractmethod
|
|
def n(self):
|
|
pass
|
|
|
|
@property
|
|
@abstractmethod
|
|
def mean(self):
|
|
pass
|
|
|
|
@property
|
|
@abstractmethod
|
|
def var(self):
|
|
pass
|
|
|
|
@property
|
|
@abstractmethod
|
|
def std(self):
|
|
pass
|
|
|
|
@property
|
|
@abstractmethod
|
|
def shape(self):
|
|
pass
|
|
|
|
@abstractmethod
|
|
def normalize(self, batch):
|
|
pass
|
|
|
|
@abstractmethod
|
|
def set_session(self, sess):
|
|
pass
|
|
|
|
@abstractmethod
|
|
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: int):
|
|
pass
|
|
|
|
@abstractmethod
|
|
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
|
pass
|
|
|
|
def get_latest_checkpoint(self, checkpoint_dir: str) -> str:
|
|
latest_checkpoint_id = -1
|
|
latest_checkpoint = ''
|
|
# get all checkpoint files
|
|
for fname in os.listdir(checkpoint_dir):
|
|
path = os.path.join(checkpoint_dir, fname)
|
|
if os.path.isdir(path) or fname.split('.')[-1] != 'srs':
|
|
continue
|
|
checkpoint_id = int(fname.split('_')[0])
|
|
if checkpoint_id > latest_checkpoint_id:
|
|
latest_checkpoint = fname
|
|
latest_checkpoint_id = checkpoint_id
|
|
|
|
return latest_checkpoint
|
|
|
|
|
|
class NumpySharedRunningStats(SharedRunningStats):
|
|
def __init__(self, name, epsilon=1e-2, pubsub_params=None):
|
|
super().__init__(name=name, pubsub_params=pubsub_params)
|
|
self._count = epsilon
|
|
self.epsilon = epsilon
|
|
|
|
def set_params(self, shape=[1], clip_values=None):
|
|
self._shape = shape
|
|
self._mean = np.zeros(shape)
|
|
self._std = np.sqrt(self.epsilon) * np.ones(shape)
|
|
self._sum = np.zeros(shape)
|
|
self._sum_squares = self.epsilon * np.ones(shape)
|
|
self.clip_values = clip_values
|
|
|
|
def push_val(self, samples: np.ndarray):
|
|
assert len(samples.shape) >= 2 # we should always have a batch dimension
|
|
assert samples.shape[1:] == self._mean.shape, 'RunningStats input shape mismatch'
|
|
self._sum += samples.sum(axis=0).ravel()
|
|
self._sum_squares += np.square(samples).sum(axis=0).ravel()
|
|
self._count += np.shape(samples)[0]
|
|
self._mean = self._sum / self._count
|
|
self._std = np.sqrt(np.maximum(
|
|
(self._sum_squares - self._count * np.square(self._mean)) / np.maximum(self._count - 1, 1),
|
|
self.epsilon))
|
|
|
|
@property
|
|
def n(self):
|
|
return self._count
|
|
|
|
@property
|
|
def mean(self):
|
|
return self._mean
|
|
|
|
@property
|
|
def var(self):
|
|
return self._std ** 2
|
|
|
|
@property
|
|
def std(self):
|
|
return self._std
|
|
|
|
@property
|
|
def shape(self):
|
|
return self._mean.shape
|
|
|
|
def normalize(self, batch):
|
|
batch = (batch - self.mean) / (self.std + 1e-15)
|
|
return np.clip(batch, *self.clip_values)
|
|
|
|
def set_session(self, sess):
|
|
# no session for the numpy implementation
|
|
pass
|
|
|
|
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: int):
|
|
dict_to_save = {'_mean': self._mean,
|
|
'_std': self._std,
|
|
'_count': self._count,
|
|
'_sum': self._sum,
|
|
'_sum_squares': self._sum_squares}
|
|
|
|
with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.srs'), 'wb') as f:
|
|
pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
|
|
|
|
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
|
latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir)
|
|
|
|
if latest_checkpoint_filename == '':
|
|
raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ")
|
|
|
|
with open(os.path.join(checkpoint_dir, str(latest_checkpoint_filename)), 'rb') as f:
|
|
saved_dict = pickle.load(f)
|
|
self.__dict__.update(saved_dict)
|