mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Adding improvements
This commit is contained in:
committed by
zach dwiel
parent
3ba0df7d07
commit
9a30c26469
@@ -3,6 +3,7 @@ import redis
|
||||
import pickle
|
||||
import uuid
|
||||
import threading
|
||||
import time
|
||||
from kubernetes import client
|
||||
|
||||
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
|
||||
@@ -31,6 +32,8 @@ class RedisPubSubBackend(MemoryBackend):
|
||||
def __init__(self, params: RedisPubSubMemoryBackendParameters):
|
||||
self.params = params
|
||||
self.redis_connection = redis.Redis(self.params.redis_address, self.params.redis_port)
|
||||
self.redis_server_name = 'redis-server-{}'.format(uuid.uuid4())
|
||||
self.redis_service_name = 'redis-service-{}'.format(uuid.uuid4())
|
||||
|
||||
def store(self, obj):
|
||||
self.redis_connection.publish(self.params.channel, pickle.dumps(obj))
|
||||
@@ -39,7 +42,9 @@ class RedisPubSubBackend(MemoryBackend):
|
||||
if not self.params.deployed:
|
||||
if self.params.orchestrator_type == 'kubernetes':
|
||||
self.deploy_kubernetes()
|
||||
self.params.deployed = True
|
||||
|
||||
# Wait till subscribe to the channel is possible or else it will cause delays in the trainer.
|
||||
time.sleep(10)
|
||||
|
||||
def deploy_kubernetes(self):
|
||||
|
||||
@@ -47,11 +52,11 @@ class RedisPubSubBackend(MemoryBackend):
|
||||
self.params.orchestrator_params['namespace'] = "default"
|
||||
|
||||
container = client.V1Container(
|
||||
name="redis-server",
|
||||
name=self.redis_server_name,
|
||||
image='redis:4-alpine',
|
||||
)
|
||||
template = client.V1PodTemplateSpec(
|
||||
metadata=client.V1ObjectMeta(labels={'app': 'redis-server'}),
|
||||
metadata=client.V1ObjectMeta(labels={'app': self.redis_server_name}),
|
||||
spec=client.V1PodSpec(
|
||||
containers=[container]
|
||||
)
|
||||
@@ -60,14 +65,14 @@ class RedisPubSubBackend(MemoryBackend):
|
||||
replicas=1,
|
||||
template=template,
|
||||
selector=client.V1LabelSelector(
|
||||
match_labels={'app': 'redis-server'}
|
||||
match_labels={'app': self.redis_server_name}
|
||||
)
|
||||
)
|
||||
|
||||
deployment = client.V1Deployment(
|
||||
api_version='apps/v1',
|
||||
kind='Deployment',
|
||||
metadata=client.V1ObjectMeta(name='redis-server', labels={'app': 'redis-server'}),
|
||||
metadata=client.V1ObjectMeta(name=self.redis_server_name, labels={'app': self.redis_server_name}),
|
||||
spec=deployment_spec
|
||||
)
|
||||
|
||||
@@ -84,10 +89,10 @@ class RedisPubSubBackend(MemoryBackend):
|
||||
api_version='v1',
|
||||
kind='Service',
|
||||
metadata=client.V1ObjectMeta(
|
||||
name='redis-service'
|
||||
name=self.redis_service_name
|
||||
),
|
||||
spec=client.V1ServiceSpec(
|
||||
selector={'app': 'redis-server'},
|
||||
selector={'app': self.redis_server_name},
|
||||
ports=[client.V1ServicePort(
|
||||
protocol='TCP',
|
||||
port=6379,
|
||||
@@ -98,7 +103,9 @@ class RedisPubSubBackend(MemoryBackend):
|
||||
|
||||
try:
|
||||
core_v1_api.create_namespaced_service(self.params.orchestrator_params['namespace'], service)
|
||||
self.params.redis_address = 'redis-service.{}.svc'.format(self.params.orchestrator_params['namespace'])
|
||||
self.params.redis_address = '{}.{}.svc'.format(
|
||||
self.redis_service_name, self.params.orchestrator_params['namespace']
|
||||
)
|
||||
self.params.redis_port = 6379
|
||||
return True
|
||||
except client.rest.ApiException as e:
|
||||
@@ -106,23 +113,21 @@ class RedisPubSubBackend(MemoryBackend):
|
||||
return False
|
||||
|
||||
def undeploy(self):
|
||||
if not self.params.deployed:
|
||||
if self.params.deployed:
|
||||
return
|
||||
api_client = client.AppsV1Api()
|
||||
delete_options = client.V1DeleteOptions()
|
||||
try:
|
||||
api_client.delete_namespaced_deployment('redis-server', self.params.orchestrator_params['namespace'], delete_options)
|
||||
api_client.delete_namespaced_deployment(self.redis_server_name, self.params.orchestrator_params['namespace'], delete_options)
|
||||
except client.rest.ApiException as e:
|
||||
print("Got exception: %s\n while deleting redis-server", e)
|
||||
|
||||
api_client = client.CoreV1Api()
|
||||
try:
|
||||
api_client.delete_namespaced_service('redis-service', self.params.orchestrator_params['namespace'], delete_options)
|
||||
api_client.delete_namespaced_service(self.redis_service_name, self.params.orchestrator_params['namespace'], delete_options)
|
||||
except client.rest.ApiException as e:
|
||||
print("Got exception: %s\n while deleting redis-server", e)
|
||||
|
||||
self.params.deployed = False
|
||||
|
||||
def sample(self, size):
|
||||
pass
|
||||
|
||||
@@ -145,7 +150,9 @@ class RedisSub(threading.Thread):
|
||||
self.subscriber = None
|
||||
self.agent = agent
|
||||
self.channel = channel
|
||||
print('Before subscribe')
|
||||
self.subscriber = self.pubsub.subscribe(self.channel)
|
||||
print('After subscribe')
|
||||
|
||||
def run(self):
|
||||
for message in self.pubsub.listen():
|
||||
|
||||
Reference in New Issue
Block a user