mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Fix numpy shared running stats to support images (#411)
This commit is contained in:
committed by
Gal Leibovich
parent
79a4161eca
commit
0a712ecc94
@@ -130,8 +130,9 @@ class NumpySharedRunningStats(SharedRunningStats):
|
||||
def push_val(self, samples: np.ndarray):
|
||||
assert len(samples.shape) >= 2 # we should always have a batch dimension
|
||||
assert samples.shape[1:] == self._mean.shape, 'RunningStats input shape mismatch'
|
||||
self._sum += samples.sum(axis=0).ravel()
|
||||
self._sum_squares += np.square(samples).sum(axis=0).ravel()
|
||||
samples = samples.astype(np.float64)
|
||||
self._sum += samples.sum(axis=0)
|
||||
self._sum_squares += np.square(samples).sum(axis=0)
|
||||
self._count += np.shape(samples)[0]
|
||||
self._mean = self._sum / self._count
|
||||
self._std = np.sqrt(np.maximum(
|
||||
|
||||
Reference in New Issue
Block a user