mirror of
https://github.com/gryf/coach.git
synced 2026-02-16 14:05:46 +01:00
Add documentation on distributed Coach. (#158)
* Added documentation on distributed Coach.
This commit is contained in:
committed by
Gal Novik
parent
e3ecf445e2
commit
d06197f663
@@ -22,10 +22,21 @@ class NFSDataStoreParameters(DataStoreParameters):
|
||||
|
||||
|
||||
class NFSDataStore(DataStore):
|
||||
"""
|
||||
An implementation of the data store which uses NFS for storing policy checkpoints when using Coach in distributed mode.
|
||||
The policy checkpoints are written by the trainer and read by the rollout worker.
|
||||
"""
|
||||
|
||||
def __init__(self, params: NFSDataStoreParameters):
|
||||
"""
|
||||
:param params: The parameters required to use the NFS data store.
|
||||
"""
|
||||
self.params = params
|
||||
|
||||
def deploy(self) -> bool:
|
||||
"""
|
||||
Deploy the NFS server in an orchestrator if/when required.
|
||||
"""
|
||||
if self.params.orchestrator_type == "kubernetes":
|
||||
if not self.params.deployed:
|
||||
if not self.deploy_k8s_nfs():
|
||||
@@ -43,6 +54,9 @@ class NFSDataStore(DataStore):
|
||||
)
|
||||
|
||||
def undeploy(self) -> bool:
|
||||
"""
|
||||
Undeploy the NFS server and resources from an orchestrator.
|
||||
"""
|
||||
if self.params.orchestrator_type == "kubernetes":
|
||||
if not self.params.deployed:
|
||||
if not self.undeploy_k8s_nfs():
|
||||
@@ -59,6 +73,9 @@ class NFSDataStore(DataStore):
|
||||
pass
|
||||
|
||||
def deploy_k8s_nfs(self) -> bool:
|
||||
"""
|
||||
Deploy the NFS server in the Kubernetes orchestrator.
|
||||
"""
|
||||
from kubernetes import client as k8sclient
|
||||
|
||||
name = "nfs-server-{}".format(uuid.uuid4())
|
||||
@@ -148,6 +165,9 @@ class NFSDataStore(DataStore):
|
||||
return True
|
||||
|
||||
def create_k8s_nfs_resources(self) -> bool:
|
||||
"""
|
||||
Create NFS resources such as PV and PVC in Kubernetes.
|
||||
"""
|
||||
from kubernetes import client as k8sclient
|
||||
|
||||
pv_name = "nfs-ckpt-pv-{}".format(uuid.uuid4())
|
||||
@@ -226,6 +246,9 @@ class NFSDataStore(DataStore):
|
||||
return True
|
||||
|
||||
def delete_k8s_nfs_resources(self) -> bool:
|
||||
"""
|
||||
Delete NFS resources such as PV and PVC from the Kubernetes orchestrator.
|
||||
"""
|
||||
from kubernetes import client as k8sclient
|
||||
|
||||
del_options = k8sclient.V1DeleteOptions()
|
||||
|
||||
@@ -23,7 +23,16 @@ class S3DataStoreParameters(DataStoreParameters):
|
||||
|
||||
|
||||
class S3DataStore(DataStore):
|
||||
"""
|
||||
An implementation of the data store using S3 for storing policy checkpoints when using Coach in distributed mode.
|
||||
The policy checkpoints are written by the trainer and read by the rollout worker.
|
||||
"""
|
||||
|
||||
def __init__(self, params: S3DataStoreParameters):
|
||||
"""
|
||||
:param params: The parameters required to use the S3 data store.
|
||||
"""
|
||||
|
||||
super(S3DataStore, self).__init__(params)
|
||||
self.params = params
|
||||
access_key = None
|
||||
@@ -51,6 +60,10 @@ class S3DataStore(DataStore):
|
||||
return True
|
||||
|
||||
def save_to_store(self):
|
||||
"""
|
||||
save_to_store() uploads the policy checkpoint, gifs and videos to the S3 data store. It reads the checkpoint state files and
|
||||
uploads only the latest checkpoint files to S3. It is used by the trainer in Coach when used in the distributed mode.
|
||||
"""
|
||||
try:
|
||||
# remove lock file if it exists
|
||||
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
|
||||
@@ -95,6 +108,10 @@ class S3DataStore(DataStore):
|
||||
print("Got exception: %s\n while saving to S3", e)
|
||||
|
||||
def load_from_store(self):
|
||||
"""
|
||||
load_from_store() downloads a new checkpoint from the S3 data store when it is not available locally. It is used
|
||||
by the rollout workers when using Coach in distributed mode.
|
||||
"""
|
||||
try:
|
||||
state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
|
||||
|
||||
|
||||
@@ -25,17 +25,30 @@ class RedisPubSubMemoryBackendParameters(MemoryBackendParameters):
|
||||
|
||||
|
||||
class RedisPubSubBackend(MemoryBackend):
|
||||
"""
|
||||
A memory backend which transfers the experiences from the rollout to the training worker using Redis Pub/Sub in
|
||||
Coach when distributed mode is used.
|
||||
"""
|
||||
|
||||
def __init__(self, params: RedisPubSubMemoryBackendParameters):
|
||||
"""
|
||||
:param params: The Redis parameters to be used with this Redis Pub/Sub instance.
|
||||
"""
|
||||
self.params = params
|
||||
self.redis_connection = redis.Redis(self.params.redis_address, self.params.redis_port)
|
||||
self.redis_server_name = 'redis-server-{}'.format(uuid.uuid4())
|
||||
self.redis_service_name = 'redis-service-{}'.format(uuid.uuid4())
|
||||
|
||||
def store(self, obj):
|
||||
"""
|
||||
:param obj: The object to store in memory. The object is either a Transition or Episode type.
|
||||
"""
|
||||
self.redis_connection.publish(self.params.channel, pickle.dumps(obj))
|
||||
|
||||
def deploy(self):
|
||||
"""
|
||||
Deploy the Redis Pub/Sub service in an orchestrator.
|
||||
"""
|
||||
if not self.params.deployed:
|
||||
if self.params.orchestrator_type == 'kubernetes':
|
||||
self.deploy_kubernetes()
|
||||
@@ -44,7 +57,9 @@ class RedisPubSubBackend(MemoryBackend):
|
||||
time.sleep(10)
|
||||
|
||||
def deploy_kubernetes(self):
|
||||
|
||||
"""
|
||||
Deploy the Redis Pub/Sub service in the Kubernetes orchestrator.
|
||||
"""
|
||||
if 'namespace' not in self.params.orchestrator_params:
|
||||
self.params.orchestrator_params['namespace'] = "default"
|
||||
from kubernetes import client
|
||||
@@ -111,6 +126,9 @@ class RedisPubSubBackend(MemoryBackend):
|
||||
return False
|
||||
|
||||
def undeploy(self):
|
||||
"""
|
||||
Undeploy the Redis Pub/Sub service in an orchestrator.
|
||||
"""
|
||||
from kubernetes import client
|
||||
if self.params.deployed:
|
||||
return
|
||||
@@ -133,9 +151,15 @@ class RedisPubSubBackend(MemoryBackend):
|
||||
pass
|
||||
|
||||
def fetch(self, num_consecutive_playing_steps=None):
|
||||
"""
|
||||
:param num_consecutive_playing_steps: The number of steps to fetch.
|
||||
"""
|
||||
return RedisSub(redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel).run(num_consecutive_playing_steps)
|
||||
|
||||
def subscribe(self, agent):
|
||||
"""
|
||||
:param agent: The agent in use.
|
||||
"""
|
||||
redis_sub = RedisSub(redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
|
||||
return redis_sub
|
||||
|
||||
@@ -154,6 +178,9 @@ class RedisSub(object):
|
||||
self.subscriber = self.pubsub.subscribe(self.channel)
|
||||
|
||||
def run(self, num_consecutive_playing_steps):
|
||||
"""
|
||||
:param num_consecutive_playing_steps: The number of steps to fetch.
|
||||
"""
|
||||
transitions = 0
|
||||
episodes = 0
|
||||
steps = 0
|
||||
|
||||
@@ -54,8 +54,17 @@ class KubernetesParameters(DeployParameters):
|
||||
|
||||
|
||||
class Kubernetes(Deploy):
|
||||
"""
|
||||
An orchestrator implementation which uses Kubernetes to deploy the components such as training and rollout workers
|
||||
and Redis Pub/Sub in Coach when used in the distributed mode.
|
||||
"""
|
||||
|
||||
def __init__(self, params: KubernetesParameters):
|
||||
"""
|
||||
:param params: The Kubernetes parameters which are used for deploying the components in Coach. These parameters
|
||||
include namespace and kubeconfig.
|
||||
"""
|
||||
|
||||
super().__init__(params)
|
||||
self.params = params
|
||||
if self.params.kubeconfig:
|
||||
@@ -93,6 +102,9 @@ class Kubernetes(Deploy):
|
||||
self.s3_secret_key = os.environ.get('SECRET_ACCESS_KEY')
|
||||
|
||||
def setup(self) -> bool:
|
||||
"""
|
||||
Deploys the memory backend and data stores if required.
|
||||
"""
|
||||
|
||||
self.memory_backend.deploy()
|
||||
if not self.data_store.deploy():
|
||||
@@ -102,6 +114,9 @@ class Kubernetes(Deploy):
|
||||
return True
|
||||
|
||||
def deploy_trainer(self) -> bool:
|
||||
"""
|
||||
Deploys the training worker in Kubernetes.
|
||||
"""
|
||||
|
||||
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
|
||||
if not trainer_params:
|
||||
@@ -179,6 +194,9 @@ class Kubernetes(Deploy):
|
||||
return False
|
||||
|
||||
def deploy_worker(self):
|
||||
"""
|
||||
Deploys the rollout worker(s) in Kubernetes.
|
||||
"""
|
||||
|
||||
worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
|
||||
if not worker_params:
|
||||
@@ -258,6 +276,9 @@ class Kubernetes(Deploy):
|
||||
return False
|
||||
|
||||
def worker_logs(self, path='./logs'):
|
||||
"""
|
||||
:param path: Path to store the worker logs.
|
||||
"""
|
||||
worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
|
||||
if not worker_params:
|
||||
return
|
||||
@@ -288,6 +309,9 @@ class Kubernetes(Deploy):
|
||||
self.tail_log(pod_name, api_client)
|
||||
|
||||
def trainer_logs(self):
|
||||
"""
|
||||
Get the logs from trainer.
|
||||
"""
|
||||
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
|
||||
if not trainer_params:
|
||||
return
|
||||
@@ -346,6 +370,10 @@ class Kubernetes(Deploy):
|
||||
return
|
||||
|
||||
def undeploy(self):
|
||||
"""
|
||||
Undeploy all the components, such as trainer and rollout worker(s), Redis pub/sub and data store, when required.
|
||||
"""
|
||||
|
||||
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
|
||||
api_client = k8sclient.BatchV1Api()
|
||||
delete_options = k8sclient.V1DeleteOptions(
|
||||
|
||||
Reference in New Issue
Block a user