
Enable distributed SharedRunningStats (#81)

- Use Redis pub/sub for updating SharedRunningStats.
commit a849c17e46 (parent 875d6ef017)
Author: Balaji Subramaniam
Date: 2018-11-13 09:17:38 -08:00
Committed by: Gal Leibovich

10 changed files with 76 additions and 27 deletions
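For orientation, the mechanism this commit builds on is plain Redis pub/sub: a worker publishes pickled NumPy batches on a per-statistic channel, and a background subscriber on the other side folds them into its local running statistics. The following stand-alone sketch is not rl_coach code; it assumes a Redis server on localhost:6379, the redis-py client the change imports, and toy names (ToyStats, CHANNEL) chosen for illustration:

import pickle
import threading
import time

import numpy as np
import redis

CHANNEL = "channel-srs-demo"   # per-stat channel name, mirroring "channel-srs-<name>" below

class ToyStats:
    """Toy stand-in for SharedRunningStats: accumulates sum, sum of squares and count."""
    def __init__(self, size):
        self.sum = np.zeros(size)
        self.sum_sq = np.zeros(size)
        self.count = 0.0

    def push_val(self, x):
        self.sum += x.sum(axis=0)
        self.sum_sq += np.square(x).sum(axis=0)
        self.count += len(x)

def subscribe(stats):
    pubsub = redis.Redis("localhost", 6379).pubsub(ignore_subscribe_messages=True)
    pubsub.subscribe(CHANNEL)
    for message in pubsub.listen():               # blocking loop, run on a daemon thread
        stats.push_val(pickle.loads(message["data"]))

stats = ToyStats(size=4)
threading.Thread(target=subscribe, args=(stats,), daemon=True).start()
time.sleep(0.5)                                   # give the subscription time to register

# A "rollout worker" publishes a batch instead of updating the stats in-process.
redis.Redis("localhost", 6379).publish(CHANNEL, pickle.dumps(np.random.randn(32, 4)))
time.sleep(0.5)
print(stats.count)                                # expected: 32.0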


@@ -112,9 +112,14 @@ class Agent(AgentInterface):
         self.output_filter = self.ap.output_filter
         self.pre_network_filter = self.ap.pre_network_filter
         device = self.replicated_device if self.replicated_device else self.worker_device
-        self.input_filter.set_device(device)
-        self.output_filter.set_device(device)
-        self.pre_network_filter.set_device(device)
+        if hasattr(self.ap.memory, 'memory_backend_params') and self.ap.algorithm.distributed_coach_synchronization_type:
+            self.input_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params)
+            self.output_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params)
+            self.pre_network_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params)
+        else:
+            self.input_filter.set_device(device)
+            self.output_filter.set_device(device)
+            self.pre_network_filter.set_device(device)
 
         # initialize all internal variables
         self._phase = RunPhase.HEATUP
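Put differently, filters receive memory backend parameters only when the agent's memory is configured with a backend and the algorithm opts into distributed Coach synchronization. A tiny restatement of that rule (the helper name is hypothetical, not part of rl_coach):

def backend_params_for_filters(memory_params, algorithm_params):
    # Mirrors the condition above: both a memory backend and a distributed
    # synchronization type must be present, otherwise stats stay process-local.
    if hasattr(memory_params, 'memory_backend_params') and \
            getattr(algorithm_params, 'distributed_coach_synchronization_type', None):
        return memory_params.memory_backend_params
    return None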


@@ -310,7 +310,7 @@ class PPOAgent(ActorCriticAgent):
         # clean memory
         self.call_memory('clean')
 
-    def _should_train_helper(self):
+    def _should_train_helper(self, wait_for_full_episode=True):
         return super()._should_train_helper(True)
 
     def train(self):


@@ -15,11 +15,16 @@
 #
 import numpy as np
 import pickle
+import redis
 import tensorflow as tf
+import threading
+
+from rl_coach.memories.backend.memory_impl import get_memory_backend
 
 
 class SharedRunningStats(object):
-    def __init__(self, replicated_device=None, epsilon=1e-2, name="", create_ops=True):
+    def __init__(self, replicated_device=None, epsilon=1e-2, name="", create_ops=True, pubsub_params=None):
         self.sess = None
         self.name = name
         self.replicated_device = replicated_device
@@ -28,6 +33,13 @@ class SharedRunningStats(object):
         if create_ops:
             with tf.device(replicated_device):
                 self.create_ops()
 
+        self.pubsub = None
+        if pubsub_params:
+            self.channel = "channel-srs-{}".format(self.name)
+            self.pubsub = get_memory_backend(pubsub_params)
+            subscribe_thread = SharedRunningStatsSubscribe(self)
+            subscribe_thread.daemon = True
+            subscribe_thread.start()
 
     def create_ops(self, shape=[1], clip_values=None):
         self.clip_values = clip_values
@@ -74,13 +86,20 @@ class SharedRunningStats(object):
         self.sess = sess
 
     def push(self, x):
+        if self.pubsub:
+            self.pubsub.redis_connection.publish(self.channel, pickle.dumps(x))
+            return
+        self.push_val(x)
+
+    def push_val(self, x):
         x = x.astype('float64')
         self.sess.run([self._inc_sum, self._inc_sum_squared, self._inc_count],
                       feed_dict={
                           self.new_sum: x.sum(axis=0).ravel(),
                           self.new_sum_squared: np.square(x).sum(axis=0).ravel(),
                           self.newcount: np.array(len(x), dtype='float64')
                       })
         if self._shape is None:
             self._shape = x.shape
@@ -119,3 +138,23 @@ class SharedRunningStats(object):
             return self.sess.run(self.clipped_obs, feed_dict={self.raw_obs: batch})
         else:
             return self.sess.run(self.normalized_obs, feed_dict={self.raw_obs: batch})
+
+
+class SharedRunningStatsSubscribe(threading.Thread):
+    def __init__(self, shared_running_stats):
+        super().__init__()
+        self.shared_running_stats = shared_running_stats
+        self.redis_address = self.shared_running_stats.pubsub.params.redis_address
+        self.redis_port = self.shared_running_stats.pubsub.params.redis_port
+        self.redis_connection = redis.Redis(self.redis_address, self.redis_port)
+        self.pubsub = self.redis_connection.pubsub()
+        self.channel = self.shared_running_stats.channel
+        self.pubsub.subscribe(self.channel)
+
+    def run(self):
+        for message in self.pubsub.listen():
+            try:
+                obj = pickle.loads(message['data'])
+                self.shared_running_stats.push_val(obj)
+            except Exception:
+                continue
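As a reminder of what push_val feeds: the class keeps a per-feature running sum, sum of squares and count (the new_sum, new_sum_squared and newcount placeholders above), from which mean and standard deviation are recovered at normalization time. A plain NumPy sketch of that arithmetic (the exact epsilon and clipping handling inside SharedRunningStats may differ):

import numpy as np

running_sum = np.zeros(4)          # per-feature accumulators
running_sum_sq = np.zeros(4)
running_count = 0.0

for _ in range(3):                 # three incoming batches of observations
    batch = np.random.randn(32, 4)
    running_sum += batch.sum(axis=0)
    running_sum_sq += np.square(batch).sum(axis=0)
    running_count += len(batch)

mean = running_sum / running_count
std = np.sqrt(np.maximum(running_sum_sq / running_count - np.square(mean), 1e-2))

normalized = (np.random.randn(32, 4) - mean) / std   # how a new batch would be normalized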


@@ -46,10 +46,11 @@ class Filter(object):
"""
raise NotImplementedError("")
def set_device(self, device) -> None:
def set_device(self, device, memory_backend_params=None) -> None:
"""
An optional function that allows the filter to get the device if it is required to use tensorflow ops
:param device: the device to use
:param memory_backend_params: parameters associated with the memory backend
:return: None
"""
pass
@@ -84,13 +85,13 @@ class OutputFilter(Filter):
         duplicate.i_am_a_reference_filter = False
         return duplicate
 
-    def set_device(self, device) -> None:
+    def set_device(self, device, memory_backend_params=None) -> None:
         """
         An optional function that allows the filter to get the device if it is required to use tensorflow ops
         :param device: the device to use
         :return: None
         """
-        [f.set_device(device) for f in self.action_filters.values()]
+        [f.set_device(device, memory_backend_params) for f in self.action_filters.values()]
 
     def set_session(self, sess) -> None:
         """
@@ -225,14 +226,14 @@ class InputFilter(Filter):
         duplicate.i_am_a_reference_filter = False
         return duplicate
 
-    def set_device(self, device) -> None:
+    def set_device(self, device, memory_backend_params=None) -> None:
         """
         An optional function that allows the filter to get the device if it is required to use tensorflow ops
         :param device: the device to use
         :return: None
         """
-        [f.set_device(device) for f in self.reward_filters.values()]
-        [[f.set_device(device) for f in filters.values()] for filters in self.observation_filters.values()]
+        [f.set_device(device, memory_backend_params) for f in self.reward_filters.values()]
+        [[f.set_device(device, memory_backend_params) for f in filters.values()] for filters in self.observation_filters.values()]
 
     def set_session(self, sess) -> None:
         """


@@ -41,13 +41,14 @@ class ObservationNormalizationFilter(ObservationFilter):
         self.supports_batching = True
         self.observation_space = None
 
-    def set_device(self, device) -> None:
+    def set_device(self, device, memory_backend_params=None) -> None:
         """
         An optional function that allows the filter to get the device if it is required to use tensorflow ops
         :param device: the device to use
         :return: None
         """
-        self.running_observation_stats = SharedRunningStats(device, name=self.name, create_ops=False)
+        self.running_observation_stats = SharedRunningStats(device, name=self.name, create_ops=False,
+                                                            pubsub_params=memory_backend_params)
 
     def set_session(self, sess) -> None:
         """


@@ -38,13 +38,14 @@ class RewardNormalizationFilter(RewardFilter):
         self.clip_max = clip_max
         self.running_rewards_stats = None
 
-    def set_device(self, device) -> None:
+    def set_device(self, device, memory_backend_params=None) -> None:
         """
         An optional function that allows the filter to get the device if it is required to use tensorflow ops
         :param device: the device to use
         :return: None
         """
-        self.running_rewards_stats = SharedRunningStats(device, name='rewards_stats')
+        self.running_rewards_stats = SharedRunningStats(device, name='rewards_stats',
+                                                        pubsub_params=memory_backend_params)
 
     def set_session(self, sess) -> None:
         """


@@ -5,7 +5,6 @@ from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentS
 from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
 from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
 from rl_coach.exploration_policies.e_greedy import EGreedyParameters
-from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
 from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
 from rl_coach.graph_managers.graph_manager import ScheduleParameters
 from rl_coach.schedules import LinearSchedule
@@ -49,8 +48,6 @@ agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoach
 agent_params.exploration = EGreedyParameters()
 agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
 
-# agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
-#                                                         ObservationNormalizationFilter(name='normalize_observation'))
 
 ###############
 # Environment #


@@ -1,5 +1,5 @@
 from rl_coach.agents.dqn_agent import DQNAgentParameters
-from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType
+from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
 from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
 from rl_coach.environments.gym_environment import GymVectorEnvironment
 from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager


@@ -1,6 +1,6 @@
 from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
 from rl_coach.architectures.layers import Dense
-from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
+from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType
 from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
 from rl_coach.environments.environment import SingleLevelSelection
 from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
@@ -43,6 +43,8 @@ agent_params.algorithm.gae_lambda = 0.95
 agent_params.algorithm.discount = 0.99
 agent_params.algorithm.optimization_epochs = 10
 agent_params.algorithm.estimate_state_value_using_gae = True
+
+# Distributed Coach synchronization type.
+agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC
 
 agent_params.input_filter = InputFilter()
 agent_params.exploration = AdditiveNoiseParameters()


@@ -1,6 +1,6 @@
 from rl_coach.agents.ppo_agent import PPOAgentParameters
 from rl_coach.architectures.layers import Dense
-from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
+from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType
 from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
 from rl_coach.environments.environment import SingleLevelSelection
 from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
@@ -33,6 +33,9 @@ agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense(64
 agent_params.input_filter = InputFilter()
 agent_params.input_filter.add_observation_filter('observation', 'normalize', ObservationNormalizationFilter())
 
+# Distributed Coach synchronization type.
+agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC
+
 ###############
 # Environment #
 ###############
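Condensing the preset changes above: a distributed run sets the synchronization type on the algorithm parameters and, if observation normalization is wanted, the corresponding input filter. An illustrative excerpt, not a complete preset: agent_params is assumed to be defined earlier in the preset as in the files above, and the InputFilter import path follows the rl_coach layout and is an assumption here.

from rl_coach.base_parameters import DistributedCoachSynchronizationType
from rl_coach.filters.filter import InputFilter
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter

agent_params.input_filter = InputFilter()
agent_params.input_filter.add_observation_filter('observation', 'normalize', ObservationNormalizationFilter())

# Opting into synchronous distributed Coach is what, together with a Redis memory
# backend, routes memory_backend_params into the filters (see the Agent change above).
agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC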