mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Enable distributed SharedRunningStats (#81)
- Use Redis pub/sub for updating SharedRunningStats.
This commit is contained in:
committed by
Gal Leibovich
parent
875d6ef017
commit
a849c17e46
@@ -112,9 +112,14 @@ class Agent(AgentInterface):
|
||||
self.output_filter = self.ap.output_filter
|
||||
self.pre_network_filter = self.ap.pre_network_filter
|
||||
device = self.replicated_device if self.replicated_device else self.worker_device
|
||||
self.input_filter.set_device(device)
|
||||
self.output_filter.set_device(device)
|
||||
self.pre_network_filter.set_device(device)
|
||||
if hasattr(self.ap.memory, 'memory_backend_params') and self.ap.algorithm.distributed_coach_synchronization_type:
|
||||
self.input_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params)
|
||||
self.output_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params)
|
||||
self.pre_network_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params)
|
||||
else:
|
||||
self.input_filter.set_device(device)
|
||||
self.output_filter.set_device(device)
|
||||
self.pre_network_filter.set_device(device)
|
||||
|
||||
# initialize all internal variables
|
||||
self._phase = RunPhase.HEATUP
|
||||
|
||||
@@ -310,7 +310,7 @@ class PPOAgent(ActorCriticAgent):
|
||||
# clean memory
|
||||
self.call_memory('clean')
|
||||
|
||||
def _should_train_helper(self):
|
||||
def _should_train_helper(self, wait_for_full_episode=True):
|
||||
return super()._should_train_helper(True)
|
||||
|
||||
def train(self):
|
||||
|
||||
@@ -15,11 +15,16 @@
|
||||
#
|
||||
|
||||
import numpy as np
|
||||
import pickle
|
||||
import redis
|
||||
import tensorflow as tf
|
||||
import threading
|
||||
|
||||
from rl_coach.memories.backend.memory_impl import get_memory_backend
|
||||
|
||||
|
||||
class SharedRunningStats(object):
|
||||
def __init__(self, replicated_device=None, epsilon=1e-2, name="", create_ops=True):
|
||||
def __init__(self, replicated_device=None, epsilon=1e-2, name="", create_ops=True, pubsub_params=None):
|
||||
self.sess = None
|
||||
self.name = name
|
||||
self.replicated_device = replicated_device
|
||||
@@ -28,6 +33,13 @@ class SharedRunningStats(object):
|
||||
if create_ops:
|
||||
with tf.device(replicated_device):
|
||||
self.create_ops()
|
||||
self.pubsub = None
|
||||
if pubsub_params:
|
||||
self.channel = "channel-srs-{}".format(self.name)
|
||||
self.pubsub = get_memory_backend(pubsub_params)
|
||||
subscribe_thread = SharedRunningStatsSubscribe(self)
|
||||
subscribe_thread.daemon = True
|
||||
subscribe_thread.start()
|
||||
|
||||
def create_ops(self, shape=[1], clip_values=None):
|
||||
self.clip_values = clip_values
|
||||
@@ -74,13 +86,20 @@ class SharedRunningStats(object):
|
||||
self.sess = sess
|
||||
|
||||
def push(self, x):
|
||||
if self.pubsub:
|
||||
self.pubsub.redis_connection.publish(self.channel, pickle.dumps(x))
|
||||
return
|
||||
|
||||
self.push_val(x)
|
||||
|
||||
def push_val(self, x):
|
||||
x = x.astype('float64')
|
||||
self.sess.run([self._inc_sum, self._inc_sum_squared, self._inc_count],
|
||||
feed_dict={
|
||||
self.new_sum: x.sum(axis=0).ravel(),
|
||||
self.new_sum_squared: np.square(x).sum(axis=0).ravel(),
|
||||
self.newcount: np.array(len(x), dtype='float64')
|
||||
})
|
||||
feed_dict={
|
||||
self.new_sum: x.sum(axis=0).ravel(),
|
||||
self.new_sum_squared: np.square(x).sum(axis=0).ravel(),
|
||||
self.newcount: np.array(len(x), dtype='float64')
|
||||
})
|
||||
if self._shape is None:
|
||||
self._shape = x.shape
|
||||
|
||||
@@ -119,3 +138,23 @@ class SharedRunningStats(object):
|
||||
return self.sess.run(self.clipped_obs, feed_dict={self.raw_obs: batch})
|
||||
else:
|
||||
return self.sess.run(self.normalized_obs, feed_dict={self.raw_obs: batch})
|
||||
|
||||
|
||||
class SharedRunningStatsSubscribe(threading.Thread):
|
||||
def __init__(self, shared_running_stats):
|
||||
super().__init__()
|
||||
self.shared_running_stats = shared_running_stats
|
||||
self.redis_address = self.shared_running_stats.pubsub.params.redis_address
|
||||
self.redis_port = self.shared_running_stats.pubsub.params.redis_port
|
||||
self.redis_connection = redis.Redis(self.redis_address, self.redis_port)
|
||||
self.pubsub = self.redis_connection.pubsub()
|
||||
self.channel = self.shared_running_stats.channel
|
||||
self.pubsub.subscribe(self.channel)
|
||||
|
||||
def run(self):
|
||||
for message in self.pubsub.listen():
|
||||
try:
|
||||
obj = pickle.loads(message['data'])
|
||||
self.shared_running_stats.push_val(obj)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
@@ -46,10 +46,11 @@ class Filter(object):
|
||||
"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
def set_device(self, device) -> None:
|
||||
def set_device(self, device, memory_backend_params=None) -> None:
|
||||
"""
|
||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||
:param device: the device to use
|
||||
:param memory_backend_params: parameters associated with the memory backend
|
||||
:return: None
|
||||
"""
|
||||
pass
|
||||
@@ -84,13 +85,13 @@ class OutputFilter(Filter):
|
||||
duplicate.i_am_a_reference_filter = False
|
||||
return duplicate
|
||||
|
||||
def set_device(self, device) -> None:
|
||||
def set_device(self, device, memory_backend_params=None) -> None:
|
||||
"""
|
||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||
:param device: the device to use
|
||||
:return: None
|
||||
"""
|
||||
[f.set_device(device) for f in self.action_filters.values()]
|
||||
[f.set_device(device, memory_backend_params) for f in self.action_filters.values()]
|
||||
|
||||
def set_session(self, sess) -> None:
|
||||
"""
|
||||
@@ -225,14 +226,14 @@ class InputFilter(Filter):
|
||||
duplicate.i_am_a_reference_filter = False
|
||||
return duplicate
|
||||
|
||||
def set_device(self, device) -> None:
|
||||
def set_device(self, device, memory_backend_params=None) -> None:
|
||||
"""
|
||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||
:param device: the device to use
|
||||
:return: None
|
||||
"""
|
||||
[f.set_device(device) for f in self.reward_filters.values()]
|
||||
[[f.set_device(device) for f in filters.values()] for filters in self.observation_filters.values()]
|
||||
[f.set_device(device, memory_backend_params) for f in self.reward_filters.values()]
|
||||
[[f.set_device(device, memory_backend_params) for f in filters.values()] for filters in self.observation_filters.values()]
|
||||
|
||||
def set_session(self, sess) -> None:
|
||||
"""
|
||||
|
||||
@@ -41,13 +41,14 @@ class ObservationNormalizationFilter(ObservationFilter):
|
||||
self.supports_batching = True
|
||||
self.observation_space = None
|
||||
|
||||
def set_device(self, device) -> None:
|
||||
def set_device(self, device, memory_backend_params=None) -> None:
|
||||
"""
|
||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||
:param device: the device to use
|
||||
:return: None
|
||||
"""
|
||||
self.running_observation_stats = SharedRunningStats(device, name=self.name, create_ops=False)
|
||||
self.running_observation_stats = SharedRunningStats(device, name=self.name, create_ops=False,
|
||||
pubsub_params=memory_backend_params)
|
||||
|
||||
def set_session(self, sess) -> None:
|
||||
"""
|
||||
|
||||
@@ -38,13 +38,14 @@ class RewardNormalizationFilter(RewardFilter):
|
||||
self.clip_max = clip_max
|
||||
self.running_rewards_stats = None
|
||||
|
||||
def set_device(self, device) -> None:
|
||||
def set_device(self, device, memory_backend_params=None) -> None:
|
||||
"""
|
||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||
:param device: the device to use
|
||||
:return: None
|
||||
"""
|
||||
self.running_rewards_stats = SharedRunningStats(device, name='rewards_stats')
|
||||
self.running_rewards_stats = SharedRunningStats(device, name='rewards_stats',
|
||||
pubsub_params=memory_backend_params)
|
||||
|
||||
def set_session(self, sess) -> None:
|
||||
"""
|
||||
|
||||
@@ -5,7 +5,6 @@ from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentS
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
||||
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
||||
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
|
||||
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.schedules import LinearSchedule
|
||||
@@ -49,8 +48,6 @@ agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoach
|
||||
|
||||
agent_params.exploration = EGreedyParameters()
|
||||
agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
|
||||
# agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
|
||||
# ObservationNormalizationFilter(name='normalize_observation'))
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
|
||||
from rl_coach.architectures.layers import Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
||||
@@ -43,6 +43,8 @@ agent_params.algorithm.gae_lambda = 0.95
|
||||
agent_params.algorithm.discount = 0.99
|
||||
agent_params.algorithm.optimization_epochs = 10
|
||||
agent_params.algorithm.estimate_state_value_using_gae = True
|
||||
# Distributed Coach synchronization type.
|
||||
agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC
|
||||
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.exploration = AdditiveNoiseParameters()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from rl_coach.agents.ppo_agent import PPOAgentParameters
|
||||
from rl_coach.architectures.layers import Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
||||
@@ -33,6 +33,9 @@ agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense(64
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_observation_filter('observation', 'normalize', ObservationNormalizationFilter())
|
||||
|
||||
# Distributed Coach synchronization type.
|
||||
agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
|
||||
Reference in New Issue
Block a user