
fixed the LSTM middleware initialization

Itai Caspi, 2017-12-30 15:18:09 +02:00
committed by Itai Caspi
parent b435c6d2d7
commit eeb3ec5497
4 changed files with 8 additions and 0 deletions
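All four diffs below make the same two-line addition: each agent's constructor now ends with self.reset_game(do_not_reset_env=True), so that per-episode state, including the hidden state of an LSTM middleware, is initialized before the agent takes its first step, without resetting the environment during construction. The minimal sketch below illustrates the pattern; the class bodies, the middleware_state attribute, and the hidden size are illustrative assumptions, not Coach's actual implementation.

import numpy as np

class Agent:
    """Minimal stand-in for Coach's Agent base class (illustrative only)."""
    def __init__(self):
        # Hidden state of a recurrent (LSTM) middleware; undefined until reset.
        self.middleware_state = None

    def reset_game(self, do_not_reset_env=False):
        # Re-initialize per-episode state: zero the LSTM hidden/cell state so
        # the first forward pass sees a well-defined recurrent state.
        hidden_size = 64  # assumed size, for illustration
        self.middleware_state = (np.zeros(hidden_size), np.zeros(hidden_size))
        if not do_not_reset_env:
            pass  # a real agent would also reset its environment here

class DDPGAgent(Agent):
    def __init__(self):
        super().__init__()
        # ... networks and signals are built here ...
        # The fix: initialize the recurrent state at the end of construction,
        # skipping the environment reset (presumably unwanted at this point).
        self.reset_game(do_not_reset_env=True)

agent = DDPGAgent()
assert agent.middleware_state is not None  # LSTM state exists before acting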


@@ -37,6 +37,8 @@ class DDPGAgent(ActorCriticAgent):
         self.q_values = Signal("Q")
         self.signals.append(self.q_values)
+
+        self.reset_game(do_not_reset_env=True)
 
     def learn_from_batch(self, batch):
         current_states, next_states, actions, rewards, game_overs, _ = self.extract_batch(batch)


@@ -47,6 +47,8 @@ class PolicyOptimizationAgent(Agent):
         self.entropy = Signal('Entropy')
         self.signals.append(self.entropy)
+
+        self.reset_game(do_not_reset_env=True)
 
     def log_to_screen(self, phase):
         # log to screen
         if self.current_episode > 0:


@@ -45,6 +45,8 @@ class PPOAgent(ActorCriticAgent):
         self.unclipped_grads = Signal('Grads (unclipped)')
         self.signals.append(self.unclipped_grads)
+
+        self.reset_game(do_not_reset_env=True)
 
     def fill_advantages(self, batch):
         current_states, next_states, actions, rewards, game_overs, total_return = self.extract_batch(batch)


@@ -28,6 +28,8 @@ class ValueOptimizationAgent(Agent):
         self.q_values = Signal("Q")
         self.signals.append(self.q_values)
+
+        self.reset_game(do_not_reset_env=True)
 
     # Algorithms for which q_values are calculated from predictions will override this function
     def get_q_values(self, prediction):
         return prediction