mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Numpy shared running stats (#97)
This commit is contained in:
@@ -46,11 +46,12 @@ class Filter(object):
|
||||
"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
def set_device(self, device, memory_backend_params=None) -> None:
|
||||
def set_device(self, device, memory_backend_params=None, mode='numpy') -> None:
|
||||
"""
|
||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||
:param device: the device to use
|
||||
:param memory_backend_params: parameters associated with the memory backend
|
||||
:param mode: arithmetic backend to be used {numpy | tf}
|
||||
:return: None
|
||||
"""
|
||||
pass
|
||||
@@ -85,13 +86,13 @@ class OutputFilter(Filter):
|
||||
duplicate.i_am_a_reference_filter = False
|
||||
return duplicate
|
||||
|
||||
def set_device(self, device, memory_backend_params=None) -> None:
|
||||
def set_device(self, device, memory_backend_params=None, mode='numpy') -> None:
|
||||
"""
|
||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||
:param device: the device to use
|
||||
:return: None
|
||||
"""
|
||||
[f.set_device(device, memory_backend_params) for f in self.action_filters.values()]
|
||||
[f.set_device(device, memory_backend_params, mode='numpy') for f in self.action_filters.values()]
|
||||
|
||||
def set_session(self, sess) -> None:
|
||||
"""
|
||||
@@ -226,14 +227,14 @@ class InputFilter(Filter):
|
||||
duplicate.i_am_a_reference_filter = False
|
||||
return duplicate
|
||||
|
||||
def set_device(self, device, memory_backend_params=None) -> None:
|
||||
def set_device(self, device, memory_backend_params=None, mode='numpy') -> None:
|
||||
"""
|
||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||
:param device: the device to use
|
||||
:return: None
|
||||
"""
|
||||
[f.set_device(device, memory_backend_params) for f in self.reward_filters.values()]
|
||||
[[f.set_device(device, memory_backend_params) for f in filters.values()] for filters in self.observation_filters.values()]
|
||||
[f.set_device(device, memory_backend_params, mode) for f in self.reward_filters.values()]
|
||||
[[f.set_device(device, memory_backend_params, mode) for f in filters.values()] for filters in self.observation_filters.values()]
|
||||
|
||||
def set_session(self, sess) -> None:
|
||||
"""
|
||||
|
||||
@@ -17,10 +17,11 @@ from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.shared_variables import SharedRunningStats
|
||||
from rl_coach.architectures.tensorflow_components.shared_variables import SharedRunningStats, TFSharedRunningStats
|
||||
from rl_coach.core_types import ObservationType
|
||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||
from rl_coach.spaces import ObservationSpace
|
||||
from rl_coach.utilities.shared_running_stats import NumpySharedRunningStats
|
||||
|
||||
|
||||
class ObservationNormalizationFilter(ObservationFilter):
|
||||
@@ -42,14 +43,20 @@ class ObservationNormalizationFilter(ObservationFilter):
|
||||
self.supports_batching = True
|
||||
self.observation_space = None
|
||||
|
||||
def set_device(self, device, memory_backend_params=None) -> None:
|
||||
def set_device(self, device, memory_backend_params=None, mode='numpy') -> None:
|
||||
"""
|
||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||
:param device: the device to use
|
||||
:param memory_backend_params: if not None, holds params for a memory backend for sharing data (e.g. Redis)
|
||||
:param mode: the arithmetic module to use {'tf' | 'numpy'}
|
||||
:return: None
|
||||
"""
|
||||
self.running_observation_stats = SharedRunningStats(device, name=self.name, create_ops=False,
|
||||
if mode == 'tf':
|
||||
self.running_observation_stats = TFSharedRunningStats(device, name=self.name, create_ops=False,
|
||||
pubsub_params=memory_backend_params)
|
||||
elif mode == 'numpy':
|
||||
self.running_observation_stats = NumpySharedRunningStats(name=self.name,
|
||||
pubsub_params=memory_backend_params)
|
||||
|
||||
def set_session(self, sess) -> None:
|
||||
"""
|
||||
@@ -66,11 +73,9 @@ class ObservationNormalizationFilter(ObservationFilter):
|
||||
self.last_mean = self.running_observation_stats.mean
|
||||
self.last_stdev = self.running_observation_stats.std
|
||||
|
||||
# TODO: make sure that a batch is given here
|
||||
return self.running_observation_stats.normalize(observations)
|
||||
|
||||
def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
|
||||
|
||||
self.running_observation_stats.create_ops(shape=input_observation_space.shape,
|
||||
self.running_observation_stats.set_params(shape=input_observation_space.shape,
|
||||
clip_values=(self.clip_min, self.clip_max))
|
||||
return input_observation_space
|
||||
|
||||
@@ -17,10 +17,11 @@
|
||||
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.shared_variables import SharedRunningStats
|
||||
from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats
|
||||
from rl_coach.core_types import RewardType
|
||||
from rl_coach.filters.reward.reward_filter import RewardFilter
|
||||
from rl_coach.spaces import RewardSpace
|
||||
from rl_coach.utilities.shared_running_stats import NumpySharedRunningStats
|
||||
|
||||
|
||||
class RewardNormalizationFilter(RewardFilter):
|
||||
@@ -39,14 +40,19 @@ class RewardNormalizationFilter(RewardFilter):
|
||||
self.clip_max = clip_max
|
||||
self.running_rewards_stats = None
|
||||
|
||||
def set_device(self, device, memory_backend_params=None) -> None:
|
||||
def set_device(self, device, memory_backend_params=None, mode='numpy') -> None:
|
||||
"""
|
||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||
:param device: the device to use
|
||||
:return: None
|
||||
"""
|
||||
self.running_rewards_stats = SharedRunningStats(device, name='rewards_stats',
|
||||
pubsub_params=memory_backend_params)
|
||||
|
||||
if mode == 'tf':
|
||||
self.running_rewards_stats = TFSharedRunningStats(device, name='rewards_stats', create_ops=False,
|
||||
pubsub_params=memory_backend_params)
|
||||
elif mode == 'numpy':
|
||||
self.running_rewards_stats = NumpySharedRunningStats(name='rewards_stats',
|
||||
pubsub_params=memory_backend_params)
|
||||
|
||||
def set_session(self, sess) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user