diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index 03da59d..6325d28 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -379,7 +379,7 @@ class GraphManager(object): if result.game_over: self.handle_episode_ended() - + self.reset_required = True def train_and_act(self, steps: StepMethod) -> None: @@ -431,6 +431,37 @@ class GraphManager(object): self.act(EnvironmentEpisodes(1)) self.sync() + def improve(self): + """ + The main loop of the run. + Defined in the following steps: + 1. Heatup + 2. Repeat: + 2.1. Repeat: + 2.1.1. Act + 2.1.2. Train + 2.1.3. Possibly save checkpoint + 2.2. Evaluate + :return: None + """ + + # initialize the network parameters from the global network + self.sync() + + # heatup + self.heatup(self.heatup_steps) + + # improve + if self.task_parameters.task_index is not None: + screen.log_title("Starting to improve {} task index {}".format(self.name, self.task_parameters.task_index)) + else: + screen.log_title("Starting to improve {}".format(self.name)) + + count_end = self.improve_steps.num_steps + while self.total_steps_counters[RunPhase.TRAIN][self.improve_steps.__class__] < count_end: + self.train_and_act(self.steps_between_evaluation_periods) + self.evaluate(self.evaluation_steps) + def restore_checkpoint(self): self.verify_graph_was_created() @@ -489,39 +520,6 @@ class GraphManager(object): data_store = get_data_store(self.data_store_params) data_store.save_to_store() - - def improve(self): - """ - The main loop of the run. - Defined in the following steps: - 1. Heatup - 2. Repeat: - 2.1. Repeat: - 2.1.1. Act - 2.1.2. Train - 2.1.3. Possibly save checkpoint - 2.2. Evaluate - :return: None - """ - self.verify_graph_was_created() - - # initialize the network parameters from the global network - self.sync() - - # heatup - self.heatup(self.heatup_steps) - - # improve - if self.task_parameters.task_index is not None: - screen.log_title("Starting to improve {} task index {}".format(self.name, self.task_parameters.task_index)) - else: - screen.log_title("Starting to improve {}".format(self.name)) - - count_end = self.improve_steps.num_steps - while self.total_steps_counters[RunPhase.TRAIN][self.improve_steps.__class__] < count_end: - self.train_and_act(self.steps_between_evaluation_periods) - self.evaluate(self.evaluation_steps) - def verify_graph_was_created(self): """ Verifies that the graph was already created, and if not, it creates it with the default task parameters