1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

allow custom number of training steps

This commit is contained in:
Zach Dwiel
2018-09-21 15:49:06 -04:00
committed by zach dwiel
parent d69332efd4
commit 67faa80ea0
2 changed files with 21 additions and 1 deletions

View File

@@ -22,6 +22,23 @@ def heatup(graph_manager):
time.sleep(1)
class StepsLoop(object):
"""StepsLoop facilitates a simple while loop"""
def __init__(self, steps_counters, phase, steps):
super(StepsLoop, self).__init__()
self.steps_counters = steps_counters
self.phase = phase
self.steps = steps
self.step_end = self._step_count() + steps.num_steps
def _step_count(self):
return self.steps_counters[self.phase][self.steps.__class__]
def continue(self):
return self._step_count() < count_end:
def training_worker(graph_manager, checkpoint_dir):
"""
restore a checkpoint then perform rollouts using the restored model
@@ -38,7 +55,8 @@ def training_worker(graph_manager, checkpoint_dir):
heatup(graph_manager)
# training loop
for _ in range(40):
stepper = StepsLoop(graph_manager.total_steps_counters, RunPhase.TRAIN, graph_manager.improve_steps)
while stepper.continue():
graph_manager.phase = core_types.RunPhase.TRAIN
graph_manager.train(core_types.TrainingSteps(1))
graph_manager.phase = core_types.RunPhase.UNDEFINED