#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from abc import ABC, abstractmethod
import pickle
import threading

import numpy as np
import redis

from rl_coach.memories.backend.memory_impl import get_memory_backend


class SharedRunningStatsSubscribe(threading.Thread):
    """Background thread that listens on a Redis pub/sub channel and feeds
    every received batch of samples into the local running statistics."""

    def __init__(self, shared_running_stats):
        super().__init__()
        self.shared_running_stats = shared_running_stats
        self.redis_address = self.shared_running_stats.pubsub.params.redis_address
        self.redis_port = self.shared_running_stats.pubsub.params.redis_port
        self.redis_connection = redis.Redis(self.redis_address, self.redis_port)
        self.pubsub = self.redis_connection.pubsub()
        self.channel = self.shared_running_stats.channel
        self.pubsub.subscribe(self.channel)

    def run(self):
        for message in self.pubsub.listen():
            try:
                obj = pickle.loads(message['data'])
                self.shared_running_stats.push_val(obj)
            except Exception:
                # the first message on a channel is a subscription
                # confirmation whose payload is not a pickle; skip it,
                # along with any other malformed message
                continue


class SharedRunningStats(ABC):
    """Base class for running mean/std statistics that can optionally be
    shared across processes through a Redis pub/sub channel."""

    def __init__(self, name="", pubsub_params=None):
        self.name = name
        self.pubsub = None
        if pubsub_params:
            self.channel = "channel-srs-{}".format(self.name)
            self.pubsub = get_memory_backend(pubsub_params)
            subscribe_thread = SharedRunningStatsSubscribe(self)
            subscribe_thread.daemon = True
            subscribe_thread.start()

    @abstractmethod
    def set_params(self, shape=[1], clip_values=None):
        pass

    def push(self, x):
        # when a pub/sub backend is configured, publish the samples and let
        # the subscription thread update the local statistics; otherwise
        # update them directly
        if self.pubsub:
            self.pubsub.redis_connection.publish(self.channel, pickle.dumps(x))
            return
        self.push_val(x)

    @abstractmethod
    def push_val(self, x):
        pass

    @property
    @abstractmethod
    def n(self):
        pass

    @property
    @abstractmethod
    def mean(self):
        pass

    @property
    @abstractmethod
    def var(self):
        pass

    @property
    @abstractmethod
    def std(self):
        pass

    @property
    @abstractmethod
    def shape(self):
        pass

    @abstractmethod
    def normalize(self, batch):
        pass

    @abstractmethod
    def set_session(self, sess):
        pass


class NumpySharedRunningStats(SharedRunningStats):
    """Pure-numpy implementation that accumulates sums and sums of squares
    to track the running mean and standard deviation."""

    def __init__(self, name, epsilon=1e-2, pubsub_params=None):
        super().__init__(name=name, pubsub_params=pubsub_params)
        self._count = epsilon
        self.epsilon = epsilon

    def set_params(self, shape=[1], clip_values=None):
        self._shape = shape
        self._mean = np.zeros(shape)
        self._std = np.sqrt(self.epsilon) * np.ones(shape)
        self._sum = np.zeros(shape)
        self._sum_squares = self.epsilon * np.ones(shape)
        self.clip_values = clip_values

    def push_val(self, samples: np.ndarray):
        assert len(samples.shape) >= 2  # we should always have a batch dimension
        assert samples.shape[1:] == self._mean.shape, 'RunningStats input shape mismatch'
        self._sum += samples.sum(axis=0).ravel()
        self._sum_squares += np.square(samples).sum(axis=0).ravel()
        self._count += np.shape(samples)[0]
        self._mean = self._sum / self._count
        # sample variance, floored at epsilon for numerical stability
        self._std = np.sqrt(np.maximum(
            (self._sum_squares - self._count * np.square(self._mean))
            / np.maximum(self._count - 1, 1), self.epsilon))

    @property
    def n(self):
        return self._count

    @property
    def mean(self):
        return self._mean

    @property
    def var(self):
        return self._std ** 2

    @property
    def std(self):
        return self._std

    @property
    def shape(self):
        return self._mean.shape
    def normalize(self, batch):
        batch = (batch - self.mean) / (self.std + 1e-15)
        # guard against clip_values=None (the default), in which case the
        # normalized batch is returned unclipped
        if self.clip_values is not None:
            return np.clip(batch, *self.clip_values)
        return batch

    def set_session(self, sess):
        # no session for the numpy implementation
        pass
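
# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the original module. It assumes a
# single-process setup (no pub/sub backend, so push() updates the local
# statistics directly); the stats name, observation shape, and clip range
# below are arbitrary example values.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    stats = NumpySharedRunningStats(name="obs_stats")
    stats.set_params(shape=[4], clip_values=(-5.0, 5.0))

    # push a batch of 32 fake observations; the first axis is the batch axis
    stats.push(np.random.randn(32, 4))

    # normalize a new batch with the accumulated mean/std, clipped to [-5, 5]
    normalized = stats.normalize(np.random.randn(8, 4))
    print(stats.n, stats.mean, stats.std, normalized.shape)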