mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 19:50:17 +01:00
allow custom number of training steps
This commit is contained in:
@@ -54,6 +54,8 @@ stop_kubernetes:
|
|||||||
kubectl delete pv --ignore-not-found nfs-checkpoint-pv
|
kubectl delete pv --ignore-not-found nfs-checkpoint-pv
|
||||||
kubectl delete pvc --ignore-not-found nfs-checkpoint-pvc
|
kubectl delete pvc --ignore-not-found nfs-checkpoint-pvc
|
||||||
kubectl delete deployment --ignore-not-found redis-server
|
kubectl delete deployment --ignore-not-found redis-server
|
||||||
|
kubectl get jobs | grep train | awk "{print $\1}" | xargs kubectl delete jobs
|
||||||
|
kubectl get jobs | grep worker | awk "{print $\1}" | xargs kubectl delete jobs
|
||||||
|
|
||||||
kubernetes: stop_kubernetes
|
kubernetes: stop_kubernetes
|
||||||
python3 ${CONTEXT}/rl_coach/orchestrators/start_training.py --preset CartPole_DQN_distributed --image ${IMAGE} -ns 10.63.249.182 -np /
|
python3 ${CONTEXT}/rl_coach/orchestrators/start_training.py --preset CartPole_DQN_distributed --image ${IMAGE} -ns 10.63.249.182 -np /
|
||||||
|
|||||||
@@ -22,6 +22,23 @@ def heatup(graph_manager):
|
|||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
|
class StepsLoop(object):
    """Drive a simple ``while`` loop that runs for a fixed number of steps.

    The loop target is computed once, at construction time, as the current
    value of the selected counter plus ``steps.num_steps``.  The caller is
    expected to advance that counter elsewhere; this class only reads it.

    NOTE: ``continue`` is a reserved word in Python and cannot be used as a
    method name, so the loop predicate is called ``should_continue``.
    """

    def __init__(self, steps_counters, phase, steps):
        """Record the counter to watch and compute the stopping value.

        :param steps_counters: nested mapping indexed first by ``phase`` and
            then by the class of ``steps``; the looked-up value must support
            ``+`` and ``<`` with ``steps.num_steps``.
        :param phase: key selecting the per-phase counters.
        :param steps: object exposing ``num_steps``; its class selects which
            counter is tracked.
        """
        super(StepsLoop, self).__init__()
        self.steps_counters = steps_counters
        self.phase = phase
        self.steps = steps

        # Snapshot the target once so that the loop runs for exactly
        # ``steps.num_steps`` additional counts from this point on.
        self.step_end = self._step_count() + steps.num_steps

    def _step_count(self):
        """Return the current value of the tracked counter."""
        return self.steps_counters[self.phase][self.steps.__class__]

    def should_continue(self):
        """Return True while the tracked counter is below the target."""
        return self._step_count() < self.step_end
|
||||||
|
|
||||||
|
|
||||||
def training_worker(graph_manager, checkpoint_dir):
|
def training_worker(graph_manager, checkpoint_dir):
|
||||||
"""
|
"""
|
||||||
restore a checkpoint then perform rollouts using the restored model
|
restore a checkpoint then perform rollouts using the restored model
|
||||||
@@ -38,7 +55,8 @@ def training_worker(graph_manager, checkpoint_dir):
|
|||||||
heatup(graph_manager)
|
heatup(graph_manager)
|
||||||
|
|
||||||
# training loop
|
# training loop
|
||||||
for _ in range(40):
|
stepper = StepsLoop(graph_manager.total_steps_counters, RunPhase.TRAIN, graph_manager.improve_steps)
|
||||||
|
while stepper.continue():
|
||||||
graph_manager.phase = core_types.RunPhase.TRAIN
|
graph_manager.phase = core_types.RunPhase.TRAIN
|
||||||
graph_manager.train(core_types.TrainingSteps(1))
|
graph_manager.train(core_types.TrainingSteps(1))
|
||||||
graph_manager.phase = core_types.RunPhase.UNDEFINED
|
graph_manager.phase = core_types.RunPhase.UNDEFINED
|
||||||
|
|||||||
Reference in New Issue
Block a user