1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Adding improvements

This commit is contained in:
Ajay Deshpande
2018-10-15 15:57:10 -07:00
committed by zach dwiel
parent 3ba0df7d07
commit 9a30c26469
7 changed files with 38 additions and 27 deletions

View File

@@ -3,6 +3,7 @@ import redis
import pickle
import uuid
import threading
import time
from kubernetes import client
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
@@ -31,6 +32,8 @@ class RedisPubSubBackend(MemoryBackend):
def __init__(self, params: RedisPubSubMemoryBackendParameters):
self.params = params
self.redis_connection = redis.Redis(self.params.redis_address, self.params.redis_port)
self.redis_server_name = 'redis-server-{}'.format(uuid.uuid4())
self.redis_service_name = 'redis-service-{}'.format(uuid.uuid4())
def store(self, obj):
self.redis_connection.publish(self.params.channel, pickle.dumps(obj))
@@ -39,7 +42,9 @@ class RedisPubSubBackend(MemoryBackend):
if not self.params.deployed:
if self.params.orchestrator_type == 'kubernetes':
self.deploy_kubernetes()
self.params.deployed = True
# Wait till subscribe to the channel is possible or else it will cause delays in the trainer.
time.sleep(10)
def deploy_kubernetes(self):
@@ -47,11 +52,11 @@ class RedisPubSubBackend(MemoryBackend):
self.params.orchestrator_params['namespace'] = "default"
container = client.V1Container(
name="redis-server",
name=self.redis_server_name,
image='redis:4-alpine',
)
template = client.V1PodTemplateSpec(
metadata=client.V1ObjectMeta(labels={'app': 'redis-server'}),
metadata=client.V1ObjectMeta(labels={'app': self.redis_server_name}),
spec=client.V1PodSpec(
containers=[container]
)
@@ -60,14 +65,14 @@ class RedisPubSubBackend(MemoryBackend):
replicas=1,
template=template,
selector=client.V1LabelSelector(
match_labels={'app': 'redis-server'}
match_labels={'app': self.redis_server_name}
)
)
deployment = client.V1Deployment(
api_version='apps/v1',
kind='Deployment',
metadata=client.V1ObjectMeta(name='redis-server', labels={'app': 'redis-server'}),
metadata=client.V1ObjectMeta(name=self.redis_server_name, labels={'app': self.redis_server_name}),
spec=deployment_spec
)
@@ -84,10 +89,10 @@ class RedisPubSubBackend(MemoryBackend):
api_version='v1',
kind='Service',
metadata=client.V1ObjectMeta(
name='redis-service'
name=self.redis_service_name
),
spec=client.V1ServiceSpec(
selector={'app': 'redis-server'},
selector={'app': self.redis_server_name},
ports=[client.V1ServicePort(
protocol='TCP',
port=6379,
@@ -98,7 +103,9 @@ class RedisPubSubBackend(MemoryBackend):
try:
core_v1_api.create_namespaced_service(self.params.orchestrator_params['namespace'], service)
self.params.redis_address = 'redis-service.{}.svc'.format(self.params.orchestrator_params['namespace'])
self.params.redis_address = '{}.{}.svc'.format(
self.redis_service_name, self.params.orchestrator_params['namespace']
)
self.params.redis_port = 6379
return True
except client.rest.ApiException as e:
@@ -106,23 +113,21 @@ class RedisPubSubBackend(MemoryBackend):
return False
def undeploy(self):
if not self.params.deployed:
if self.params.deployed:
return
api_client = client.AppsV1Api()
delete_options = client.V1DeleteOptions()
try:
api_client.delete_namespaced_deployment('redis-server', self.params.orchestrator_params['namespace'], delete_options)
api_client.delete_namespaced_deployment(self.redis_server_name, self.params.orchestrator_params['namespace'], delete_options)
except client.rest.ApiException as e:
print("Got exception: %s\n while deleting redis-server", e)
api_client = client.CoreV1Api()
try:
api_client.delete_namespaced_service('redis-service', self.params.orchestrator_params['namespace'], delete_options)
api_client.delete_namespaced_service(self.redis_service_name, self.params.orchestrator_params['namespace'], delete_options)
except client.rest.ApiException as e:
print("Got exception: %s\n while deleting redis-server", e)
self.params.deployed = False
def sample(self, size):
pass
@@ -145,7 +150,9 @@ class RedisSub(threading.Thread):
self.subscriber = None
self.agent = agent
self.channel = channel
print('Before subscribe')
self.subscriber = self.pubsub.subscribe(self.channel)
print('After subscribe')
def run(self):
for message in self.pubsub.listen():