1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Distiller's AMC induced changes (#359)

* override episode rewards with the last transition reward

* EWMA normalization filter

* allowing control over when the pre_network filter runs
This commit is contained in:
Gal Leibovich
2019-08-05 10:24:58 +03:00
committed by GitHub
parent 7df67dafa3
commit c1d1fae342
10 changed files with 137 additions and 30 deletions

View File

@@ -573,6 +573,9 @@ class Agent(AgentInterface):
if self.phase != RunPhase.TEST:
if isinstance(self.memory, EpisodicExperienceReplay):
if self.ap.algorithm.override_episode_rewards_with_the_last_transition_reward:
for t in self.current_episode_buffer.transitions:
t.reward = self.current_episode_buffer.transitions[-1].reward
self.call_memory('store_episode', self.current_episode_buffer)
elif self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
for transition in self.current_episode_buffer.transitions:
@@ -727,7 +730,8 @@ class Agent(AgentInterface):
# update counters
self.training_iteration += 1
if self.pre_network_filter is not None:
batch = self.pre_network_filter.filter(batch, update_internal_state=False, deep_copy=False)
update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_train
batch = self.pre_network_filter.filter(batch, update_internal_state=update_internal_state, deep_copy=False)
# if the batch returned empty then there are not enough samples in the replay buffer -> skip
# training step
@@ -837,7 +841,8 @@ class Agent(AgentInterface):
# informed action
if self.pre_network_filter is not None:
# before choosing an action, first use the pre_network_filter to filter out the current state
update_filter_internal_state = self.phase is not RunPhase.TEST
update_filter_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_inference and \
self.phase is not RunPhase.TEST
curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state)
else:
@@ -865,6 +870,10 @@ class Agent(AgentInterface):
:return: The filtered state
"""
dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
# TODO actually we only want to run the observation filters. No point in running the reward filters as the
# filtered reward is being ignored anyway (and it might unncecessarily affect the reward filters' internal
# state).
return self.pre_network_filter.filter(dummy_env_response,
update_internal_state=update_filter_internal_state)[0].next_state