1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00
Files
coach/rl_coach/utilities/shared_running_stats.py
Gal Leibovich c1d1fae342 Distiller's AMC induced changes (#359)
* override episode rewards with the last transition reward

* EWMA normalization filter

* allowing control over when the pre_network filter runs
2019-08-05 10:24:58 +03:00

189 lines
5.8 KiB
Python

#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import pickle
import threading
from abc import ABC, abstractmethod

import numpy as np
import redis

from rl_coach.utils import get_latest_checkpoint
class SharedRunningStatsSubscribe(threading.Thread):
    """Thread that forwards values published on a Redis channel to a stats object.

    The thread opens its own Redis connection, subscribes to the channel of the
    given stats object and, for every published payload, unpickles it and hands
    it to ``shared_running_stats.push_val``.
    """

    # Event types that carry a published payload in their 'data' field.
    _PAYLOAD_EVENT_TYPES = ('message', 'pmessage')

    def __init__(self, shared_running_stats):
        """Connect to Redis and subscribe to the stats channel.

        :param shared_running_stats: object exposing ``pubsub.params.redis_address``,
            ``pubsub.params.redis_port``, ``channel`` and ``push_val``.
        """
        super().__init__()
        self.shared_running_stats = shared_running_stats
        self.redis_address = self.shared_running_stats.pubsub.params.redis_address
        self.redis_port = self.shared_running_stats.pubsub.params.redis_port
        self.redis_connection = redis.Redis(self.redis_address, self.redis_port)
        self.pubsub = self.redis_connection.pubsub()
        self.channel = self.shared_running_stats.channel
        self.pubsub.subscribe(self.channel)

    def run(self):
        """Consume channel events until the listener stops, pushing each payload.

        Handling is best-effort: an event that cannot be unpickled or pushed is
        logged and skipped so that one bad payload does not kill the thread.
        """
        for message in self.pubsub.listen():
            # redis-py also yields subscribe/unsubscribe confirmations whose
            # 'data' is not a pickled payload; skip them instead of relying on
            # the exception handler below. Events without a 'type' key are
            # still processed, as before.
            if message.get('type', 'message') not in self._PAYLOAD_EVENT_TYPES:
                continue
            try:
                # NOTE(security): pickle.loads can execute arbitrary code. This is
                # only safe if every publisher on the channel is trusted.
                obj = pickle.loads(message['data'])
                self.shared_running_stats.push_val(obj)
            except Exception:
                # Keep the thread alive, but make the failure visible.
                logging.getLogger(__name__).warning(
                    "Dropping message received on channel %r",
                    getattr(self, 'channel', None), exc_info=True)
class SharedRunningStats(ABC):
    """Abstract interface for running statistics that may be shared over pub/sub.

    When ``pubsub_params`` is supplied, pushed values are published to a
    channel and a daemon subscriber thread feeds received values back into
    ``push_val``. Otherwise values go straight to ``push_val``.
    """

    def __init__(self, name="", pubsub_params=None):
        """Store the name and, if requested, set up the pub/sub plumbing.

        :param name: identifier used to build the channel name.
        :param pubsub_params: parameters handed to ``get_memory_backend``; when
            falsy, no backend is created and ``self.pubsub`` stays ``None``.
        """
        self.name = name
        self.pubsub = None
        if not pubsub_params:
            return

        self.channel = "channel-srs-{}".format(self.name)
        # Imported lazily, only when a pub/sub backend is actually requested.
        from rl_coach.memories.backend.memory_impl import get_memory_backend
        self.pubsub = get_memory_backend(pubsub_params)

        listener = SharedRunningStatsSubscribe(self)
        listener.daemon = True
        listener.start()

    @abstractmethod
    def set_params(self, shape=[1], clip_values=None):
        """Configure the statistics with a shape and optional clip values."""

    def push(self, x):
        """Publish ``x`` when a pub/sub backend exists, else push it locally."""
        if not self.pubsub:
            self.push_val(x)
            return
        self.pubsub.redis_connection.publish(self.channel, pickle.dumps(x))

    @abstractmethod
    def push_val(self, x):
        """Fold ``x`` into the statistics held by this object."""

    @property
    @abstractmethod
    def n(self):
        """Sample count tracked by the implementation."""

    @property
    @abstractmethod
    def mean(self):
        """Current running mean."""

    @property
    @abstractmethod
    def var(self):
        """Current running variance."""

    @property
    @abstractmethod
    def std(self):
        """Current running standard deviation."""

    @property
    @abstractmethod
    def shape(self):
        """Shape of the tracked statistics."""

    @abstractmethod
    def normalize(self, batch):
        """Return ``batch`` normalized with the current statistics."""

    @abstractmethod
    def set_session(self, sess):
        """Attach a session object, for implementations that need one."""

    @abstractmethod
    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: int):
        """Persist the statistics under ``checkpoint_dir``."""

    @abstractmethod
    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
        """Load previously saved statistics from ``checkpoint_dir``."""
