mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Integrate coach.py params with distributed Coach. (#42)
* Integrate coach.py params with distributed Coach. * Minor improvements - Use enums instead of constants. - Reduce code duplication. - Ask for the experiment name with a timeout.
This commit is contained in:
committed by
GitHub
parent
95b4fc6888
commit
7e7006305a
@@ -2,6 +2,7 @@ import os
|
||||
import uuid
|
||||
import json
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
from configparser import ConfigParser, Error
|
||||
from rl_coach.orchestrators.deploy import Deploy, DeployParameters
|
||||
@@ -12,10 +13,19 @@ from rl_coach.data_stores.data_store import DataStoreParameters
|
||||
from rl_coach.data_stores.data_store_impl import get_data_store
|
||||
|
||||
|
||||
class RunType(Enum):
    """Closed set of run types, each backed by its string identifier.

    ``str(member)`` returns the member's raw value (for example
    ``str(RunType.TRAINER) == "trainer"``), so a member can be converted
    to the plain string form wherever a string identifier is needed.
    """

    ORCHESTRATOR = "orchestrator"
    TRAINER = "trainer"
    ROLLOUT_WORKER = "rollout-worker"

    def __str__(self) -> str:
        """Return the underlying string value instead of ``RunType.NAME``."""
        return self.value
|
||||
|
||||
|
||||
class RunTypeParameters():
|
||||
|
||||
def __init__(self, image: str, command: list(), arguments: list() = None,
|
||||
run_type: str = "trainer", checkpoint_dir: str = "/checkpoint",
|
||||
run_type: str = str(RunType.TRAINER), checkpoint_dir: str = "/checkpoint",
|
||||
num_replicas: int = 1, orchestration_params: dict=None):
|
||||
self.image = image
|
||||
self.command = command
|
||||
@@ -97,12 +107,12 @@ class Kubernetes(Deploy):
|
||||
|
||||
def deploy_trainer(self) -> bool:
|
||||
|
||||
trainer_params = self.params.run_type_params.get('trainer', None)
|
||||
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
|
||||
if not trainer_params:
|
||||
return False
|
||||
|
||||
trainer_params.command += ['--memory-backend-params', json.dumps(self.params.memory_backend_parameters.__dict__)]
|
||||
trainer_params.command += ['--data-store-params', json.dumps(self.params.data_store_params.__dict__)]
|
||||
trainer_params.command += ['--memory_backend_params', json.dumps(self.params.memory_backend_parameters.__dict__)]
|
||||
trainer_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)]
|
||||
|
||||
name = "{}-{}".format(trainer_params.run_type, uuid.uuid4())
|
||||
|
||||
@@ -175,13 +185,13 @@ class Kubernetes(Deploy):
|
||||
|
||||
def deploy_worker(self):
|
||||
|
||||
worker_params = self.params.run_type_params.get('worker', None)
|
||||
worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
|
||||
if not worker_params:
|
||||
return False
|
||||
|
||||
worker_params.command += ['--memory-backend-params', json.dumps(self.params.memory_backend_parameters.__dict__)]
|
||||
worker_params.command += ['--data-store-params', json.dumps(self.params.data_store_params.__dict__)]
|
||||
worker_params.command += ['--num-workers', '{}'.format(worker_params.num_replicas)]
|
||||
worker_params.command += ['--memory_backend_params', json.dumps(self.params.memory_backend_parameters.__dict__)]
|
||||
worker_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)]
|
||||
worker_params.command += ['--num_workers', '{}'.format(worker_params.num_replicas)]
|
||||
|
||||
name = "{}-{}".format(worker_params.run_type, uuid.uuid4())
|
||||
|
||||
@@ -255,7 +265,7 @@ class Kubernetes(Deploy):
|
||||
pass
|
||||
|
||||
def trainer_logs(self):
|
||||
trainer_params = self.params.run_type_params.get('trainer', None)
|
||||
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
|
||||
if not trainer_params:
|
||||
return
|
||||
|
||||
@@ -313,7 +323,7 @@ class Kubernetes(Deploy):
|
||||
return
|
||||
|
||||
def undeploy(self):
|
||||
trainer_params = self.params.run_type_params.get('trainer', None)
|
||||
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
|
||||
api_client = k8sclient.AppsV1Api()
|
||||
delete_options = k8sclient.V1DeleteOptions()
|
||||
if trainer_params:
|
||||
@@ -321,7 +331,7 @@ class Kubernetes(Deploy):
|
||||
api_client.delete_namespaced_deployment(trainer_params.orchestration_params['deployment_name'], self.params.namespace, delete_options)
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while deleting trainer", e)
|
||||
worker_params = self.params.run_type_params.get('worker', None)
|
||||
worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
|
||||
if worker_params:
|
||||
try:
|
||||
api_client.delete_namespaced_deployment(worker_params.orchestration_params['deployment_name'], self.params.namespace, delete_options)
|
||||
|
||||
Reference in New Issue
Block a user