diff --git a/rl_coach/orchestrators/kubernetes_orchestrator.py b/rl_coach/orchestrators/kubernetes_orchestrator.py index 9b21ee4..6bdcd8c 100644 --- a/rl_coach/orchestrators/kubernetes_orchestrator.py +++ b/rl_coach/orchestrators/kubernetes_orchestrator.py @@ -42,7 +42,7 @@ class Kubernetes(Deploy): print("Failed to setup redis") return False - self.deploy_parameters.command += ['-r', self.deploy_parameters.redis_ip, '-p', '{}'.format(self.deploy_parameters.redis_port)] + self.deploy_parameters.command += ['--redis_ip', self.deploy_parameters.redis_ip, '--redis_port', '{}'.format(self.deploy_parameters.redis_port)] return True diff --git a/rl_coach/orchestrators/test.py b/rl_coach/orchestrators/test.py index 56428e0..3142b6a 100644 --- a/rl_coach/orchestrators/test.py +++ b/rl_coach/orchestrators/test.py @@ -2,10 +2,10 @@ from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, # image = 'gcr.io/constant-cubist-173123/coach:latest' image = 'ajaysudh/testing:coach' -command = ['python3', 'rl_coach/rollout_worker.py'] +command = ['python3', 'rl_coach/rollout_worker.py', '-p', 'CartPole_DQN_distributed'] # command = ['sleep', '10h'] -params = KubernetesParameters(image, command, kubeconfig='~/.kube/config', redis_ip='redis-service.ajay.svc', redis_port=6379, num_workers=10) +params = KubernetesParameters(image, command, kubeconfig='~/.kube/config', redis_ip='redis-service.ajay.svc', redis_port=6379, num_workers=1) # params = KubernetesParameters(image, command, kubeconfig='~/.kube/config') obj = Kubernetes(params) diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py index f1c5363..ef0ff7d 100644 --- a/rl_coach/rollout_worker.py +++ b/rl_coach/rollout_worker.py @@ -20,7 +20,7 @@ def rollout_worker(graph_manager, checkpoint_dir): task_parameters = TaskParameters() task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir graph_manager.create_graph(task_parameters) - + graph_manager.phase = RunPhase.TRAIN graph_manager.act(EnvironmentEpisodes(num_steps=10)) graph_manager.phase = RunPhase.UNDEFINED @@ -36,11 +36,19 @@ def main(): help='(string) Path to a folder containing a checkpoint to restore the model from.', type=str, default='/checkpoint') + parser.add_argument('-r', '--redis_ip', + help="(string) IP or host for the redis server", + default='localhost', + type=str) + parser.add_argument('-rp', '--redis_port', + help="(int) Port of the redis server", + default=6379, + type=int) args = parser.parse_args() graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True) - graph_manager.agent_parameters.memory.redis_ip = args.redis_ip + graph_manager.agent_params.memory.redis_ip = args.redis_ip graph_manager.agent_params.memory.redis_port = args.redis_port rollout_worker(