mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Enable distributed SharedRunningStats (#81)
- Use Redis pub/sub for updating SharedRunningStats.
This commit is contained in:
committed by
Gal Leibovich
parent
875d6ef017
commit
a849c17e46
@@ -112,6 +112,11 @@ class Agent(AgentInterface):
|
|||||||
self.output_filter = self.ap.output_filter
|
self.output_filter = self.ap.output_filter
|
||||||
self.pre_network_filter = self.ap.pre_network_filter
|
self.pre_network_filter = self.ap.pre_network_filter
|
||||||
device = self.replicated_device if self.replicated_device else self.worker_device
|
device = self.replicated_device if self.replicated_device else self.worker_device
|
||||||
|
if hasattr(self.ap.memory, 'memory_backend_params') and self.ap.algorithm.distributed_coach_synchronization_type:
|
||||||
|
self.input_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params)
|
||||||
|
self.output_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params)
|
||||||
|
self.pre_network_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params)
|
||||||
|
else:
|
||||||
self.input_filter.set_device(device)
|
self.input_filter.set_device(device)
|
||||||
self.output_filter.set_device(device)
|
self.output_filter.set_device(device)
|
||||||
self.pre_network_filter.set_device(device)
|
self.pre_network_filter.set_device(device)
|
||||||
|
|||||||
@@ -310,7 +310,7 @@ class PPOAgent(ActorCriticAgent):
|
|||||||
# clean memory
|
# clean memory
|
||||||
self.call_memory('clean')
|
self.call_memory('clean')
|
||||||
|
|
||||||
def _should_train_helper(self):
|
def _should_train_helper(self, wait_for_full_episode=True):
|
||||||
return super()._should_train_helper(True)
|
return super()._should_train_helper(True)
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
|
|||||||
@@ -15,11 +15,16 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pickle
|
||||||
|
import redis
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from rl_coach.memories.backend.memory_impl import get_memory_backend
|
||||||
|
|
||||||
|
|
||||||
class SharedRunningStats(object):
|
class SharedRunningStats(object):
|
||||||
def __init__(self, replicated_device=None, epsilon=1e-2, name="", create_ops=True):
|
def __init__(self, replicated_device=None, epsilon=1e-2, name="", create_ops=True, pubsub_params=None):
|
||||||
self.sess = None
|
self.sess = None
|
||||||
self.name = name
|
self.name = name
|
||||||
self.replicated_device = replicated_device
|
self.replicated_device = replicated_device
|
||||||
@@ -28,6 +33,13 @@ class SharedRunningStats(object):
|
|||||||
if create_ops:
|
if create_ops:
|
||||||
with tf.device(replicated_device):
|
with tf.device(replicated_device):
|
||||||
self.create_ops()
|
self.create_ops()
|
||||||
|
self.pubsub = None
|
||||||
|
if pubsub_params:
|
||||||
|
self.channel = "channel-srs-{}".format(self.name)
|
||||||
|
self.pubsub = get_memory_backend(pubsub_params)
|
||||||
|
subscribe_thread = SharedRunningStatsSubscribe(self)
|
||||||
|
subscribe_thread.daemon = True
|
||||||
|
subscribe_thread.start()
|
||||||
|
|
||||||
def create_ops(self, shape=[1], clip_values=None):
|
def create_ops(self, shape=[1], clip_values=None):
|
||||||
self.clip_values = clip_values
|
self.clip_values = clip_values
|
||||||
@@ -74,6 +86,13 @@ class SharedRunningStats(object):
|
|||||||
self.sess = sess
|
self.sess = sess
|
||||||
|
|
||||||
def push(self, x):
|
def push(self, x):
|
||||||
|
if self.pubsub:
|
||||||
|
self.pubsub.redis_connection.publish(self.channel, pickle.dumps(x))
|
||||||
|
return
|
||||||
|
|
||||||
|
self.push_val(x)
|
||||||
|
|
||||||
|
def push_val(self, x):
|
||||||
x = x.astype('float64')
|
x = x.astype('float64')
|
||||||
self.sess.run([self._inc_sum, self._inc_sum_squared, self._inc_count],
|
self.sess.run([self._inc_sum, self._inc_sum_squared, self._inc_count],
|
||||||
feed_dict={
|
feed_dict={
|
||||||
@@ -119,3 +138,23 @@ class SharedRunningStats(object):
|
|||||||
return self.sess.run(self.clipped_obs, feed_dict={self.raw_obs: batch})
|
return self.sess.run(self.clipped_obs, feed_dict={self.raw_obs: batch})
|
||||||
else:
|
else:
|
||||||
return self.sess.run(self.normalized_obs, feed_dict={self.raw_obs: batch})
|
return self.sess.run(self.normalized_obs, feed_dict={self.raw_obs: batch})
|
||||||
|
|
||||||
|
|
||||||
|
class SharedRunningStatsSubscribe(threading.Thread):
|
||||||
|
def __init__(self, shared_running_stats):
|
||||||
|
super().__init__()
|
||||||
|
self.shared_running_stats = shared_running_stats
|
||||||
|
self.redis_address = self.shared_running_stats.pubsub.params.redis_address
|
||||||
|
self.redis_port = self.shared_running_stats.pubsub.params.redis_port
|
||||||
|
self.redis_connection = redis.Redis(self.redis_address, self.redis_port)
|
||||||
|
self.pubsub = self.redis_connection.pubsub()
|
||||||
|
self.channel = self.shared_running_stats.channel
|
||||||
|
self.pubsub.subscribe(self.channel)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
for message in self.pubsub.listen():
|
||||||
|
try:
|
||||||
|
obj = pickle.loads(message['data'])
|
||||||
|
self.shared_running_stats.push_val(obj)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|||||||
@@ -46,10 +46,11 @@ class Filter(object):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError("")
|
raise NotImplementedError("")
|
||||||
|
|
||||||
def set_device(self, device) -> None:
|
def set_device(self, device, memory_backend_params=None) -> None:
|
||||||
"""
|
"""
|
||||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||||
:param device: the device to use
|
:param device: the device to use
|
||||||
|
:param memory_backend_params: parameters associated with the memory backend
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
@@ -84,13 +85,13 @@ class OutputFilter(Filter):
|
|||||||
duplicate.i_am_a_reference_filter = False
|
duplicate.i_am_a_reference_filter = False
|
||||||
return duplicate
|
return duplicate
|
||||||
|
|
||||||
def set_device(self, device) -> None:
|
def set_device(self, device, memory_backend_params=None) -> None:
|
||||||
"""
|
"""
|
||||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||||
:param device: the device to use
|
:param device: the device to use
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
[f.set_device(device) for f in self.action_filters.values()]
|
[f.set_device(device, memory_backend_params) for f in self.action_filters.values()]
|
||||||
|
|
||||||
def set_session(self, sess) -> None:
|
def set_session(self, sess) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -225,14 +226,14 @@ class InputFilter(Filter):
|
|||||||
duplicate.i_am_a_reference_filter = False
|
duplicate.i_am_a_reference_filter = False
|
||||||
return duplicate
|
return duplicate
|
||||||
|
|
||||||
def set_device(self, device) -> None:
|
def set_device(self, device, memory_backend_params=None) -> None:
|
||||||
"""
|
"""
|
||||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||||
:param device: the device to use
|
:param device: the device to use
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
[f.set_device(device) for f in self.reward_filters.values()]
|
[f.set_device(device, memory_backend_params) for f in self.reward_filters.values()]
|
||||||
[[f.set_device(device) for f in filters.values()] for filters in self.observation_filters.values()]
|
[[f.set_device(device, memory_backend_params) for f in filters.values()] for filters in self.observation_filters.values()]
|
||||||
|
|
||||||
def set_session(self, sess) -> None:
|
def set_session(self, sess) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -41,13 +41,14 @@ class ObservationNormalizationFilter(ObservationFilter):
|
|||||||
self.supports_batching = True
|
self.supports_batching = True
|
||||||
self.observation_space = None
|
self.observation_space = None
|
||||||
|
|
||||||
def set_device(self, device) -> None:
|
def set_device(self, device, memory_backend_params=None) -> None:
|
||||||
"""
|
"""
|
||||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||||
:param device: the device to use
|
:param device: the device to use
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
self.running_observation_stats = SharedRunningStats(device, name=self.name, create_ops=False)
|
self.running_observation_stats = SharedRunningStats(device, name=self.name, create_ops=False,
|
||||||
|
pubsub_params=memory_backend_params)
|
||||||
|
|
||||||
def set_session(self, sess) -> None:
|
def set_session(self, sess) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -38,13 +38,14 @@ class RewardNormalizationFilter(RewardFilter):
|
|||||||
self.clip_max = clip_max
|
self.clip_max = clip_max
|
||||||
self.running_rewards_stats = None
|
self.running_rewards_stats = None
|
||||||
|
|
||||||
def set_device(self, device) -> None:
|
def set_device(self, device, memory_backend_params=None) -> None:
|
||||||
"""
|
"""
|
||||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||||
:param device: the device to use
|
:param device: the device to use
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
self.running_rewards_stats = SharedRunningStats(device, name='rewards_stats')
|
self.running_rewards_stats = SharedRunningStats(device, name='rewards_stats',
|
||||||
|
pubsub_params=memory_backend_params)
|
||||||
|
|
||||||
def set_session(self, sess) -> None:
|
def set_session(self, sess) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentS
|
|||||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
||||||
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
||||||
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
|
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
|
||||||
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
|
|
||||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||||
from rl_coach.schedules import LinearSchedule
|
from rl_coach.schedules import LinearSchedule
|
||||||
@@ -49,8 +48,6 @@ agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoach
|
|||||||
|
|
||||||
agent_params.exploration = EGreedyParameters()
|
agent_params.exploration = EGreedyParameters()
|
||||||
agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
|
agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
|
||||||
# agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
|
|
||||||
# ObservationNormalizationFilter(name='normalize_observation'))
|
|
||||||
|
|
||||||
###############
|
###############
|
||||||
# Environment #
|
# Environment #
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
||||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType
|
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
|
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
|
||||||
from rl_coach.architectures.layers import Dense
|
from rl_coach.architectures.layers import Dense
|
||||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType
|
||||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||||
from rl_coach.environments.environment import SingleLevelSelection
|
from rl_coach.environments.environment import SingleLevelSelection
|
||||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
||||||
@@ -43,6 +43,8 @@ agent_params.algorithm.gae_lambda = 0.95
|
|||||||
agent_params.algorithm.discount = 0.99
|
agent_params.algorithm.discount = 0.99
|
||||||
agent_params.algorithm.optimization_epochs = 10
|
agent_params.algorithm.optimization_epochs = 10
|
||||||
agent_params.algorithm.estimate_state_value_using_gae = True
|
agent_params.algorithm.estimate_state_value_using_gae = True
|
||||||
|
# Distributed Coach synchronization type.
|
||||||
|
agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC
|
||||||
|
|
||||||
agent_params.input_filter = InputFilter()
|
agent_params.input_filter = InputFilter()
|
||||||
agent_params.exploration = AdditiveNoiseParameters()
|
agent_params.exploration = AdditiveNoiseParameters()
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from rl_coach.agents.ppo_agent import PPOAgentParameters
|
from rl_coach.agents.ppo_agent import PPOAgentParameters
|
||||||
from rl_coach.architectures.layers import Dense
|
from rl_coach.architectures.layers import Dense
|
||||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType
|
||||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||||
from rl_coach.environments.environment import SingleLevelSelection
|
from rl_coach.environments.environment import SingleLevelSelection
|
||||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
||||||
@@ -33,6 +33,9 @@ agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense(64
|
|||||||
agent_params.input_filter = InputFilter()
|
agent_params.input_filter = InputFilter()
|
||||||
agent_params.input_filter.add_observation_filter('observation', 'normalize', ObservationNormalizationFilter())
|
agent_params.input_filter.add_observation_filter('observation', 'normalize', ObservationNormalizationFilter())
|
||||||
|
|
||||||
|
# Distributed Coach synchronization type.
|
||||||
|
agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC
|
||||||
|
|
||||||
###############
|
###############
|
||||||
# Environment #
|
# Environment #
|
||||||
###############
|
###############
|
||||||
|
|||||||
Reference in New Issue
Block a user