1
0
mirror of https://github.com/gryf/coach.git synced 2026-03-03 15:25:49 +01:00

Adding initial interface for backend and redis pubsub (#19)

* Adding initial interface for backend and redis pubsub

* Addressing comments, adding super in all memories

* Removing distributed experience replay
This commit is contained in:
Ajay Deshpande
2018-10-03 15:07:48 -07:00
committed by zach dwiel
parent a54ef2757f
commit 6b2de6ba6d
21 changed files with 459 additions and 444 deletions

View File

View File

@@ -0,0 +1,36 @@
class MemoryBackendParameters(object):
    """Parameters describing how a distributed memory backend is set up.

    :param store_type: identifier of the backing store (e.g. 'redispubsub')
    :param orchestrator_type: orchestrator used for deployment (e.g. 'kubernetes')
    :param run_type: role of the process using the backend (e.g. 'trainer')
    :param deployed: True if the backend is already deployed, so deploy() is a no-op
    """

    # NOTE: the original annotated `deployed` as `str` although it is used as
    # a boolean flag throughout (default False, set True after deployment).
    def __init__(self, store_type, orchestrator_type, run_type, deployed: bool = False):
        self.store_type = store_type
        self.orchestrator_type = orchestrator_type
        self.run_type = run_type
        self.deployed = deployed
class MemoryBackend(object):
    """Abstract interface for a distributed memory backend.

    Concrete backends (e.g. RedisPubSubBackend) override the deployment
    lifecycle (deploy/undeploy/get_endpoint) and the data-plane methods
    (store/store_episode/sample/subscribe).
    """

    def __init__(self, params: MemoryBackendParameters):
        pass

    # The original raised `NotImplemented("...")`, which fails with
    # "TypeError: 'NotImplementedType' object is not callable" instead of the
    # intended signal. NotImplementedError is the correct exception for an
    # unimplemented interface method.
    def deploy(self):
        raise NotImplementedError("Not yet implemented")

    def get_endpoint(self):
        raise NotImplementedError("Not yet implemented")

    def undeploy(self):
        raise NotImplementedError("Not yet implemented")

    def sample(self, size: int):
        raise NotImplementedError("Not yet implemented")

    def store(self, obj):
        raise NotImplementedError("Not yet implemented")

    def store_episode(self, obj):
        raise NotImplementedError("Not yet implemented")

    def subscribe(self, memory):
        raise NotImplementedError("Not yet implemented")

View File

@@ -0,0 +1,21 @@
from rl_coach.memories.backend.memory import MemoryBackendParameters
from rl_coach.memories.backend.redis import RedisPubSubBackend, RedisPubSubMemoryBackendParameters
def get_memory_backend(params: MemoryBackendParameters):
    """Factory: return the memory backend instance matching the given parameters.

    :param params: backend parameters; their concrete type selects the backend
    :return: a MemoryBackend instance, or None if the type is not recognized
    """
    backend = None
    # isinstance is the idiomatic type check and also accepts subclasses of
    # the parameter class (the original `type(params) ==` rejected them).
    if isinstance(params, RedisPubSubMemoryBackendParameters):
        backend = RedisPubSubBackend(params)
    return backend
def construct_memory_params(json: dict):
    """Build memory backend parameters from a JSON-style dict.

    Expected keys for 'redispubsub': 'store_type', 'redis_address',
    'redis_port', 'run_type', and optionally 'channel'.

    :param json: dict describing the backend (name kept for API compatibility,
        although it shadows the stdlib `json` module inside this function)
    :return: a RedisPubSubMemoryBackendParameters instance, or None for an
        unrecognized store_type
    """
    # Initialize to None: the original left `memory_params` unbound for any
    # store_type other than 'redispubsub', raising UnboundLocalError at the
    # return statement instead of returning a sentinel.
    memory_params = None
    if json['store_type'] == 'redispubsub':
        memory_params = RedisPubSubMemoryBackendParameters(
            json['redis_address'], json['redis_port'], channel=json.get('channel', ''), run_type=json['run_type']
        )
    return memory_params

View File

@@ -0,0 +1,160 @@
import redis
import pickle
import uuid
import threading
from kubernetes import client
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
from rl_coach.memories.memory import Memory
from rl_coach.core_types import Transition, Episode
class RedisPubSubMemoryBackendParameters(MemoryBackendParameters):
    """Parameters for the Redis pub/sub memory backend.

    :param redis_address: hostname of the Redis server ('' until deployed)
    :param redis_port: port of the Redis server
    :param channel: pub/sub channel name; when None, a unique per-instance
        channel is generated
    :param orchestrator_params: orchestrator-specific settings (e.g. namespace)
    :param run_type: role of the process ('trainer' or a worker type)
    :param orchestrator_type: orchestrator used for deployment
    :param deployed: True if the Redis server is already deployed
    """

    def __init__(self, redis_address: str = "", redis_port: int = 6379, channel: str = None,
                 orchestrator_params: dict = None, run_type='trainer',
                 orchestrator_type: str = "kubernetes", deployed: bool = False):
        # Let the base class hold the common fields (store_type is fixed for
        # this backend), consistent with the MemoryBackendParameters interface.
        super().__init__("redispubsub", orchestrator_type, run_type, deployed)
        self.redis_address = redis_address
        self.redis_port = redis_port
        # The original default 'channel-{uuid4()}' was evaluated once at class
        # definition time, so every instance constructed with the default
        # shared the very same channel. Generate a fresh one per instance.
        # A None sentinel is used so an explicitly passed '' is preserved.
        if channel is None:
            channel = "channel-{}".format(uuid.uuid4())
        self.channel = channel
        if not orchestrator_params:
            orchestrator_params = {}
        self.orchestrator_params = orchestrator_params
class RedisPubSubBackend(MemoryBackend):
    """Memory backend that relays experience over a Redis pub/sub channel.

    Workers call store() to publish pickled objects; trainers call subscribe()
    to spawn a background RedisSub thread that feeds incoming Transition /
    Episode objects into a local memory. deploy()/undeploy() manage a
    single-replica Redis server on the orchestrator (kubernetes only).
    """

    def __init__(self, params: RedisPubSubMemoryBackendParameters):
        self.params = params
        self.redis_connection = redis.Redis(self.params.redis_address, self.params.redis_port)

    def store(self, obj):
        """Publish a pickled object on the configured channel."""
        self.redis_connection.publish(self.params.channel, pickle.dumps(obj))

    def deploy(self):
        """Deploy the Redis server if it is not already deployed."""
        if not self.params.deployed:
            if self.params.orchestrator_type == 'kubernetes':
                # Only mark as deployed when deployment actually succeeded;
                # the original unconditionally set deployed = True even when
                # deploy_kubernetes() returned False.
                self.params.deployed = self.deploy_kubernetes()
            else:
                self.params.deployed = True

    def deploy_kubernetes(self):
        """Create a redis-server Deployment and a redis-service Service.

        On success, repoints this backend's parameters at the in-cluster
        service address.
        :return: True on success, False otherwise
        """
        if 'namespace' not in self.params.orchestrator_params:
            self.params.orchestrator_params['namespace'] = "default"

        container = client.V1Container(
            name="redis-server",
            image='redis:4-alpine',
        )
        template = client.V1PodTemplateSpec(
            metadata=client.V1ObjectMeta(labels={'app': 'redis-server'}),
            spec=client.V1PodSpec(
                containers=[container]
            )
        )
        deployment_spec = client.V1DeploymentSpec(
            replicas=1,
            template=template,
            selector=client.V1LabelSelector(
                match_labels={'app': 'redis-server'}
            )
        )
        deployment = client.V1Deployment(
            api_version='apps/v1',
            kind='Deployment',
            metadata=client.V1ObjectMeta(name='redis-server', labels={'app': 'redis-server'}),
            spec=deployment_spec
        )

        api_client = client.AppsV1Api()
        try:
            api_client.create_namespaced_deployment(self.params.orchestrator_params['namespace'], deployment)
        except client.rest.ApiException as e:
            # %-substitute the exception into the message: the original passed
            # `e` as a second positional argument to print(), leaving the %s
            # placeholder unexpanded.
            print("Got exception: %s\n while creating redis-server" % e)
            return False

        core_v1_api = client.CoreV1Api()
        service = client.V1Service(
            api_version='v1',
            kind='Service',
            metadata=client.V1ObjectMeta(
                name='redis-service'
            ),
            spec=client.V1ServiceSpec(
                selector={'app': 'redis-server'},
                ports=[client.V1ServicePort(
                    protocol='TCP',
                    port=6379,
                    target_port=6379
                )]
            )
        )

        try:
            core_v1_api.create_namespaced_service(self.params.orchestrator_params['namespace'], service)
            # Address other pods in the cluster should use to reach Redis.
            self.params.redis_address = 'redis-service.{}.svc'.format(self.params.orchestrator_params['namespace'])
            self.params.redis_port = 6379
            return True
        except client.rest.ApiException as e:
            print("Got exception: %s\n while creating a service for redis-server" % e)
            return False

    def undeploy(self):
        """Delete the redis-server Deployment and redis-service Service."""
        if not self.params.deployed:
            return
        api_client = client.AppsV1Api()
        delete_options = client.V1DeleteOptions()
        try:
            api_client.delete_namespaced_deployment('redis-server', self.params.orchestrator_params['namespace'], delete_options)
        except client.rest.ApiException as e:
            print("Got exception: %s\n while deleting redis-server" % e)
        api_client = client.CoreV1Api()
        try:
            api_client.delete_namespaced_service('redis-service', self.params.orchestrator_params['namespace'], delete_options)
        except client.rest.ApiException as e:
            # The original message said "redis-server" here although this
            # branch deletes the service.
            print("Got exception: %s\n while deleting redis-service" % e)
        self.params.deployed = False

    def sample(self, size):
        # Sampling is served by the local memory the subscriber feeds into,
        # not by the pub/sub channel itself.
        pass

    def subscribe(self, memory):
        """Start a daemon thread that stores incoming objects into `memory`."""
        redis_sub = RedisSub(memory, redis_address=self.params.redis_address,
                             redis_port=self.params.redis_port, channel=self.params.channel)
        redis_sub.daemon = True
        redis_sub.start()

    def get_endpoint(self):
        """Return the connection parameters workers need to reach this backend."""
        return {'redis_address': self.params.redis_address,
                'redis_port': self.params.redis_port}
class RedisSub(threading.Thread):
    """Background thread subscribing to a Redis channel and filling a memory.

    Unpickles each published message and stores Transitions via
    memory.store() and Episodes via memory.store_episode().
    """

    def __init__(self, memory: Memory, redis_address: str = "localhost", redis_port: int = 6379,
                 channel: str = "PubsubChannel"):
        super().__init__()
        self.redis_connection = redis.Redis(redis_address, redis_port)
        self.pubsub = self.redis_connection.pubsub()
        self.memory = memory
        self.channel = channel
        # pubsub.subscribe() has no useful return value; the attribute is kept
        # for interface compatibility with the original implementation.
        self.subscriber = self.pubsub.subscribe(self.channel)

    def run(self):
        for message in self.pubsub.listen():
            if message and 'data' in message:
                try:
                    obj = pickle.loads(message['data'])
                    # isinstance is the idiomatic check (the original used
                    # `type(obj) ==`, which would reject subclasses).
                    if isinstance(obj, Transition):
                        self.memory.store(obj)
                    elif isinstance(obj, Episode):
                        self.memory.store_episode(obj)
                except Exception:
                    # Control messages (e.g. the subscribe confirmation, whose
                    # data is an int) and malformed payloads are skipped so the
                    # subscriber thread keeps running.
                    continue

View File

@@ -160,6 +160,8 @@ class EpisodicExperienceReplay(Memory):
:param transition: a transition to store
:return: None
"""
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)
self.reader_writer_lock.lock_writing_and_reading()
if len(self._buffer) == 0:
@@ -181,6 +183,9 @@ class EpisodicExperienceReplay(Memory):
:param episode: the new episode to store
:return: None
"""
# Calling super.store() so that in case a memory backend is used, the memory backend can store this episode.
super().store(episode)
if lock:
self.reader_writer_lock.lock_writing_and_reading()

View File

@@ -106,6 +106,10 @@ class EpisodicHindsightExperienceReplay(EpisodicExperienceReplay):
]
def store_episode(self, episode: Episode, lock: bool=True) -> None:
# Calling super.store() so that in case a memory backend is used, the memory backend can store this episode.
super().store_episode(episode)
# generate hindsight transitions only when an episode is finished
last_episode_transitions = copy.copy(episode.transitions)

