
Enable distributed SharedRunningStats (#81)

- Use Redis pub/sub for updating SharedRunningStats.
Balaji Subramaniam
2018-11-13 09:17:38 -08:00
committed by Gal Leibovich
parent 875d6ef017
commit a849c17e46
10 changed files with 76 additions and 27 deletions
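
At a high level, a worker no longer applies normalization-statistics updates only to its local TensorFlow variables; it publishes the raw values to a Redis channel, and a background thread on every process subscribes to that channel and applies whatever arrives. A minimal, self-contained sketch of that round trip is below; the host, port, channel name, and the apply_update callback are placeholders for illustration and are not taken from this commit:

import pickle
import threading

import numpy as np
import redis

# Placeholder connection details; distributed Coach would supply these via its memory backend parameters.
redis_connection = redis.Redis(host='localhost', port=6379)
channel = 'channel-srs-example'

def listen_and_apply(apply_update):
    # Background consumer: apply every update that any worker publishes on the channel.
    pubsub = redis_connection.pubsub()
    pubsub.subscribe(channel)
    for message in pubsub.listen():
        if message['type'] != 'message':
            continue
        apply_update(pickle.loads(message['data']))

received = []
threading.Thread(target=listen_and_apply, args=(received.append,), daemon=True).start()

# Producer side: broadcast the batch instead of updating local statistics directly.
redis_connection.publish(channel, pickle.dumps(np.array([[0.1, 0.2]])))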


@@ -112,6 +112,11 @@ class Agent(AgentInterface):
         self.output_filter = self.ap.output_filter
         self.pre_network_filter = self.ap.pre_network_filter
         device = self.replicated_device if self.replicated_device else self.worker_device
-        self.input_filter.set_device(device)
-        self.output_filter.set_device(device)
-        self.pre_network_filter.set_device(device)
+        if hasattr(self.ap.memory, 'memory_backend_params') and self.ap.algorithm.distributed_coach_synchronization_type:
+            self.input_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params)
+            self.output_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params)
+            self.pre_network_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params)
+        else:
+            self.input_filter.set_device(device)
+            self.output_filter.set_device(device)
+            self.pre_network_filter.set_device(device)


@@ -310,7 +310,7 @@ class PPOAgent(ActorCriticAgent):
         # clean memory
         self.call_memory('clean')
 
-    def _should_train_helper(self):
+    def _should_train_helper(self, wait_for_full_episode=True):
         return super()._should_train_helper(True)
 
     def train(self):


@@ -15,11 +15,16 @@
 #
 
 import numpy as np
+import pickle
+import redis
 import tensorflow as tf
+import threading
+
+from rl_coach.memories.backend.memory_impl import get_memory_backend
 
 
 class SharedRunningStats(object):
-    def __init__(self, replicated_device=None, epsilon=1e-2, name="", create_ops=True):
+    def __init__(self, replicated_device=None, epsilon=1e-2, name="", create_ops=True, pubsub_params=None):
         self.sess = None
         self.name = name
         self.replicated_device = replicated_device
@@ -28,6 +33,13 @@ class SharedRunningStats(object):
         if create_ops:
             with tf.device(replicated_device):
                 self.create_ops()
+        self.pubsub = None
+        if pubsub_params:
+            self.channel = "channel-srs-{}".format(self.name)
+            self.pubsub = get_memory_backend(pubsub_params)
+            subscribe_thread = SharedRunningStatsSubscribe(self)
+            subscribe_thread.daemon = True
+            subscribe_thread.start()
 
     def create_ops(self, shape=[1], clip_values=None):
         self.clip_values = clip_values
@@ -74,6 +86,13 @@ class SharedRunningStats(object):
         self.sess = sess
 
     def push(self, x):
+        if self.pubsub:
+            self.pubsub.redis_connection.publish(self.channel, pickle.dumps(x))
+            return
+        self.push_val(x)
+
+    def push_val(self, x):
         x = x.astype('float64')
         self.sess.run([self._inc_sum, self._inc_sum_squared, self._inc_count],
                       feed_dict={
@@ -119,3 +138,23 @@ class SharedRunningStats(object):
             return self.sess.run(self.clipped_obs, feed_dict={self.raw_obs: batch})
         else:
             return self.sess.run(self.normalized_obs, feed_dict={self.raw_obs: batch})
+
+
+class SharedRunningStatsSubscribe(threading.Thread):
+    def __init__(self, shared_running_stats):
+        super().__init__()
+        self.shared_running_stats = shared_running_stats
+        self.redis_address = self.shared_running_stats.pubsub.params.redis_address
+        self.redis_port = self.shared_running_stats.pubsub.params.redis_port
+        self.redis_connection = redis.Redis(self.redis_address, self.redis_port)
+        self.pubsub = self.redis_connection.pubsub()
+        self.channel = self.shared_running_stats.channel
+        self.pubsub.subscribe(self.channel)
+
+    def run(self):
+        for message in self.pubsub.listen():
+            try:
+                obj = pickle.loads(message['data'])
+                self.shared_running_stats.push_val(obj)
+            except Exception:
+                continue
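
In this design, push() on a pub/sub-enabled instance only publishes the batch, while the subscriber thread (started on every process that configures pubsub_params) calls push_val(), which increments the running sum, sum-of-squares, and count variables. For intuition, a rough NumPy sketch of the statistics those accumulators represent follows; the exact TensorFlow ops live in create_ops(), which this hunk does not show, so treat this as an approximation rather than the library's implementation:

import numpy as np

class RunningStatsSketch:
    # Accumulate sum, sum of squares, and count; derive mean and std on demand.
    def __init__(self, shape=(1,), epsilon=1e-2):
        self._sum = np.zeros(shape, dtype=np.float64)
        self._sum_sq = epsilon * np.ones(shape, dtype=np.float64)
        self._count = epsilon

    def push_val(self, x):
        x = np.asarray(x, dtype=np.float64).reshape(-1, *self._sum.shape)
        self._sum += x.sum(axis=0)
        self._sum_sq += np.square(x).sum(axis=0)
        self._count += x.shape[0]

    @property
    def mean(self):
        return self._sum / self._count

    @property
    def std(self):
        # Var(x) = E[x^2] - E[x]^2, floored to avoid dividing by ~0.
        return np.sqrt(np.maximum(self._sum_sq / self._count - np.square(self.mean), 1e-2))

stats = RunningStatsSketch(shape=(2,))
stats.push_val(np.array([[0.0, 1.0], [2.0, 3.0]]))
normalized = (np.array([1.0, 2.0]) - stats.mean) / stats.std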


@@ -46,10 +46,11 @@ class Filter(object):
""" """
raise NotImplementedError("") raise NotImplementedError("")
def set_device(self, device) -> None: def set_device(self, device, memory_backend_params=None) -> None:
""" """
An optional function that allows the filter to get the device if it is required to use tensorflow ops An optional function that allows the filter to get the device if it is required to use tensorflow ops
:param device: the device to use :param device: the device to use
:param memory_backend_params: parameters associated with the memory backend
:return: None :return: None
""" """
pass pass
@@ -84,13 +85,13 @@ class OutputFilter(Filter):
         duplicate.i_am_a_reference_filter = False
         return duplicate
 
-    def set_device(self, device) -> None:
+    def set_device(self, device, memory_backend_params=None) -> None:
         """
         An optional function that allows the filter to get the device if it is required to use tensorflow ops
         :param device: the device to use
         :return: None
         """
-        [f.set_device(device) for f in self.action_filters.values()]
+        [f.set_device(device, memory_backend_params) for f in self.action_filters.values()]
 
     def set_session(self, sess) -> None:
         """
@@ -225,14 +226,14 @@ class InputFilter(Filter):
         duplicate.i_am_a_reference_filter = False
         return duplicate
 
-    def set_device(self, device) -> None:
+    def set_device(self, device, memory_backend_params=None) -> None:
         """
         An optional function that allows the filter to get the device if it is required to use tensorflow ops
         :param device: the device to use
         :return: None
         """
-        [f.set_device(device) for f in self.reward_filters.values()]
-        [[f.set_device(device) for f in filters.values()] for filters in self.observation_filters.values()]
+        [f.set_device(device, memory_backend_params) for f in self.reward_filters.values()]
+        [[f.set_device(device, memory_backend_params) for f in filters.values()] for filters in self.observation_filters.values()]
 
     def set_session(self, sess) -> None:
         """


@@ -41,13 +41,14 @@ class ObservationNormalizationFilter(ObservationFilter):
         self.supports_batching = True
         self.observation_space = None
 
-    def set_device(self, device) -> None:
+    def set_device(self, device, memory_backend_params=None) -> None:
         """
         An optional function that allows the filter to get the device if it is required to use tensorflow ops
         :param device: the device to use
         :return: None
         """
-        self.running_observation_stats = SharedRunningStats(device, name=self.name, create_ops=False)
+        self.running_observation_stats = SharedRunningStats(device, name=self.name, create_ops=False,
+                                                            pubsub_params=memory_backend_params)
 
     def set_session(self, sess) -> None:
         """


@@ -38,13 +38,14 @@ class RewardNormalizationFilter(RewardFilter):
         self.clip_max = clip_max
         self.running_rewards_stats = None
 
-    def set_device(self, device) -> None:
+    def set_device(self, device, memory_backend_params=None) -> None:
         """
         An optional function that allows the filter to get the device if it is required to use tensorflow ops
         :param device: the device to use
         :return: None
         """
-        self.running_rewards_stats = SharedRunningStats(device, name='rewards_stats')
+        self.running_rewards_stats = SharedRunningStats(device, name='rewards_stats',
+                                                        pubsub_params=memory_backend_params)
 
     def set_session(self, sess) -> None:
         """


@@ -5,7 +5,6 @@ from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentS
 from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
 from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
 from rl_coach.exploration_policies.e_greedy import EGreedyParameters
-from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
 from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
 from rl_coach.graph_managers.graph_manager import ScheduleParameters
 from rl_coach.schedules import LinearSchedule
@@ -49,8 +48,6 @@ agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoach
 agent_params.exploration = EGreedyParameters()
 agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
 
-# agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
-#                                                         ObservationNormalizationFilter(name='normalize_observation'))
 
 ###############
 # Environment #


@@ -1,5 +1,5 @@
 from rl_coach.agents.dqn_agent import DQNAgentParameters
-from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType
+from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
 from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
 from rl_coach.environments.gym_environment import GymVectorEnvironment
 from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager


@@ -1,6 +1,6 @@
 from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
 from rl_coach.architectures.layers import Dense
-from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
+from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType
 from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
 from rl_coach.environments.environment import SingleLevelSelection
 from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
@@ -43,6 +43,8 @@ agent_params.algorithm.gae_lambda = 0.95
 agent_params.algorithm.discount = 0.99
 agent_params.algorithm.optimization_epochs = 10
 agent_params.algorithm.estimate_state_value_using_gae = True
+# Distributed Coach synchronization type.
+agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC
 
 agent_params.input_filter = InputFilter()
 agent_params.exploration = AdditiveNoiseParameters()


@@ -1,6 +1,6 @@
 from rl_coach.agents.ppo_agent import PPOAgentParameters
 from rl_coach.architectures.layers import Dense
-from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
+from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType
 from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
 from rl_coach.environments.environment import SingleLevelSelection
 from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
@@ -33,6 +33,9 @@ agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense(64
 agent_params.input_filter = InputFilter()
 agent_params.input_filter.add_observation_filter('observation', 'normalize', ObservationNormalizationFilter())
+
+# Distributed Coach synchronization type.
+agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC
 
 ###############
 # Environment #
 ###############
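
Taken together, the preset changes show the two settings a preset needs for its normalization filters to pick up the memory backend under distributed Coach: a distributed_coach_synchronization_type on the algorithm parameters, and the normalization filter attached to an input filter. A minimal fragment combining the two, mirroring the lines added above (the agent parameters class is just one example; any preset's agent_params would do):

from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
from rl_coach.base_parameters import DistributedCoachSynchronizationType
from rl_coach.filters.filter import InputFilter
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter

agent_params = ClippedPPOAgentParameters()

# Opt in to distributed Coach synchronization (SYNC, as in the presets above).
agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC

# Normalize observations with SharedRunningStats; set_device() passes memory_backend_params through to it.
agent_params.input_filter = InputFilter()
agent_params.input_filter.add_observation_filter('observation', 'normalize', ObservationNormalizationFilter())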