mirror of https://github.com/gryf/coach.git
fixed the LSTM middleware initialization
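The same two-line addition appears in all four agent constructors below: self.reset_game(do_not_reset_env=True) is called right after the agent's signals are registered, so per-episode bookkeeping (and, per the commit title, the LSTM middleware's recurrent state) is presumably initialized once at construction time without triggering an environment reset. A minimal sketch of the pattern follows the last hunk.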
@@ -37,6 +37,8 @@ class DDPGAgent(ActorCriticAgent):
         self.q_values = Signal("Q")
         self.signals.append(self.q_values)

+        self.reset_game(do_not_reset_env=True)
+
     def learn_from_batch(self, batch):
         current_states, next_states, actions, rewards, game_overs, _ = self.extract_batch(batch)
@@ -47,6 +47,8 @@ class PolicyOptimizationAgent(Agent):
         self.entropy = Signal('Entropy')
         self.signals.append(self.entropy)

+        self.reset_game(do_not_reset_env=True)
+
     def log_to_screen(self, phase):
         # log to screen
         if self.current_episode > 0:
@@ -45,6 +45,8 @@ class PPOAgent(ActorCriticAgent):
         self.unclipped_grads = Signal('Grads (unclipped)')
         self.signals.append(self.unclipped_grads)

+        self.reset_game(do_not_reset_env=True)
+
     def fill_advantages(self, batch):
         current_states, next_states, actions, rewards, game_overs, total_return = self.extract_batch(batch)
@@ -28,6 +28,8 @@ class ValueOptimizationAgent(Agent):
         self.q_values = Signal("Q")
         self.signals.append(self.q_values)

+        self.reset_game(do_not_reset_env=True)
+
     # Algorithms for which q_values are calculated from predictions will override this function
     def get_q_values(self, prediction):
         return prediction
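What reset_game() actually does is not part of this diff, so the following is only a minimal sketch of the pattern the commit relies on, assuming a reset_game() that clears the registered signals and re-zeroes an LSTM middleware's recurrent state. The Signal stand-in, the lstm_hidden_size and curr_rnn_state names, and the env handling are illustrative placeholders, not coach's actual implementation.

import numpy as np


class Signal:
    """Stand-in for coach's Signal: collects per-episode values for logging."""
    def __init__(self, name):
        self.name = name
        self.values = []

    def add_sample(self, value):
        self.values.append(value)

    def reset(self):
        self.values = []


class AgentSketch:
    """Illustrates the constructor pattern added by this commit (not coach's Agent)."""

    def __init__(self, env, lstm_hidden_size=256):
        self.env = env
        self.lstm_hidden_size = lstm_hidden_size  # assumed size of the LSTM middleware
        self.signals = []

        # Signals are registered first, exactly as in the hunks above ...
        self.q_values = Signal("Q")
        self.signals.append(self.q_values)

        # ... and only then is per-episode state initialized; do_not_reset_env=True
        # keeps the constructor from resetting the environment as a side effect.
        self.reset_game(do_not_reset_env=True)

    def reset_game(self, do_not_reset_env=False):
        # Clear per-episode statistics for every registered signal.
        for signal in self.signals:
            signal.reset()

        # (Re)initialize the LSTM middleware's recurrent state (cell, hidden).
        # Before the fix, an agent that never called reset_game() during
        # construction would start its first episode with this state undefined.
        self.curr_rnn_state = (np.zeros((1, self.lstm_hidden_size)),
                               np.zeros((1, self.lstm_hidden_size)))

        if not do_not_reset_env:
            self.env.reset()


# Construction alone does not touch the environment:
agent = AgentSketch(env=None)
assert agent.curr_rnn_state[0].shape == (1, 256)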