1
0
mirror of https://github.com/gryf/coach.git synced 2026-03-14 13:45:46 +01:00

Distiller's AMC induced changes (#359)

* override episode rewards with the last transition reward

* EWMA normalization filter

* allowing control over when the pre_network filter runs
This commit is contained in:
Gal Leibovich
2019-08-05 10:24:58 +03:00
committed by GitHub
parent 7df67dafa3
commit c1d1fae342
10 changed files with 137 additions and 30 deletions

View File

@@ -338,11 +338,14 @@ class InputFilter(Filter):
state_object[observation_name] = filtered_observations[i]
# filter reward
for f in filtered_data:
filtered_reward = f.reward
for filter in self._reward_filters.values():
filtered_reward = filter.filter(filtered_reward, update_internal_state)
f.reward = filtered_reward
for filter in self._reward_filters.values():
if filter.supports_batching:
filtered_rewards = filter.filter([f.reward for f in filtered_data], update_internal_state)
for d, filtered_reward in zip(filtered_data, filtered_rewards):
d.reward = filtered_reward
else:
for d in filtered_data:
d.reward = filter.filter(d.reward, update_internal_state)
return filtered_data

View File

@@ -1,8 +1,11 @@
from .reward_rescale_filter import RewardRescaleFilter
from .reward_clipping_filter import RewardClippingFilter
from .reward_normalization_filter import RewardNormalizationFilter
from .reward_ewma_normalization_filter import RewardEwmaNormalizationFilter
__all__ = [
'RewardRescaleFilter',
'RewardClippingFilter',
'RewardNormalizationFilter'
'RewardNormalizationFilter',
'RewardEwmaNormalizationFilter'
]

View File

@@ -0,0 +1,76 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import numpy as np
import pickle
from rl_coach.core_types import RewardType
from rl_coach.filters.reward.reward_filter import RewardFilter
from rl_coach.spaces import RewardSpace
from rl_coach.utils import get_latest_checkpoint
class RewardEwmaNormalizationFilter(RewardFilter):
    """
    Normalizes reward values by subtracting an Exponentially Weighted Moving Average
    (EWMA) of the mean reward observed so far.
    """
    def __init__(self, alpha: float = 0.01):
        """
        :param alpha: the degree of weighting decrease, a constant smoothing factor between 0 and 1.
                      A higher alpha discounts older observations faster
        """
        super().__init__()
        self.alpha = alpha
        # running EWMA of the mean reward; seeded from the first batch seen
        self.moving_average = 0
        self.initialized = False
        self.checkpoint_file_extension = 'ewma'
        # this filter can process a whole batch of rewards in a single call
        self.supports_batching = True

    def filter(self, reward: RewardType, update_internal_state: bool = True) -> RewardType:
        """
        Shift the given reward(s) by the current moving average.

        :param reward: a single reward or a batch (array) of rewards
        :param update_internal_state: if True, fold the mean of the given reward(s) into the EWMA
        :return: the reward(s) with the moving average subtracted
        """
        if not isinstance(reward, np.ndarray):
            reward = np.array(reward)

        if update_internal_state:
            mean_rewards = np.mean(reward)
            if not self.initialized:
                # seed the EWMA with the first observed mean instead of biasing towards 0
                self.moving_average = mean_rewards
                self.initialized = True
            else:
                self.moving_average += self.alpha * (mean_rewards - self.moving_average)

        return reward - self.moving_average

    def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
        # subtracting a running constant does not change the reward space definition
        return input_reward_space

    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: int):
        """
        Persist the filter's internal state to a pickle file in the checkpoint directory.

        :param checkpoint_dir: directory to write the checkpoint file into
        :param checkpoint_prefix: prefix used to name the checkpoint file
        """
        # store 'initialized' as well, so a restored filter does not re-seed the EWMA
        # from the first post-restore batch and silently discard the saved average
        dict_to_save = {'moving_average': self.moving_average, 'initialized': self.initialized}
        with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.' + self.checkpoint_file_extension), 'wb') as f:
            pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)

    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
        """
        Restore the filter's internal state from the latest matching checkpoint file.

        :param checkpoint_dir: directory to search for checkpoint files
        :param checkpoint_prefix: prefix of the checkpoint file to look for
        :raises ValueError: if no matching checkpoint file is found
        """
        latest_checkpoint_filename = get_latest_checkpoint(checkpoint_dir, checkpoint_prefix,
                                                           self.checkpoint_file_extension)
        if latest_checkpoint_filename == '':
            raise ValueError("Could not find RewardEwmaNormalizationFilter checkpoint file. ")

        with open(os.path.join(checkpoint_dir, str(latest_checkpoint_filename)), 'rb') as f:
            saved_dict = pickle.load(f)
            self.__dict__.update(saved_dict)
            # BUGFIX: older checkpoints did not store 'initialized', leaving it False after
            # restore, so the next update_internal_state call would overwrite the restored
            # moving average. Default to True for such checkpoints: a saved average was
            # necessarily already seeded.
            self.initialized = saved_dict.get('initialized', True)

View File

@@ -21,6 +21,7 @@ from rl_coach.spaces import RewardSpace
class RewardFilter(Filter):
def __init__(self):
super().__init__()
self.supports_batching = False
def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
"""