1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Numpy shared running stats (#97)

This commit is contained in:
Gal Leibovich
2018-11-18 14:46:40 +02:00
committed by GitHub
parent e1fa6e9681
commit 6caf721d1c
10 changed files with 226 additions and 114 deletions

View File

@@ -239,51 +239,6 @@ def squeeze_list(var):
return var
# http://www.johndcook.com/blog/standard_deviation/
class RunningStat(object):
def __init__(self, shape):
self._shape = shape
self._num_samples = 0
self._mean = np.zeros(shape)
self._std = np.zeros(shape)
def reset(self):
self._num_samples = 0
self._mean = np.zeros(self._shape)
self._std = np.zeros(self._shape)
def push(self, sample):
sample = np.asarray(sample)
assert sample.shape == self._mean.shape, 'RunningStat input shape mismatch'
self._num_samples += 1
if self._num_samples == 1:
self._mean[...] = sample
else:
old_mean = self._mean.copy()
self._mean[...] = old_mean + (sample - old_mean) / self._num_samples
self._std[...] = self._std + (sample - old_mean) * (sample - self._mean)
@property
def n(self):
return self._num_samples
@property
def mean(self):
return self._mean
@property
def var(self):
return self._std / (self._num_samples - 1) if self._num_samples > 1 else np.square(self._mean)
@property
def std(self):
return np.sqrt(self.var)
@property
def shape(self):
return self._mean.shape
def get_open_port():
import socket
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -590,3 +545,5 @@ def start_shell_command_and_wait(command):
def indent_string(string):
return '\t' + string.replace('\n', '\n\t')