View File

@@ -59,6 +59,10 @@ class EpisodicHRLHindsightExperienceReplay(EpisodicHindsightExperienceReplay):
# for a layer producing sub-goals, we will replace in hindsight the action (sub-goal) given to the lower
# level with the actual achieved goal. the achieved goal (and observation) seen is assumed to be the same
# for all levels - we can use this level's achieved goal instead of the lower level's one
# Calling super.store() so that in case a memory backend is used, the memory backend can store this episode.
super().store_episode(episode)
for transition in episode.transitions:
new_achieved_goal = transition.next_state[self.goals_space.goal_name]
transition.action = new_achieved_goal

View File

@@ -18,6 +18,7 @@ from enum import Enum
from typing import Tuple
from rl_coach.base_parameters import Parameters
from rl_coach.memories.backend.memory import MemoryBackend
class MemoryGranularity(Enum):
@@ -32,7 +33,6 @@ class MemoryParameters(Parameters):
self.shared_memory = False
self.load_memory_from_file_path = None
@property
def path(self):
return 'rl_coach.memories.memory:Memory'
@@ -45,9 +45,16 @@ class Memory(object):
"""
self.max_size = max_size
self._length = 0
self.memory_backend = None
def store(self, obj):
raise NotImplementedError("")
if self.memory_backend:
self.memory_backend.store(obj)
def store_episode(self, episode):
if self.memory_backend:
for transition in episode:
self.memory_backend.store(transition)
def get(self, index):
raise NotImplementedError("")
@@ -64,4 +71,5 @@ class Memory(object):
def clean(self):
raise NotImplementedError("")
def set_memory_backend(self, memory_backend: MemoryBackend):
self.memory_backend = memory_backend

View File

@@ -72,6 +72,8 @@ class BalancedExperienceReplay(ExperienceReplay):
locks and then calls store with lock = True
:return: None
"""
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)
if lock:
self.reader_writer_lock.lock_writing_and_reading()

