diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index 80e5149..6c72e54 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -446,19 +446,17 @@ class GraphManager(object): self.verify_graph_was_created() if steps.num_steps > 0: - self.phase = RunPhase.TEST - self.last_evaluation_start_time = time.time() + with self.phase_context(RunPhase.TEST): + self.last_evaluation_start_time = time.time() - # reset all the levels before starting to evaluate - self.reset_internal_state(force_environment_reset=True) - self.sync_graph() + # reset all the levels before starting to evaluate + self.reset_internal_state(force_environment_reset=True) + self.sync_graph() - # act for at least `steps`, though don't interrupt an episode - count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps - while self.total_steps_counters[self.phase][steps.__class__] < count_end: - self.act(EnvironmentEpisodes(1), keep_networks_in_sync=keep_networks_in_sync) - - self.phase = RunPhase.UNDEFINED + # act for at least `steps`, though don't interrupt an episode + count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps + while self.total_steps_counters[self.phase][steps.__class__] < count_end: + self.act(EnvironmentEpisodes(1), keep_networks_in_sync=keep_networks_in_sync) def restore_checkpoint(self): self.verify_graph_was_created()