1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Numpy shared running stats (#97)

This commit is contained in:
Gal Leibovich
2018-11-18 14:46:40 +02:00
committed by GitHub
parent e1fa6e9681
commit 6caf721d1c
10 changed files with 226 additions and 114 deletions

View File

@@ -46,11 +46,12 @@ class Filter(object):
"""
raise NotImplementedError("")
def set_device(self, device, memory_backend_params=None) -> None:
def set_device(self, device, memory_backend_params=None, mode='numpy') -> None:
"""
An optional function that allows the filter to get the device if it is required to use tensorflow ops
:param device: the device to use
:param memory_backend_params: parameters associated with the memory backend
:param mode: arithmetic backend to be used {numpy | tf}
:return: None
"""
pass
@@ -85,13 +86,13 @@ class OutputFilter(Filter):
duplicate.i_am_a_reference_filter = False
return duplicate
def set_device(self, device, memory_backend_params=None) -> None:
def set_device(self, device, memory_backend_params=None, mode='numpy') -> None:
"""
An optional function that allows the filter to get the device if it is required to use tensorflow ops
:param device: the device to use
:return: None
"""
[f.set_device(device, memory_backend_params) for f in self.action_filters.values()]
[f.set_device(device, memory_backend_params, mode='numpy') for f in self.action_filters.values()]
def set_session(self, sess) -> None:
"""
@@ -226,14 +227,14 @@ class InputFilter(Filter):
duplicate.i_am_a_reference_filter = False
return duplicate
def set_device(self, device, memory_backend_params=None) -> None:
def set_device(self, device, memory_backend_params=None, mode='numpy') -> None:
"""
An optional function that allows the filter to get the device if it is required to use tensorflow ops
:param device: the device to use
:return: None
"""
[f.set_device(device, memory_backend_params) for f in self.reward_filters.values()]
[[f.set_device(device, memory_backend_params) for f in filters.values()] for filters in self.observation_filters.values()]
[f.set_device(device, memory_backend_params, mode) for f in self.reward_filters.values()]
[[f.set_device(device, memory_backend_params, mode) for f in filters.values()] for filters in self.observation_filters.values()]
def set_session(self, sess) -> None:
"""

View File

@@ -17,10 +17,11 @@ from typing import List
import numpy as np
from rl_coach.architectures.tensorflow_components.shared_variables import SharedRunningStats
from rl_coach.architectures.tensorflow_components.shared_variables import SharedRunningStats, TFSharedRunningStats
from rl_coach.core_types import ObservationType
from rl_coach.filters.observation.observation_filter import ObservationFilter
from rl_coach.spaces import ObservationSpace
from rl_coach.utilities.shared_running_stats import NumpySharedRunningStats, NumpySharedRunningStats
class ObservationNormalizationFilter(ObservationFilter):
@@ -42,14 +43,20 @@ class ObservationNormalizationFilter(ObservationFilter):
self.supports_batching = True
self.observation_space = None
def set_device(self, device, memory_backend_params=None) -> None:
def set_device(self, device, memory_backend_params=None, mode='numpy') -> None:
"""
An optional function that allows the filter to get the device if it is required to use tensorflow ops
:param device: the device to use
:memory_backend_params: if not None, holds params for a memory backend for sharing data (e.g. Redis)
:param mode: the arithmetic module to use {'tf' | 'numpy'}
:return: None
"""
self.running_observation_stats = SharedRunningStats(device, name=self.name, create_ops=False,
if mode == 'tf':
self.running_observation_stats = TFSharedRunningStats(device, name=self.name, create_ops=False,
pubsub_params=memory_backend_params)
elif mode == 'numpy':
self.running_observation_stats = NumpySharedRunningStats(name=self.name,
pubsub_params=memory_backend_params)
def set_session(self, sess) -> None:
"""
@@ -66,11 +73,9 @@ class ObservationNormalizationFilter(ObservationFilter):
self.last_mean = self.running_observation_stats.mean
self.last_stdev = self.running_observation_stats.std
# TODO: make sure that a batch is given here
return self.running_observation_stats.normalize(observations)
def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
self.running_observation_stats.create_ops(shape=input_observation_space.shape,
self.running_observation_stats.set_params(shape=input_observation_space.shape,
clip_values=(self.clip_min, self.clip_max))
return input_observation_space

View File

@@ -17,10 +17,11 @@
import numpy as np
from rl_coach.architectures.tensorflow_components.shared_variables import SharedRunningStats
from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats
from rl_coach.core_types import RewardType
from rl_coach.filters.reward.reward_filter import RewardFilter
from rl_coach.spaces import RewardSpace
from rl_coach.utilities.shared_running_stats import NumpySharedRunningStats
class RewardNormalizationFilter(RewardFilter):
@@ -39,14 +40,19 @@ class RewardNormalizationFilter(RewardFilter):
self.clip_max = clip_max
self.running_rewards_stats = None
def set_device(self, device, memory_backend_params=None) -> None:
def set_device(self, device, memory_backend_params=None, mode='numpy') -> None:
"""
An optional function that allows the filter to get the device if it is required to use tensorflow ops
:param device: the device to use
:return: None
"""
self.running_rewards_stats = SharedRunningStats(device, name='rewards_stats',
pubsub_params=memory_backend_params)
if mode == 'tf':
self.running_rewards_stats = TFSharedRunningStats(device, name='rewards_stats', create_ops=False,
pubsub_params=memory_backend_params)
elif mode == 'numpy':
self.running_rewards_stats = NumpySharedRunningStats(name='rewards_stats',
pubsub_params=memory_backend_params)
def set_session(self, sess) -> None:
"""