1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Distiller's AMC induced changes (#359)

* override episode rewards with the last transition reward

* EWMA normalization filter

* allowing control over when the pre_network filter runs
This commit is contained in:
Gal Leibovich
2019-08-05 10:24:58 +03:00
committed by GitHub
parent 7df67dafa3
commit c1d1fae342
10 changed files with 137 additions and 30 deletions

View File

@@ -573,6 +573,9 @@ class Agent(AgentInterface):
if self.phase != RunPhase.TEST: if self.phase != RunPhase.TEST:
if isinstance(self.memory, EpisodicExperienceReplay): if isinstance(self.memory, EpisodicExperienceReplay):
if self.ap.algorithm.override_episode_rewards_with_the_last_transition_reward:
for t in self.current_episode_buffer.transitions:
t.reward = self.current_episode_buffer.transitions[-1].reward
self.call_memory('store_episode', self.current_episode_buffer) self.call_memory('store_episode', self.current_episode_buffer)
elif self.ap.algorithm.store_transitions_only_when_episodes_are_terminated: elif self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
for transition in self.current_episode_buffer.transitions: for transition in self.current_episode_buffer.transitions:
@@ -727,7 +730,8 @@ class Agent(AgentInterface):
# update counters # update counters
self.training_iteration += 1 self.training_iteration += 1
if self.pre_network_filter is not None: if self.pre_network_filter is not None:
batch = self.pre_network_filter.filter(batch, update_internal_state=False, deep_copy=False) update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_train
batch = self.pre_network_filter.filter(batch, update_internal_state=update_internal_state, deep_copy=False)
# if the batch returned empty then there are not enough samples in the replay buffer -> skip # if the batch returned empty then there are not enough samples in the replay buffer -> skip
# training step # training step
@@ -837,7 +841,8 @@ class Agent(AgentInterface):
# informed action # informed action
if self.pre_network_filter is not None: if self.pre_network_filter is not None:
# before choosing an action, first use the pre_network_filter to filter out the current state # before choosing an action, first use the pre_network_filter to filter out the current state
update_filter_internal_state = self.phase is not RunPhase.TEST update_filter_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_inference and \
self.phase is not RunPhase.TEST
curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state) curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state)
else: else:
@@ -865,6 +870,10 @@ class Agent(AgentInterface):
:return: The filtered state :return: The filtered state
""" """
dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False) dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
# TODO actually we only want to run the observation filters. No point in running the reward filters as the
# filtered reward is being ignored anyway (and it might unncecessarily affect the reward filters' internal
# state).
return self.pre_network_filter.filter(dummy_env_response, return self.pre_network_filter.filter(dummy_env_response,
update_internal_state=update_filter_internal_state)[0].next_state update_internal_state=update_filter_internal_state)[0].next_state

View File

@@ -113,6 +113,8 @@ class ClippedPPOAlgorithmParameters(AlgorithmParameters):
self.normalization_stats = None self.normalization_stats = None
self.clipping_decay_schedule = ConstantSchedule(1) self.clipping_decay_schedule = ConstantSchedule(1)
self.act_for_full_episodes = True self.act_for_full_episodes = True
self.update_pre_network_filters_state_on_train = True
self.update_pre_network_filters_state_on_inference = False
class ClippedPPOAgentParameters(AgentParameters): class ClippedPPOAgentParameters(AgentParameters):
@@ -303,7 +305,9 @@ class ClippedPPOAgent(ActorCriticAgent):
network.set_is_training(True) network.set_is_training(True)
dataset = self.memory.transitions dataset = self.memory.transitions
dataset = self.pre_network_filter.filter(dataset, deep_copy=False) update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_train
dataset = self.pre_network_filter.filter(dataset, deep_copy=False,
update_internal_state=update_internal_state)
batch = Batch(dataset) batch = Batch(dataset)
for training_step in range(self.ap.algorithm.num_consecutive_training_steps): for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
@@ -329,7 +333,9 @@ class ClippedPPOAgent(ActorCriticAgent):
def run_pre_network_filter_for_inference(self, state: StateType, update_internal_state: bool=False): def run_pre_network_filter_for_inference(self, state: StateType, update_internal_state: bool=False):
dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False) dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
return self.pre_network_filter.filter(dummy_env_response, update_internal_state=False)[0].next_state update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_inference
return self.pre_network_filter.filter(
dummy_env_response, update_internal_state=update_internal_state)[0].next_state
def choose_action(self, curr_state): def choose_action(self, curr_state):
self.ap.algorithm.clipping_decay_schedule.step() self.ap.algorithm.clipping_decay_schedule.step()