class NumpySharedRunningStats(SharedRunningStats):
    """Running mean / standard deviation kept in NumPy arrays.

    Statistics are derived from running sums of the samples and of their
    squares. ``set_params`` must be called before values are pushed or
    normalized, because it allocates the arrays.
    """

    def __init__(self, name, epsilon=1e-2, pubsub_params=None):
        """Create the stats object.

        :param name: name forwarded to the base class.
        :param epsilon: initial pseudo-count, initial sum of squares and the
            lower bound applied to the variance.
        :param pubsub_params: forwarded to the base class.
        """
        super().__init__(name=name, pubsub_params=pubsub_params)
        # The count starts at epsilon rather than 0, so mean = sum / count is
        # defined before any sample arrives.
        self._count = epsilon
        self.epsilon = epsilon
        self.checkpoint_file_extension = 'srs'

    def set_params(self, shape=[1], clip_values=None):
        """Allocate the statistics arrays and store the clipping range.

        :param shape: shape of a single sample (without the batch dimension).
            The default list is stored as-is and never mutated by this class.
        :param clip_values: ``(min, max)`` pair applied by ``normalize``, or
            ``None`` to disable clipping.
        """
        self._shape = shape
        self._mean = np.zeros(shape)
        self._std = np.sqrt(self.epsilon) * np.ones(shape)
        self._sum = np.zeros(shape)
        self._sum_squares = self.epsilon * np.ones(shape)
        self.clip_values = clip_values

    def push_val(self, samples: np.ndarray):
        """Fold a batch of samples into the running statistics.

        :param samples: array whose first axis is the batch dimension and whose
            remaining axes match the shape given to ``set_params``.
        """
        assert len(samples.shape) >= 2  # we should always have a batch dimension
        assert samples.shape[1:] == self._mean.shape, 'RunningStats input shape mismatch'
        # Summing over the batch axis already yields the per-sample shape.
        # Flattening it would break the in-place add for multi-dimensional shapes.
        self._sum += samples.sum(axis=0)
        self._sum_squares += np.square(samples).sum(axis=0)
        self._count += np.shape(samples)[0]
        self._mean = self._sum / self._count
        # Variance from the running sums: (sum(x^2) - n * mean^2) / (n - 1),
        # with the divisor floored at 1 and the variance floored at epsilon.
        self._std = np.sqrt(np.maximum(
            (self._sum_squares - self._count * np.square(self._mean)) / np.maximum(self._count - 1, 1),
            self.epsilon))

    @property
    def n(self):
        """Accumulated count (starts at ``epsilon``, grows by each batch size)."""
        return self._count

    @property
    def mean(self):
        """Current running mean."""
        return self._mean

    @property
    def var(self):
        """Current running variance (square of ``std``)."""
        return self._std ** 2

    @property
    def std(self):
        """Current running standard deviation."""
        return self._std

    @property
    def shape(self):
        """Shape of the mean array."""
        return self._mean.shape

    def normalize(self, batch):
        """Return ``(batch - mean) / std``, clipped when clip values are set.

        :param batch: array broadcastable against the statistics arrays.
        :return: the normalized batch; clipped to ``clip_values`` unless those
            are ``None``, in which case it is returned unclipped.
        """
        batch = (batch - self.mean) / (self.std + 1e-15)
        # clip_values defaults to None in set_params; unpacking None would raise.
        if self.clip_values is None:
            return batch
        return np.clip(batch, *self.clip_values)

    def set_session(self, sess):
        """No-op: the numpy implementation does not use a session."""
        # no session for the numpy implementation
        pass

    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: int):
        """Pickle the statistics to ``<checkpoint_dir>/<checkpoint_prefix>.srs``.

        :param checkpoint_dir: directory the file is written to.
        :param checkpoint_prefix: file name stem (converted with ``str``).
        """
        dict_to_save = {'_mean': self._mean,
                        '_std': self._std,
                        '_count': self._count,
                        '_sum': self._sum,
                        '_sum_squares': self._sum_squares}

        with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.' + self.checkpoint_file_extension), 'wb') as f:
            pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)

    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
        """Load statistics from the latest matching checkpoint file.

        :param checkpoint_dir: directory searched for the checkpoint.
        :param checkpoint_prefix: prefix passed to ``get_latest_checkpoint``.
        :raises ValueError: if ``get_latest_checkpoint`` returns an empty string.
        """
        latest_checkpoint_filename = get_latest_checkpoint(checkpoint_dir, checkpoint_prefix,
                                                           self.checkpoint_file_extension)

        if latest_checkpoint_filename == '':
            raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ")

        # NOTE(security): pickle.load can execute arbitrary code. Only restore
        # checkpoints that come from a trusted source.
        with open(os.path.join(checkpoint_dir, str(latest_checkpoint_filename)), 'rb') as f:
            saved_dict = pickle.load(f)
            self.__dict__.update(saved_dict)