mirror of
https://github.com/gryf/coach.git
synced 2026-03-23 11:03:32 +01:00
Enable distributed SharedRunningStats (#81)
- Use Redis pub/sub for updating SharedRunningStats.
This commit is contained in:
committed by
Gal Leibovich
parent
875d6ef017
commit
a849c17e46
@@ -46,10 +46,11 @@ class Filter(object):
|
||||
"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
def set_device(self, device) -> None:
|
||||
def set_device(self, device, memory_backend_params=None) -> None:
|
||||
"""
|
||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||
:param device: the device to use
|
||||
:param memory_backend_params: parameters associated with the memory backend
|
||||
:return: None
|
||||
"""
|
||||
pass
|
||||
@@ -84,13 +85,13 @@ class OutputFilter(Filter):
|
||||
duplicate.i_am_a_reference_filter = False
|
||||
return duplicate
|
||||
|
||||
def set_device(self, device) -> None:
|
||||
def set_device(self, device, memory_backend_params=None) -> None:
|
||||
"""
|
||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||
:param device: the device to use
|
||||
:return: None
|
||||
"""
|
||||
[f.set_device(device) for f in self.action_filters.values()]
|
||||
[f.set_device(device, memory_backend_params) for f in self.action_filters.values()]
|
||||
|
||||
def set_session(self, sess) -> None:
|
||||
"""
|
||||
@@ -225,14 +226,14 @@ class InputFilter(Filter):
|
||||
duplicate.i_am_a_reference_filter = False
|
||||
return duplicate
|
||||
|
||||
def set_device(self, device) -> None:
|
||||
def set_device(self, device, memory_backend_params=None) -> None:
|
||||
"""
|
||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||
:param device: the device to use
|
||||
:return: None
|
||||
"""
|
||||
[f.set_device(device) for f in self.reward_filters.values()]
|
||||
[[f.set_device(device) for f in filters.values()] for filters in self.observation_filters.values()]
|
||||
[f.set_device(device, memory_backend_params) for f in self.reward_filters.values()]
|
||||
[[f.set_device(device, memory_backend_params) for f in filters.values()] for filters in self.observation_filters.values()]
|
||||
|
||||
def set_session(self, sess) -> None:
|
||||
"""
|
||||
|
||||
@@ -41,13 +41,14 @@ class ObservationNormalizationFilter(ObservationFilter):
|
||||
self.supports_batching = True
|
||||
self.observation_space = None
|
||||
|
||||
def set_device(self, device) -> None:
|
||||
def set_device(self, device, memory_backend_params=None) -> None:
|
||||
"""
|
||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||
:param device: the device to use
|
||||
:return: None
|
||||
"""
|
||||
self.running_observation_stats = SharedRunningStats(device, name=self.name, create_ops=False)
|
||||
self.running_observation_stats = SharedRunningStats(device, name=self.name, create_ops=False,
|
||||
pubsub_params=memory_backend_params)
|
||||
|
||||
def set_session(self, sess) -> None:
|
||||
"""
|
||||
|
||||
@@ -38,13 +38,14 @@ class RewardNormalizationFilter(RewardFilter):
|
||||
self.clip_max = clip_max
|
||||
self.running_rewards_stats = None
|
||||
|
||||
def set_device(self, device) -> None:
|
||||
def set_device(self, device, memory_backend_params=None) -> None:
|
||||
"""
|
||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||
:param device: the device to use
|
||||
:return: None
|
||||
"""
|
||||
self.running_rewards_stats = SharedRunningStats(device, name='rewards_stats')
|
||||
self.running_rewards_stats = SharedRunningStats(device, name='rewards_stats',
|
||||
pubsub_params=memory_backend_params)
|
||||
|
||||
def set_session(self, sess) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user