mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
allow custom number of training steps
This commit is contained in:
@@ -54,6 +54,8 @@ stop_kubernetes:
|
||||
kubectl delete pv --ignore-not-found nfs-checkpoint-pv
|
||||
kubectl delete pvc --ignore-not-found nfs-checkpoint-pvc
|
||||
kubectl delete deployment --ignore-not-found redis-server
|
||||
kubectl get jobs | grep train | awk "{print $\1}" | xargs kubectl delete jobs
|
||||
kubectl get jobs | grep worker | awk "{print $\1}" | xargs kubectl delete jobs
|
||||
|
||||
kubernetes: stop_kubernetes
|
||||
python3 ${CONTEXT}/rl_coach/orchestrators/start_training.py --preset CartPole_DQN_distributed --image ${IMAGE} -ns 10.63.249.182 -np /
|
||||
|
||||
@@ -22,6 +22,23 @@ def heatup(graph_manager):
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
class StepsLoop(object):
    """Decide when a loop has advanced a requested number of steps.

    At construction time the current value of a step counter is read and
    ``steps.num_steps`` is added to it to form a fixed end mark
    (``step_end``).  ``should_continue()`` then reports whether the live
    counter is still below that mark, which makes the object usable as the
    condition of a simple ``while`` loop.

    The counter is looked up as ``steps_counters[phase][type(steps)]``; this
    class never modifies it, so something else must advance the counter for
    the loop to terminate.

    :param steps_counters: nested mapping indexed first by ``phase`` and then
        by the class of ``steps``; the looked-up value must support ``+``
        with ``steps.num_steps`` and ``<`` comparison.
    :param phase: key selecting the first level of ``steps_counters``.
    :param steps: object exposing a ``num_steps`` attribute; its class
        selects the second level of ``steps_counters``.
    """

    def __init__(self, steps_counters, phase, steps):
        super(StepsLoop, self).__init__()
        self.steps_counters = steps_counters
        self.phase = phase
        self.steps = steps

        # Snapshot the target once, relative to the counter's value right
        # now, so the loop covers ``num_steps`` additional steps no matter
        # where the counter started.
        self.step_end = self._step_count() + steps.num_steps

    def _step_count(self):
        """Return the live counter value for this loop's phase and step type."""
        return self.steps_counters[self.phase][self.steps.__class__]

    def should_continue(self):
        """Return True while the counter has not yet reached ``step_end``.

        NOTE: ``continue`` is a reserved keyword in Python and cannot be used
        as a method name, so call sites must use ``should_continue()``.
        """
        return self._step_count() < self.step_end
|
||||
|
||||
|
||||
def training_worker(graph_manager, checkpoint_dir):
|
||||
"""
|
||||
restore a checkpoint then perform rollouts using the restored model
|
||||
@@ -38,7 +55,8 @@ def training_worker(graph_manager, checkpoint_dir):
|
||||
heatup(graph_manager)
|
||||
|
||||
# training loop
|
||||
for _ in range(40):
|
||||
stepper = StepsLoop(graph_manager.total_steps_counters, RunPhase.TRAIN, graph_manager.improve_steps)
|
||||
while stepper.continue():
|
||||
graph_manager.phase = core_types.RunPhase.TRAIN
|
||||
graph_manager.train(core_types.TrainingSteps(1))
|
||||
graph_manager.phase = core_types.RunPhase.UNDEFINED
|
||||
|
||||
Reference in New Issue
Block a user