mirror of
https://github.com/gryf/coach.git
synced 2026-03-14 13:45:46 +01:00
Distiller's AMC induced changes (#359)
* Override episode rewards with the last transition reward
* Add an EWMA reward-normalization filter
* Allow control over when the pre_network filter runs
This commit is contained in:
@@ -338,11 +338,14 @@ class InputFilter(Filter):
|
||||
state_object[observation_name] = filtered_observations[i]
|
||||
|
||||
# filter reward
|
||||
for f in filtered_data:
|
||||
filtered_reward = f.reward
|
||||
for filter in self._reward_filters.values():
|
||||
filtered_reward = filter.filter(filtered_reward, update_internal_state)
|
||||
f.reward = filtered_reward
|
||||
for filter in self._reward_filters.values():
|
||||
if filter.supports_batching:
|
||||
filtered_rewards = filter.filter([f.reward for f in filtered_data], update_internal_state)
|
||||
for d, filtered_reward in zip(filtered_data, filtered_rewards):
|
||||
d.reward = filtered_reward
|
||||
else:
|
||||
for d in filtered_data:
|
||||
d.reward = filter.filter(d.reward, update_internal_state)
|
||||
|
||||
return filtered_data
|
||||
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from .reward_rescale_filter import RewardRescaleFilter
from .reward_clipping_filter import RewardClippingFilter
from .reward_normalization_filter import RewardNormalizationFilter
from .reward_ewma_normalization_filter import RewardEwmaNormalizationFilter

# Public API of the reward-filters package. Note: every entry needs a trailing
# comma — a missing comma between two strings silently concatenates them into
# one bogus name (the bug the pre-commit version of this list had).
__all__ = [
    'RewardRescaleFilter',
    'RewardClippingFilter',
    'RewardNormalizationFilter',
    'RewardEwmaNormalizationFilter',
]
|
||||
76
rl_coach/filters/reward/reward_ewma_normalization_filter.py
Normal file
76
rl_coach/filters/reward/reward_ewma_normalization_filter.py
Normal file
@@ -0,0 +1,76 @@
|
||||
#
|
||||
# Copyright (c) 2019 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pickle
|
||||
|
||||
from rl_coach.core_types import RewardType
|
||||
from rl_coach.filters.reward.reward_filter import RewardFilter
|
||||
from rl_coach.spaces import RewardSpace
|
||||
from rl_coach.utils import get_latest_checkpoint
|
||||
|
||||
|
||||
class RewardEwmaNormalizationFilter(RewardFilter):
    """
    Normalizes reward values based on an Exponential Weighted Moving Average.

    The filter maintains a running EWMA of the rewards it has seen and returns
    each incoming reward with that average subtracted, centering rewards
    around zero. The running average can be checkpointed and restored so that
    filtering resumes consistently after a restart.
    """
    def __init__(self, alpha: float = 0.01):
        """
        :param alpha: the degree of weighting decrease, a constant smoothing factor
            between 0 and 1. A higher alpha discounts older observations faster.
        """
        super().__init__()
        self.alpha = alpha
        # running EWMA of observed rewards (float; seeded from the first batch)
        self.moving_average = 0.0
        # True once the average has been seeded with the first observation
        self.initialized = False
        self.checkpoint_file_extension = 'ewma'
        # this filter can process a whole batch of rewards in a single call
        self.supports_batching = True

    def filter(self, reward: RewardType, update_internal_state: bool = True) -> RewardType:
        """
        Subtract the current moving average from the given reward(s).

        :param reward: a single reward or a batch of rewards
        :param update_internal_state: when True, fold the mean of the incoming
            reward(s) into the moving average before subtracting it
        :return: the centered reward(s), as a numpy array/scalar
        """
        if not isinstance(reward, np.ndarray):
            reward = np.array(reward)

        if update_internal_state:
            # a scalar reward averages to itself; a batch contributes its mean
            mean_rewards = np.mean(reward)

            if not self.initialized:
                # seed the average with the first observation instead of
                # slowly decaying from 0 towards the true reward scale
                self.moving_average = mean_rewards
                self.initialized = True
            else:
                self.moving_average += self.alpha * (mean_rewards - self.moving_average)

        return reward - self.moving_average

    def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
        # subtracting a (bounded-only-in-practice) offset does not change the
        # declared reward space, so it is passed through unchanged
        return input_reward_space

    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
        """
        Persist the moving average to ``<checkpoint_dir>/<checkpoint_prefix>.ewma``.

        :param checkpoint_dir: directory to write the checkpoint file into
        :param checkpoint_prefix: prefix for the checkpoint file name
            (annotated ``str`` for consistency with ``restore_state_from_checkpoint``;
            any value is stringified before use)
        """
        # only the moving average must survive a restart; alpha comes from config
        # and `initialized` is implied by the presence of a checkpoint
        dict_to_save = {'moving_average': self.moving_average}

        checkpoint_path = os.path.join(
            checkpoint_dir, str(checkpoint_prefix) + '.' + self.checkpoint_file_extension)
        with open(checkpoint_path, 'wb') as f:
            pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)

    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
        """
        Restore the moving average from the latest matching checkpoint file.

        :param checkpoint_dir: directory to search for checkpoint files
        :param checkpoint_prefix: prefix the checkpoint file name must start with
        :raises ValueError: if no matching checkpoint file is found
        """
        latest_checkpoint_filename = get_latest_checkpoint(checkpoint_dir, checkpoint_prefix,
                                                           self.checkpoint_file_extension)

        if latest_checkpoint_filename == '':
            raise ValueError("Could not find RewardEwmaNormalizationFilter checkpoint file. ")

        # NOTE(review): pickle.load on checkpoint files assumes the checkpoint
        # directory is trusted — do not point this at untrusted data
        with open(os.path.join(checkpoint_dir, str(latest_checkpoint_filename)), 'rb') as f:
            saved_dict = pickle.load(f)
            self.__dict__.update(saved_dict)
|
||||
@@ -21,6 +21,7 @@ from rl_coach.spaces import RewardSpace
|
||||
class RewardFilter(Filter):
|
||||
def __init__(self):
    super().__init__()
    # Subclasses that can filter an entire batch of rewards in one call set
    # this to True (e.g. the EWMA normalization filter); the default is
    # element-by-element filtering.
    self.supports_batching = False
|
||||
|
||||
def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user