mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 03:00:20 +01:00
Distiller's AMC induced changes (#359)
* override episode rewards with the last transition reward * EWMA normalization filter * allowing control over when the pre_network filter runs
This commit is contained in:
@@ -573,6 +573,9 @@ class Agent(AgentInterface):
|
||||
|
||||
if self.phase != RunPhase.TEST:
|
||||
if isinstance(self.memory, EpisodicExperienceReplay):
|
||||
if self.ap.algorithm.override_episode_rewards_with_the_last_transition_reward:
|
||||
for t in self.current_episode_buffer.transitions:
|
||||
t.reward = self.current_episode_buffer.transitions[-1].reward
|
||||
self.call_memory('store_episode', self.current_episode_buffer)
|
||||
elif self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
|
||||
for transition in self.current_episode_buffer.transitions:
|
||||
@@ -727,7 +730,8 @@ class Agent(AgentInterface):
|
||||
# update counters
|
||||
self.training_iteration += 1
|
||||
if self.pre_network_filter is not None:
|
||||
batch = self.pre_network_filter.filter(batch, update_internal_state=False, deep_copy=False)
|
||||
update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_train
|
||||
batch = self.pre_network_filter.filter(batch, update_internal_state=update_internal_state, deep_copy=False)
|
||||
|
||||
# if the batch returned empty then there are not enough samples in the replay buffer -> skip
|
||||
# training step
|
||||
@@ -837,7 +841,8 @@ class Agent(AgentInterface):
|
||||
# informed action
|
||||
if self.pre_network_filter is not None:
|
||||
# before choosing an action, first use the pre_network_filter to filter out the current state
|
||||
update_filter_internal_state = self.phase is not RunPhase.TEST
|
||||
update_filter_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_inference and \
|
||||
self.phase is not RunPhase.TEST
|
||||
curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state)
|
||||
|
||||
else:
|
||||
@@ -865,6 +870,10 @@ class Agent(AgentInterface):
|
||||
:return: The filtered state
|
||||
"""
|
||||
dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
|
||||
|
||||
# TODO actually we only want to run the observation filters. No point in running the reward filters as the
|
||||
# filtered reward is being ignored anyway (and it might unncecessarily affect the reward filters' internal
|
||||
# state).
|
||||
return self.pre_network_filter.filter(dummy_env_response,
|
||||
update_internal_state=update_filter_internal_state)[0].next_state
|
||||
|
||||
|
||||
@@ -113,6 +113,8 @@ class ClippedPPOAlgorithmParameters(AlgorithmParameters):
|
||||
self.normalization_stats = None
|
||||
self.clipping_decay_schedule = ConstantSchedule(1)
|
||||
self.act_for_full_episodes = True
|
||||
self.update_pre_network_filters_state_on_train = True
|
||||
self.update_pre_network_filters_state_on_inference = False
|
||||
|
||||
|
||||
class ClippedPPOAgentParameters(AgentParameters):
|
||||
@@ -303,7 +305,9 @@ class ClippedPPOAgent(ActorCriticAgent):
|
||||
network.set_is_training(True)
|
||||
|
||||
dataset = self.memory.transitions
|
||||
dataset = self.pre_network_filter.filter(dataset, deep_copy=False)
|
||||
update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_train
|
||||
dataset = self.pre_network_filter.filter(dataset, deep_copy=False,
|
||||
update_internal_state=update_internal_state)
|
||||
batch = Batch(dataset)
|
||||
|
||||
for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
|
||||
@@ -329,7 +333,9 @@ class ClippedPPOAgent(ActorCriticAgent):
|
||||
|
||||
def run_pre_network_filter_for_inference(self, state: StateType, update_internal_state: bool=False):
|
||||
dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
|
||||
return self.pre_network_filter.filter(dummy_env_response, update_internal_state=False)[0].next_state
|
||||
update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_inference
|
||||
return self.pre_network_filter.filter(
|
||||
dummy_env_response, update_internal_state=update_internal_state)[0].next_state
|
||||
|
||||
def choose_action(self, curr_state):
|
||||
self.ap.algorithm.clipping_decay_schedule.step()
|
||||
|
||||
@@ -213,6 +213,14 @@ class AlgorithmParameters(Parameters):
|
||||
# Support for parameter noise
|
||||
self.supports_parameter_noise = False
|
||||
|
||||
# Override, in retrospective, all the episode rewards with the last reward in the episode
|
||||
# (sometimes useful for sparse, end of the episode, rewards problems)
|
||||
self.override_episode_rewards_with_the_last_transition_reward = False
|
||||
|
||||
# Filters - TODO consider creating a FilterParameters class and initialize the filters with it
|
||||
self.update_pre_network_filters_state_on_train = False
|
||||
self.update_pre_network_filters_state_on_inference = True
|
||||
|
||||
|
||||
class PresetValidationParameters(Parameters):
|
||||
def __init__(self,
|
||||
|
||||
@@ -88,9 +88,6 @@ class TruncatedNormal(ContinuousActionExplorationPolicy):
|
||||
else:
|
||||
action_values_std = current_noise
|
||||
|
||||
# scale the noise to the action space range
|
||||
action_values_std = current_noise * (self.action_space.high - self.action_space.low)
|
||||
|
||||
# extract the mean values
|
||||
if isinstance(action_values, list):
|
||||
# the action values are expected to be a list with the action mean and optionally the action stdev
|
||||
|
||||
@@ -338,11 +338,14 @@ class InputFilter(Filter):
|
||||
state_object[observation_name] = filtered_observations[i]
|
||||
|
||||
# filter reward
|
||||
for f in filtered_data:
|
||||
filtered_reward = f.reward
|
||||
for filter in self._reward_filters.values():
|
||||
filtered_reward = filter.filter(filtered_reward, update_internal_state)
|
||||
f.reward = filtered_reward
|
||||
for filter in self._reward_filters.values():
|
||||
if filter.supports_batching:
|
||||
filtered_rewards = filter.filter([f.reward for f in filtered_data], update_internal_state)
|
||||
for d, filtered_reward in zip(filtered_data, filtered_rewards):
|
||||
d.reward = filtered_reward
|
||||
else:
|
||||
for d in filtered_data:
|
||||
d.reward = filter.filter(d.reward, update_internal_state)
|
||||
|
||||
return filtered_data
|
||||
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from .reward_rescale_filter import RewardRescaleFilter
|
||||
from .reward_clipping_filter import RewardClippingFilter
|
||||
from .reward_normalization_filter import RewardNormalizationFilter
|
||||
from .reward_ewma_normalization_filter import RewardEwmaNormalizationFilter
|
||||
|
||||
__all__ = [
|
||||
'RewardRescaleFilter',
|
||||
'RewardClippingFilter',
|
||||
'RewardNormalizationFilter'
|
||||
'RewardNormalizationFilter',
|
||||
'RewardEwmaNormalizationFilter'
|
||||
]
|
||||
76
rl_coach/filters/reward/reward_ewma_normalization_filter.py
Normal file
76
rl_coach/filters/reward/reward_ewma_normalization_filter.py
Normal file
@@ -0,0 +1,76 @@
|
||||
#
|
||||
# Copyright (c) 2019 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pickle
|
||||
|
||||
from rl_coach.core_types import RewardType
|
||||
from rl_coach.filters.reward.reward_filter import RewardFilter
|
||||
from rl_coach.spaces import RewardSpace
|
||||
from rl_coach.utils import get_latest_checkpoint
|
||||
|
||||
|
||||
class RewardEwmaNormalizationFilter(RewardFilter):
|
||||
"""
|
||||
Normalizes the reward values based on Exponential Weighted Moving Average.
|
||||
"""
|
||||
def __init__(self, alpha: float = 0.01):
|
||||
"""
|
||||
:param alpha: the degree of weighting decrease, a constant smoothing factor between 0 and 1.
|
||||
A higher alpha discounts older observations faster
|
||||
"""
|
||||
super().__init__()
|
||||
self.alpha = alpha
|
||||
self.moving_average = 0
|
||||
self.initialized = False
|
||||
self.checkpoint_file_extension = 'ewma'
|
||||
self.supports_batching = True
|
||||
|
||||
def filter(self, reward: RewardType, update_internal_state: bool=True) -> RewardType:
|
||||
if not isinstance(reward, np.ndarray):
|
||||
reward = np.array(reward)
|
||||
|
||||
if update_internal_state:
|
||||
mean_rewards = np.mean(reward)
|
||||
|
||||
if not self.initialized:
|
||||
self.moving_average = mean_rewards
|
||||
self.initialized = True
|
||||
else:
|
||||
self.moving_average += self.alpha * (mean_rewards - self.moving_average)
|
||||
|
||||
return reward - self.moving_average
|
||||
|
||||
def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
|
||||
return input_reward_space
|
||||
|
||||
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: int):
|
||||
dict_to_save = {'moving_average': self.moving_average}
|
||||
|
||||
with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.' + self.checkpoint_file_extension), 'wb') as f:
|
||||
pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
||||
latest_checkpoint_filename = get_latest_checkpoint(checkpoint_dir, checkpoint_prefix,
|
||||
self.checkpoint_file_extension)
|
||||
|
||||
if latest_checkpoint_filename == '':
|
||||
raise ValueError("Could not find RewardEwmaNormalizationFilter checkpoint file. ")
|
||||
|
||||
with open(os.path.join(checkpoint_dir, str(latest_checkpoint_filename)), 'rb') as f:
|
||||
saved_dict = pickle.load(f)
|
||||
self.__dict__.update(saved_dict)
|
||||
@@ -21,6 +21,7 @@ from rl_coach.spaces import RewardSpace
|
||||
class RewardFilter(Filter):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supports_batching = False
|
||||
|
||||
def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
|
||||
"""
|
||||
|
||||
@@ -20,6 +20,8 @@ import pickle
|
||||
import redis
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.utils import get_latest_checkpoint
|
||||
|
||||
|
||||
class SharedRunningStatsSubscribe(threading.Thread):
|
||||
def __init__(self, shared_running_stats):
|
||||
@@ -109,27 +111,13 @@ class SharedRunningStats(ABC):
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
||||
pass
|
||||
|
||||
def get_latest_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str) -> str:
|
||||
latest_checkpoint_id = -1
|
||||
latest_checkpoint = ''
|
||||
# get all checkpoint files
|
||||
for fname in os.listdir(checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, fname)
|
||||
if os.path.isdir(path) or fname.split('.')[-1] != 'srs' or checkpoint_prefix not in fname:
|
||||
continue
|
||||
checkpoint_id = int(fname.split('_')[0])
|
||||
if checkpoint_id > latest_checkpoint_id:
|
||||
latest_checkpoint = fname
|
||||
latest_checkpoint_id = checkpoint_id
|
||||
|
||||
return latest_checkpoint
|
||||
|
||||
|
||||
class NumpySharedRunningStats(SharedRunningStats):
|
||||
def __init__(self, name, epsilon=1e-2, pubsub_params=None):
|
||||
super().__init__(name=name, pubsub_params=pubsub_params)
|
||||
self._count = epsilon
|
||||
self.epsilon = epsilon
|
||||
self.checkpoint_file_extension = 'srs'
|
||||
|
||||
def set_params(self, shape=[1], clip_values=None):
|
||||
self._shape = shape
|
||||
@@ -185,11 +173,12 @@ class NumpySharedRunningStats(SharedRunningStats):
|
||||
'_sum': self._sum,
|
||||
'_sum_squares': self._sum_squares}
|
||||
|
||||
with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.srs'), 'wb') as f:
|
||||
with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.' + self.checkpoint_file_extension), 'wb') as f:
|
||||
pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
||||
latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
latest_checkpoint_filename = get_latest_checkpoint(checkpoint_dir, checkpoint_prefix,
|
||||
self.checkpoint_file_extension)
|
||||
|
||||
if latest_checkpoint_filename == '':
|
||||
raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ")
|
||||
|
||||
@@ -532,3 +532,18 @@ def start_shell_command_and_wait(command):
|
||||
|
||||
def indent_string(string):
|
||||
return '\t' + string.replace('\n', '\n\t')
|
||||
|
||||
def get_latest_checkpoint(checkpoint_dir: str, checkpoint_prefix: str, checkpoint_file_extension: str) -> str:
|
||||
latest_checkpoint_id = -1
|
||||
latest_checkpoint = ''
|
||||
# get all checkpoint files
|
||||
for fname in os.listdir(checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, fname)
|
||||
if os.path.isdir(path) or fname.split('.')[-1] != checkpoint_file_extension or checkpoint_prefix not in fname:
|
||||
continue
|
||||
checkpoint_id = int(fname.split('_')[0])
|
||||
if checkpoint_id > latest_checkpoint_id:
|
||||
latest_checkpoint = fname
|
||||
latest_checkpoint_id = checkpoint_id
|
||||
|
||||
return latest_checkpoint
|
||||
|
||||
Reference in New Issue
Block a user