From eeb3ec5497cc4f7ced35ef8754755706f96bf133 Mon Sep 17 00:00:00 2001
From: Itai Caspi
Date: Sat, 30 Dec 2017 15:18:09 +0200
Subject: [PATCH] fixed the LSTM middleware initialization

---
 agents/ddpg_agent.py                | 2 ++
 agents/policy_optimization_agent.py | 2 ++
 agents/ppo_agent.py                 | 2 ++
 agents/value_optimization_agent.py  | 2 ++
 4 files changed, 8 insertions(+)

diff --git a/agents/ddpg_agent.py b/agents/ddpg_agent.py
index f5d0275..c973c06 100644
--- a/agents/ddpg_agent.py
+++ b/agents/ddpg_agent.py
@@ -37,6 +37,8 @@ class DDPGAgent(ActorCriticAgent):
         self.q_values = Signal("Q")
         self.signals.append(self.q_values)
 
+        self.reset_game(do_not_reset_env=True)
+
     def learn_from_batch(self, batch):
         current_states, next_states, actions, rewards, game_overs, _ = self.extract_batch(batch)
 
diff --git a/agents/policy_optimization_agent.py b/agents/policy_optimization_agent.py
index 07aac6a..be23760 100644
--- a/agents/policy_optimization_agent.py
+++ b/agents/policy_optimization_agent.py
@@ -47,6 +47,8 @@ class PolicyOptimizationAgent(Agent):
         self.entropy = Signal('Entropy')
         self.signals.append(self.entropy)
 
+        self.reset_game(do_not_reset_env=True)
+
     def log_to_screen(self, phase):
         # log to screen
         if self.current_episode > 0:
diff --git a/agents/ppo_agent.py b/agents/ppo_agent.py
index 990abd8..92e7cd6 100644
--- a/agents/ppo_agent.py
+++ b/agents/ppo_agent.py
@@ -45,6 +45,8 @@ class PPOAgent(ActorCriticAgent):
         self.unclipped_grads = Signal('Grads (unclipped)')
         self.signals.append(self.unclipped_grads)
 
+        self.reset_game(do_not_reset_env=True)
+
     def fill_advantages(self, batch):
         current_states, next_states, actions, rewards, game_overs, total_return = self.extract_batch(batch)
 
diff --git a/agents/value_optimization_agent.py b/agents/value_optimization_agent.py
index 0684e34..f318577 100644
--- a/agents/value_optimization_agent.py
+++ b/agents/value_optimization_agent.py
@@ -28,6 +28,8 @@ class ValueOptimizationAgent(Agent):
         self.q_values = Signal("Q")
         self.signals.append(self.q_values)
 
+        self.reset_game(do_not_reset_env=True)
+
     # Algorithms for which q_values are calculated from predictions will override this function
     def get_q_values(self, prediction):
         return prediction
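
Note: agents that use an LSTM middleware carry recurrent hidden state between
steps, and that state must exist before the first action is taken. Each hunk
above therefore ends the agent's constructor with
self.reset_game(do_not_reset_env=True), which initializes the per-episode
state (including the LSTM hidden state) without also resetting the
environment. The base Agent class is not part of this patch, so the following
is only a rough sketch of the pattern under assumed names (Agent.reset_game,
self.lstm_state, self.env are illustrative, not the repository's actual
definitions):

    import numpy as np

    class Agent:
        """Stand-in for the base agent class (not shown in this patch)."""
        def __init__(self, hidden_size=256):
            self.hidden_size = hidden_size
            self.lstm_state = None  # recurrent middleware state (assumed shape)

        def reset_game(self, do_not_reset_env=False):
            # Re-initialize per-episode state, including the LSTM hidden and
            # cell states, optionally without touching the environment.
            self.lstm_state = (np.zeros((1, self.hidden_size)),
                               np.zeros((1, self.hidden_size)))
            if not do_not_reset_env:
                self.env.reset()  # hypothetical environment handle

    class ValueOptimizationAgent(Agent):
        def __init__(self):
            super().__init__()
            # ... signal setup as in the patch ...
            # The added call guarantees the recurrent state exists before the
            # first act()/train() call, without forcing an environment reset.
            self.reset_game(do_not_reset_env=True)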