mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Add RedisDataStore (#295)

* GraphManager.set_session also sets self.sess

* make sure that GraphManager.fetch_from_worker uses training phase

* remove unnecessary phase setting in training worker

* reorganize rollout worker

* provide default name to GlobalVariableSaver.__init__ since it isn't really used anyway

* allow dividing TrainingSteps and EnvironmentSteps

* add timestamps to the log

* added redis data store

* merge conflict fix
Zach Dwiel authored 2019-08-28 14:15:58 -04:00, committed by shadiendrawis
parent 34e1c04f29
commit 7b0fccb041
18 changed files with 528 additions and 120 deletions
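The RedisDataStore class itself is not part of the excerpt below, which only covers the Kubernetes orchestrator. As a rough sketch of the idea, assuming the redis-py client and illustrative method names (only the redis_address/redis_port fields actually appear in the diff):

import redis


class RedisDataStoreSketch(object):
    """Illustrative only -- not the RedisDataStore added by this commit."""

    def __init__(self, redis_address, redis_port):
        # The orchestrator below fills these in from the memory backend's Redis service.
        self._client = redis.Redis(host=redis_address, port=redis_port)

    def save_policy(self, key, serialized_policy):
        # The trainer publishes the latest checkpoint/policy blob under a well-known key.
        self._client.set(key, serialized_policy)

    def load_policy(self, key):
        # Rollout workers poll the same key to pick up new policies.
        return self._client.get(key)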


@@ -124,6 +124,11 @@ class Kubernetes(Deploy):
"""
self.memory_backend.deploy()
if self.params.data_store_params.store_type == "redis":
self.data_store.params.redis_address = self.memory_backend.params.redis_address
self.data_store.params.redis_port = self.memory_backend.params.redis_port
if not self.data_store.deploy():
return False
if self.params.data_store_params.store_type == "nfs":
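The hunk above reuses the memory backend's Redis endpoint for the data store, so both ride on a single Redis deployment. A hypothetical parameters object that would make those assignments work (class name and defaults are assumptions; only redis_address, redis_port and store_type come from the diff):

class RedisDataStoreParameters(object):
    """Hypothetical sketch of the parameters the assignments above expect."""

    def __init__(self, redis_address="", redis_port=6379):
        self.store_type = "redis"
        # Overwritten by the orchestrator once the memory backend has been deployed.
        self.redis_address = redis_address
        self.redis_port = redis_port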
@@ -146,6 +151,8 @@ class Kubernetes(Deploy):
        trainer_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)]
        name = "{}-{}".format(trainer_params.run_type, uuid.uuid4())
        # TODO: instead of defining each container and template spec from scratch, load the default
        # configuration and modify it as necessary depending on the store type
        if self.params.data_store_params.store_type == "nfs":
            container = k8sclient.V1Container(
                name=name,
@@ -171,7 +178,7 @@ class Kubernetes(Deploy):
                    restart_policy='Never'
                ),
            )
        else:
        elif self.params.data_store_params.store_type == "s3":
            container = k8sclient.V1Container(
                name=name,
                image=trainer_params.image,
@@ -190,6 +197,34 @@ class Kubernetes(Deploy):
                    restart_policy='Never'
                ),
            )
        elif self.params.data_store_params.store_type == "redis":
            container = k8sclient.V1Container(
                name=name,
                image=trainer_params.image,
                command=trainer_params.command,
                args=trainer_params.arguments,
                image_pull_policy='Always',
                stdin=True,
                tty=True,
                resources=k8sclient.V1ResourceRequirements(
                    limits={
                        "cpu": "40",
                        "memory": "4Gi",
                        "nvidia.com/gpu": "1",
                    }
                ),
            )
            template = k8sclient.V1PodTemplateSpec(
                metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
                spec=k8sclient.V1PodSpec(
                    containers=[container],
                    restart_policy='Never'
                ),
            )
        else:
            raise ValueError("unexpected store_type {}. expected 's3', 'nfs', 'redis'".format(
                self.params.data_store_params.store_type
            ))
        job_spec = k8sclient.V1JobSpec(
            completions=1,
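The TODO above suggests building one default container spec and only overriding the store-specific pieces, instead of repeating the nfs/s3/redis branches. A hedged sketch of that refactor (helper name and signature are assumptions, not code from this commit):

from kubernetes import client as k8sclient


def base_container(name, run_params, gpu_limit=None):
    # Fields shared by every branch of the if/elif chain above.
    resources = None
    if gpu_limit is not None:
        resources = k8sclient.V1ResourceRequirements(limits={"nvidia.com/gpu": gpu_limit})
    return k8sclient.V1Container(
        name=name,
        image=run_params.image,
        command=run_params.command,
        args=run_params.arguments,
        image_pull_policy='Always',
        stdin=True,
        tty=True,
        resources=resources,
    )

Store-specific details (the NFS volume mount, CPU/memory limits) would then be set on the returned object before it is wrapped in the pod template.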
@@ -221,12 +256,17 @@ class Kubernetes(Deploy):
        if not worker_params:
            return False
        # At this point, the memory backend and data store have been deployed and, in the process,
        # these parameters have been updated to include things like the hostname and port the
        # service can be found at.
        worker_params.command += ['--memory_backend_params', json.dumps(self.params.memory_backend_parameters.__dict__)]
        worker_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)]
        worker_params.command += ['--num_workers', '{}'.format(worker_params.num_replicas)]
        name = "{}-{}".format(worker_params.run_type, uuid.uuid4())
        # TODO: instead of defining each container and template spec from scratch, load the default
        # configuration and modify it as necessary depending on the store type
        if self.params.data_store_params.store_type == "nfs":
            container = k8sclient.V1Container(
                name=name,
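Both the trainer and the worker receive the memory backend and data store parameters as JSON-encoded dicts on their command line. A sketch of the consuming side (the argparse wiring is an assumption; the flag names and the json.dumps(...__dict__) encoding come from the diff):

import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument('--memory_backend_params', type=str)
parser.add_argument('--data_store_params', type=str)
parser.add_argument('--num_workers', type=int, default=1)
args = parser.parse_args()

# Each flag carries a plain dict, e.g. {'store_type': 'redis', 'redis_address': ..., 'redis_port': ...}
data_store_params = json.loads(args.data_store_params)
redis_address = data_store_params.get('redis_address')
redis_port = data_store_params.get('redis_port')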
@@ -252,7 +292,7 @@ class Kubernetes(Deploy):
                    restart_policy='Never'
                ),
            )
        else:
        elif self.params.data_store_params.store_type == "s3":
            container = k8sclient.V1Container(
                name=name,
                image=worker_params.image,
@@ -271,6 +311,32 @@ class Kubernetes(Deploy):
                    restart_policy='Never'
                )
            )
        elif self.params.data_store_params.store_type == "redis":
            container = k8sclient.V1Container(
                name=name,
                image=worker_params.image,
                command=worker_params.command,
                args=worker_params.arguments,
                image_pull_policy='Always',
                stdin=True,
                tty=True,
                resources=k8sclient.V1ResourceRequirements(
                    limits={
                        "cpu": "8",
                        "memory": "4Gi",
                        # "nvidia.com/gpu": "0",
                    }
                ),
            )
            template = k8sclient.V1PodTemplateSpec(
                metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
                spec=k8sclient.V1PodSpec(
                    containers=[container],
                    restart_policy='Never'
                )
            )
        else:
            raise ValueError('unexpected store type {}'.format(self.params.data_store_params.store_type))
        job_spec = k8sclient.V1JobSpec(
            completions=worker_params.num_replicas,