diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index a8a1f5b..80e5149 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -319,7 +319,7 @@ class GraphManager(object): # act for at least steps, though don't interrupt an episode count_end = self.total_steps_counters[self.phase][EnvironmentSteps] + steps.num_steps while self.total_steps_counters[self.phase][steps.__class__] < count_end: - self.act(steps, continue_until_game_over=True, return_on_game_over=True) + self.act(EnvironmentEpisodes(1)) def handle_episode_ended(self) -> None: """ @@ -356,8 +356,7 @@ class GraphManager(object): [environment.reset_internal_state(force_environment_reset) for environment in self.environments] [manager.reset_internal_state() for manager in self.level_managers] - def act(self, steps: PlayingStepsType, return_on_game_over: bool=False, continue_until_game_over=False, - keep_networks_in_sync=False) -> (int, bool): + def act(self, steps: PlayingStepsType, keep_networks_in_sync=False) -> (int, bool): """ Do several steps of acting on the environment :param steps: the number of steps as a tuple of steps time and steps count @@ -380,7 +379,7 @@ class GraphManager(object): # The assumption here is that the total_steps_counters are each updated when an event # takes place (i.e. an episode ends) # TODO - The counter of frames is not updated correctly. need to fix that. - while self.total_steps_counters[self.phase][steps.__class__] < count_end or continue_until_game_over: + while self.total_steps_counters[self.phase][steps.__class__] < count_end: # reset the environment if the previous episode was terminated if self.reset_required: self.reset_internal_state() @@ -402,14 +401,11 @@ class GraphManager(object): self.total_steps_counters[self.phase][EnvironmentSteps] += max(1, steps_end - steps_begin) if result.game_over: - continue_until_game_over = False self.handle_episode_ended() # TODO: why not just reset right now? self.reset_required = True if keep_networks_in_sync: self.sync_graph() - if return_on_game_over: - return def train_and_act(self, steps: StepMethod) -> None: """ @@ -460,8 +456,7 @@ class GraphManager(object): # act for at least `steps`, though don't interrupt an episode count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps while self.total_steps_counters[self.phase][steps.__class__] < count_end: - self.act(steps, continue_until_game_over=True, return_on_game_over=True, - keep_networks_in_sync=keep_networks_in_sync) + self.act(EnvironmentEpisodes(1), keep_networks_in_sync=keep_networks_in_sync) self.phase = RunPhase.UNDEFINED