View File

@@ -1,182 +0,0 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Tuple, Union
import numpy as np
import redis
import uuid
import pickle
from rl_coach.core_types import Transition
from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters
class DistributedExperienceReplayParameters(MemoryParameters):
    """Parameter set for the Redis-backed distributed experience replay."""

    def __init__(self):
        super().__init__()
        # Redis connection settings.
        self.redis_ip = 'localhost'
        self.redis_port = 6379
        self.redis_db = 0
        # Capacity and sampling behavior.
        self.max_size = (MemoryGranularity.Transitions, 1000000)
        self.allow_duplicates_in_batch_sampling = True

    @property
    def path(self):
        # Import path used by the framework to instantiate the memory class.
        return 'rl_coach.memories.non_episodic.distributed_experience_replay:DistributedExperienceReplay'
class DistributedExperienceReplay(Memory):
    """
    A replay buffer shared between processes through a Redis database.
    Each transition is pickled and stored under a random UUID key, without
    any additional structure; sampling draws random keys from the database.
    """

    def __init__(self, max_size: Tuple[MemoryGranularity, int], allow_duplicates_in_batch_sampling: bool=True,
                 redis_ip='localhost', redis_port=6379, redis_db=0):
        """
        :param max_size: the maximum number of transitions or episodes to hold in the memory
        :param allow_duplicates_in_batch_sampling: allow having the same transition multiple times in a batch
        :param redis_ip: hostname of the Redis server
        :param redis_port: port of the Redis server
        :param redis_db: index of the Redis database holding this buffer
        """
        super().__init__(max_size)
        if max_size[0] != MemoryGranularity.Transitions:
            raise ValueError("Experience replay size can only be configured in terms of transitions")
        self.allow_duplicates_in_batch_sampling = allow_duplicates_in_batch_sampling
        self.db = redis_db
        self.redis_connection = redis.Redis(redis_ip, redis_port, self.db)

    def length(self) -> int:
        """
        Get the number of transitions in the ER
        """
        return self.num_transitions()

    def num_transitions(self) -> int:
        """
        Get the number of transitions in the ER
        """
        try:
            return self.redis_connection.info(section='keyspace')['db{}'.format(self.db)]['keys']
        except Exception:
            # Redis omits the 'db<N>' keyspace section when the db is empty,
            # so the lookup above raises KeyError and we report zero.
            return 0

    def sample(self, size: int) -> List[Transition]:
        """
        Sample a batch of transitions from the replay buffer.
        :param size: the size of the batch to sample
        :return: a batch (list) of selected transitions from the replay buffer
        :raises ValueError: if the buffer does not hold enough transitions
        """
        num_available = self.num_transitions()
        # Guard the empty buffer in both modes (randomkey() returns None on an
        # empty db, which would crash the get() below), and require enough
        # distinct transitions when duplicates are not allowed.
        if num_available == 0 or (not self.allow_duplicates_in_batch_sampling and num_available < size):
            raise ValueError("The replay buffer cannot be sampled since there are not enough transitions yet. "
                             "There are currently {} transitions".format(num_available))

        if self.allow_duplicates_in_batch_sampling:
            # Accumulate into a list so the same key may genuinely appear more
            # than once. The original collected into a dict keyed by the redis
            # key, which silently deduplicated and could loop forever whenever
            # size exceeded the number of stored keys.
            batch = []
            while len(batch) != size:
                key = self.redis_connection.randomkey()
                batch.append(pickle.loads(self.redis_connection.get(key)))
        else:
            sampled = dict()
            while len(sampled) != size:
                key = self.redis_connection.randomkey()
                if key in sampled:
                    continue
                sampled[key] = pickle.loads(self.redis_connection.get(key))
            batch = list(sampled.values())
        return batch

    def _enforce_max_length(self) -> None:
        """
        Make sure that the size of the replay buffer does not pass the maximum size allowed.
        If it passes the max size, a randomly chosen transition is removed
        (keys carry no insertion order, so the "oldest" cannot be identified).
        This function does not use locks since it is only called internally
        :return: None
        """
        granularity, size = self.max_size
        if granularity == MemoryGranularity.Transitions:
            while size != 0 and self.num_transitions() > size:
                self.redis_connection.delete(self.redis_connection.randomkey())
        else:
            raise ValueError("The granularity of the replay buffer can only be set in terms of transitions")

    def store(self, transition: Transition, lock: bool=True) -> None:
        """
        Store a new transition in the memory.
        :param transition: a transition to store
        :param lock: kept for API compatibility; Redis operations are atomic so no lock is taken
        :return: None
        """
        # redis-py only accepts str/bytes/int/float keys, so the UUID must be
        # stringified — the original passed the raw uuid.UUID object, which
        # redis-py rejects with a DataError.
        self.redis_connection.set(str(uuid.uuid4()), pickle.dumps(transition))
        self._enforce_max_length()

    def get_transition(self, transition_index: int, lock: bool=True) -> Union[None, Transition]:
        """
        Returns the transition stored under the given key. If the transition does not exist, returns None instead.
        :param transition_index: the key of the transition to return
        :param lock: use write locking if this is a shared memory
        :return: the corresponding transition
        """
        return pickle.loads(self.redis_connection.get(transition_index))

    def remove_transition(self, transition_index: int, lock: bool=True) -> None:
        """
        Remove the transition stored under the given key.
        :param transition_index: the key of the transition to remove
        :return: None
        """
        self.redis_connection.delete(transition_index)

    # for API compatibility
    def get(self, transition_index: int, lock: bool=True) -> Union[None, Transition]:
        """
        Returns the transition stored under the given key. If the transition does not exist, returns None instead.
        :param transition_index: the key of the transition to return
        :return: the corresponding transition
        """
        return self.get_transition(transition_index, lock)

    # for API compatibility
    def remove(self, transition_index: int, lock: bool=True):
        """
        Remove the transition stored under the given key
        :param transition_index: the key of the transition to remove
        :return: None
        """
        self.remove_transition(transition_index, lock)

    def clean(self, lock: bool=True) -> None:
        """
        Clean the memory by removing all the transitions
        :return: None
        """
        self.redis_connection.flushall()

    def mean_reward(self) -> np.ndarray:
        """
        Get the mean reward in the replay buffer
        :return: the mean reward
        """
        mean = np.mean([pickle.loads(self.redis_connection.get(key)).reward
                        for key in self.redis_connection.keys()])
        return mean

View File

@@ -90,7 +90,6 @@ class ExperienceReplay(Memory):
batch = [self.transitions[i] for i in transitions_idx]
self.reader_writer_lock.release_writing()
return batch
def _enforce_max_length(self) -> None:
@@ -115,6 +114,8 @@ class ExperienceReplay(Memory):
locks and then calls store with lock = True
:return: None
"""
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)
if lock:
self.reader_writer_lock.lock_writing_and_reading()

View File

@@ -267,6 +267,9 @@ class PrioritizedExperienceReplay(ExperienceReplay):
:param transition: a transition to store
:return: None
"""
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)
self.reader_writer_lock.lock_writing_and_reading()
transition_priority = self.maximal_priority