1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 19:50:17 +01:00

reorder methods in GraphManager

This commit is contained in:
Zach Dwiel
2018-10-04 16:47:41 -04:00
committed by zach dwiel
parent 496a516de1
commit 776c94d551

View File

@@ -379,7 +379,7 @@ class GraphManager(object):
if result.game_over: if result.game_over:
self.handle_episode_ended() self.handle_episode_ended()
self.reset_required = True self.reset_required = True
def train_and_act(self, steps: StepMethod) -> None: def train_and_act(self, steps: StepMethod) -> None:
@@ -431,6 +431,37 @@ class GraphManager(object):
self.act(EnvironmentEpisodes(1)) self.act(EnvironmentEpisodes(1))
self.sync() self.sync()
def improve(self):
"""
The main loop of the run.
Defined in the following steps:
1. Heatup
2. Repeat:
2.1. Repeat:
2.1.1. Act
2.1.2. Train
2.1.3. Possibly save checkpoint
2.2. Evaluate
:return: None
"""
# initialize the network parameters from the global network
self.sync()
# heatup
self.heatup(self.heatup_steps)
# improve
if self.task_parameters.task_index is not None:
screen.log_title("Starting to improve {} task index {}".format(self.name, self.task_parameters.task_index))
else:
screen.log_title("Starting to improve {}".format(self.name))
count_end = self.improve_steps.num_steps
while self.total_steps_counters[RunPhase.TRAIN][self.improve_steps.__class__] < count_end:
self.train_and_act(self.steps_between_evaluation_periods)
self.evaluate(self.evaluation_steps)
def restore_checkpoint(self): def restore_checkpoint(self):
self.verify_graph_was_created() self.verify_graph_was_created()
@@ -489,39 +520,6 @@ class GraphManager(object):
data_store = get_data_store(self.data_store_params) data_store = get_data_store(self.data_store_params)
data_store.save_to_store() data_store.save_to_store()
def improve(self):
"""
The main loop of the run.
Defined in the following steps:
1. Heatup
2. Repeat:
2.1. Repeat:
2.1.1. Act
2.1.2. Train
2.1.3. Possibly save checkpoint
2.2. Evaluate
:return: None
"""
self.verify_graph_was_created()
# initialize the network parameters from the global network
self.sync()
# heatup
self.heatup(self.heatup_steps)
# improve
if self.task_parameters.task_index is not None:
screen.log_title("Starting to improve {} task index {}".format(self.name, self.task_parameters.task_index))
else:
screen.log_title("Starting to improve {}".format(self.name))
count_end = self.improve_steps.num_steps
while self.total_steps_counters[RunPhase.TRAIN][self.improve_steps.__class__] < count_end:
self.train_and_act(self.steps_between_evaluation_periods)
self.evaluate(self.evaluation_steps)
def verify_graph_was_created(self): def verify_graph_was_created(self):
""" """
Verifies that the graph was already created, and if not, it creates it with the default task parameters Verifies that the graph was already created, and if not, it creates it with the default task parameters