diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py
index 866fe8a..3db0aaf 100644
--- a/rl_coach/agents/agent.py
+++ b/rl_coach/agents/agent.py
@@ -573,6 +573,9 @@ class Agent(AgentInterface):
 
         if self.phase != RunPhase.TEST:
             if isinstance(self.memory, EpisodicExperienceReplay):
+                if self.ap.algorithm.override_episode_rewards_with_the_last_transition_reward:
+                    for t in self.current_episode_buffer.transitions:
+                        t.reward = self.current_episode_buffer.transitions[-1].reward
                 self.call_memory('store_episode', self.current_episode_buffer)
             elif self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
                 for transition in self.current_episode_buffer.transitions:
@@ -727,7 +730,8 @@ class Agent(AgentInterface):
                 # update counters
                 self.training_iteration += 1
                 if self.pre_network_filter is not None:
-                    batch = self.pre_network_filter.filter(batch, update_internal_state=False, deep_copy=False)
+                    update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_train
+                    batch = self.pre_network_filter.filter(batch, update_internal_state=update_internal_state, deep_copy=False)
 
                 # if the batch returned empty then there are not enough samples in the replay buffer -> skip
                 # training step
@@ -837,7 +841,8 @@ class Agent(AgentInterface):
             # informed action
             if self.pre_network_filter is not None:
                 # before choosing an action, first use the pre_network_filter to filter out the current state
-                update_filter_internal_state = self.phase is not RunPhase.TEST
+                update_filter_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_inference and \
+                                               self.phase is not RunPhase.TEST
                 curr_state = self.run_pre_network_filter_for_inference(self.curr_state,
                                                                        update_filter_internal_state)
             else:
@@ -865,6 +870,10 @@ class Agent(AgentInterface):
         :return: The filtered state
         """
         dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
+
+        # TODO actually we only want to run the observation filters. No point in running the reward filters as the
+        # filtered reward is being ignored anyway (and it might unnecessarily affect the reward filters' internal
+        # state).
         return self.pre_network_filter.filter(dummy_env_response, update_internal_state=update_filter_internal_state)[0].next_state
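
Note (not part of the patch): the flags referenced above are defined later in this diff on AlgorithmParameters. Below is a minimal preset-side sketch of how they could be enabled, assuming a standard Coach preset built around ClippedPPOAgentParameters; only the flag names come from this change, the rest is illustrative.

    # Illustrative only, not part of the patch.
    from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters

    agent_params = ClippedPPOAgentParameters()

    # retroactively replace every reward in a finished episode with the episode's last reward
    # (intended for problems where the reward is sparse and only given at the end of the episode)
    agent_params.algorithm.override_episode_rewards_with_the_last_transition_reward = True

    # control when the pre-network filters are allowed to update their internal statistics
    agent_params.algorithm.update_pre_network_filters_state_on_train = True
    agent_params.algorithm.update_pre_network_filters_state_on_inference = False
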
diff --git a/rl_coach/agents/clipped_ppo_agent.py b/rl_coach/agents/clipped_ppo_agent.py
index cc29f33..1a9d202 100644
--- a/rl_coach/agents/clipped_ppo_agent.py
+++ b/rl_coach/agents/clipped_ppo_agent.py
@@ -113,6 +113,8 @@ class ClippedPPOAlgorithmParameters(AlgorithmParameters):
         self.normalization_stats = None
         self.clipping_decay_schedule = ConstantSchedule(1)
         self.act_for_full_episodes = True
+        self.update_pre_network_filters_state_on_train = True
+        self.update_pre_network_filters_state_on_inference = False
 
 
 class ClippedPPOAgentParameters(AgentParameters):
@@ -303,7 +305,9 @@ class ClippedPPOAgent(ActorCriticAgent):
             network.set_is_training(True)
 
         dataset = self.memory.transitions
-        dataset = self.pre_network_filter.filter(dataset, deep_copy=False)
+        update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_train
+        dataset = self.pre_network_filter.filter(dataset, deep_copy=False,
+                                                 update_internal_state=update_internal_state)
         batch = Batch(dataset)
 
         for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
@@ -329,7 +333,9 @@ class ClippedPPOAgent(ActorCriticAgent):
 
     def run_pre_network_filter_for_inference(self, state: StateType, update_internal_state: bool=False):
         dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
-        return self.pre_network_filter.filter(dummy_env_response, update_internal_state=False)[0].next_state
+        update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_inference
+        return self.pre_network_filter.filter(
+            dummy_env_response, update_internal_state=update_internal_state)[0].next_state
 
     def choose_action(self, curr_state):
         self.ap.algorithm.clipping_decay_schedule.step()
diff --git a/rl_coach/base_parameters.py b/rl_coach/base_parameters.py
index a84eb5b..65bc940 100644
--- a/rl_coach/base_parameters.py
+++ b/rl_coach/base_parameters.py
@@ -213,6 +213,14 @@ class AlgorithmParameters(Parameters):
 
         # Support for parameter noise
         self.supports_parameter_noise = False
 
+        # Retroactively override all of the episode's rewards with the last reward in the episode
+        # (sometimes useful for problems with sparse, end-of-episode rewards)
+        self.override_episode_rewards_with_the_last_transition_reward = False
+
+        # Filters - TODO consider creating a FilterParameters class and initialize the filters with it
+        self.update_pre_network_filters_state_on_train = False
+        self.update_pre_network_filters_state_on_inference = True
+
 
 class PresetValidationParameters(Parameters):
     def __init__(self,
diff --git a/rl_coach/exploration_policies/truncated_normal.py b/rl_coach/exploration_policies/truncated_normal.py
index 7d859ee..23f98d4 100644
--- a/rl_coach/exploration_policies/truncated_normal.py
+++ b/rl_coach/exploration_policies/truncated_normal.py
@@ -88,9 +88,6 @@ class TruncatedNormal(ContinuousActionExplorationPolicy):
         else:
             action_values_std = current_noise
 
-        # scale the noise to the action space range
-        action_values_std = current_noise * (self.action_space.high - self.action_space.low)
-
         # extract the mean values
         if isinstance(action_values, list):
             # the action values are expected to be a list with the action mean and optionally the action stdev
diff --git a/rl_coach/filters/filter.py b/rl_coach/filters/filter.py
index 6881f8d..3035e81 100644
--- a/rl_coach/filters/filter.py
+++ b/rl_coach/filters/filter.py
@@ -338,11 +338,14 @@ class InputFilter(Filter):
                         state_object[observation_name] = filtered_observations[i]
 
         # filter reward
-        for f in filtered_data:
-            filtered_reward = f.reward
-            for filter in self._reward_filters.values():
-                filtered_reward = filter.filter(filtered_reward, update_internal_state)
-            f.reward = filtered_reward
+        for filter in self._reward_filters.values():
+            if filter.supports_batching:
+                filtered_rewards = filter.filter([f.reward for f in filtered_data], update_internal_state)
+                for d, filtered_reward in zip(filtered_data, filtered_rewards):
+                    d.reward = filtered_reward
+            else:
+                for d in filtered_data:
+                    d.reward = filter.filter(d.reward, update_internal_state)
 
         return filtered_data
 
diff --git a/rl_coach/filters/reward/__init__.py b/rl_coach/filters/reward/__init__.py
index 7be10cc..40d0bdc 100644
--- a/rl_coach/filters/reward/__init__.py
+++ b/rl_coach/filters/reward/__init__.py
@@ -1,8 +1,11 @@
 from .reward_rescale_filter import RewardRescaleFilter
 from .reward_clipping_filter import RewardClippingFilter
 from .reward_normalization_filter import RewardNormalizationFilter
+from .reward_ewma_normalization_filter import RewardEwmaNormalizationFilter
+
 __all__ = [
     'RewardRescaleFilter',
     'RewardClippingFilter',
-    'RewardNormalizationFilter'
+    'RewardNormalizationFilter',
+    'RewardEwmaNormalizationFilter'
 ]
\ No newline at end of file
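
Note (not part of the patch): a minimal sketch of the contract a reward filter has to satisfy to opt into the batched path added above. The RewardScaleBatchFilter class and its scale parameter are hypothetical; the supports_batching attribute and the RewardFilter base class come from this change. The RewardEwmaNormalizationFilter added below is the first real filter to use this path.

    # Illustrative only, not part of the patch. When supports_batching is True,
    # InputFilter.filter passes the whole list of rewards to filter() in a single call;
    # otherwise it keeps calling filter() once per transition, as before.
    import numpy as np

    from rl_coach.core_types import RewardType
    from rl_coach.filters.reward.reward_filter import RewardFilter
    from rl_coach.spaces import RewardSpace


    class RewardScaleBatchFilter(RewardFilter):  # hypothetical example class
        def __init__(self, scale: float = 0.1):
            super().__init__()
            self.scale = scale
            self.supports_batching = True  # opt into the batched path

        def filter(self, reward: RewardType, update_internal_state: bool = True) -> RewardType:
            # 'reward' is a list/array of rewards when called through the batched path
            return np.array(reward) * self.scale

        def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
            # space rescaling omitted for brevity in this sketch
            return input_reward_space
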
diff --git a/rl_coach/filters/reward/reward_ewma_normalization_filter.py b/rl_coach/filters/reward/reward_ewma_normalization_filter.py
new file mode 100644
index 0000000..968be03
--- /dev/null
+++ b/rl_coach/filters/reward/reward_ewma_normalization_filter.py
@@ -0,0 +1,76 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+
+import numpy as np
+import pickle
+
+from rl_coach.core_types import RewardType
+from rl_coach.filters.reward.reward_filter import RewardFilter
+from rl_coach.spaces import RewardSpace
+from rl_coach.utils import get_latest_checkpoint
+
+
+class RewardEwmaNormalizationFilter(RewardFilter):
+    """
+    Normalizes the reward values based on an exponentially weighted moving average.
+    """
+    def __init__(self, alpha: float = 0.01):
+        """
+        :param alpha: the degree of weighting decrease, a constant smoothing factor between 0 and 1.
+                      A higher alpha discounts older observations faster
+        """
+        super().__init__()
+        self.alpha = alpha
+        self.moving_average = 0
+        self.initialized = False
+        self.checkpoint_file_extension = 'ewma'
+        self.supports_batching = True
+
+    def filter(self, reward: RewardType, update_internal_state: bool=True) -> RewardType:
+        if not isinstance(reward, np.ndarray):
+            reward = np.array(reward)
+
+        if update_internal_state:
+            mean_rewards = np.mean(reward)
+
+            if not self.initialized:
+                self.moving_average = mean_rewards
+                self.initialized = True
+            else:
+                self.moving_average += self.alpha * (mean_rewards - self.moving_average)
+
+        return reward - self.moving_average
+
+    def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
+        return input_reward_space
+
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: int):
+        dict_to_save = {'moving_average': self.moving_average}
+
+        with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.' + self.checkpoint_file_extension), 'wb') as f:
+            pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
+
+    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        latest_checkpoint_filename = get_latest_checkpoint(checkpoint_dir, checkpoint_prefix,
+                                                           self.checkpoint_file_extension)
+
+        if latest_checkpoint_filename == '':
+            raise ValueError("Could not find RewardEwmaNormalizationFilter checkpoint file. ")
+
+        with open(os.path.join(checkpoint_dir, str(latest_checkpoint_filename)), 'rb') as f:
+            saved_dict = pickle.load(f)
+        self.__dict__.update(saved_dict)
diff --git a/rl_coach/filters/reward/reward_filter.py b/rl_coach/filters/reward/reward_filter.py
index d105b8b..58d467c 100644
--- a/rl_coach/filters/reward/reward_filter.py
+++ b/rl_coach/filters/reward/reward_filter.py
@@ -21,6 +21,7 @@ from rl_coach.spaces import RewardSpace
 class RewardFilter(Filter):
     def __init__(self):
         super().__init__()
+        self.supports_batching = False
 
     def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
         """
diff --git a/rl_coach/utilities/shared_running_stats.py b/rl_coach/utilities/shared_running_stats.py
index b78b66b..263fae6 100644
--- a/rl_coach/utilities/shared_running_stats.py
+++ b/rl_coach/utilities/shared_running_stats.py
@@ -20,6 +20,8 @@ import pickle
 import redis
 import numpy as np
 
+from rl_coach.utils import get_latest_checkpoint
+
 
 class SharedRunningStatsSubscribe(threading.Thread):
     def __init__(self, shared_running_stats):
@@ -109,27 +111,13 @@ class SharedRunningStats(ABC):
     def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
         pass
 
-    def get_latest_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str) -> str:
-        latest_checkpoint_id = -1
-        latest_checkpoint = ''
-        # get all checkpoint files
-        for fname in os.listdir(checkpoint_dir):
-            path = os.path.join(checkpoint_dir, fname)
-            if os.path.isdir(path) or fname.split('.')[-1] != 'srs' or checkpoint_prefix not in fname:
-                continue
-            checkpoint_id = int(fname.split('_')[0])
-            if checkpoint_id > latest_checkpoint_id:
-                latest_checkpoint = fname
-                latest_checkpoint_id = checkpoint_id
-
-        return latest_checkpoint
-
 
 class NumpySharedRunningStats(SharedRunningStats):
     def __init__(self, name, epsilon=1e-2, pubsub_params=None):
         super().__init__(name=name, pubsub_params=pubsub_params)
         self._count = epsilon
         self.epsilon = epsilon
+        self.checkpoint_file_extension = 'srs'
 
     def set_params(self, shape=[1], clip_values=None):
         self._shape = shape
@@ -185,11 +173,12 @@ class NumpySharedRunningStats(SharedRunningStats):
                         '_sum': self._sum,
                         '_sum_squares': self._sum_squares}
 
-        with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.srs'), 'wb') as f:
+        with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.' + self.checkpoint_file_extension), 'wb') as f:
             pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
 
     def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
-        latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir, checkpoint_prefix)
+        latest_checkpoint_filename = get_latest_checkpoint(checkpoint_dir, checkpoint_prefix,
+                                                           self.checkpoint_file_extension)
 
         if latest_checkpoint_filename == '':
             raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ")
diff --git a/rl_coach/utils.py b/rl_coach/utils.py
index f51b02b..72714fa 100644
--- a/rl_coach/utils.py
+++ b/rl_coach/utils.py
@@ -532,3 +532,18 @@ def start_shell_command_and_wait(command):
 
 def indent_string(string):
     return '\t' + string.replace('\n', '\n\t')
+
+def get_latest_checkpoint(checkpoint_dir: str, checkpoint_prefix: str, checkpoint_file_extension: str) -> str:
+    latest_checkpoint_id = -1
+    latest_checkpoint = ''
+    # get all checkpoint files
+    for fname in os.listdir(checkpoint_dir):
+        path = os.path.join(checkpoint_dir, fname)
+        if os.path.isdir(path) or fname.split('.')[-1] != checkpoint_file_extension or checkpoint_prefix not in fname:
+            continue
+        checkpoint_id = int(fname.split('_')[0])
+        if checkpoint_id > latest_checkpoint_id:
+            latest_checkpoint = fname
+            latest_checkpoint_id = checkpoint_id
+
+    return latest_checkpoint
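
Note (not part of the patch): a minimal sketch of wiring the new filter into a preset, assuming standard Coach preset usage. The RewardEwmaNormalizationFilter and its alpha parameter come from this change; the alpha value shown is an arbitrary example, and whether input_filter or pre_network_filter is the right attachment point depends on the preset.

    # Illustrative only, not part of the patch.
    from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
    from rl_coach.filters.filter import InputFilter
    from rl_coach.filters.reward import RewardEwmaNormalizationFilter

    agent_params = ClippedPPOAgentParameters()

    # normalize rewards by subtracting an exponentially weighted moving average of the mean reward
    input_filter = InputFilter()
    input_filter.add_reward_filter('ewma_normalization', RewardEwmaNormalizationFilter(alpha=0.01))
    agent_params.input_filter = input_filter
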