1
0
mirror of https://github.com/gryf/coach.git synced 2026-03-23 11:03:32 +01:00

Enable distributed SharedRunningStats (#81)

- Use Redis pub/sub for updating SharedRunningStats.
This commit is contained in:
Balaji Subramaniam
2018-11-13 09:17:38 -08:00
committed by Gal Leibovich
parent 875d6ef017
commit a849c17e46
10 changed files with 76 additions and 27 deletions

View File

@@ -46,10 +46,11 @@ class Filter(object):
"""
raise NotImplementedError("")
def set_device(self, device) -> None:
def set_device(self, device, memory_backend_params=None) -> None:
"""
An optional function that allows the filter to get the device if it is required to use tensorflow ops
:param device: the device to use
:param memory_backend_params: parameters associated with the memory backend
:return: None
"""
pass
@@ -84,13 +85,13 @@ class OutputFilter(Filter):
duplicate.i_am_a_reference_filter = False
return duplicate
def set_device(self, device) -> None:
def set_device(self, device, memory_backend_params=None) -> None:
"""
An optional function that allows the filter to get the device if it is required to use tensorflow ops
:param device: the device to use
:return: None
"""
[f.set_device(device) for f in self.action_filters.values()]
[f.set_device(device, memory_backend_params) for f in self.action_filters.values()]
def set_session(self, sess) -> None:
"""
@@ -225,14 +226,14 @@ class InputFilter(Filter):
duplicate.i_am_a_reference_filter = False
return duplicate
def set_device(self, device) -> None:
def set_device(self, device, memory_backend_params=None) -> None:
"""
An optional function that allows the filter to get the device if it is required to use tensorflow ops
:param device: the device to use
:return: None
"""
[f.set_device(device) for f in self.reward_filters.values()]
[[f.set_device(device) for f in filters.values()] for filters in self.observation_filters.values()]
[f.set_device(device, memory_backend_params) for f in self.reward_filters.values()]
[[f.set_device(device, memory_backend_params) for f in filters.values()] for filters in self.observation_filters.values()]
def set_session(self, sess) -> None:
"""

View File

@@ -41,13 +41,14 @@ class ObservationNormalizationFilter(ObservationFilter):
self.supports_batching = True
self.observation_space = None
def set_device(self, device) -> None:
def set_device(self, device, memory_backend_params=None) -> None:
"""
An optional function that allows the filter to get the device if it is required to use tensorflow ops
:param device: the device to use
:return: None
"""
self.running_observation_stats = SharedRunningStats(device, name=self.name, create_ops=False)
self.running_observation_stats = SharedRunningStats(device, name=self.name, create_ops=False,
pubsub_params=memory_backend_params)
def set_session(self, sess) -> None:
"""

View File

@@ -38,13 +38,14 @@ class RewardNormalizationFilter(RewardFilter):
self.clip_max = clip_max
self.running_rewards_stats = None
def set_device(self, device) -> None:
def set_device(self, device, memory_backend_params=None) -> None:
"""
An optional function that allows the filter to get the device if it is required to use tensorflow ops
:param device: the device to use
:return: None
"""
self.running_rewards_stats = SharedRunningStats(device, name='rewards_stats')
self.running_rewards_stats = SharedRunningStats(device, name='rewards_stats',
pubsub_params=memory_backend_params)
def set_session(self, sess) -> None:
"""