
Distiller's AMC induced changes (#359)

* override episode rewards with the last transition reward

* EWMA normalization filter

* allowing control over when the pre_network filters update their internal state (on train vs. on inference)
Author: Gal Leibovich
Date: 2019-08-05 10:24:58 +03:00
Committed by: GitHub
Parent: 7df67dafa3
Commit: c1d1fae342
10 changed files with 137 additions and 30 deletions
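
The "EWMA normalization filter" mentioned in the commit message presumably lives in one of the other changed files and is not shown in the hunks below. For orientation only, here is a minimal sketch of an exponentially weighted moving average observation normalizer; the class name, the alpha parameter, and the filter signature are illustrative assumptions, not Coach's actual implementation:

import numpy as np


class EWMANormalizationFilter:
    """Hypothetical sketch: normalize observations using exponentially
    weighted moving averages (EWMA) of the mean and variance."""

    def __init__(self, alpha=0.01, eps=1e-8):
        self.alpha = alpha  # smoothing factor for the moving statistics
        self.eps = eps      # numerical stability constant
        self.mean = None    # EWMA estimate of the observation mean
        self.var = None     # EWMA estimate of the observation variance

    def filter(self, observation, update_internal_state=True):
        obs = np.asarray(observation, dtype=np.float64)
        if self.mean is None:
            self.mean = np.zeros_like(obs)
            self.var = np.ones_like(obs)
        if update_internal_state:
            # Standard EWMA update: blend the current observation into the stats.
            delta = obs - self.mean
            self.mean = self.mean + self.alpha * delta
            self.var = (1.0 - self.alpha) * (self.var + self.alpha * delta ** 2)
        return (obs - self.mean) / np.sqrt(self.var + self.eps)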


@@ -113,6 +113,8 @@ class ClippedPPOAlgorithmParameters(AlgorithmParameters):
         self.normalization_stats = None
         self.clipping_decay_schedule = ConstantSchedule(1)
         self.act_for_full_episodes = True
+        self.update_pre_network_filters_state_on_train = True
+        self.update_pre_network_filters_state_on_inference = False
 
 
 class ClippedPPOAgentParameters(AgentParameters):
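
The two new flags default to updating filter statistics during training but not during inference. A usage sketch, assuming the standard Coach import path for ClippedPPOAgentParameters; the surrounding setup is hypothetical and only the two flag names come from this diff:

from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters

agent_params = ClippedPPOAgentParameters()
# Keep adapting the pre-network filters' running statistics while training...
agent_params.algorithm.update_pre_network_filters_state_on_train = True
# ...but freeze them when filtering states for action selection.
agent_params.algorithm.update_pre_network_filters_state_on_inference = False
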
@@ -303,7 +305,9 @@ class ClippedPPOAgent(ActorCriticAgent):
         network.set_is_training(True)
 
         dataset = self.memory.transitions
-        dataset = self.pre_network_filter.filter(dataset, deep_copy=False)
+        update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_train
+        dataset = self.pre_network_filter.filter(dataset, deep_copy=False,
+                                                 update_internal_state=update_internal_state)
         batch = Batch(dataset)
 
         for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
@@ -329,7 +333,9 @@ class ClippedPPOAgent(ActorCriticAgent):
 
     def run_pre_network_filter_for_inference(self, state: StateType, update_internal_state: bool=False):
         dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
-        return self.pre_network_filter.filter(dummy_env_response, update_internal_state=False)[0].next_state
+        update_internal_state = self.ap.algorithm.update_pre_network_filters_state_on_inference
+        return self.pre_network_filter.filter(
+            dummy_env_response, update_internal_state=update_internal_state)[0].next_state
 
     def choose_action(self, curr_state):
         self.ap.algorithm.clipping_decay_schedule.step()
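
Taken together, the two flags decide whether a stateful pre-network filter keeps adapting its internal statistics on the training path and on the inference path. A small illustration, reusing the hypothetical EWMANormalizationFilter sketched above (not part of this commit):

f = EWMANormalizationFilter(alpha=0.05)

# Training-style calls: update_internal_state=True lets the statistics adapt.
for obs in ([0.0, 10.0], [1.0, 12.0], [2.0, 8.0]):
    f.filter(obs, update_internal_state=True)

# Inference-style call: statistics stay frozen, so repeated calls on the
# same state yield the same normalized observation.
normalized = f.filter([1.5, 11.0], update_internal_state=False)
print(normalized)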