mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Adding target reward and target sucess (#58)
* Adding target reward * Adding target successs * Addressing comments * Using custom_reward_threshold and target_success_rate * Adding exit message * Moving success rate to environment * Making target_success_rate optional
This commit is contained in:
committed by
Balaji Subramaniam
parent
0fe583186e
commit
875d6ef017
@@ -79,6 +79,7 @@ class Kubernetes(Deploy):
|
||||
self.memory_backend = get_memory_backend(self.params.memory_backend_parameters)
|
||||
|
||||
self.params.data_store_params.orchestrator_params = {'namespace': self.params.namespace}
|
||||
self.params.data_store_params.namespace = self.params.namespace
|
||||
self.data_store = get_data_store(self.params.data_store_params)
|
||||
|
||||
if self.params.data_store_params.store_type == "s3":
|
||||
@@ -137,7 +138,8 @@ class Kubernetes(Deploy):
|
||||
volumes=[k8sclient.V1Volume(
|
||||
name="nfs-pvc",
|
||||
persistent_volume_claim=self.nfs_pvc
|
||||
)]
|
||||
)],
|
||||
restart_policy='OnFailure'
|
||||
),
|
||||
)
|
||||
else:
|
||||
@@ -155,32 +157,30 @@ class Kubernetes(Deploy):
|
||||
template = k8sclient.V1PodTemplateSpec(
|
||||
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
||||
spec=k8sclient.V1PodSpec(
|
||||
containers=[container]
|
||||
containers=[container],
|
||||
restart_policy='OnFailure'
|
||||
),
|
||||
)
|
||||
|
||||
deployment_spec = k8sclient.V1DeploymentSpec(
|
||||
replicas=trainer_params.num_replicas,
|
||||
template=template,
|
||||
selector=k8sclient.V1LabelSelector(
|
||||
match_labels={'app': name}
|
||||
)
|
||||
job_spec = k8sclient.V1JobSpec(
|
||||
completions=1,
|
||||
template=template
|
||||
)
|
||||
|
||||
deployment = k8sclient.V1Deployment(
|
||||
api_version='apps/v1',
|
||||
kind='Deployment',
|
||||
job = k8sclient.V1Job(
|
||||
api_version="batch/v1",
|
||||
kind="Job",
|
||||
metadata=k8sclient.V1ObjectMeta(name=name),
|
||||
spec=deployment_spec
|
||||
spec=job_spec
|
||||
)
|
||||
|
||||
api_client = k8sclient.AppsV1Api()
|
||||
api_client = k8sclient.BatchV1Api()
|
||||
try:
|
||||
api_client.create_namespaced_deployment(self.params.namespace, deployment)
|
||||
trainer_params.orchestration_params['deployment_name'] = name
|
||||
api_client.create_namespaced_job(self.params.namespace, job)
|
||||
trainer_params.orchestration_params['job_name'] = name
|
||||
return True
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while creating deployment", e)
|
||||
print("Got exception: %s\n while creating job", e)
|
||||
return False
|
||||
|
||||
def deploy_worker(self):
|
||||
@@ -217,6 +217,7 @@ class Kubernetes(Deploy):
|
||||
name="nfs-pvc",
|
||||
persistent_volume_claim=self.nfs_pvc
|
||||
)],
|
||||
restart_policy='OnFailure'
|
||||
),
|
||||
)
|
||||
else:
|
||||
@@ -234,31 +235,31 @@ class Kubernetes(Deploy):
|
||||
template = k8sclient.V1PodTemplateSpec(
|
||||
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
||||
spec=k8sclient.V1PodSpec(
|
||||
containers=[container]
|
||||
containers=[container],
|
||||
restart_policy='OnFailure'
|
||||
)
|
||||
)
|
||||
|
||||
deployment_spec = k8sclient.V1DeploymentSpec(
|
||||
replicas=worker_params.num_replicas,
|
||||
template=template,
|
||||
selector=k8sclient.V1LabelSelector(
|
||||
match_labels={'app': name}
|
||||
)
|
||||
)
|
||||
deployment = k8sclient.V1Deployment(
|
||||
api_version='apps/v1',
|
||||
kind="Deployment",
|
||||
metadata=k8sclient.V1ObjectMeta(name=name),
|
||||
spec=deployment_spec
|
||||
job_spec = k8sclient.V1JobSpec(
|
||||
completions=worker_params.num_replicas,
|
||||
parallelism=worker_params.num_replicas,
|
||||
template=template
|
||||
)
|
||||
|
||||
api_client = k8sclient.AppsV1Api()
|
||||
job = k8sclient.V1Job(
|
||||
api_version="batch/v1",
|
||||
kind="Job",
|
||||
metadata=k8sclient.V1ObjectMeta(name=name),
|
||||
spec=job_spec
|
||||
)
|
||||
|
||||
api_client = k8sclient.BatchV1Api()
|
||||
try:
|
||||
api_client.create_namespaced_deployment(self.params.namespace, deployment)
|
||||
worker_params.orchestration_params['deployment_name'] = name
|
||||
api_client.create_namespaced_job(self.params.namespace, job)
|
||||
worker_params.orchestration_params['job_name'] = name
|
||||
return True
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while creating deployment", e)
|
||||
print("Got exception: %s\n while creating Job", e)
|
||||
return False
|
||||
|
||||
def worker_logs(self):
|
||||
@@ -273,7 +274,7 @@ class Kubernetes(Deploy):
|
||||
pod = None
|
||||
try:
|
||||
pods = api_client.list_namespaced_pod(self.params.namespace, label_selector='app={}'.format(
|
||||
trainer_params.orchestration_params['deployment_name']
|
||||
trainer_params.orchestration_params['job_name']
|
||||
))
|
||||
|
||||
pod = pods.items[0]
|
||||
@@ -324,17 +325,20 @@ class Kubernetes(Deploy):
|
||||
|
||||
def undeploy(self):
|
||||
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
|
||||
api_client = k8sclient.AppsV1Api()
|
||||
delete_options = k8sclient.V1DeleteOptions()
|
||||
api_client = k8sclient.BatchV1Api()
|
||||
delete_options = k8sclient.V1DeleteOptions(
|
||||
propagation_policy="Foreground"
|
||||
)
|
||||
|
||||
if trainer_params:
|
||||
try:
|
||||
api_client.delete_namespaced_deployment(trainer_params.orchestration_params['deployment_name'], self.params.namespace, delete_options)
|
||||
api_client.delete_namespaced_job(trainer_params.orchestration_params['job_name'], self.params.namespace, delete_options)
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while deleting trainer", e)
|
||||
worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
|
||||
if worker_params:
|
||||
try:
|
||||
api_client.delete_namespaced_deployment(worker_params.orchestration_params['deployment_name'], self.params.namespace, delete_options)
|
||||
api_client.delete_namespaced_job(worker_params.orchestration_params['job_name'], self.params.namespace, delete_options)
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while deleting workers", e)
|
||||
self.memory_backend.undeploy()
|
||||
|
||||
Reference in New Issue
Block a user