diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index 201c67b..f0788ff 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -348,11 +348,10 @@ class GraphManager(object): [environment.reset_internal_state(force_environment_reset) for environment in self.environments] [manager.reset_internal_state() for manager in self.level_managers] - def act(self, steps: PlayingStepsType, keep_networks_in_sync=False) -> (int, bool): + def act(self, steps: PlayingStepsType) -> (int, bool): """ Do several steps of acting on the environment :param steps: the number of steps as a tuple of steps time and steps count - :param keep_networks_in_sync: sync the network parameters with the global network before each episode """ self.verify_graph_was_created() @@ -386,8 +385,6 @@ class GraphManager(object): self.handle_episode_ended() self.reset_required = True - if keep_networks_in_sync: - self.sync_graph() def train_and_act(self, steps: StepMethod) -> None: """ @@ -436,7 +433,8 @@ class GraphManager(object): # act for at least `steps`, though don't interrupt an episode count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps while self.total_steps_counters[self.phase][steps.__class__] < count_end: - self.act(EnvironmentEpisodes(1), keep_networks_in_sync=keep_networks_in_sync) + self.act(EnvironmentEpisodes(1)) + self.sync_graph() def restore_checkpoint(self): self.verify_graph_was_created()