mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Enable distributed SharedRunningStats (#81)
- Use Redis pub/sub for updating SharedRunningStats.
This commit is contained in:
committed by
Gal Leibovich
parent
875d6ef017
commit
a849c17e46
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -15,11 +15,16 @@
|
||||
#
|
||||
|
||||
import numpy as np
|
||||
import pickle
|
||||
import redis
|
||||
import tensorflow as tf
|
||||
import threading
|
||||
|
||||
from rl_coach.memories.backend.memory_impl import get_memory_backend
|
||||
|
||||
|
||||
class SharedRunningStats(object):
|
||||
def __init__(self, replicated_device=None, epsilon=1e-2, name="", create_ops=True):
|
||||
def __init__(self, replicated_device=None, epsilon=1e-2, name="", create_ops=True, pubsub_params=None):
|
||||
self.sess = None
|
||||
self.name = name
|
||||
self.replicated_device = replicated_device
|
||||
@@ -28,6 +33,13 @@ class SharedRunningStats(object):
|
||||
if create_ops:
|
||||
with tf.device(replicated_device):
|
||||
self.create_ops()
|
||||
self.pubsub = None
|
||||
if pubsub_params:
|
||||
self.channel = "channel-srs-{}".format(self.name)
|
||||
self.pubsub = get_memory_backend(pubsub_params)
|
||||
subscribe_thread = SharedRunningStatsSubscribe(self)
|
||||
subscribe_thread.daemon = True
|
||||
subscribe_thread.start()
|
||||
|
||||
def create_ops(self, shape=[1], clip_values=None):
|
||||
self.clip_values = clip_values
|
||||
@@ -74,13 +86,20 @@ class SharedRunningStats(object):
|
||||
self.sess = sess
|
||||
|
||||
def push(self, x):
|
||||
if self.pubsub:
|
||||
self.pubsub.redis_connection.publish(self.channel, pickle.dumps(x))
|
||||
return
|
||||
|
||||
self.push_val(x)
|
||||
|
||||
def push_val(self, x):
|
||||
x = x.astype('float64')
|
||||
self.sess.run([self._inc_sum, self._inc_sum_squared, self._inc_count],
|
||||
feed_dict={
|
||||
self.new_sum: x.sum(axis=0).ravel(),
|
||||
self.new_sum_squared: np.square(x).sum(axis=0).ravel(),
|
||||
self.newcount: np.array(len(x), dtype='float64')
|
||||
})
|
||||
feed_dict={
|
||||
self.new_sum: x.sum(axis=0).ravel(),
|
||||
self.new_sum_squared: np.square(x).sum(axis=0).ravel(),
|
||||
self.newcount: np.array(len(x), dtype='float64')
|
||||
})
|
||||
if self._shape is None:
|
||||
self._shape = x.shape
|
||||
|
||||
@@ -119,3 +138,23 @@ class SharedRunningStats(object):
|
||||
return self.sess.run(self.clipped_obs, feed_dict={self.raw_obs: batch})
|
||||
else:
|
||||
return self.sess.run(self.normalized_obs, feed_dict={self.raw_obs: batch})
|
||||
|
||||
|
||||
class SharedRunningStatsSubscribe(threading.Thread):
|
||||
def __init__(self, shared_running_stats):
|
||||
super().__init__()
|
||||
self.shared_running_stats = shared_running_stats
|
||||
self.redis_address = self.shared_running_stats.pubsub.params.redis_address
|
||||
self.redis_port = self.shared_running_stats.pubsub.params.redis_port
|
||||
self.redis_connection = redis.Redis(self.redis_address, self.redis_port)
|
||||
self.pubsub = self.redis_connection.pubsub()
|
||||
self.channel = self.shared_running_stats.channel
|
||||
self.pubsub.subscribe(self.channel)
|
||||
|
||||
def run(self):
|
||||
for message in self.pubsub.listen():
|
||||
try:
|
||||
obj = pickle.loads(message['data'])
|
||||
self.shared_running_stats.push_val(obj)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user