View File

@@ -213,6 +213,14 @@ class AlgorithmParameters(Parameters):
# Support for parameter noise # Support for parameter noise
self.supports_parameter_noise = False self.supports_parameter_noise = False
# Override, in retrospective, all the episode rewards with the last reward in the episode
# (sometimes useful for sparse, end of the episode, rewards problems)
self.override_episode_rewards_with_the_last_transition_reward = False
# Filters - TODO consider creating a FilterParameters class and initialize the filters with it
self.update_pre_network_filters_state_on_train = False
self.update_pre_network_filters_state_on_inference = True
class PresetValidationParameters(Parameters): class PresetValidationParameters(Parameters):
def __init__(self, def __init__(self,

View File

@@ -88,9 +88,6 @@ class TruncatedNormal(ContinuousActionExplorationPolicy):
else: else:
action_values_std = current_noise action_values_std = current_noise
# scale the noise to the action space range
action_values_std = current_noise * (self.action_space.high - self.action_space.low)
# extract the mean values # extract the mean values
if isinstance(action_values, list): if isinstance(action_values, list):
# the action values are expected to be a list with the action mean and optionally the action stdev # the action values are expected to be a list with the action mean and optionally the action stdev

View File

@@ -338,11 +338,14 @@ class InputFilter(Filter):
state_object[observation_name] = filtered_observations[i] state_object[observation_name] = filtered_observations[i]
# filter reward # filter reward
for f in filtered_data: for filter in self._reward_filters.values():
filtered_reward = f.reward if filter.supports_batching:
for filter in self._reward_filters.values(): filtered_rewards = filter.filter([f.reward for f in filtered_data], update_internal_state)
filtered_reward = filter.filter(filtered_reward, update_internal_state) for d, filtered_reward in zip(filtered_data, filtered_rewards):
f.reward = filtered_reward d.reward = filtered_reward
else:
for d in filtered_data:
d.reward = filter.filter(d.reward, update_internal_state)
return filtered_data return filtered_data

View File

@@ -1,8 +1,11 @@
from .reward_rescale_filter import RewardRescaleFilter from .reward_rescale_filter import RewardRescaleFilter
from .reward_clipping_filter import RewardClippingFilter from .reward_clipping_filter import RewardClippingFilter
from .reward_normalization_filter import RewardNormalizationFilter from .reward_normalization_filter import RewardNormalizationFilter
from .reward_ewma_normalization_filter import RewardEwmaNormalizationFilter
__all__ = [ __all__ = [
'RewardRescaleFilter', 'RewardRescaleFilter',
'RewardClippingFilter', 'RewardClippingFilter',
'RewardNormalizationFilter' 'RewardNormalizationFilter',
'RewardEwmaNormalizationFilter'
] ]

View File

@@ -0,0 +1,76 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import numpy as np
import pickle
from rl_coach.core_types import RewardType
from rl_coach.filters.reward.reward_filter import RewardFilter
from rl_coach.spaces import RewardSpace
from rl_coach.utils import get_latest_checkpoint
class RewardEwmaNormalizationFilter(RewardFilter):
"""
Normalizes the reward values based on Exponential Weighted Moving Average.
"""
def __init__(self, alpha: float = 0.01):
"""
:param alpha: the degree of weighting decrease, a constant smoothing factor between 0 and 1.
A higher alpha discounts older observations faster
"""
super().__init__()
self.alpha = alpha
self.moving_average = 0
self.initialized = False
self.checkpoint_file_extension = 'ewma'
self.supports_batching = True
def filter(self, reward: RewardType, update_internal_state: bool=True) -> RewardType:
if not isinstance(reward, np.ndarray):
reward = np.array(reward)
if update_internal_state:
mean_rewards = np.mean(reward)
if not self.initialized:
self.moving_average = mean_rewards
self.initialized = True
else:
self.moving_average += self.alpha * (mean_rewards - self.moving_average)
return reward - self.moving_average
def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
return input_reward_space
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: int):
dict_to_save = {'moving_average': self.moving_average}
with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.' + self.checkpoint_file_extension), 'wb') as f:
pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
latest_checkpoint_filename = get_latest_checkpoint(checkpoint_dir, checkpoint_prefix,
self.checkpoint_file_extension)
if latest_checkpoint_filename == '':
raise ValueError("Could not find RewardEwmaNormalizationFilter checkpoint file. ")
with open(os.path.join(checkpoint_dir, str(latest_checkpoint_filename)), 'rb') as f:
saved_dict = pickle.load(f)
self.__dict__.update(saved_dict)

View File

@@ -21,6 +21,7 @@ from rl_coach.spaces import RewardSpace
class RewardFilter(Filter): class RewardFilter(Filter):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.supports_batching = False
def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace: def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
""" """

View File

@@ -20,6 +20,8 @@ import pickle
import redis import redis
import numpy as np import numpy as np
from rl_coach.utils import get_latest_checkpoint
class SharedRunningStatsSubscribe(threading.Thread): class SharedRunningStatsSubscribe(threading.Thread):
def __init__(self, shared_running_stats): def __init__(self, shared_running_stats):
@@ -109,27 +111,13 @@ class SharedRunningStats(ABC):
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str): def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
pass pass
def get_latest_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str) -> str:
latest_checkpoint_id = -1
latest_checkpoint = ''
# get all checkpoint files
for fname in os.listdir(checkpoint_dir):
path = os.path.join(checkpoint_dir, fname)
if os.path.isdir(path) or fname.split('.')[-1] != 'srs' or checkpoint_prefix not in fname:
continue
checkpoint_id = int(fname.split('_')[0])
if checkpoint_id > latest_checkpoint_id:
latest_checkpoint = fname
latest_checkpoint_id = checkpoint_id
return latest_checkpoint
class NumpySharedRunningStats(SharedRunningStats): class NumpySharedRunningStats(SharedRunningStats):
def __init__(self, name, epsilon=1e-2, pubsub_params=None): def __init__(self, name, epsilon=1e-2, pubsub_params=None):
super().__init__(name=name, pubsub_params=pubsub_params) super().__init__(name=name, pubsub_params=pubsub_params)
self._count = epsilon self._count = epsilon
self.epsilon = epsilon self.epsilon = epsilon
self.checkpoint_file_extension = 'srs'
def set_params(self, shape=[1], clip_values=None): def set_params(self, shape=[1], clip_values=None):
self._shape = shape self._shape = shape
@@ -185,11 +173,12 @@ class NumpySharedRunningStats(SharedRunningStats):
'_sum': self._sum, '_sum': self._sum,
'_sum_squares': self._sum_squares} '_sum_squares': self._sum_squares}
with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.srs'), 'wb') as f: with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.' + self.checkpoint_file_extension), 'wb') as f:
pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL) pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str): def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir, checkpoint_prefix) latest_checkpoint_filename = get_latest_checkpoint(checkpoint_dir, checkpoint_prefix,
self.checkpoint_file_extension)
if latest_checkpoint_filename == '': if latest_checkpoint_filename == '':
raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ") raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ")

View File

@@ -532,3 +532,18 @@ def start_shell_command_and_wait(command):
def indent_string(string): def indent_string(string):
return '\t' + string.replace('\n', '\n\t') return '\t' + string.replace('\n', '\n\t')
def get_latest_checkpoint(checkpoint_dir: str, checkpoint_prefix: str, checkpoint_file_extension: str) -> str:
latest_checkpoint_id = -1
latest_checkpoint = ''
# get all checkpoint files
for fname in os.listdir(checkpoint_dir):
path = os.path.join(checkpoint_dir, fname)
if os.path.isdir(path) or fname.split('.')[-1] != checkpoint_file_extension or checkpoint_prefix not in fname:
continue
checkpoint_id = int(fname.split('_')[0])
if checkpoint_id > latest_checkpoint_id:
latest_checkpoint = fname
latest_checkpoint_id = checkpoint_id
return latest_checkpoint