1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

allow custom number of training steps

This commit is contained in:
Zach Dwiel
2018-09-21 15:49:06 -04:00
committed by zach dwiel
parent d69332efd4
commit 67faa80ea0
2 changed files with 21 additions and 1 deletions

View File

@@ -54,6 +54,8 @@ stop_kubernetes:
kubectl delete pv --ignore-not-found nfs-checkpoint-pv kubectl delete pv --ignore-not-found nfs-checkpoint-pv
kubectl delete pvc --ignore-not-found nfs-checkpoint-pvc kubectl delete pvc --ignore-not-found nfs-checkpoint-pvc
kubectl delete deployment --ignore-not-found redis-server kubectl delete deployment --ignore-not-found redis-server
kubectl get jobs | grep train | awk "{print $\1}" | xargs kubectl delete jobs
kubectl get jobs | grep worker | awk "{print $\1}" | xargs kubectl delete jobs
kubernetes: stop_kubernetes kubernetes: stop_kubernetes
python3 ${CONTEXT}/rl_coach/orchestrators/start_training.py --preset CartPole_DQN_distributed --image ${IMAGE} -ns 10.63.249.182 -np / python3 ${CONTEXT}/rl_coach/orchestrators/start_training.py --preset CartPole_DQN_distributed --image ${IMAGE} -ns 10.63.249.182 -np /

View File

@@ -22,6 +22,23 @@ def heatup(graph_manager):
time.sleep(1) time.sleep(1)
class StepsLoop(object):
    """Track progress of a counter toward a fixed number of additional steps.

    On construction the current counter value is read and ``steps.num_steps``
    is added to it to obtain the stopping point (``step_end``). Callers then
    drive a plain ``while`` loop with :meth:`continue_`::

        stepper = StepsLoop(counters, phase, steps)
        while stepper.continue_():
            ...  # work that advances counters[phase][type(steps)]

    Args:
        steps_counters: mapping indexed first by ``phase`` and then by the
            class of ``steps``; the looked-up value must support ``+`` with
            ``steps.num_steps`` and ``<`` comparison.
        phase: key selecting which set of counters to read.
        steps: object exposing ``num_steps``; its class selects the counter.
    """

    def __init__(self, steps_counters, phase, steps):
        super(StepsLoop, self).__init__()
        self.steps_counters = steps_counters
        self.phase = phase
        self.steps = steps
        # Snapshot the stopping point once, relative to the counter's value
        # at construction time.
        self.step_end = self._step_count() + steps.num_steps

    def _step_count(self):
        """Return the current value of the counter selected by phase/steps."""
        return self.steps_counters[self.phase][self.steps.__class__]

    def continue_(self):
        """Return True while the counter has not yet reached ``step_end``.

        The trailing underscore is required: ``continue`` is a reserved
        keyword and cannot be used as a method name.
        """
        return self._step_count() < self.step_end
def training_worker(graph_manager, checkpoint_dir): def training_worker(graph_manager, checkpoint_dir):
""" """
restore a checkpoint then perform rollouts using the restored model restore a checkpoint then perform rollouts using the restored model
@@ -38,7 +55,8 @@ def training_worker(graph_manager, checkpoint_dir):
heatup(graph_manager) heatup(graph_manager)
# training loop # training loop
for _ in range(40): stepper = StepsLoop(graph_manager.total_steps_counters, RunPhase.TRAIN, graph_manager.improve_steps)
while stepper.continue():
graph_manager.phase = core_types.RunPhase.TRAIN graph_manager.phase = core_types.RunPhase.TRAIN
graph_manager.train(core_types.TrainingSteps(1)) graph_manager.train(core_types.TrainingSteps(1))
graph_manager.phase = core_types.RunPhase.UNDEFINED graph_manager.phase = core_types.RunPhase.UNDEFINED