Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 03:00:20 +01:00)
Distiller's AMC induced changes (#359)
* override episode rewards with the last transition reward
* EWMA normalization filter
* allowing control over when the pre_network filter runs
@@ -573,6 +573,9 @@ class Agent(AgentInterface):
         if self.phase != RunPhase.TEST:
             if isinstance(self.memory, EpisodicExperienceReplay):
+                if self.ap.algorithm.override_episode_rewards_with_the_last_transition_reward:
+                    for t in self.current_episode_buffer.transitions:
+                        t.reward = self.current_episode_buffer.transitions[-1].reward
                 self.call_memory('store_episode', self.current_episode_buffer)
             elif self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
                 for transition in self.current_episode_buffer.transitions:
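For intuition, the new override_episode_rewards_with_the_last_transition_reward flag retroactively copies the final reward of an episode onto every transition in it. A minimal standalone sketch of that effect (FakeTransition is a hypothetical stand-in for Coach's Transition, not Coach code):

# Sketch only: illustrates the retroactive reward override added above.
from dataclasses import dataclass

@dataclass
class FakeTransition:      # hypothetical stand-in for rl_coach.core_types.Transition
    reward: float

episode = [FakeTransition(0.0), FakeTransition(0.0), FakeTransition(10.0)]

# Same loop as the added block in Agent.handle_episode_ended():
for t in episode:
    t.reward = episode[-1].reward

print([t.reward for t in episode])   # [10.0, 10.0, 10.0]

As the comment added to AlgorithmParameters further down explains, this is mainly useful for sparse, end-of-episode reward problems.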
@@ -727,7 +730,8 @@ class Agent(AgentInterface):
         # update counters
         self.training_iteration += 1
         if self.pre_network_filter is not None:
-            batch = self.pre_network_filter.filter(batch, update_internal_state=False, deep_copy=False)
+            update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_train
+            batch = self.pre_network_filter.filter(batch, update_internal_state=update_internal_state, deep_copy=False)

         # if the batch returned empty then there are not enough samples in the replay buffer -> skip
         # training step
@@ -837,7 +841,8 @@ class Agent(AgentInterface):
         # informed action
         if self.pre_network_filter is not None:
             # before choosing an action, first use the pre_network_filter to filter out the current state
-            update_filter_internal_state = self.phase is not RunPhase.TEST
+            update_filter_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_inference and \
+                                           self.phase is not RunPhase.TEST
             curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state)

         else:
@@ -865,6 +870,10 @@ class Agent(AgentInterface):
         :return: The filtered state
         """
         dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)

+        # TODO actually we only want to run the observation filters. No point in running the reward filters as the
+        # filtered reward is being ignored anyway (and it might unncecessarily affect the reward filters' internal
+        # state).
         return self.pre_network_filter.filter(dummy_env_response,
                                               update_internal_state=update_filter_internal_state)[0].next_state
@@ -113,6 +113,8 @@ class ClippedPPOAlgorithmParameters(AlgorithmParameters):
         self.normalization_stats = None
         self.clipping_decay_schedule = ConstantSchedule(1)
         self.act_for_full_episodes = True
+        self.update_pre_network_filters_state_on_train = True
+        self.update_pre_network_filters_state_on_inference = False


 class ClippedPPOAgentParameters(AgentParameters):
@@ -303,7 +305,9 @@ class ClippedPPOAgent(ActorCriticAgent):
             network.set_is_training(True)

         dataset = self.memory.transitions
-        dataset = self.pre_network_filter.filter(dataset, deep_copy=False)
+        update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_train
+        dataset = self.pre_network_filter.filter(dataset, deep_copy=False,
+                                                 update_internal_state=update_internal_state)
         batch = Batch(dataset)

         for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
@@ -329,7 +333,9 @@ class ClippedPPOAgent(ActorCriticAgent):

     def run_pre_network_filter_for_inference(self, state: StateType, update_internal_state: bool=False):
         dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
-        return self.pre_network_filter.filter(dummy_env_response, update_internal_state=False)[0].next_state
+        update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_inference
+        return self.pre_network_filter.filter(
+            dummy_env_response, update_internal_state=update_internal_state)[0].next_state

     def choose_action(self, curr_state):
         self.ap.algorithm.clipping_decay_schedule.step()
@@ -213,6 +213,14 @@ class AlgorithmParameters(Parameters):
         # Support for parameter noise
         self.supports_parameter_noise = False

+        # Override, in retrospective, all the episode rewards with the last reward in the episode
+        # (sometimes useful for sparse, end of the episode, rewards problems)
+        self.override_episode_rewards_with_the_last_transition_reward = False
+
+        # Filters - TODO consider creating a FilterParameters class and initialize the filters with it
+        self.update_pre_network_filters_state_on_train = False
+        self.update_pre_network_filters_state_on_inference = True
+

 class PresetValidationParameters(Parameters):
     def __init__(self,
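The two filter-state flags control when the pre-network filters (e.g. observation/reward normalization) update their running statistics: the base defaults update on inference but not on training, while ClippedPPO (see the earlier hunk) flips both. A preset-style sketch, assuming the usual Coach preset pattern, of toggling them explicitly:

# Sketch only: setting the new filter-state flags from a preset.
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters

agent_params = ClippedPPOAgentParameters()

# Keep updating normalization statistics during training AND while acting
# (the ClippedPPO defaults freeze them at inference time):
agent_params.algorithm.update_pre_network_filters_state_on_train = True
agent_params.algorithm.update_pre_network_filters_state_on_inference = True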
@@ -88,9 +88,6 @@ class TruncatedNormal(ContinuousActionExplorationPolicy):
         else:
             action_values_std = current_noise

-        # scale the noise to the action space range
-        action_values_std = current_noise * (self.action_space.high - self.action_space.low)
-
         # extract the mean values
         if isinstance(action_values, list):
             # the action values are expected to be a list with the action mean and optionally the action stdev
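In other words, the current noise level is now used directly as the exploration standard deviation instead of being rescaled by the action-space range. A toy numeric illustration (the action bounds below are invented for the example, not taken from the diff):

# Sketch only: before/after behavior of the removed scaling line, with assumed bounds.
import numpy as np

low, high = np.array([-2.0]), np.array([2.0])
current_noise = 0.1

old_std = current_noise * (high - low)   # previously: 0.4, scaled by the range
new_std = current_noise                  # now: 0.1, used as-is
print(old_std, new_std)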
@@ -338,11 +338,14 @@ class InputFilter(Filter):
             state_object[observation_name] = filtered_observations[i]

         # filter reward
-        for f in filtered_data:
-            filtered_reward = f.reward
-            for filter in self._reward_filters.values():
-                filtered_reward = filter.filter(filtered_reward, update_internal_state)
-            f.reward = filtered_reward
+        for filter in self._reward_filters.values():
+            if filter.supports_batching:
+                filtered_rewards = filter.filter([f.reward for f in filtered_data], update_internal_state)
+                for d, filtered_reward in zip(filtered_data, filtered_rewards):
+                    d.reward = filtered_reward
+            else:
+                for d in filtered_data:
+                    d.reward = filter.filter(d.reward, update_internal_state)

         return filtered_data
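The dispatch above lets a reward filter declare supports_batching = True and receive all rewards of the batch in a single call (which the new EWMA filter needs in order to compute a batch mean). A self-contained sketch of that contract, using toy filter classes rather than Coach's:

# Sketch only: the supports_batching contract used by InputFilter above (toy classes).
class PerRewardClip:
    supports_batching = False
    def filter(self, reward, update_internal_state=True):
        return max(-1.0, min(1.0, reward))           # called once per transition

class BatchMeanCenter:
    supports_batching = True
    def filter(self, rewards, update_internal_state=True):
        mean = sum(rewards) / len(rewards)           # sees the whole batch at once
        return [r - mean for r in rewards]

rewards = [0.5, 2.0, -3.0]
print([PerRewardClip().filter(r) for r in rewards])  # element-wise path
print(BatchMeanCenter().filter(rewards))             # batched path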
@@ -1,8 +1,11 @@
 from .reward_rescale_filter import RewardRescaleFilter
 from .reward_clipping_filter import RewardClippingFilter
 from .reward_normalization_filter import RewardNormalizationFilter
+from .reward_ewma_normalization_filter import RewardEwmaNormalizationFilter

 __all__ = [
     'RewardRescaleFilter',
     'RewardClippingFilter',
-    'RewardNormalizationFilter'
+    'RewardNormalizationFilter',
+    'RewardEwmaNormalizationFilter'
 ]
rl_coach/filters/reward/reward_ewma_normalization_filter.py (new file, 76 lines)
@@ -0,0 +1,76 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+
+import numpy as np
+import pickle
+
+from rl_coach.core_types import RewardType
+from rl_coach.filters.reward.reward_filter import RewardFilter
+from rl_coach.spaces import RewardSpace
+from rl_coach.utils import get_latest_checkpoint
+
+
+class RewardEwmaNormalizationFilter(RewardFilter):
+    """
+    Normalizes the reward values based on Exponential Weighted Moving Average.
+    """
+    def __init__(self, alpha: float = 0.01):
+        """
+        :param alpha: the degree of weighting decrease, a constant smoothing factor between 0 and 1.
+            A higher alpha discounts older observations faster
+        """
+        super().__init__()
+        self.alpha = alpha
+        self.moving_average = 0
+        self.initialized = False
+        self.checkpoint_file_extension = 'ewma'
+        self.supports_batching = True
+
+    def filter(self, reward: RewardType, update_internal_state: bool=True) -> RewardType:
+        if not isinstance(reward, np.ndarray):
+            reward = np.array(reward)
+
+        if update_internal_state:
+            mean_rewards = np.mean(reward)
+
+            if not self.initialized:
+                self.moving_average = mean_rewards
+                self.initialized = True
+            else:
+                self.moving_average += self.alpha * (mean_rewards - self.moving_average)
+
+        return reward - self.moving_average
+
+    def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
+        return input_reward_space
+
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: int):
+        dict_to_save = {'moving_average': self.moving_average}
+
+        with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.' + self.checkpoint_file_extension), 'wb') as f:
+            pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
+
+    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        latest_checkpoint_filename = get_latest_checkpoint(checkpoint_dir, checkpoint_prefix,
+                                                           self.checkpoint_file_extension)
+
+        if latest_checkpoint_filename == '':
+            raise ValueError("Could not find RewardEwmaNormalizationFilter checkpoint file. ")
+
+        with open(os.path.join(checkpoint_dir, str(latest_checkpoint_filename)), 'rb') as f:
+            saved_dict = pickle.load(f)
+            self.__dict__.update(saved_dict)
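The filter keeps an EWMA of the batch-mean reward, m <- m + alpha * (mean(r) - m), and returns rewards centered by m. A short usage sketch in isolation (the reward values are illustrative), relying only on the class added above:

# Usage sketch of the new filter (made-up reward values).
import numpy as np
from rl_coach.filters.reward.reward_ewma_normalization_filter import RewardEwmaNormalizationFilter

f = RewardEwmaNormalizationFilter(alpha=0.1)

print(f.filter(np.array([1.0, 2.0, 3.0])))   # first call initializes m to the batch mean: output [-1. 0. 1.]
print(f.moving_average)                      # 2.0
print(f.filter(np.array([4.0])))             # m <- 2.0 + 0.1 * (4.0 - 2.0) = 2.2, output [1.8]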
@@ -21,6 +21,7 @@ from rl_coach.spaces import RewardSpace
 class RewardFilter(Filter):
     def __init__(self):
         super().__init__()
+        self.supports_batching = False

     def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
         """
@@ -20,6 +20,8 @@ import pickle
 import redis
 import numpy as np

+from rl_coach.utils import get_latest_checkpoint
+

 class SharedRunningStatsSubscribe(threading.Thread):
     def __init__(self, shared_running_stats):
@@ -109,27 +111,13 @@ class SharedRunningStats(ABC):
     def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
         pass

-    def get_latest_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str) -> str:
-        latest_checkpoint_id = -1
-        latest_checkpoint = ''
-        # get all checkpoint files
-        for fname in os.listdir(checkpoint_dir):
-            path = os.path.join(checkpoint_dir, fname)
-            if os.path.isdir(path) or fname.split('.')[-1] != 'srs' or checkpoint_prefix not in fname:
-                continue
-            checkpoint_id = int(fname.split('_')[0])
-            if checkpoint_id > latest_checkpoint_id:
-                latest_checkpoint = fname
-                latest_checkpoint_id = checkpoint_id
-
-        return latest_checkpoint
-

 class NumpySharedRunningStats(SharedRunningStats):
     def __init__(self, name, epsilon=1e-2, pubsub_params=None):
         super().__init__(name=name, pubsub_params=pubsub_params)
         self._count = epsilon
         self.epsilon = epsilon
+        self.checkpoint_file_extension = 'srs'

     def set_params(self, shape=[1], clip_values=None):
         self._shape = shape
@@ -185,11 +173,12 @@ class NumpySharedRunningStats(SharedRunningStats):
                         '_sum': self._sum,
                         '_sum_squares': self._sum_squares}

-        with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.srs'), 'wb') as f:
+        with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.' + self.checkpoint_file_extension), 'wb') as f:
             pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)

     def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
-        latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir, checkpoint_prefix)
+        latest_checkpoint_filename = get_latest_checkpoint(checkpoint_dir, checkpoint_prefix,
+                                                           self.checkpoint_file_extension)

         if latest_checkpoint_filename == '':
             raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ")
@@ -532,3 +532,18 @@ def start_shell_command_and_wait(command):

 def indent_string(string):
     return '\t' + string.replace('\n', '\n\t')
+
+def get_latest_checkpoint(checkpoint_dir: str, checkpoint_prefix: str, checkpoint_file_extension: str) -> str:
+    latest_checkpoint_id = -1
+    latest_checkpoint = ''
+    # get all checkpoint files
+    for fname in os.listdir(checkpoint_dir):
+        path = os.path.join(checkpoint_dir, fname)
+        if os.path.isdir(path) or fname.split('.')[-1] != checkpoint_file_extension or checkpoint_prefix not in fname:
+            continue
+        checkpoint_id = int(fname.split('_')[0])
+        if checkpoint_id > latest_checkpoint_id:
+            latest_checkpoint = fname
+            latest_checkpoint_id = checkpoint_id
+
+    return latest_checkpoint
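The helper moved to rl_coach/utils.py selects the checkpoint file whose name has the highest leading numeric id, the requested extension, and the given prefix somewhere in its name. A sketch of calling it; the concrete file names below follow the "<id>_..." convention the parser expects but are invented for illustration:

# Sketch only: the file names are assumptions, not Coach's real checkpoint names.
import os, tempfile
from rl_coach.utils import get_latest_checkpoint

ckpt_dir = tempfile.mkdtemp()
for name in ['0_Step-0.ewma', '4_Step-2000.ewma', '2_Step-1000.srs']:
    open(os.path.join(ckpt_dir, name), 'wb').close()

# Only '.ewma' files containing the prefix are considered; the highest leading id wins.
print(get_latest_checkpoint(ckpt_dir, 'Step', 'ewma'))   # -> '4_Step-2000.ewma'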