
Numpy shared running stats (#97)

commit 6caf721d1c
parent e1fa6e9681
Author: Gal Leibovich
Date:   2018-11-18 14:46:40 +02:00
Committed by: GitHub

10 changed files with 226 additions and 114 deletions
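Only one of the ten changed files appears below: the TensorFlow running-stats class is renamed to TFSharedRunningStats and now derives from a SharedRunningStats base imported from rl_coach.utilities.shared_running_stats, while the Redis pub/sub plumbing is dropped from this file. The NumPy-backed implementation itself lives in one of the files not shown here; as a rough, hedged sketch (class and attribute names below are illustrative assumptions, not the committed code), a NumPy counterpart would track the same sum / sum-of-squares / count accumulators without a TF session:

import numpy as np

class NumpySharedRunningStats:
    # Illustrative sketch only -- the real class in rl_coach.utilities.shared_running_stats may differ.
    def __init__(self, name="", epsilon=1e-2, clip_values=None):
        self.name = name
        self.epsilon = epsilon
        self.clip_values = clip_values

    def set_params(self, shape=(1,)):
        # the same three accumulators the TF version keeps as variables
        self._sum = np.zeros(shape, dtype=np.float64)
        self._sum_squared = np.zeros(shape, dtype=np.float64)
        self._count = self.epsilon   # epsilon avoids division by zero before the first push

    def push_val(self, x):
        x = x.astype('float64')
        self._sum += x.sum(axis=0)
        self._sum_squared += np.square(x).sum(axis=0)
        self._count += x.shape[0]

    @property
    def mean(self):
        return self._sum / self._count

    @property
    def std(self):
        # variance from the accumulators, clipped from below so std stays positive
        var = np.maximum(self._sum_squared / self._count - np.square(self.mean), 1e-2)
        return np.sqrt(var)

    def normalize(self, batch):
        normalized = (batch - self.mean) / self.std
        if self.clip_values is not None:
            normalized = np.clip(normalized, self.clip_values[0], self.clip_values[1])
        return normalized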


@@ -15,33 +15,30 @@
#
import numpy as np
import pickle
import redis
import tensorflow as tf
import threading
from rl_coach.memories.backend.memory_impl import get_memory_backend
from rl_coach.utilities.shared_running_stats import SharedRunningStats
class SharedRunningStats(object):
class TFSharedRunningStats(SharedRunningStats):
    def __init__(self, replicated_device=None, epsilon=1e-2, name="", create_ops=True, pubsub_params=None):
        super().__init__(name=name, pubsub_params=pubsub_params)
        self.sess = None
        self.name = name
        self.replicated_device = replicated_device
        self.epsilon = epsilon
        self.ops_were_created = False
        if create_ops:
            with tf.device(replicated_device):
                self.create_ops()
        self.pubsub = None
        if pubsub_params:
            self.channel = "channel-srs-{}".format(self.name)
            self.pubsub = get_memory_backend(pubsub_params)
            subscribe_thread = SharedRunningStatsSubscribe(self)
            subscribe_thread.daemon = True
            subscribe_thread.start()
            self.set_params()
    def set_params(self, shape=[1], clip_values=None):
        """
        set params and create ops
        :param shape: shape of the stats to track
        :param clip_values: if not None, sets clip min/max thresholds
        """
    def create_ops(self, shape=[1], clip_values=None):
        self.clip_values = clip_values
        with tf.variable_scope(self.name):
            self._sum = tf.get_variable(
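The hunk is cut off at the first tf.get_variable call, and the increment ops run by push_val below are not shown either. In this kind of running-stats class they are typically assign_add ops over the sum / sum-of-squares / count variables, fed through placeholders; a hedged sketch of what create_ops plausibly builds (variable and placeholder names are assumptions, only the attribute names that appear elsewhere in the diff are taken from the code):

# sketch of the body of create_ops(), under tf.variable_scope(self.name) -- not the verbatim source
self._sum = tf.get_variable("running_sum", initializer=tf.zeros(shape, dtype=tf.float64), trainable=False)
self._sum_squared = tf.get_variable("running_sum_squared", initializer=tf.zeros(shape, dtype=tf.float64), trainable=False)
self._count = tf.get_variable("count", initializer=tf.constant(self.epsilon, dtype=tf.float64), trainable=False)

self.new_sum = tf.placeholder(dtype=tf.float64, name="new_sum")
self.new_sum_squared = tf.placeholder(dtype=tf.float64, name="new_sum_squared")
self.new_count = tf.placeholder(dtype=tf.float64, name="new_count")

# the three increment ops that push_val() runs in a single session call
self._inc_sum = tf.assign_add(self._sum, self.new_sum)
self._inc_sum_squared = tf.assign_add(self._sum_squared, self.new_sum_squared)
self._inc_count = tf.assign_add(self._count, self.new_count)

self.ops_were_created = True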
@@ -85,13 +82,6 @@ class SharedRunningStats(object):
    def set_session(self, sess):
        self.sess = sess
    def push(self, x):
        if self.pubsub:
            self.pubsub.redis_connection.publish(self.channel, pickle.dumps(x))
            return
        self.push_val(x)
    def push_val(self, x):
        x = x.astype('float64')
        self.sess.run([self._inc_sum, self._inc_sum_squared, self._inc_count],
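set_session wires a live tf.Session into the object after graph construction. A hedged usage sketch, assuming TF1-style sessions and the constructor / create_ops signatures shown above (the shape and clip values are placeholders):

import numpy as np
import tensorflow as tf

# defer op creation until the observation shape is known, then attach a session
stats = TFSharedRunningStats(replicated_device=None, name="obs_stats", create_ops=False)
stats.create_ops(shape=[17], clip_values=(-5.0, 5.0))

sess = tf.Session()
sess.run(tf.global_variables_initializer())
stats.set_session(sess)

stats.push(np.random.randn(32, 17))   # without pubsub_params this ends up in push_val()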
@@ -138,23 +128,3 @@ class SharedRunningStats(object):
            return self.sess.run(self.clipped_obs, feed_dict={self.raw_obs: batch})
        else:
            return self.sess.run(self.normalized_obs, feed_dict={self.raw_obs: batch})
class SharedRunningStatsSubscribe(threading.Thread):
    def __init__(self, shared_running_stats):
        super().__init__()
        self.shared_running_stats = shared_running_stats
        self.redis_address = self.shared_running_stats.pubsub.params.redis_address
        self.redis_port = self.shared_running_stats.pubsub.params.redis_port
        self.redis_connection = redis.Redis(self.redis_address, self.redis_port)
        self.pubsub = self.redis_connection.pubsub()
        self.channel = self.shared_running_stats.channel
        self.pubsub.subscribe(self.channel)
    def run(self):
        for message in self.pubsub.listen():
            try:
                obj = pickle.loads(message['data'])
                self.shared_running_stats.push_val(obj)
            except Exception:
                continue
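The removed subscriber thread is the receiving end of the push() publish above: each worker runs a daemon subscriber per stats object, so a value pushed by any worker reaches every copy through push_val. A minimal standalone sketch of the same pattern with redis-py (host, port and channel name are placeholders):

import pickle
import threading

import numpy as np
import redis

channel = "channel-srs-demo"                  # placeholder channel name
connection = redis.Redis("localhost", 6379)   # placeholder host/port

def subscribe(apply_update):
    pubsub = connection.pubsub()
    pubsub.subscribe(channel)
    for message in pubsub.listen():
        try:
            # same unpickle-and-apply loop as the removed run() method
            apply_update(pickle.loads(message['data']))
        except Exception:
            continue

received = []
threading.Thread(target=subscribe, args=(received.append,), daemon=True).start()

# publisher side: what push() does when a pub/sub backend is configured
connection.publish(channel, pickle.dumps(np.ones((1, 4))))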