From 6caf721d1c33cf2fe2a18bdc26f25c8be599ff50 Mon Sep 17 00:00:00 2001 From: Gal Leibovich Date: Sun, 18 Nov 2018 14:46:40 +0200 Subject: [PATCH] Numpy shared running stats (#97) --- rl_coach/agents/agent.py | 27 ++- .../tensorflow_components/shared_variables.py | 54 ++---- rl_coach/coach.py | 2 +- rl_coach/filters/filter.py | 13 +- .../observation_normalization_filter.py | 17 +- .../reward/reward_normalization_filter.py | 14 +- rl_coach/graph_managers/graph_manager.py | 3 +- rl_coach/presets/CartPole_ClippedPPO.py | 4 +- rl_coach/utilities/shared_running_stats.py | 159 ++++++++++++++++++ rl_coach/utils.py | 47 +----- 10 files changed, 226 insertions(+), 114 deletions(-) create mode 100644 rl_coach/utilities/shared_running_stats.py diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py index bd1fc71..c6e992e 100644 --- a/rl_coach/agents/agent.py +++ b/rl_coach/agents/agent.py @@ -25,7 +25,7 @@ from six.moves import range from rl_coach.agents.agent_interface import AgentInterface from rl_coach.architectures.network_wrapper import NetworkWrapper -from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters +from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters, Frameworks from rl_coach.core_types import RunPhase, PredictionType, EnvironmentEpisodes, ActionType, Batch, Episode, StateType from rl_coach.core_types import Transition, ActionInfo, TrainingSteps, EnvironmentSteps, EnvResponse from rl_coach.logger import screen, Logger, EpisodeLogger @@ -110,14 +110,27 @@ class Agent(AgentInterface): self.output_filter = self.ap.output_filter self.pre_network_filter = self.ap.pre_network_filter device = self.replicated_device if self.replicated_device else self.worker_device + + # TODO-REMOVE This is a temporary flow dividing to 3 modes. To be converged to a single flow once distributed tf + # is removed, and Redis is used for sharing data between local workers. + # Filters MoW will be split between different configurations + # 1. Distributed coach synchrnization type (=distributed across multiple nodes) - Redis based data sharing + numpy arithmetic backend + # 2. Distributed TF (=distributed on a single node, using distributed TF) - TF for both data sharing and arithmetic backend + # 3. Single worker (=both TF and Mxnet) - no data sharing needed + numpy arithmetic backend + if hasattr(self.ap.memory, 'memory_backend_params') and self.ap.algorithm.distributed_coach_synchronization_type: - self.input_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params) - self.output_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params) - self.pre_network_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params) + self.input_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params, mode='numpy') + self.output_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params, mode='numpy') + self.pre_network_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params, mode='numpy') + elif (type(agent_parameters.task_parameters) == DistributedTaskParameters and + agent_parameters.task_parameters.framework_type == Frameworks.tensorflow): + self.input_filter.set_device(device, mode='tf') + self.output_filter.set_device(device, mode='tf') + self.pre_network_filter.set_device(device, mode='tf') else: - self.input_filter.set_device(device) - self.output_filter.set_device(device) - self.pre_network_filter.set_device(device) + self.input_filter.set_device(device, mode='numpy') + self.output_filter.set_device(device, mode='numpy') + self.pre_network_filter.set_device(device, mode='numpy') # initialize all internal variables self._phase = RunPhase.HEATUP diff --git a/rl_coach/architectures/tensorflow_components/shared_variables.py b/rl_coach/architectures/tensorflow_components/shared_variables.py index 8748885..33a2c05 100644 --- a/rl_coach/architectures/tensorflow_components/shared_variables.py +++ b/rl_coach/architectures/tensorflow_components/shared_variables.py @@ -15,33 +15,30 @@ # import numpy as np -import pickle -import redis import tensorflow as tf -import threading -from rl_coach.memories.backend.memory_impl import get_memory_backend +from rl_coach.utilities.shared_running_stats import SharedRunningStats -class SharedRunningStats(object): +class TFSharedRunningStats(SharedRunningStats): def __init__(self, replicated_device=None, epsilon=1e-2, name="", create_ops=True, pubsub_params=None): + super().__init__(name=name, pubsub_params=pubsub_params) self.sess = None - self.name = name self.replicated_device = replicated_device self.epsilon = epsilon self.ops_were_created = False if create_ops: with tf.device(replicated_device): - self.create_ops() - self.pubsub = None - if pubsub_params: - self.channel = "channel-srs-{}".format(self.name) - self.pubsub = get_memory_backend(pubsub_params) - subscribe_thread = SharedRunningStatsSubscribe(self) - subscribe_thread.daemon = True - subscribe_thread.start() + self.set_params() + + def set_params(self, shape=[1], clip_values=None): + """ + set params and create ops + + :param shape: shape of the stats to track + :param clip_values: if not None, sets clip min/max thresholds + """ - def create_ops(self, shape=[1], clip_values=None): self.clip_values = clip_values with tf.variable_scope(self.name): self._sum = tf.get_variable( @@ -85,13 +82,6 @@ class SharedRunningStats(object): def set_session(self, sess): self.sess = sess - def push(self, x): - if self.pubsub: - self.pubsub.redis_connection.publish(self.channel, pickle.dumps(x)) - return - - self.push_val(x) - def push_val(self, x): x = x.astype('float64') self.sess.run([self._inc_sum, self._inc_sum_squared, self._inc_count], @@ -138,23 +128,3 @@ class SharedRunningStats(object): return self.sess.run(self.clipped_obs, feed_dict={self.raw_obs: batch}) else: return self.sess.run(self.normalized_obs, feed_dict={self.raw_obs: batch}) - - -class SharedRunningStatsSubscribe(threading.Thread): - def __init__(self, shared_running_stats): - super().__init__() - self.shared_running_stats = shared_running_stats - self.redis_address = self.shared_running_stats.pubsub.params.redis_address - self.redis_port = self.shared_running_stats.pubsub.params.redis_port - self.redis_connection = redis.Redis(self.redis_address, self.redis_port) - self.pubsub = self.redis_connection.pubsub() - self.channel = self.shared_running_stats.channel - self.pubsub.subscribe(self.channel) - - def run(self): - for message in self.pubsub.listen(): - try: - obj = pickle.loads(message['data']) - self.shared_running_stats.push_val(obj) - except Exception: - continue diff --git a/rl_coach/coach.py b/rl_coach/coach.py index 0cc982f..6ae8ba3 100644 --- a/rl_coach/coach.py +++ b/rl_coach/coach.py @@ -78,7 +78,7 @@ def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters' # let the adventure begin if task_parameters.evaluate_only: - graph_manager.evaluate(EnvironmentSteps(sys.maxsize), keep_networks_in_sync=True) + graph_manager.evaluate(EnvironmentSteps(sys.maxsize)) else: graph_manager.improve() diff --git a/rl_coach/filters/filter.py b/rl_coach/filters/filter.py index 705aa64..dbf59f2 100644 --- a/rl_coach/filters/filter.py +++ b/rl_coach/filters/filter.py @@ -46,11 +46,12 @@ class Filter(object): """ raise NotImplementedError("") - def set_device(self, device, memory_backend_params=None) -> None: + def set_device(self, device, memory_backend_params=None, mode='numpy') -> None: """ An optional function that allows the filter to get the device if it is required to use tensorflow ops :param device: the device to use :param memory_backend_params: parameters associated with the memory backend + :param mode: arithmetic backend to be used {numpy | tf} :return: None """ pass @@ -85,13 +86,13 @@ class OutputFilter(Filter): duplicate.i_am_a_reference_filter = False return duplicate - def set_device(self, device, memory_backend_params=None) -> None: + def set_device(self, device, memory_backend_params=None, mode='numpy') -> None: """ An optional function that allows the filter to get the device if it is required to use tensorflow ops :param device: the device to use :return: None """ - [f.set_device(device, memory_backend_params) for f in self.action_filters.values()] + [f.set_device(device, memory_backend_params, mode='numpy') for f in self.action_filters.values()] def set_session(self, sess) -> None: """ @@ -226,14 +227,14 @@ class InputFilter(Filter): duplicate.i_am_a_reference_filter = False return duplicate - def set_device(self, device, memory_backend_params=None) -> None: + def set_device(self, device, memory_backend_params=None, mode='numpy') -> None: """ An optional function that allows the filter to get the device if it is required to use tensorflow ops :param device: the device to use :return: None """ - [f.set_device(device, memory_backend_params) for f in self.reward_filters.values()] - [[f.set_device(device, memory_backend_params) for f in filters.values()] for filters in self.observation_filters.values()] + [f.set_device(device, memory_backend_params, mode) for f in self.reward_filters.values()] + [[f.set_device(device, memory_backend_params, mode) for f in filters.values()] for filters in self.observation_filters.values()] def set_session(self, sess) -> None: """ diff --git a/rl_coach/filters/observation/observation_normalization_filter.py b/rl_coach/filters/observation/observation_normalization_filter.py index 6ecc057..796ef31 100644 --- a/rl_coach/filters/observation/observation_normalization_filter.py +++ b/rl_coach/filters/observation/observation_normalization_filter.py @@ -17,10 +17,11 @@ from typing import List import numpy as np -from rl_coach.architectures.tensorflow_components.shared_variables import SharedRunningStats +from rl_coach.architectures.tensorflow_components.shared_variables import SharedRunningStats, TFSharedRunningStats from rl_coach.core_types import ObservationType from rl_coach.filters.observation.observation_filter import ObservationFilter from rl_coach.spaces import ObservationSpace +from rl_coach.utilities.shared_running_stats import NumpySharedRunningStats, NumpySharedRunningStats class ObservationNormalizationFilter(ObservationFilter): @@ -42,14 +43,20 @@ class ObservationNormalizationFilter(ObservationFilter): self.supports_batching = True self.observation_space = None - def set_device(self, device, memory_backend_params=None) -> None: + def set_device(self, device, memory_backend_params=None, mode='numpy') -> None: """ An optional function that allows the filter to get the device if it is required to use tensorflow ops :param device: the device to use + :memory_backend_params: if not None, holds params for a memory backend for sharing data (e.g. Redis) + :param mode: the arithmetic module to use {'tf' | 'numpy'} :return: None """ - self.running_observation_stats = SharedRunningStats(device, name=self.name, create_ops=False, + if mode == 'tf': + self.running_observation_stats = TFSharedRunningStats(device, name=self.name, create_ops=False, pubsub_params=memory_backend_params) + elif mode == 'numpy': + self.running_observation_stats = NumpySharedRunningStats(name=self.name, + pubsub_params=memory_backend_params) def set_session(self, sess) -> None: """ @@ -66,11 +73,9 @@ class ObservationNormalizationFilter(ObservationFilter): self.last_mean = self.running_observation_stats.mean self.last_stdev = self.running_observation_stats.std - # TODO: make sure that a batch is given here return self.running_observation_stats.normalize(observations) def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace: - - self.running_observation_stats.create_ops(shape=input_observation_space.shape, + self.running_observation_stats.set_params(shape=input_observation_space.shape, clip_values=(self.clip_min, self.clip_max)) return input_observation_space diff --git a/rl_coach/filters/reward/reward_normalization_filter.py b/rl_coach/filters/reward/reward_normalization_filter.py index fd7dfed..daf5562 100644 --- a/rl_coach/filters/reward/reward_normalization_filter.py +++ b/rl_coach/filters/reward/reward_normalization_filter.py @@ -17,10 +17,11 @@ import numpy as np -from rl_coach.architectures.tensorflow_components.shared_variables import SharedRunningStats +from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats from rl_coach.core_types import RewardType from rl_coach.filters.reward.reward_filter import RewardFilter from rl_coach.spaces import RewardSpace +from rl_coach.utilities.shared_running_stats import NumpySharedRunningStats class RewardNormalizationFilter(RewardFilter): @@ -39,14 +40,19 @@ class RewardNormalizationFilter(RewardFilter): self.clip_max = clip_max self.running_rewards_stats = None - def set_device(self, device, memory_backend_params=None) -> None: + def set_device(self, device, memory_backend_params=None, mode='numpy') -> None: """ An optional function that allows the filter to get the device if it is required to use tensorflow ops :param device: the device to use :return: None """ - self.running_rewards_stats = SharedRunningStats(device, name='rewards_stats', - pubsub_params=memory_backend_params) + + if mode == 'tf': + self.running_rewards_stats = TFSharedRunningStats(device, name='rewards_stats', create_ops=False, + pubsub_params=memory_backend_params) + elif mode == 'numpy': + self.running_rewards_stats = NumpySharedRunningStats(name='rewards_stats', + pubsub_params=memory_backend_params) def set_session(self, sess) -> None: """ diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index ef67776..45ae4f6 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -462,11 +462,10 @@ class GraphManager(object): """ [manager.sync() for manager in self.level_managers] - def evaluate(self, steps: PlayingStepsType, keep_networks_in_sync: bool=False) -> bool: + def evaluate(self, steps: PlayingStepsType) -> bool: """ Perform evaluation for several steps :param steps: the number of steps as a tuple of steps time and steps count - :param keep_networks_in_sync: sync the network parameters with the global network before each episode :return: bool, True if the target reward and target success has been reached """ self.verify_graph_was_created() diff --git a/rl_coach/presets/CartPole_ClippedPPO.py b/rl_coach/presets/CartPole_ClippedPPO.py index 35dfbda..6aa9d4c 100644 --- a/rl_coach/presets/CartPole_ClippedPPO.py +++ b/rl_coach/presets/CartPole_ClippedPPO.py @@ -5,6 +5,7 @@ from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentS from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2 from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters from rl_coach.exploration_policies.e_greedy import EGreedyParameters +from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager from rl_coach.graph_managers.graph_manager import ScheduleParameters from rl_coach.schedules import LinearSchedule @@ -48,7 +49,8 @@ agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoach agent_params.exploration = EGreedyParameters() agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000) - +agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation', + ObservationNormalizationFilter(name='normalize_observation')) ############### # Environment # ############### diff --git a/rl_coach/utilities/shared_running_stats.py b/rl_coach/utilities/shared_running_stats.py new file mode 100644 index 0000000..6d865ad --- /dev/null +++ b/rl_coach/utilities/shared_running_stats.py @@ -0,0 +1,159 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABC, abstractmethod +import threading +import pickle +import redis +import numpy as np + +from rl_coach.memories.backend.memory_impl import get_memory_backend + + +class SharedRunningStatsSubscribe(threading.Thread): + def __init__(self, shared_running_stats): + super().__init__() + self.shared_running_stats = shared_running_stats + self.redis_address = self.shared_running_stats.pubsub.params.redis_address + self.redis_port = self.shared_running_stats.pubsub.params.redis_port + self.redis_connection = redis.Redis(self.redis_address, self.redis_port) + self.pubsub = self.redis_connection.pubsub() + self.channel = self.shared_running_stats.channel + self.pubsub.subscribe(self.channel) + + def run(self): + for message in self.pubsub.listen(): + try: + obj = pickle.loads(message['data']) + self.shared_running_stats.push_val(obj) + except Exception: + continue + + +class SharedRunningStats(ABC): + def __init__(self, name="", pubsub_params=None): + self.name = name + self.pubsub = None + if pubsub_params: + self.channel = "channel-srs-{}".format(self.name) + self.pubsub = get_memory_backend(pubsub_params) + subscribe_thread = SharedRunningStatsSubscribe(self) + subscribe_thread.daemon = True + subscribe_thread.start() + + @abstractmethod + def set_params(self, shape=[1], clip_values=None): + pass + + def push(self, x): + if self.pubsub: + self.pubsub.redis_connection.publish(self.channel, pickle.dumps(x)) + return + + self.push_val(x) + + @abstractmethod + def push_val(self, x): + pass + + @property + @abstractmethod + def n(self): + pass + + @property + @abstractmethod + def mean(self): + pass + + @property + @abstractmethod + def var(self): + pass + + @property + @abstractmethod + def std(self): + pass + + @property + @abstractmethod + def shape(self): + pass + + @abstractmethod + def normalize(self, batch): + pass + + @abstractmethod + def set_session(self, sess): + pass + + +class NumpySharedRunningStats(SharedRunningStats): + def __init__(self, name, epsilon=1e-2, pubsub_params=None): + super().__init__(name=name, pubsub_params=pubsub_params) + self._count = epsilon + self.epsilon = epsilon + + def set_params(self, shape=[1], clip_values=None): + self._shape = shape + self._mean = np.zeros(shape) + self._std = np.sqrt(self.epsilon) * np.ones(shape) + self._sum = np.zeros(shape) + self._sum_squares = self.epsilon * np.ones(shape) + self.clip_values = clip_values + + def push_val(self, samples: np.ndarray): + assert len(samples.shape) >= 2 # we should always have a batch dimension + assert samples.shape[1:] == self._mean.shape, 'RunningStats input shape mismatch' + self._sum += samples.sum(axis=0).ravel() + self._sum_squares += np.square(samples).sum(axis=0).ravel() + self._count += np.shape(samples)[0] + self._mean = self._sum / self._count + self._std = np.sqrt(np.maximum( + (self._sum_squares - self._count * np.square(self._mean)) / np.maximum(self._count - 1, 1), + self.epsilon)) + + @property + def n(self): + return self._count + + @property + def mean(self): + return self._mean + + @property + def var(self): + return self._std ** 2 + + @property + def std(self): + return self._std + + @property + def shape(self): + return self._mean.shape + + def normalize(self, batch): + batch = (batch - self.mean) / (self.std + 1e-15) + return np.clip(batch, *self.clip_values) + + def set_session(self, sess): + # no session for the numpy implementation + pass + + diff --git a/rl_coach/utils.py b/rl_coach/utils.py index 76729b3..61808c8 100644 --- a/rl_coach/utils.py +++ b/rl_coach/utils.py @@ -239,51 +239,6 @@ def squeeze_list(var): return var -# http://www.johndcook.com/blog/standard_deviation/ -class RunningStat(object): - def __init__(self, shape): - self._shape = shape - self._num_samples = 0 - self._mean = np.zeros(shape) - self._std = np.zeros(shape) - - def reset(self): - self._num_samples = 0 - self._mean = np.zeros(self._shape) - self._std = np.zeros(self._shape) - - def push(self, sample): - sample = np.asarray(sample) - assert sample.shape == self._mean.shape, 'RunningStat input shape mismatch' - self._num_samples += 1 - if self._num_samples == 1: - self._mean[...] = sample - else: - old_mean = self._mean.copy() - self._mean[...] = old_mean + (sample - old_mean) / self._num_samples - self._std[...] = self._std + (sample - old_mean) * (sample - self._mean) - - @property - def n(self): - return self._num_samples - - @property - def mean(self): - return self._mean - - @property - def var(self): - return self._std / (self._num_samples - 1) if self._num_samples > 1 else np.square(self._mean) - - @property - def std(self): - return np.sqrt(self.var) - - @property - def shape(self): - return self._mean.shape - - def get_open_port(): import socket s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -590,3 +545,5 @@ def start_shell_command_and_wait(command): def indent_string(string): return '\t' + string.replace('\n', '\n\t') + +