Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)
Fixes for having NumpySharedRunningStats syncing on multi-node (#139)
1. Use the standard checkpoint prefix so that the data store grabs the file and syncs it to S3.
2. Remove the reference to Redis so that it does not get pulled in when pickling.
3. Enable restoring a checkpoint that was saved by a single-node multi-worker run into a single-worker run.
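For reference, point 1 works because the data store syncs checkpoint files by their common prefix, so writing the running-stats pickle as '<checkpoint_prefix>.srs' (see the diff below) makes it travel to S3 together with the regular TF checkpoint files. A minimal sketch of that idea, assuming a plain boto3 upload; the function name, bucket argument and boto3-based sync are illustrative assumptions, not Coach's actual data-store code:

import os

import boto3


def sync_checkpoint_to_s3(checkpoint_dir, checkpoint_prefix, bucket):
    # Upload every file whose name starts with the checkpoint prefix.
    # Because the running-stats state is saved as '<checkpoint_prefix>.srs',
    # it matches the same prefix filter as the TF checkpoint files.
    s3 = boto3.client('s3')
    for file_name in os.listdir(checkpoint_dir):
        if file_name.startswith(checkpoint_prefix):
            s3.upload_file(os.path.join(checkpoint_dir, file_name), bucket, file_name)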
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
+import pickle
 
 import numpy as np
 import tensorflow as tf
@@ -46,10 +48,10 @@ class TFSharedRunningStats(SharedRunningStats):
             initializer=tf.constant_initializer(0.0),
             name="running_sum", trainable=False, shape=shape, validate_shape=False,
             collections=[tf.GraphKeys.GLOBAL_VARIABLES])
-        self._sum_squared = tf.get_variable(
+        self._sum_squares = tf.get_variable(
             dtype=tf.float64,
             initializer=tf.constant_initializer(self.epsilon),
-            name="running_sum_squared", trainable=False, shape=shape, validate_shape=False,
+            name="running_sum_squares", trainable=False, shape=shape, validate_shape=False,
             collections=[tf.GraphKeys.GLOBAL_VARIABLES])
         self._count = tf.get_variable(
             dtype=tf.float64,
@@ -59,17 +61,17 @@ class TFSharedRunningStats(SharedRunningStats):
 
         self._shape = None
         self._mean = tf.div(self._sum, self._count, name="mean")
-        self._std = tf.sqrt(tf.maximum((self._sum_squared - self._count*tf.square(self._mean))
+        self._std = tf.sqrt(tf.maximum((self._sum_squares - self._count * tf.square(self._mean))
                             / tf.maximum(self._count-1, 1), self.epsilon), name="stdev")
         self.tf_mean = tf.cast(self._mean, 'float32')
         self.tf_std = tf.cast(self._std, 'float32')
 
         self.new_sum = tf.placeholder(dtype=tf.float64, name='sum')
-        self.new_sum_squared = tf.placeholder(dtype=tf.float64, name='var')
+        self.new_sum_squares = tf.placeholder(dtype=tf.float64, name='var')
         self.newcount = tf.placeholder(shape=[], dtype=tf.float64, name='count')
 
         self._inc_sum = tf.assign_add(self._sum, self.new_sum, use_locking=True)
-        self._inc_sum_squared = tf.assign_add(self._sum_squared, self.new_sum_squared, use_locking=True)
+        self._inc_sum_squares = tf.assign_add(self._sum_squares, self.new_sum_squares, use_locking=True)
         self._inc_count = tf.assign_add(self._count, self.newcount, use_locking=True)
 
         self.raw_obs = tf.placeholder(dtype=tf.float64, name='raw_obs')
@@ -84,10 +86,10 @@ class TFSharedRunningStats(SharedRunningStats):
 
     def push_val(self, x):
         x = x.astype('float64')
-        self.sess.run([self._inc_sum, self._inc_sum_squared, self._inc_count],
+        self.sess.run([self._inc_sum, self._inc_sum_squares, self._inc_count],
                       feed_dict={
                           self.new_sum: x.sum(axis=0).ravel(),
-                          self.new_sum_squared: np.square(x).sum(axis=0).ravel(),
+                          self.new_sum_squares: np.square(x).sum(axis=0).ravel(),
                           self.newcount: np.array(len(x), dtype='float64')
                       })
         if self._shape is None:
@@ -117,11 +119,11 @@ class TFSharedRunningStats(SharedRunningStats):
     def shape(self, val):
         self._shape = val
         self.new_sum.set_shape(val)
-        self.new_sum_squared.set_shape(val)
+        self.new_sum_squares.set_shape(val)
         self.tf_mean.set_shape(val)
         self.tf_std.set_shape(val)
         self._sum.set_shape(val)
-        self._sum_squared.set_shape(val)
+        self._sum_squares.set_shape(val)
 
     def normalize(self, batch):
         if self.clip_values is not None:
@@ -129,10 +131,25 @@ class TFSharedRunningStats(SharedRunningStats):
         else:
             return self.sess.run(self.normalized_obs, feed_dict={self.raw_obs: batch})
 
-    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
-        # the stats are part of the TF graph - no need to explicitly save anything
-        pass
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        # Since the internal state is maintained as part of the TF graph, no need to do anything special for
+        # save/restore, when going from single-node-multi-thread run back to a single-node-multi-worker run.
+        # Nevertheless, if we'll want to restore a checkpoint back to either a * single-worker, or a
+        # multi-node-multi-worker * run, we have to save the internal state, so that it can be restored to the
+        # NumpySharedRunningStats implementation.
+
+        dict_to_save = {'_mean': self.mean,
+                        '_std': self.std,
+                        '_count': self.n,
+                        '_sum': self.sess.run(self._sum),
+                        '_sum_squares': self.sess.run(self._sum_squares)}
+
+        with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.srs'), 'wb') as f:
+            pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
 
-    def restore_state_from_checkpoint(self, checkpoint_dir: str):
-        # the stats are part of the TF graph - no need to explicitly restore anything
-        pass
+    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        # Since the internal state is maintained as part of the TF graph, no need to do anything special for
+        # save/restore, when going from single-node-multi-thread run back to a single-node-multi-worker run.
+        # Restoring from either a * single-worker, or a multi-node-multi-worker * run, to a single-node-multi-thread run
+        # is not supported.
+        pass
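The restore path that point 3 of the commit message refers to lives in NumpySharedRunningStats and is not part of this diff. A rough sketch of what consuming the '.srs' file could look like; the function name and the attribute assignment on the target object are assumptions that simply mirror the keys written by dict_to_save above:

import os
import pickle


def restore_numpy_running_stats(stats, checkpoint_dir, checkpoint_prefix):
    # Load the pickle written by TFSharedRunningStats.save_state_to_checkpoint().
    with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.srs'), 'rb') as f:
        saved = pickle.load(f)
    # '_mean', '_std', '_count', '_sum' and '_sum_squares' are exactly the keys
    # saved in the diff above; copy them onto the numpy-based stats object.
    for key, value in saved.items():
        setattr(stats, key, value)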