From a849c17e46ea7188a7b6229c9f68417803e730ad Mon Sep 17 00:00:00 2001 From: Balaji Subramaniam Date: Tue, 13 Nov 2018 09:17:38 -0800 Subject: [PATCH] Enable distributed SharedRunningStats (#81) - Use Redis pub/sub for updating SharedRunningStats. --- rl_coach/agents/agent.py | 11 ++-- rl_coach/agents/ppo_agent.py | 2 +- .../tensorflow_components/shared_variables.py | 53 ++++++++++++++++--- rl_coach/filters/filter.py | 13 ++--- .../observation_normalization_filter.py | 5 +- .../reward/reward_normalization_filter.py | 5 +- rl_coach/presets/CartPole_ClippedPPO.py | 3 -- rl_coach/presets/CartPole_DQN.py | 2 +- rl_coach/presets/Mujoco_ClippedPPO.py | 4 +- rl_coach/presets/Mujoco_PPO.py | 5 +- 10 files changed, 76 insertions(+), 27 deletions(-) diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py index 9b4fef6..c26282a 100644 --- a/rl_coach/agents/agent.py +++ b/rl_coach/agents/agent.py @@ -112,9 +112,14 @@ class Agent(AgentInterface): self.output_filter = self.ap.output_filter self.pre_network_filter = self.ap.pre_network_filter device = self.replicated_device if self.replicated_device else self.worker_device - self.input_filter.set_device(device) - self.output_filter.set_device(device) - self.pre_network_filter.set_device(device) + if hasattr(self.ap.memory, 'memory_backend_params') and self.ap.algorithm.distributed_coach_synchronization_type: + self.input_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params) + self.output_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params) + self.pre_network_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params) + else: + self.input_filter.set_device(device) + self.output_filter.set_device(device) + self.pre_network_filter.set_device(device) # initialize all internal variables self._phase = RunPhase.HEATUP diff --git a/rl_coach/agents/ppo_agent.py b/rl_coach/agents/ppo_agent.py index d455caa..64539e9 100644 --- a/rl_coach/agents/ppo_agent.py +++ b/rl_coach/agents/ppo_agent.py @@ -310,7 +310,7 @@ class PPOAgent(ActorCriticAgent): # clean memory self.call_memory('clean') - def _should_train_helper(self): + def _should_train_helper(self, wait_for_full_episode=True): return super()._should_train_helper(True) def train(self): diff --git a/rl_coach/architectures/tensorflow_components/shared_variables.py b/rl_coach/architectures/tensorflow_components/shared_variables.py index 1a3289c..8748885 100644 --- a/rl_coach/architectures/tensorflow_components/shared_variables.py +++ b/rl_coach/architectures/tensorflow_components/shared_variables.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2017 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -15,11 +15,16 @@ # import numpy as np +import pickle +import redis import tensorflow as tf +import threading + +from rl_coach.memories.backend.memory_impl import get_memory_backend class SharedRunningStats(object): - def __init__(self, replicated_device=None, epsilon=1e-2, name="", create_ops=True): + def __init__(self, replicated_device=None, epsilon=1e-2, name="", create_ops=True, pubsub_params=None): self.sess = None self.name = name self.replicated_device = replicated_device @@ -28,6 +33,13 @@ class SharedRunningStats(object): if create_ops: with tf.device(replicated_device): self.create_ops() + self.pubsub = None + if pubsub_params: + self.channel = "channel-srs-{}".format(self.name) + self.pubsub = get_memory_backend(pubsub_params) + subscribe_thread = SharedRunningStatsSubscribe(self) + subscribe_thread.daemon = True + subscribe_thread.start() def create_ops(self, shape=[1], clip_values=None): self.clip_values = clip_values @@ -74,13 +86,20 @@ class SharedRunningStats(object): self.sess = sess def push(self, x): + if self.pubsub: + self.pubsub.redis_connection.publish(self.channel, pickle.dumps(x)) + return + + self.push_val(x) + + def push_val(self, x): x = x.astype('float64') self.sess.run([self._inc_sum, self._inc_sum_squared, self._inc_count], - feed_dict={ - self.new_sum: x.sum(axis=0).ravel(), - self.new_sum_squared: np.square(x).sum(axis=0).ravel(), - self.newcount: np.array(len(x), dtype='float64') - }) + feed_dict={ + self.new_sum: x.sum(axis=0).ravel(), + self.new_sum_squared: np.square(x).sum(axis=0).ravel(), + self.newcount: np.array(len(x), dtype='float64') + }) if self._shape is None: self._shape = x.shape @@ -119,3 +138,23 @@ class SharedRunningStats(object): return self.sess.run(self.clipped_obs, feed_dict={self.raw_obs: batch}) else: return self.sess.run(self.normalized_obs, feed_dict={self.raw_obs: batch}) + + +class SharedRunningStatsSubscribe(threading.Thread): + def __init__(self, shared_running_stats): + super().__init__() + self.shared_running_stats = shared_running_stats + self.redis_address = self.shared_running_stats.pubsub.params.redis_address + self.redis_port = self.shared_running_stats.pubsub.params.redis_port + self.redis_connection = redis.Redis(self.redis_address, self.redis_port) + self.pubsub = self.redis_connection.pubsub() + self.channel = self.shared_running_stats.channel + self.pubsub.subscribe(self.channel) + + def run(self): + for message in self.pubsub.listen(): + try: + obj = pickle.loads(message['data']) + self.shared_running_stats.push_val(obj) + except Exception: + continue diff --git a/rl_coach/filters/filter.py b/rl_coach/filters/filter.py index 35d7e7c..705aa64 100644 --- a/rl_coach/filters/filter.py +++ b/rl_coach/filters/filter.py @@ -46,10 +46,11 @@ class Filter(object): """ raise NotImplementedError("") - def set_device(self, device) -> None: + def set_device(self, device, memory_backend_params=None) -> None: """ An optional function that allows the filter to get the device if it is required to use tensorflow ops :param device: the device to use + :param memory_backend_params: parameters associated with the memory backend :return: None """ pass @@ -84,13 +85,13 @@ class OutputFilter(Filter): duplicate.i_am_a_reference_filter = False return duplicate - def set_device(self, device) -> None: + def set_device(self, device, memory_backend_params=None) -> None: """ An optional function that allows the filter to get the device if it is required to use tensorflow ops :param device: the device to use :return: None """ - 
[f.set_device(device) for f in self.action_filters.values()] + [f.set_device(device, memory_backend_params) for f in self.action_filters.values()] def set_session(self, sess) -> None: """ @@ -225,14 +226,14 @@ class InputFilter(Filter): duplicate.i_am_a_reference_filter = False return duplicate - def set_device(self, device) -> None: + def set_device(self, device, memory_backend_params=None) -> None: """ An optional function that allows the filter to get the device if it is required to use tensorflow ops :param device: the device to use :return: None """ - [f.set_device(device) for f in self.reward_filters.values()] - [[f.set_device(device) for f in filters.values()] for filters in self.observation_filters.values()] + [f.set_device(device, memory_backend_params) for f in self.reward_filters.values()] + [[f.set_device(device, memory_backend_params) for f in filters.values()] for filters in self.observation_filters.values()] def set_session(self, sess) -> None: """ diff --git a/rl_coach/filters/observation/observation_normalization_filter.py b/rl_coach/filters/observation/observation_normalization_filter.py index 178036d..21be759 100644 --- a/rl_coach/filters/observation/observation_normalization_filter.py +++ b/rl_coach/filters/observation/observation_normalization_filter.py @@ -41,13 +41,14 @@ class ObservationNormalizationFilter(ObservationFilter): self.supports_batching = True self.observation_space = None - def set_device(self, device) -> None: + def set_device(self, device, memory_backend_params=None) -> None: """ An optional function that allows the filter to get the device if it is required to use tensorflow ops :param device: the device to use :return: None """ - self.running_observation_stats = SharedRunningStats(device, name=self.name, create_ops=False) + self.running_observation_stats = SharedRunningStats(device, name=self.name, create_ops=False, + pubsub_params=memory_backend_params) def set_session(self, sess) -> None: """ diff --git a/rl_coach/filters/reward/reward_normalization_filter.py b/rl_coach/filters/reward/reward_normalization_filter.py index fa33a4e..cd46995 100644 --- a/rl_coach/filters/reward/reward_normalization_filter.py +++ b/rl_coach/filters/reward/reward_normalization_filter.py @@ -38,13 +38,14 @@ class RewardNormalizationFilter(RewardFilter): self.clip_max = clip_max self.running_rewards_stats = None - def set_device(self, device) -> None: + def set_device(self, device, memory_backend_params=None) -> None: """ An optional function that allows the filter to get the device if it is required to use tensorflow ops :param device: the device to use :return: None """ - self.running_rewards_stats = SharedRunningStats(device, name='rewards_stats') + self.running_rewards_stats = SharedRunningStats(device, name='rewards_stats', + pubsub_params=memory_backend_params) def set_session(self, sess) -> None: """ diff --git a/rl_coach/presets/CartPole_ClippedPPO.py b/rl_coach/presets/CartPole_ClippedPPO.py index f400478..35dfbda 100644 --- a/rl_coach/presets/CartPole_ClippedPPO.py +++ b/rl_coach/presets/CartPole_ClippedPPO.py @@ -5,7 +5,6 @@ from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentS from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2 from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters from rl_coach.exploration_policies.e_greedy import EGreedyParameters -from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter from 
rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager from rl_coach.graph_managers.graph_manager import ScheduleParameters from rl_coach.schedules import LinearSchedule @@ -49,8 +48,6 @@ agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoach agent_params.exploration = EGreedyParameters() agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000) -# agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation', -# ObservationNormalizationFilter(name='normalize_observation')) ############### # Environment # diff --git a/rl_coach/presets/CartPole_DQN.py b/rl_coach/presets/CartPole_DQN.py index ba4472f..02a38d7 100644 --- a/rl_coach/presets/CartPole_DQN.py +++ b/rl_coach/presets/CartPole_DQN.py @@ -1,5 +1,5 @@ from rl_coach.agents.dqn_agent import DQNAgentParameters -from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType +from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps from rl_coach.environments.gym_environment import GymVectorEnvironment from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager diff --git a/rl_coach/presets/Mujoco_ClippedPPO.py b/rl_coach/presets/Mujoco_ClippedPPO.py index ca2d662..d7ec89c 100644 --- a/rl_coach/presets/Mujoco_ClippedPPO.py +++ b/rl_coach/presets/Mujoco_ClippedPPO.py @@ -1,6 +1,6 @@ from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters from rl_coach.architectures.layers import Dense -from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters +from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps from rl_coach.environments.environment import SingleLevelSelection from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2 @@ -43,6 +43,8 @@ agent_params.algorithm.gae_lambda = 0.95 agent_params.algorithm.discount = 0.99 agent_params.algorithm.optimization_epochs = 10 agent_params.algorithm.estimate_state_value_using_gae = True +# Distributed Coach synchronization type. 
+agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC agent_params.input_filter = InputFilter() agent_params.exploration = AdditiveNoiseParameters() diff --git a/rl_coach/presets/Mujoco_PPO.py b/rl_coach/presets/Mujoco_PPO.py index 4eb2f72..d5deaa2 100644 --- a/rl_coach/presets/Mujoco_PPO.py +++ b/rl_coach/presets/Mujoco_PPO.py @@ -1,6 +1,6 @@ from rl_coach.agents.ppo_agent import PPOAgentParameters from rl_coach.architectures.layers import Dense -from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters +from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps from rl_coach.environments.environment import SingleLevelSelection from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2 @@ -33,6 +33,9 @@ agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense(64 agent_params.input_filter = InputFilter() agent_params.input_filter.add_observation_filter('observation', 'normalize', ObservationNormalizationFilter()) +# Distributed Coach synchronization type. +agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC + ############### # Environment # ###############
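
For reference, the pub/sub flow this patch adds to SharedRunningStats can be exercised outside of Coach with a short standalone sketch. The sketch below assumes a local Redis server on localhost:6379 and a hypothetical channel name following the "channel-srs-<name>" convention from shared_variables.py; StatsSubscriber is an illustrative stand-in for SharedRunningStatsSubscribe, and the print call marks where push_val() would be invoked.

import pickle
import threading
import time

import numpy as np
import redis

# Hypothetical channel name; shared_variables.py derives it as "channel-srs-{name}".
CHANNEL = "channel-srs-normalize_observation"


class StatsSubscriber(threading.Thread):
    """Illustrative stand-in for SharedRunningStatsSubscribe."""

    def __init__(self, host="localhost", port=6379):
        super().__init__(daemon=True)
        self.pubsub = redis.Redis(host, port).pubsub()
        self.pubsub.subscribe(CHANNEL)

    def run(self):
        for message in self.pubsub.listen():
            try:
                batch = pickle.loads(message["data"])
            except Exception:
                # The initial subscribe confirmation carries an int payload;
                # skip anything that does not unpickle, as the patch does.
                continue
            # SharedRunningStatsSubscribe calls shared_running_stats.push_val(batch) here.
            print("received batch with shape", batch.shape)


if __name__ == "__main__":
    StatsSubscriber().start()
    # Publisher side: with pubsub_params set, SharedRunningStats.push() just
    # publishes the pickled batch and returns without running the TF update ops.
    redis.Redis("localhost", 6379).publish(CHANNEL, pickle.dumps(np.random.randn(8, 4)))
    time.sleep(1.0)  # give the daemon subscriber a moment to consume the message

Because every SharedRunningStats created with pubsub_params also starts its own subscriber thread, the publishing worker updates its local statistics through the same channel it publishes to, so all workers apply the same push_val() updates in the same order the channel delivers them.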