mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 19:50:17 +01:00
reorder methods in GraphManager
This commit is contained in:
@@ -379,7 +379,7 @@ class GraphManager(object):
|
|||||||
|
|
||||||
if result.game_over:
|
if result.game_over:
|
||||||
self.handle_episode_ended()
|
self.handle_episode_ended()
|
||||||
|
|
||||||
self.reset_required = True
|
self.reset_required = True
|
||||||
|
|
||||||
def train_and_act(self, steps: StepMethod) -> None:
|
def train_and_act(self, steps: StepMethod) -> None:
|
||||||
@@ -431,6 +431,37 @@ class GraphManager(object):
|
|||||||
self.act(EnvironmentEpisodes(1))
|
self.act(EnvironmentEpisodes(1))
|
||||||
self.sync()
|
self.sync()
|
||||||
|
|
||||||
|
def improve(self):
|
||||||
|
"""
|
||||||
|
The main loop of the run.
|
||||||
|
Defined in the following steps:
|
||||||
|
1. Heatup
|
||||||
|
2. Repeat:
|
||||||
|
2.1. Repeat:
|
||||||
|
2.1.1. Act
|
||||||
|
2.1.2. Train
|
||||||
|
2.1.3. Possibly save checkpoint
|
||||||
|
2.2. Evaluate
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
|
||||||
|
# initialize the network parameters from the global network
|
||||||
|
self.sync()
|
||||||
|
|
||||||
|
# heatup
|
||||||
|
self.heatup(self.heatup_steps)
|
||||||
|
|
||||||
|
# improve
|
||||||
|
if self.task_parameters.task_index is not None:
|
||||||
|
screen.log_title("Starting to improve {} task index {}".format(self.name, self.task_parameters.task_index))
|
||||||
|
else:
|
||||||
|
screen.log_title("Starting to improve {}".format(self.name))
|
||||||
|
|
||||||
|
count_end = self.improve_steps.num_steps
|
||||||
|
while self.total_steps_counters[RunPhase.TRAIN][self.improve_steps.__class__] < count_end:
|
||||||
|
self.train_and_act(self.steps_between_evaluation_periods)
|
||||||
|
self.evaluate(self.evaluation_steps)
|
||||||
|
|
||||||
def restore_checkpoint(self):
|
def restore_checkpoint(self):
|
||||||
self.verify_graph_was_created()
|
self.verify_graph_was_created()
|
||||||
|
|
||||||
@@ -489,39 +520,6 @@ class GraphManager(object):
|
|||||||
data_store = get_data_store(self.data_store_params)
|
data_store = get_data_store(self.data_store_params)
|
||||||
data_store.save_to_store()
|
data_store.save_to_store()
|
||||||
|
|
||||||
|
|
||||||
def improve(self):
|
|
||||||
"""
|
|
||||||
The main loop of the run.
|
|
||||||
Defined in the following steps:
|
|
||||||
1. Heatup
|
|
||||||
2. Repeat:
|
|
||||||
2.1. Repeat:
|
|
||||||
2.1.1. Act
|
|
||||||
2.1.2. Train
|
|
||||||
2.1.3. Possibly save checkpoint
|
|
||||||
2.2. Evaluate
|
|
||||||
:return: None
|
|
||||||
"""
|
|
||||||
self.verify_graph_was_created()
|
|
||||||
|
|
||||||
# initialize the network parameters from the global network
|
|
||||||
self.sync()
|
|
||||||
|
|
||||||
# heatup
|
|
||||||
self.heatup(self.heatup_steps)
|
|
||||||
|
|
||||||
# improve
|
|
||||||
if self.task_parameters.task_index is not None:
|
|
||||||
screen.log_title("Starting to improve {} task index {}".format(self.name, self.task_parameters.task_index))
|
|
||||||
else:
|
|
||||||
screen.log_title("Starting to improve {}".format(self.name))
|
|
||||||
|
|
||||||
count_end = self.improve_steps.num_steps
|
|
||||||
while self.total_steps_counters[RunPhase.TRAIN][self.improve_steps.__class__] < count_end:
|
|
||||||
self.train_and_act(self.steps_between_evaluation_periods)
|
|
||||||
self.evaluate(self.evaluation_steps)
|
|
||||||
|
|
||||||
def verify_graph_was_created(self):
|
def verify_graph_was_created(self):
|
||||||
"""
|
"""
|
||||||
Verifies that the graph was already created, and if not, it creates it with the default task parameters
|
Verifies that the graph was already created, and if not, it creates it with the default task parameters
|
||||||
|
|||||||
Reference in New Issue
Block a user