1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

reorder methods in GraphManager

This commit is contained in:
Zach Dwiel
2018-10-04 16:47:41 -04:00
committed by zach dwiel
parent 496a516de1
commit 776c94d551

View File

@@ -431,6 +431,37 @@ class GraphManager(object):
self.act(EnvironmentEpisodes(1))
self.sync()
def improve(self):
"""
The main loop of the run.
Defined in the following steps:
1. Heatup
2. Repeat:
2.1. Repeat:
2.1.1. Act
2.1.2. Train
2.1.3. Possibly save checkpoint
2.2. Evaluate
:return: None
"""
# initialize the network parameters from the global network
self.sync()
# heatup
self.heatup(self.heatup_steps)
# improve
if self.task_parameters.task_index is not None:
screen.log_title("Starting to improve {} task index {}".format(self.name, self.task_parameters.task_index))
else:
screen.log_title("Starting to improve {}".format(self.name))
count_end = self.improve_steps.num_steps
while self.total_steps_counters[RunPhase.TRAIN][self.improve_steps.__class__] < count_end:
self.train_and_act(self.steps_between_evaluation_periods)
self.evaluate(self.evaluation_steps)
def restore_checkpoint(self):
self.verify_graph_was_created()
@@ -489,39 +520,6 @@ class GraphManager(object):
data_store = get_data_store(self.data_store_params)
data_store.save_to_store()
def improve(self):
"""
The main loop of the run.
Defined in the following steps:
1. Heatup
2. Repeat:
2.1. Repeat:
2.1.1. Act
2.1.2. Train
2.1.3. Possibly save checkpoint
2.2. Evaluate
:return: None
"""
self.verify_graph_was_created()
# initialize the network parameters from the global network
self.sync()
# heatup
self.heatup(self.heatup_steps)
# improve
if self.task_parameters.task_index is not None:
screen.log_title("Starting to improve {} task index {}".format(self.name, self.task_parameters.task_index))
else:
screen.log_title("Starting to improve {}".format(self.name))
count_end = self.improve_steps.num_steps
while self.total_steps_counters[RunPhase.TRAIN][self.improve_steps.__class__] < count_end:
self.train_and_act(self.steps_between_evaluation_periods)
self.evaluate(self.evaluation_steps)
def verify_graph_was_created(self):
"""
Verifies that the graph was already created, and if not, it creates it with the default task parameters