1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Fix numpy shared running stats to support images (#411)

This commit is contained in:
shadiendrawis
2019-10-06 12:16:38 +03:00
committed by Gal Leibovich
parent 79a4161eca
commit 0a712ecc94

View File

@@ -130,8 +130,9 @@ class NumpySharedRunningStats(SharedRunningStats):
def push_val(self, samples: np.ndarray):
assert len(samples.shape) >= 2 # we should always have a batch dimension
assert samples.shape[1:] == self._mean.shape, 'RunningStats input shape mismatch'
self._sum += samples.sum(axis=0).ravel()
self._sum_squares += np.square(samples).sum(axis=0).ravel()
samples = samples.astype(np.float64)
self._sum += samples.sum(axis=0)
self._sum_squares += np.square(samples).sum(axis=0)
self._count += np.shape(samples)[0]
self._mean = self._sum / self._count
self._std = np.sqrt(np.maximum(