diff --git a/rl_coach/utilities/shared_running_stats.py b/rl_coach/utilities/shared_running_stats.py index 263fae6..875d902 100644 --- a/rl_coach/utilities/shared_running_stats.py +++ b/rl_coach/utilities/shared_running_stats.py @@ -130,8 +130,9 @@ class NumpySharedRunningStats(SharedRunningStats): def push_val(self, samples: np.ndarray): assert len(samples.shape) >= 2 # we should always have a batch dimension assert samples.shape[1:] == self._mean.shape, 'RunningStats input shape mismatch' - self._sum += samples.sum(axis=0).ravel() - self._sum_squares += np.square(samples).sum(axis=0).ravel() + samples = samples.astype(np.float64) + self._sum += samples.sum(axis=0) + self._sum_squares += np.square(samples).sum(axis=0) self._count += np.shape(samples)[0] self._mean = self._sum / self._count self._std = np.sqrt(np.maximum(