Distiller's AMC-induced changes (#359)
* override episode rewards with the last transition reward
* EWMA normalization filter
* allowing control over when the pre_network filter runs
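Of the three changes, only the first and third appear in the hunks below, which touch the base Agent class and the ClippedPPO agent; the EWMA normalization filter itself is outside this excerpt. As a rough sketch of the idea only -- the class and parameter names below are invented for illustration, not taken from this commit -- an EWMA normalizer keeps exponentially decaying mean/variance estimates, so recent observations dominate the normalization statistics:

    import numpy as np

    # Illustrative sketch only -- not the filter added in this commit.
    class EWMANormalizer:
        def __init__(self, alpha: float = 0.01, eps: float = 1e-8):
            self.alpha = alpha  # smoothing factor; larger values forget faster
            self.eps = eps
            self.mean = None
            self.var = None

        def normalize(self, x: np.ndarray, update_internal_state: bool = True) -> np.ndarray:
            if self.mean is None:
                # first observation initializes the statistics
                self.mean = np.array(x, dtype=np.float64)
                self.var = np.ones_like(self.mean)
            elif update_internal_state:
                delta = x - self.mean
                self.mean += self.alpha * delta
                # standard EWMA variance recurrence: var <- (1 - a) * (var + a * delta^2)
                self.var = (1 - self.alpha) * (self.var + self.alpha * delta ** 2)
            return (x - self.mean) / np.sqrt(self.var + self.eps)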
@@ -573,6 +573,9 @@ class Agent(AgentInterface):
         if self.phase != RunPhase.TEST:
             if isinstance(self.memory, EpisodicExperienceReplay):
+                if self.ap.algorithm.override_episode_rewards_with_the_last_transition_reward:
+                    for t in self.current_episode_buffer.transitions:
+                        t.reward = self.current_episode_buffer.transitions[-1].reward
                 self.call_memory('store_episode', self.current_episode_buffer)
             elif self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
                 for transition in self.current_episode_buffer.transitions:
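Context for this first hunk: when an episode ends outside the TEST phase, the new override_episode_rewards_with_the_last_transition_reward switch rewrites every transition's reward to the reward of the episode's final transition before the episode is stored. This fits settings like Distiller's AMC, where a meaningful score (e.g., the compressed model's accuracy) only arrives at episode end. A plain-Python illustration of the semantics -- Transition here is a stand-in, not Coach's class:

    from dataclasses import dataclass

    @dataclass
    class Transition:
        reward: float

    episode = [Transition(0.0), Transition(0.0), Transition(0.93)]
    for t in episode:
        t.reward = episode[-1].reward  # every step now carries the terminal reward

    print([t.reward for t in episode])  # [0.93, 0.93, 0.93]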
@@ -727,7 +730,8 @@ class Agent(AgentInterface):
         # update counters
         self.training_iteration += 1
         if self.pre_network_filter is not None:
-            batch = self.pre_network_filter.filter(batch, update_internal_state=False, deep_copy=False)
+            update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_train
+            batch = self.pre_network_filter.filter(batch, update_internal_state=update_internal_state, deep_copy=False)
 
         # if the batch returned empty then there are not enough samples in the replay buffer -> skip
         # training step
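In this second hunk, the training path previously hard-coded update_internal_state=False when filtering a batch; the new update_pre_network_filters_state_on_train flag lets an agent opt into updating filter statistics from training data instead. A toy stand-in (not Coach's filter class) showing what the flag gates:

    # Running-mean filter whose statistics advance only on request.
    class RunningMeanFilter:
        def __init__(self):
            self.mean, self.n = 0.0, 0

        def filter(self, x: float, update_internal_state: bool = False) -> float:
            if update_internal_state:
                self.n += 1
                self.mean += (x - self.mean) / self.n
            return x - self.mean

    f = RunningMeanFilter()
    f.filter(10.0, update_internal_state=True)   # statistics move: mean becomes 10.0
    f.filter(99.0, update_internal_state=False)  # statistics frozen: mean stays 10.0
    print(f.mean)                                # 10.0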
@@ -837,7 +841,8 @@ class Agent(AgentInterface):
         # informed action
         if self.pre_network_filter is not None:
             # before choosing an action, first use the pre_network_filter to filter out the current state
-            update_filter_internal_state = self.phase is not RunPhase.TEST
+            update_filter_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_inference and \
+                                           self.phase is not RunPhase.TEST
             curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state)
 
         else:
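In this third hunk, updating filter state while choosing an action now requires both the new flag and a non-TEST phase, where before the phase check alone decided. A toy truth check of the gating expression -- Phase is a stand-in for Coach's RunPhase:

    from enum import Enum

    class Phase(Enum):
        TRAIN = "train"
        TEST = "test"

    def should_update(update_on_inference: bool, phase: Phase) -> bool:
        # mirrors: ..._state_on_inference and self.phase is not RunPhase.TEST
        return update_on_inference and phase is not Phase.TEST

    assert should_update(True, Phase.TRAIN)       # opted in, training: update
    assert not should_update(True, Phase.TEST)    # never update while testing
    assert not should_update(False, Phase.TRAIN)  # ClippedPPO's new default (see the parameters hunk below)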
@@ -865,6 +870,10 @@ class Agent(AgentInterface):
         :return: The filtered state
         """
         dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
+
+        # TODO actually we only want to run the observation filters. No point in running the reward filters as the
+        # filtered reward is being ignored anyway (and it might unnecessarily affect the reward filters' internal
+        # state).
         return self.pre_network_filter.filter(dummy_env_response,
                                               update_internal_state=update_filter_internal_state)[0].next_state
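The TODO added in this hunk flags a real side effect: the dummy EnvResponse carries reward=0, so any reward filter in the pre-network chain also processes that placeholder and, when internal-state updates are enabled, folds spurious zeros into its statistics. A toy stand-in (not a Coach class) demonstrating the concern:

    # A reward filter tracking a running mean of the rewards it has seen.
    class RewardRunningMean:
        def __init__(self):
            self.mean, self.n = 0.0, 0

        def filter(self, reward: float, update_internal_state: bool = True) -> float:
            if update_internal_state:
                self.n += 1
                self.mean += (reward - self.mean) / self.n
            return reward - self.mean

    f = RewardRunningMean()
    for r in (1.0, 2.0, 3.0):  # real rewards seen while stepping the environment
        f.filter(r)
    f.filter(0.0)              # dummy reward from the inference path
    print(f.mean)              # 1.5 -- dragged down from 2.0 by the placeholder zero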
@@ -113,6 +113,8 @@ class ClippedPPOAlgorithmParameters(AlgorithmParameters):
         self.normalization_stats = None
         self.clipping_decay_schedule = ConstantSchedule(1)
         self.act_for_full_episodes = True
+        self.update_pre_network_filters_state_on_train = True
+        self.update_pre_network_filters_state_on_inference = False
 
 
 class ClippedPPOAgentParameters(AgentParameters):
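These two parameters are the knobs read by the Agent hunks above and the ClippedPPO hunks below. A sketch of how a preset might set them explicitly -- the flag names come from this diff, while the import path assumes Coach's usual module layout:

    from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters

    agent_params = ClippedPPOAgentParameters()
    # Update pre-network filter statistics during training passes...
    agent_params.algorithm.update_pre_network_filters_state_on_train = True
    # ...and keep them frozen while acting (the default set above).
    agent_params.algorithm.update_pre_network_filters_state_on_inference = False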
@@ -303,7 +305,9 @@ class ClippedPPOAgent(ActorCriticAgent):
         network.set_is_training(True)
 
         dataset = self.memory.transitions
-        dataset = self.pre_network_filter.filter(dataset, deep_copy=False)
+        update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_train
+        dataset = self.pre_network_filter.filter(dataset, deep_copy=False,
+                                                 update_internal_state=update_internal_state)
         batch = Batch(dataset)
 
         for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
@@ -329,7 +333,9 @@ class ClippedPPOAgent(ActorCriticAgent):
 
     def run_pre_network_filter_for_inference(self, state: StateType, update_internal_state: bool=False):
         dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
-        return self.pre_network_filter.filter(dummy_env_response, update_internal_state=False)[0].next_state
+        update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_inference
+        return self.pre_network_filter.filter(
+            dummy_env_response, update_internal_state=update_internal_state)[0].next_state
 
     def choose_action(self, curr_state):
         self.ap.algorithm.clipping_decay_schedule.step()
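One observation on this last hunk: the update_internal_state argument in the method's signature is immediately reassigned from the algorithm parameter, so whatever a caller passes is ignored (before the change, the call hard-coded False, so the argument was equally dead). A minimal demonstration of the shadowing pattern:

    def f(flag: bool = False) -> bool:
        flag = True          # reassignment shadows whatever the caller passed
        return flag

    assert f(False) is True  # the argument no longer has any effect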