mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Fix numpy shared running stats to support images (#411)
This commit is contained in:
committed by
Gal Leibovich
parent
79a4161eca
commit
0a712ecc94
@@ -130,8 +130,9 @@ class NumpySharedRunningStats(SharedRunningStats):
|
|||||||
def push_val(self, samples: np.ndarray):
|
def push_val(self, samples: np.ndarray):
|
||||||
assert len(samples.shape) >= 2 # we should always have a batch dimension
|
assert len(samples.shape) >= 2 # we should always have a batch dimension
|
||||||
assert samples.shape[1:] == self._mean.shape, 'RunningStats input shape mismatch'
|
assert samples.shape[1:] == self._mean.shape, 'RunningStats input shape mismatch'
|
||||||
self._sum += samples.sum(axis=0).ravel()
|
samples = samples.astype(np.float64)
|
||||||
self._sum_squares += np.square(samples).sum(axis=0).ravel()
|
self._sum += samples.sum(axis=0)
|
||||||
|
self._sum_squares += np.square(samples).sum(axis=0)
|
||||||
self._count += np.shape(samples)[0]
|
self._count += np.shape(samples)[0]
|
||||||
self._mean = self._sum / self._count
|
self._mean = self._sum / self._count
|
||||||
self._std = np.sqrt(np.maximum(
|
self._std = np.sqrt(np.maximum(
|
||||||
|
|||||||
Reference in New Issue
Block a user