From 67faa80ea0dfe62da3215933c2c8da0aa814d4c4 Mon Sep 17 00:00:00 2001
From: Zach Dwiel
Date: Fri, 21 Sep 2018 15:49:06 -0400
Subject: [PATCH] allow custom number of training steps

---
 docker/Makefile             |  2 ++
 rl_coach/training_worker.py | 32 +++++++++++++++++++++++++++++++-
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/docker/Makefile b/docker/Makefile
index ea35972..9881378 100644
--- a/docker/Makefile
+++ b/docker/Makefile
@@ -54,6 +54,8 @@ stop_kubernetes:
 	kubectl delete pv --ignore-not-found nfs-checkpoint-pv
 	kubectl delete pvc --ignore-not-found nfs-checkpoint-pvc
 	kubectl delete deployment --ignore-not-found redis-server
+	kubectl get jobs | grep train | awk '{print $$1}' | xargs kubectl delete jobs
+	kubectl get jobs | grep worker | awk '{print $$1}' | xargs kubectl delete jobs
 
 kubernetes: stop_kubernetes
 	python3 ${CONTEXT}/rl_coach/orchestrators/start_training.py --preset CartPole_DQN_distributed --image ${IMAGE} -ns 10.63.249.182 -np /
diff --git a/rl_coach/training_worker.py b/rl_coach/training_worker.py
index f2c569b..d81f54c 100644
--- a/rl_coach/training_worker.py
+++ b/rl_coach/training_worker.py
@@ -22,6 +22,33 @@ def heatup(graph_manager):
         time.sleep(1)
 
 
+class StepsLoop(object):
+    """Bound a while loop by a step budget measured from construction."""
+
+    def __init__(self, steps_counters, phase, steps):
+        """
+        :param steps_counters: counters indexed as [phase][type of steps]
+        :param phase: key selecting which phase's counters to read
+        :param steps: budget object; reads its num_steps and its class
+        """
+        super(StepsLoop, self).__init__()
+        self.steps_counters = steps_counters
+        self.phase = phase
+        self.steps = steps
+
+        # fix the end point once: counter value now plus the budget
+        self.step_end = self._step_count() + steps.num_steps
+
+    def _step_count(self):
+        """Return the current value of the tracked counter."""
+        return self.steps_counters[self.phase][self.steps.__class__]
+
+    def should_continue(self):
+        """Return True while the counter is below the fixed end point."""
+        # `continue` is a reserved word, so it cannot be the method name
+        return self._step_count() < self.step_end
+
+
 def training_worker(graph_manager, checkpoint_dir):
     """
     restore a checkpoint then perform rollouts using the restored model
@@ -38,7 +65,10 @@ def training_worker(graph_manager, checkpoint_dir):
     heatup(graph_manager)
 
     # training loop
-    for _ in range(40):
+    stepper = StepsLoop(graph_manager.total_steps_counters,
+                        core_types.RunPhase.TRAIN,
+                        graph_manager.improve_steps)
+    while stepper.should_continue():
         graph_manager.phase = core_types.RunPhase.TRAIN
         graph_manager.train(core_types.TrainingSteps(1))
         graph_manager.phase = core_types.RunPhase.UNDEFINED