diff --git a/agents/ddpg_agent.py b/agents/ddpg_agent.py
index f5d0275..c973c06 100644
--- a/agents/ddpg_agent.py
+++ b/agents/ddpg_agent.py
@@ -37,6 +37,8 @@ class DDPGAgent(ActorCriticAgent):
         self.q_values = Signal("Q")
         self.signals.append(self.q_values)
 
+        self.reset_game(do_not_reset_env=True)
+
     def learn_from_batch(self, batch):
         current_states, next_states, actions, rewards, game_overs, _ = self.extract_batch(batch)
 
diff --git a/agents/policy_optimization_agent.py b/agents/policy_optimization_agent.py
index 07aac6a..be23760 100644
--- a/agents/policy_optimization_agent.py
+++ b/agents/policy_optimization_agent.py
@@ -47,6 +47,8 @@ class PolicyOptimizationAgent(Agent):
         self.entropy = Signal('Entropy')
         self.signals.append(self.entropy)
 
+        self.reset_game(do_not_reset_env=True)
+
     def log_to_screen(self, phase):
         # log to screen
         if self.current_episode > 0:
diff --git a/agents/ppo_agent.py b/agents/ppo_agent.py
index 990abd8..92e7cd6 100644
--- a/agents/ppo_agent.py
+++ b/agents/ppo_agent.py
@@ -45,6 +45,8 @@ class PPOAgent(ActorCriticAgent):
         self.unclipped_grads = Signal('Grads (unclipped)')
         self.signals.append(self.unclipped_grads)
 
+        self.reset_game(do_not_reset_env=True)
+
     def fill_advantages(self, batch):
         current_states, next_states, actions, rewards, game_overs, total_return = self.extract_batch(batch)
 
diff --git a/agents/value_optimization_agent.py b/agents/value_optimization_agent.py
index 0684e34..f318577 100644
--- a/agents/value_optimization_agent.py
+++ b/agents/value_optimization_agent.py
@@ -28,6 +28,8 @@ class ValueOptimizationAgent(Agent):
         self.q_values = Signal("Q")
         self.signals.append(self.q_values)
 
+        self.reset_game(do_not_reset_env=True)
+
     # Algorithms for which q_values are calculated from predictions will override this function
     def get_q_values(self, prediction):
         return prediction
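
The change is identical in all four subclasses: after an agent constructor has registered its extra signals, it calls `self.reset_game(do_not_reset_env=True)` so the agent's per-episode bookkeeping is initialized without triggering an environment reset. The diff does not show the base-class `reset_game()`, so the sketch below is only a minimal illustration of how such a flag might be honored; the counter names, `signal.reset()`, and `self.env.reset()` are assumptions for context, not the actual Coach implementation.

```python
# Hypothetical sketch of a base Agent.reset_game() that honors
# do_not_reset_env. Names below (counters, signal.reset(), env.reset())
# are illustrative assumptions, not taken from the Coach source.

class Agent:
    def __init__(self, env=None):
        self.env = env
        self.signals = []                        # Signal objects registered by subclasses
        self.current_episode_steps_counter = 0   # assumed per-episode counter
        self.total_reward_in_current_episode = 0

    def reset_game(self, do_not_reset_env=False):
        # Clear per-episode counters and any accumulated signal values.
        self.current_episode_steps_counter = 0
        self.total_reward_in_current_episode = 0
        for signal in self.signals:
            signal.reset()

        # Skip the environment reset when the caller only needs the agent's
        # internal state re-initialized -- e.g. the constructors in this diff,
        # which run before the agent is ready to interact with the environment.
        if not do_not_reset_env and self.env is not None:
            self.env.reset()
```

Calling `reset_game(do_not_reset_env=True)` at the end of each subclass constructor (rather than in the shared base constructor) ensures the reset happens only after the subclass has appended its own signals, so those signals are cleared along with the rest of the per-episode state.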