diff --git a/architectures/tensorflow_components/shared_variables.py b/architectures/tensorflow_components/shared_variables.py index 8ee623e..2775251 100644 --- a/architectures/tensorflow_components/shared_variables.py +++ b/architectures/tensorflow_components/shared_variables.py @@ -40,8 +40,9 @@ class SharedRunningStats(object): name="count", trainable=False) self._shape = shape - self._mean = tf.to_float(self._sum / self._count) - self._std = tf.sqrt(tf.maximum(tf.to_float(self._sum_squared / self._count) - tf.square(self._mean), 1e-2)) + self._mean = self._sum / self._count + self._std = tf.sqrt(tf.maximum((self._sum_squared - self._count*tf.square(self._mean)) + / tf.maximum(self._count-1, 1), epsilon)) self.new_sum = tf.placeholder(shape=self.shape, dtype=tf.float64, name='sum') self.new_sum_squared = tf.placeholder(shape=self.shape, dtype=tf.float64, name='var')