
Add documentation on distributed Coach. (#158)

* Added documentation on distributed Coach.
Balaji Subramaniam
2018-11-27 02:26:15 -08:00
committed by Gal Novik
parent e3ecf445e2
commit d06197f663
151 changed files with 5302 additions and 643 deletions

View File

@@ -22,10 +22,21 @@ class NFSDataStoreParameters(DataStoreParameters):
class NFSDataStore(DataStore):
"""
An implementation of a data store which uses NFS for storing policy checkpoints when using Coach in distributed mode.
The policy checkpoints are written by the trainer and read by the rollout worker.
"""
def __init__(self, params: NFSDataStoreParameters):
"""
:param params: The parameters required to use the NFS data store.
"""
self.params = params
def deploy(self) -> bool:
"""
Deploy the NFS server in an orchestrator if/when required.
"""
if self.params.orchestrator_type == "kubernetes":
if not self.params.deployed:
if not self.deploy_k8s_nfs():
@@ -43,6 +54,9 @@ class NFSDataStore(DataStore):
)
def undeploy(self) -> bool:
"""
Undeploy the NFS server and resources from an orchestrator.
"""
if self.params.orchestrator_type == "kubernetes":
if not self.params.deployed:
if not self.undeploy_k8s_nfs():
@@ -59,6 +73,9 @@ class NFSDataStore(DataStore):
pass
def deploy_k8s_nfs(self) -> bool:
"""
Deploy the NFS server in the Kubernetes orchestrator.
"""
from kubernetes import client as k8sclient
name = "nfs-server-{}".format(uuid.uuid4())
@@ -148,6 +165,9 @@ class NFSDataStore(DataStore):
return True
def create_k8s_nfs_resources(self) -> bool:
"""
Create NFS resources such as PV and PVC in Kubernetes.
"""
from kubernetes import client as k8sclient
pv_name = "nfs-ckpt-pv-{}".format(uuid.uuid4())
@@ -226,6 +246,9 @@ class NFSDataStore(DataStore):
return True
def delete_k8s_nfs_resources(self) -> bool:
"""
Delete NFS resources such as PV and PVC from the Kubernetes orchestrator.
"""
from kubernetes import client as k8sclient
del_options = k8sclient.V1DeleteOptions()
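
To make the PV/PVC creation described in create_k8s_nfs_resources() concrete, here is a minimal sketch using the kubernetes Python client. The server address, export path, capacity and resource names are illustrative placeholders, not the values Coach actually uses.

# Minimal sketch of creating an NFS-backed PV/PVC pair, similar in spirit to
# create_k8s_nfs_resources() above. Server address, path, capacity and names
# are placeholders.
import uuid
from kubernetes import client as k8sclient, config as k8sconfig

k8sconfig.load_kube_config()
core_api = k8sclient.CoreV1Api()

pv_name = "nfs-ckpt-pv-{}".format(uuid.uuid4())
pvc_name = "nfs-ckpt-pvc-{}".format(uuid.uuid4())

pv = k8sclient.V1PersistentVolume(
    metadata=k8sclient.V1ObjectMeta(name=pv_name),
    spec=k8sclient.V1PersistentVolumeSpec(
        access_modes=["ReadWriteMany"],
        capacity={"storage": "1Gi"},
        nfs=k8sclient.V1NFSVolumeSource(server="10.0.0.1", path="/exports"),
    ),
)
core_api.create_persistent_volume(pv)

pvc = k8sclient.V1PersistentVolumeClaim(
    metadata=k8sclient.V1ObjectMeta(name=pvc_name),
    spec=k8sclient.V1PersistentVolumeClaimSpec(
        access_modes=["ReadWriteMany"],
        resources=k8sclient.V1ResourceRequirements(requests={"storage": "1Gi"}),
        volume_name=pv_name,
    ),
)
core_api.create_namespaced_persistent_volume_claim(namespace="default", body=pvc)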

View File

@@ -23,7 +23,16 @@ class S3DataStoreParameters(DataStoreParameters):
class S3DataStore(DataStore):
"""
An implementation of a data store which uses S3 for storing policy checkpoints when using Coach in distributed mode.
The policy checkpoints are written by the trainer and read by the rollout worker.
"""
def __init__(self, params: S3DataStoreParameters):
"""
:param params: The parameters required to use the S3 data store.
"""
super(S3DataStore, self).__init__(params)
self.params = params
access_key = None
@@ -51,6 +60,10 @@ class S3DataStore(DataStore):
return True
def save_to_store(self):
"""
save_to_store() uploads the policy checkpoint, GIFs and videos to the S3 data store. It reads the checkpoint state files and
uploads only the latest checkpoint files to S3. It is used by the trainer when Coach is used in distributed mode.
"""
try:
# remove lock file if it exists
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
@@ -95,6 +108,10 @@ class S3DataStore(DataStore):
print("Got exception: %s\n while saving to S3", e)
def load_from_store(self):
"""
load_from_store() downloads a new checkpoint from the S3 data store when it is not available locally. It is used
by the rollout workers when using Coach in distributed mode.
"""
try:
state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
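
As a rough illustration of the save_to_store()/load_from_store() round trip described above, the following sketch uses the minio client (the same client S3DataStore wraps as self.mc). The endpoint, credentials, bucket name, checkpoint directory and checkpoint prefix are illustrative assumptions, not Coach's actual values.

# Minimal sketch of the S3 checkpoint round trip; all names and credentials
# below are placeholders.
import os
from minio import Minio

mc = Minio("s3.amazonaws.com", access_key="...", secret_key="...", secure=True)
bucket = "coach-checkpoints"
checkpoint_dir = "/checkpoint"

# Trainer side: upload only the files that belong to the newest checkpoint.
latest_prefix = "42_Step-1000"   # assumed to come from the checkpoint state file
for fname in os.listdir(checkpoint_dir):
    if fname.startswith(latest_prefix):
        mc.fput_object(bucket, fname, os.path.join(checkpoint_dir, fname))

# Rollout-worker side: fetch the same files when they are not available locally.
for obj in mc.list_objects(bucket, prefix=latest_prefix):
    mc.fget_object(bucket, obj.object_name,
                   os.path.join(checkpoint_dir, obj.object_name))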

View File

@@ -25,17 +25,30 @@ class RedisPubSubMemoryBackendParameters(MemoryBackendParameters):
class RedisPubSubBackend(MemoryBackend):
"""
A memory backend which transfers experiences from the rollout worker(s) to the training worker using Redis Pub/Sub when
Coach is used in distributed mode.
"""
def __init__(self, params: RedisPubSubMemoryBackendParameters):
"""
:param params: The Redis parameters to be used with this Redis Pub/Sub instance.
"""
self.params = params
self.redis_connection = redis.Redis(self.params.redis_address, self.params.redis_port)
self.redis_server_name = 'redis-server-{}'.format(uuid.uuid4())
self.redis_service_name = 'redis-service-{}'.format(uuid.uuid4())
def store(self, obj):
"""
:param obj: The object to store in memory. The object is either a Transition or an Episode.
"""
self.redis_connection.publish(self.params.channel, pickle.dumps(obj))
def deploy(self):
"""
Deploy the Redis Pub/Sub service in an orchestrator.
"""
if not self.params.deployed:
if self.params.orchestrator_type == 'kubernetes':
self.deploy_kubernetes()
@@ -44,7 +57,9 @@ class RedisPubSubBackend(MemoryBackend):
time.sleep(10)
def deploy_kubernetes(self):
"""
Deploy the Redis Pub/Sub service in the Kubernetes orchestrator.
"""
if 'namespace' not in self.params.orchestrator_params:
self.params.orchestrator_params['namespace'] = "default"
from kubernetes import client
@@ -111,6 +126,9 @@ class RedisPubSubBackend(MemoryBackend):
return False
def undeploy(self):
"""
Undeploy the Redis Pub/Sub service from an orchestrator.
"""
from kubernetes import client
if self.params.deployed:
return
@@ -133,9 +151,15 @@ class RedisPubSubBackend(MemoryBackend):
pass
def fetch(self, num_consecutive_playing_steps=None):
"""
:param num_consecutive_playing_steps: The number of steps to fetch.
"""
return RedisSub(redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel).run(num_consecutive_playing_steps)
def subscribe(self, agent):
"""
:param agent: The agent in use.
"""
redis_sub = RedisSub(redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
return redis_sub
@@ -154,6 +178,9 @@ class RedisSub(object):
self.subscriber = self.pubsub.subscribe(self.channel)
def run(self, num_consecutive_playing_steps):
"""
:param num_consecutive_playing_steps: The number of steps to fetch.
"""
transitions = 0
episodes = 0
steps = 0
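
The store()/fetch() pair above boils down to publishing pickled objects on a Redis channel and consuming them on the other side. A minimal sketch with redis-py, assuming a local Redis server and a placeholder channel name:

# Rollout-worker side (what store() does): publish a pickled object.
import pickle
import redis

r = redis.Redis("localhost", 6379)
channel = "experiences"
r.publish(channel, pickle.dumps({"state": 0, "action": 1, "reward": 0.5}))

# Trainer side (what RedisSub.run() does): subscribe and unpickle messages.
pubsub = r.pubsub()
pubsub.subscribe(channel)
for message in pubsub.listen():
    if message["type"] != "message":
        continue  # skip the subscription confirmation
    obj = pickle.loads(message["data"])
    print(obj)
    break  # a real worker keeps consuming until enough steps have been fetched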

View File

@@ -54,8 +54,17 @@ class KubernetesParameters(DeployParameters):
class Kubernetes(Deploy):
"""
An orchestrator implementation which uses Kubernetes to deploy components such as the training and rollout workers
and Redis Pub/Sub when Coach is used in distributed mode.
"""
def __init__(self, params: KubernetesParameters):
"""
:param params: The Kubernetes parameters which are used for deploying the components in Coach. These parameters
include namespace and kubeconfig.
"""
super().__init__(params)
self.params = params
if self.params.kubeconfig:
@@ -93,6 +102,9 @@ class Kubernetes(Deploy):
self.s3_secret_key = os.environ.get('SECRET_ACCESS_KEY')
def setup(self) -> bool:
"""
Deploys the memory backend and data stores if required.
"""
self.memory_backend.deploy()
if not self.data_store.deploy():
@@ -102,6 +114,9 @@ class Kubernetes(Deploy):
return True
def deploy_trainer(self) -> bool:
"""
Deploys the training worker in Kubernetes.
"""
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
if not trainer_params:
@@ -179,6 +194,9 @@ class Kubernetes(Deploy):
return False
def deploy_worker(self):
"""
Deploys the rollout worker(s) in Kubernetes.
"""
worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
if not worker_params:
@@ -258,6 +276,9 @@ class Kubernetes(Deploy):
return False
def worker_logs(self, path='./logs'):
"""
:param path: Path to store the worker logs.
"""
worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
if not worker_params:
return
@@ -288,6 +309,9 @@ class Kubernetes(Deploy):
self.tail_log(pod_name, api_client)
def trainer_logs(self):
"""
Get the logs from the trainer.
"""
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
if not trainer_params:
return
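
trainer_logs() and worker_logs() above retrieve pod logs through the Kubernetes API client (note the tail_log(pod_name, api_client) call earlier). A minimal sketch of that building block, with a placeholder pod name and namespace:

# Minimal sketch of reading a pod's logs with the Kubernetes Python client;
# the pod name and namespace are placeholders.
from kubernetes import client as k8sclient, config as k8sconfig

k8sconfig.load_kube_config()
core_api = k8sclient.CoreV1Api()
print(core_api.read_namespaced_pod_log(name="coach-trainer-pod", namespace="default"))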
@@ -346,6 +370,10 @@ class Kubernetes(Deploy):
return
def undeploy(self):
"""
Undeploy all the components, such as the trainer, rollout worker(s), Redis Pub/Sub and data store, when required.
"""
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
api_client = k8sclient.BatchV1Api()
delete_options = k8sclient.V1DeleteOptions(
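
Finally, deploy_trainer() and deploy_worker() launch the training and rollout containers on the cluster; the undeploy() code above deletes them through BatchV1Api, i.e. as Kubernetes Jobs. A minimal sketch of creating such a Job with the Python client, where the image, command, labels and names are illustrative placeholders rather than Coach's actual values:

# Minimal sketch of launching a worker as a Kubernetes Job; image, command and
# names below are placeholders.
from kubernetes import client as k8sclient, config as k8sconfig

k8sconfig.load_kube_config()

container = k8sclient.V1Container(
    name="trainer",
    image="coach:latest",
    command=["python3", "rl_coach/coach.py"],
)
template = k8sclient.V1PodTemplateSpec(
    metadata=k8sclient.V1ObjectMeta(labels={"app": "trainer"}),
    spec=k8sclient.V1PodSpec(containers=[container], restart_policy="Never"),
)
job = k8sclient.V1Job(
    metadata=k8sclient.V1ObjectMeta(name="coach-trainer"),
    spec=k8sclient.V1JobSpec(template=template),
)
k8sclient.BatchV1Api().create_namespaced_job(namespace="default", body=job)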