mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Numpy shared running stats (#97)
This commit is contained in:
159
rl_coach/utilities/shared_running_stats.py
Normal file
159
rl_coach/utilities/shared_running_stats.py
Normal file
@@ -0,0 +1,159 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import threading
|
||||
import pickle
|
||||
import redis
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.memories.backend.memory_impl import get_memory_backend
|
||||
|
||||
|
||||
class SharedRunningStatsSubscribe(threading.Thread):
|
||||
def __init__(self, shared_running_stats):
|
||||
super().__init__()
|
||||
self.shared_running_stats = shared_running_stats
|
||||
self.redis_address = self.shared_running_stats.pubsub.params.redis_address
|
||||
self.redis_port = self.shared_running_stats.pubsub.params.redis_port
|
||||
self.redis_connection = redis.Redis(self.redis_address, self.redis_port)
|
||||
self.pubsub = self.redis_connection.pubsub()
|
||||
self.channel = self.shared_running_stats.channel
|
||||
self.pubsub.subscribe(self.channel)
|
||||
|
||||
def run(self):
|
||||
for message in self.pubsub.listen():
|
||||
try:
|
||||
obj = pickle.loads(message['data'])
|
||||
self.shared_running_stats.push_val(obj)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
class SharedRunningStats(ABC):
|
||||
def __init__(self, name="", pubsub_params=None):
|
||||
self.name = name
|
||||
self.pubsub = None
|
||||
if pubsub_params:
|
||||
self.channel = "channel-srs-{}".format(self.name)
|
||||
self.pubsub = get_memory_backend(pubsub_params)
|
||||
subscribe_thread = SharedRunningStatsSubscribe(self)
|
||||
subscribe_thread.daemon = True
|
||||
subscribe_thread.start()
|
||||
|
||||
@abstractmethod
|
||||
def set_params(self, shape=[1], clip_values=None):
|
||||
pass
|
||||
|
||||
def push(self, x):
|
||||
if self.pubsub:
|
||||
self.pubsub.redis_connection.publish(self.channel, pickle.dumps(x))
|
||||
return
|
||||
|
||||
self.push_val(x)
|
||||
|
||||
@abstractmethod
|
||||
def push_val(self, x):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def n(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def mean(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def var(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def std(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def shape(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def normalize(self, batch):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_session(self, sess):
|
||||
pass
|
||||
|
||||
|
||||
class NumpySharedRunningStats(SharedRunningStats):
|
||||
def __init__(self, name, epsilon=1e-2, pubsub_params=None):
|
||||
super().__init__(name=name, pubsub_params=pubsub_params)
|
||||
self._count = epsilon
|
||||
self.epsilon = epsilon
|
||||
|
||||
def set_params(self, shape=[1], clip_values=None):
|
||||
self._shape = shape
|
||||
self._mean = np.zeros(shape)
|
||||
self._std = np.sqrt(self.epsilon) * np.ones(shape)
|
||||
self._sum = np.zeros(shape)
|
||||
self._sum_squares = self.epsilon * np.ones(shape)
|
||||
self.clip_values = clip_values
|
||||
|
||||
def push_val(self, samples: np.ndarray):
|
||||
assert len(samples.shape) >= 2 # we should always have a batch dimension
|
||||
assert samples.shape[1:] == self._mean.shape, 'RunningStats input shape mismatch'
|
||||
self._sum += samples.sum(axis=0).ravel()
|
||||
self._sum_squares += np.square(samples).sum(axis=0).ravel()
|
||||
self._count += np.shape(samples)[0]
|
||||
self._mean = self._sum / self._count
|
||||
self._std = np.sqrt(np.maximum(
|
||||
(self._sum_squares - self._count * np.square(self._mean)) / np.maximum(self._count - 1, 1),
|
||||
self.epsilon))
|
||||
|
||||
@property
|
||||
def n(self):
|
||||
return self._count
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
return self._mean
|
||||
|
||||
@property
|
||||
def var(self):
|
||||
return self._std ** 2
|
||||
|
||||
@property
|
||||
def std(self):
|
||||
return self._std
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._mean.shape
|
||||
|
||||
def normalize(self, batch):
|
||||
batch = (batch - self.mean) / (self.std + 1e-15)
|
||||
return np.clip(batch, *self.clip_values)
|
||||
|
||||
def set_session(self, sess):
|
||||
# no session for the numpy implementation
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user