diff --git a/docker/Makefile b/docker/Makefile index 9bd680f..b3f74a0 100644 --- a/docker/Makefile +++ b/docker/Makefile @@ -1,3 +1,4 @@ +REGISTRY=nervana-dockrepo01.fm.intel.com:5001/ IMAGE=zdwiel/coach # IMAGE=gcr.io/deep-greens/inference:v5 @@ -43,5 +44,9 @@ run_training_worker: build run_rollout_worker: build ${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/rollout_worker.py --preset CartPole_DQN_distributed -push: - docker push ${IMAGE} +kubernetes: build push + kubectl run -i --tty --attach --image=${IMAGE} --restart=Never date -- python3 rl_coach/orchestrators/start_training.py --preset CartPole_DQN_distributed --image ${IMAGE} + +push: build + ${DOCKER} tag ${IMAGE} ${REGISTRY}${IMAGE} + ${DOCKER} push ${REGISTRY}${IMAGE} diff --git a/rl_coach/orchestrators/start_training.py b/rl_coach/orchestrators/start_training.py new file mode 100644 index 0000000..9950b61 --- /dev/null +++ b/rl_coach/orchestrators/start_training.py @@ -0,0 +1,48 @@ +import argparse + +from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes + + +def main(preset, image='ajaysudh/testing:coach', redis_ip='redis-service.ajay.svc'): + rollout_command = ['python3', 'rl_coach/rollout_worker.py', '-p', preset] + training_command = ['python3', 'rl_coach/training_worker.py', '-p', preset] + + rollout_params = KubernetesParameters(image, rollout_command, redis_ip=redis_ip, redis_port=6379, num_workers=1) + training_params = KubernetesParameters(image, training_command, redis_ip=redis_ip, redis_port=6379, num_workers=1) + + training_obj = Kubernetes(training_params) + if not training_obj.setup(): + print("Could not setup") + + rollout_obj = Kubernetes(training_params) + if not rollout_obj.setup(): + print("Could not setup") + + if training_obj.deploy(): + print("Successfully deployed") + else: + print("Could not deploy") + + if rollout_obj.deploy(): + print("Successfully deployed") + else: + print("Could not deploy") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--image', + help="(string) Name of a docker image.", + type=str, + required=True) + parser.add_argument('-p', '--preset', + help="(string) Name of a preset to run (class name from the 'presets' directory.)", + type=str, + required=True) + # parser.add_argument('--checkpoint_dir', + # help='(string) Path to a folder containing a checkpoint to write the model to.', + # type=str, + # default='/checkpoint') + args = parser.parse_args() + + main(preset=args.preset, image=args.image) diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py index 8afde08..27845a6 100644 --- a/rl_coach/rollout_worker.py +++ b/rl_coach/rollout_worker.py @@ -41,8 +41,8 @@ def wait_for_checkpoint(checkpoint_dir, timeout=10): return raise ValueError(( - 'Waited {timeout} seconds, but checkpoint never found in' - ' {checkpoint_dir}' + 'Waited {timeout} seconds, but checkpoint never found in ' + '{checkpoint_dir}' ).format( timeout=timeout, checkpoint_dir=checkpoint_dir,