mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Adding initial interface for backend and redis pubsub (#19)

* Adding initial interface for backend and redis pubsub

* Addressing comments, adding super in all memories

* Removing distributed experience replay
Ajay Deshpande
2018-10-03 15:07:48 -07:00
committed by zach dwiel
parent a54ef2757f
commit 6b2de6ba6d
21 changed files with 459 additions and 444 deletions
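In outline, the commit splits experience storage into a local replay buffer plus an optional memory backend: rollout workers publish every transition to a Redis channel, and the trainer subscribes to the same channel and fills its own replay buffer from it. Below is a minimal wiring sketch using the classes added in this commit; it is not code from the commit itself, it assumes a Redis server reachable on localhost:6379, and the channel name and buffer size are illustrative (the replay-buffer constructor follows the max_size convention used elsewhere in the repo).

    from rl_coach.memories.backend.memory_impl import get_memory_backend
    from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
    from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplay
    from rl_coach.memories.memory import MemoryGranularity

    channel = 'coach-demo'

    # Trainer side: a background subscriber thread fills the local replay buffer.
    trainer_memory = EpisodicExperienceReplay(max_size=(MemoryGranularity.Transitions, 40000))
    trainer_backend = get_memory_backend(
        RedisPubSubMemoryBackendParameters(redis_address='localhost', channel=channel, run_type='trainer'))
    trainer_backend.subscribe(trainer_memory)

    # Worker side: every transition stored locally is also published to the channel.
    worker_memory = EpisodicExperienceReplay(max_size=(MemoryGranularity.Transitions, 40000))
    worker_backend = get_memory_backend(
        RedisPubSubMemoryBackendParameters(redis_address='localhost', channel=channel, run_type='worker'))
    worker_memory.set_memory_backend(worker_backend)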

View File

@@ -33,6 +33,7 @@ from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperi
from rl_coach.spaces import SpacesDefinition, VectorObservationSpace, GoalsSpace, AttentionActionSpace
from rl_coach.utils import Signal, force_list
from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
from rl_coach.memories.backend.memory_impl import get_memory_backend
class Agent(AgentInterface):
@@ -76,6 +77,14 @@ class Agent(AgentInterface):
# modules
self.memory = dynamic_import_and_instantiate_module_from_params(self.ap.memory)
if hasattr(self.ap.memory, 'memory_backend_params'):
self.memory_backend = get_memory_backend(self.ap.memory.memory_backend_params)
if self.ap.memory.memory_backend_params.run_type == 'trainer':
self.memory_backend.subscribe(self.memory)
else:
self.memory.set_memory_backend(self.memory_backend)
if agent_parameters.memory.load_memory_from_file_path:
screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
.format(agent_parameters.memory.load_memory_from_file_path))
@@ -534,6 +543,9 @@ class Agent(AgentInterface):
Determine if we should start a training phase according to the number of steps passed since the last training
:return: boolean: True if we should start a training phase
"""
if hasattr(self.ap.memory, 'memory_backend_params'):
self.total_steps_counter = self.call_memory('num_transitions')
step_method = self.ap.algorithm.num_consecutive_playing_steps
if step_method.__class__ == EnvironmentEpisodes:
should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps

View File

@@ -27,7 +27,6 @@ from rl_coach.architectures.tensorflow_components.embedders.embedder import Inpu
from rl_coach.core_types import EnvironmentSteps
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
from rl_coach.memories.non_episodic.distributed_experience_replay import DistributedExperienceReplayParameters
from rl_coach.schedules import LinearSchedule
@@ -51,20 +50,6 @@ class DQNNetworkParameters(NetworkParameters):
self.create_target_network = True
class DQNAgentParametersDistributed(AgentParameters):
def __init__(self):
super().__init__(algorithm=DQNAlgorithmParameters(),
exploration=EGreedyParameters(),
memory=DistributedExperienceReplayParameters(),
networks={"main": DQNNetworkParameters()})
self.exploration.epsilon_schedule = LinearSchedule(1, 0.1, 1000000)
self.exploration.evaluation_epsilon = 0.05
@property
def path(self):
return 'rl_coach.agents.dqn_agent:DQNAgent'
class DQNAgentParameters(AgentParameters):
def __init__(self):
super().__init__(algorithm=DQNAlgorithmParameters(),

View File

View File

@@ -0,0 +1,36 @@
class MemoryBackendParameters(object):
def __init__(self, store_type, orchestrator_type, run_type, deployed: bool = False):
self.store_type = store_type
self.orchestrator_type = orchestrator_type
self.run_type = run_type
self.deployed = deployed
class MemoryBackend(object):
def __init__(self, params: MemoryBackendParameters):
pass
def deploy(self):
raise NotImplementedError("Not yet implemented")
def get_endpoint(self):
raise NotImplementedError("Not yet implemented")
def undeploy(self):
raise NotImplementedError("Not yet implemented")
def sample(self, size: int):
raise NotImplementedError("Not yet implemented")
def store(self, obj):
raise NotImplementedError("Not yet implemented")
def store_episode(self, obj):
raise NotImplementedError("Not yet implemented")
def subscribe(self, memory):
raise NotImplementedError("Not yet implemented")
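Concrete backends are expected to override these hooks; the Redis pub/sub implementation further down is the one this commit actually ships. A hypothetical in-memory backend, purely to illustrate the contract (class name and behaviour are invented for illustration):

    import pickle

    class ListBackend(MemoryBackend):
        """Toy backend that keeps pickled objects in a plain Python list."""
        def __init__(self, params: MemoryBackendParameters):
            super().__init__(params)
            self.buffer = []

        def store(self, obj):
            # Called from Memory.store() on the rollout-worker side.
            self.buffer.append(pickle.dumps(obj))

        def subscribe(self, memory):
            # Called on the trainer side; replays everything stored so far.
            for blob in self.buffer:
                memory.store(pickle.loads(blob))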

View File

@@ -0,0 +1,21 @@
from rl_coach.memories.backend.memory import MemoryBackendParameters
from rl_coach.memories.backend.redis import RedisPubSubBackend, RedisPubSubMemoryBackendParameters
def get_memory_backend(params: MemoryBackendParameters):
backend = None
if type(params) == RedisPubSubMemoryBackendParameters:
backend = RedisPubSubBackend(params)
return backend
def construct_memory_params(json: dict):
if json['store_type'] == 'redispubsub':
memory_params = RedisPubSubMemoryBackendParameters(
json['redis_address'], json['redis_port'], channel=json.get('channel', ''), run_type=json['run_type']
)
return memory_params
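construct_memory_params is what the workers call on the deserialized --memory_backend_params JSON. A sketch of the expected shape; the keys mirror the ones read above, the values are illustrative:

    memory_params = construct_memory_params({
        'store_type': 'redispubsub',
        'redis_address': 'redis-service.default.svc',
        'redis_port': 6379,
        'channel': 'channel-1234',   # optional, falls back to ''
        'run_type': 'worker',
    })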

View File

@@ -0,0 +1,160 @@
import redis
import pickle
import uuid
import threading
from kubernetes import client
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
from rl_coach.memories.memory import Memory
from rl_coach.core_types import Transition, Episode
class RedisPubSubMemoryBackendParameters(MemoryBackendParameters):
def __init__(self, redis_address: str="", redis_port: int=6379, channel: str="channel-{}".format(uuid.uuid4()),
orchestrator_params: dict=None, run_type='trainer', orchestrator_type: str = "kubernetes", deployed: bool = False):
self.redis_address = redis_address
self.redis_port = redis_port
self.channel = channel
if not orchestrator_params:
orchestrator_params = {}
self.orchestrator_params = orchestrator_params
self.run_type = run_type
self.store_type = "redispubsub"
self.orchestrator_type = orchestrator_type
self.deployed = deployed
class RedisPubSubBackend(MemoryBackend):
def __init__(self, params: RedisPubSubMemoryBackendParameters):
self.params = params
self.redis_connection = redis.Redis(self.params.redis_address, self.params.redis_port)
def store(self, obj):
self.redis_connection.publish(self.params.channel, pickle.dumps(obj))
def deploy(self):
if not self.params.deployed:
if self.params.orchestrator_type == 'kubernetes':
self.deploy_kubernetes()
self.params.deployed = True
def deploy_kubernetes(self):
if 'namespace' not in self.params.orchestrator_params:
self.params.orchestrator_params['namespace'] = "default"
container = client.V1Container(
name="redis-server",
image='redis:4-alpine',
)
template = client.V1PodTemplateSpec(
metadata=client.V1ObjectMeta(labels={'app': 'redis-server'}),
spec=client.V1PodSpec(
containers=[container]
)
)
deployment_spec = client.V1DeploymentSpec(
replicas=1,
template=template,
selector=client.V1LabelSelector(
match_labels={'app': 'redis-server'}
)
)
deployment = client.V1Deployment(
api_version='apps/v1',
kind='Deployment',
metadata=client.V1ObjectMeta(name='redis-server', labels={'app': 'redis-server'}),
spec=deployment_spec
)
api_client = client.AppsV1Api()
try:
api_client.create_namespaced_deployment(self.params.orchestrator_params['namespace'], deployment)
except client.rest.ApiException as e:
print("Got exception: %s\n while creating redis-server", e)
return False
core_v1_api = client.CoreV1Api()
service = client.V1Service(
api_version='v1',
kind='Service',
metadata=client.V1ObjectMeta(
name='redis-service'
),
spec=client.V1ServiceSpec(
selector={'app': 'redis-server'},
ports=[client.V1ServicePort(
protocol='TCP',
port=6379,
target_port=6379
)]
)
)
try:
core_v1_api.create_namespaced_service(self.params.orchestrator_params['namespace'], service)
self.params.redis_address = 'redis-service.{}.svc'.format(self.params.orchestrator_params['namespace'])
self.params.redis_port = 6379
return True
except client.rest.ApiException as e:
print("Got exception: %s\n while creating a service for redis-server", e)
return False
def undeploy(self):
if not self.params.deployed:
return
api_client = client.AppsV1Api()
delete_options = client.V1DeleteOptions()
try:
api_client.delete_namespaced_deployment('redis-server', self.params.orchestrator_params['namespace'], delete_options)
except client.rest.ApiException as e:
print("Got exception: %s\n while deleting redis-server", e)
api_client = client.CoreV1Api()
try:
api_client.delete_namespaced_service('redis-service', self.params.orchestrator_params['namespace'], delete_options)
except client.rest.ApiException as e:
print("Got exception: %s\n while deleting redis-server", e)
self.params.deployed = False
def sample(self, size):
pass
def subscribe(self, memory):
redis_sub = RedisSub(memory, redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
redis_sub.daemon = True
redis_sub.start()
def get_endpoint(self):
return {'redis_address': self.params.redis_address,
'redis_port': self.params.redis_port}
class RedisSub(threading.Thread):
def __init__(self, memory: Memory, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
super().__init__()
self.redis_connection = redis.Redis(redis_address, redis_port)
self.pubsub = self.redis_connection.pubsub()
self.subscriber = None
self.memory = memory
self.channel = channel
self.subscriber = self.pubsub.subscribe(self.channel)
def run(self):
for message in self.pubsub.listen():
if message and 'data' in message:
try:
obj = pickle.loads(message['data'])
if type(obj) == Transition:
self.memory.store(obj)
elif type(obj) == Episode:
self.memory.store_episode(obj)
except Exception:
continue
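A hedged usage sketch of the orchestration half of this class: deploy() creates the redis-server Deployment and the redis-service Service, and get_endpoint() then reports where the server is reachable. It assumes the Kubernetes client configuration has already been loaded, as the orchestrator below does:

    from rl_coach.memories.backend.redis import RedisPubSubBackend, RedisPubSubMemoryBackendParameters

    params = RedisPubSubMemoryBackendParameters(run_type='trainer',
                                                orchestrator_params={'namespace': 'default'})
    backend = RedisPubSubBackend(params)

    backend.deploy()               # creates the redis-server Deployment and the redis-service Service
    print(backend.get_endpoint())  # e.g. {'redis_address': 'redis-service.default.svc', 'redis_port': 6379}
    # ... run the experiment ...
    backend.undeploy()             # deletes the Deployment and the Service again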

View File

@@ -160,6 +160,8 @@ class EpisodicExperienceReplay(Memory):
:param transition: a transition to store
:return: None
"""
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)
self.reader_writer_lock.lock_writing_and_reading()
if len(self._buffer) == 0:
@@ -181,6 +183,9 @@ class EpisodicExperienceReplay(Memory):
:param episode: the new episode to store
:return: None
"""
# Calling super.store() so that in case a memory backend is used, the memory backend can store this episode.
super().store(episode)
if lock:
self.reader_writer_lock.lock_writing_and_reading()

View File

@@ -106,6 +106,10 @@ class EpisodicHindsightExperienceReplay(EpisodicExperienceReplay):
]
def store_episode(self, episode: Episode, lock: bool=True) -> None:
# Calling super.store() so that in case a memory backend is used, the memory backend can store this episode.
super().store_episode(episode)
# generate hindsight transitions only when an episode is finished
last_episode_transitions = copy.copy(episode.transitions)

View File

@@ -59,6 +59,10 @@ class EpisodicHRLHindsightExperienceReplay(EpisodicHindsightExperienceReplay):
# for a layer producing sub-goals, we will replace in hindsight the action (sub-goal) given to the lower
# level with the actual achieved goal. the achieved goal (and observation) seen is assumed to be the same
# for all levels - we can use this level's achieved goal instead of the lower level's one
# Calling super.store() so that in case a memory backend is used, the memory backend can store this episode.
super().store_episode(episode)
for transition in episode.transitions:
new_achieved_goal = transition.next_state[self.goals_space.goal_name]
transition.action = new_achieved_goal

View File

@@ -18,6 +18,7 @@ from enum import Enum
from typing import Tuple
from rl_coach.base_parameters import Parameters
from rl_coach.memories.backend.memory import MemoryBackend
class MemoryGranularity(Enum):
@@ -32,7 +33,6 @@ class MemoryParameters(Parameters):
self.shared_memory = False
self.load_memory_from_file_path = None
@property
def path(self):
return 'rl_coach.memories.memory:Memory'
@@ -45,9 +45,16 @@ class Memory(object):
""" """
self.max_size = max_size self.max_size = max_size
self._length = 0 self._length = 0
self.memory_backend = None
def store(self, obj):
- raise NotImplementedError("")
+ if self.memory_backend:
+ self.memory_backend.store(obj)
def store_episode(self, episode):
if self.memory_backend:
for transition in episode:
self.memory_backend.store(transition)
def get(self, index):
raise NotImplementedError("")
@@ -64,4 +71,5 @@ class Memory(object):
def clean(self):
raise NotImplementedError("")
def set_memory_backend(self, memory_backend: MemoryBackend):
self.memory_backend = memory_backend
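The "Calling super.store()" comments in the replay-buffer subclasses refer to this base-class hook: a subclass first lets the base class forward the object to the backend (a no-op when none was attached via set_memory_backend), then does its local bookkeeping. A hypothetical minimal subclass showing the pattern:

    from rl_coach.memories.memory import Memory

    class MinimalReplayBuffer(Memory):
        """Hypothetical subclass, only to illustrate the super().store() pattern."""
        def __init__(self, max_size):
            super().__init__(max_size)
            self._transitions = []

        def store(self, transition):
            # Forward to the memory backend first (no-op if none is attached) ...
            super().store(transition)
            # ... then store locally as usual.
            self._transitions.append(transition)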

View File

@@ -72,6 +72,8 @@ class BalancedExperienceReplay(ExperienceReplay):
locks and then calls store with lock = True
:return: None
"""
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)
if lock:
self.reader_writer_lock.lock_writing_and_reading()

View File

@@ -1,182 +0,0 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Tuple, Union
import numpy as np
import redis
import uuid
import pickle
from rl_coach.core_types import Transition
from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters
class DistributedExperienceReplayParameters(MemoryParameters):
def __init__(self):
super().__init__()
self.max_size = (MemoryGranularity.Transitions, 1000000)
self.allow_duplicates_in_batch_sampling = True
self.redis_ip = 'localhost'
self.redis_port = 6379
self.redis_db = 0
@property
def path(self):
return 'rl_coach.memories.non_episodic.distributed_experience_replay:DistributedExperienceReplay'
class DistributedExperienceReplay(Memory):
"""
A regular replay buffer which stores transition without any additional structure
"""
def __init__(self, max_size: Tuple[MemoryGranularity, int], allow_duplicates_in_batch_sampling: bool=True,
redis_ip='localhost', redis_port=6379, redis_db=0):
"""
:param max_size: the maximum number of transitions or episodes to hold in the memory
:param allow_duplicates_in_batch_sampling: allow having the same transition multiple times in a batch
"""
super().__init__(max_size)
if max_size[0] != MemoryGranularity.Transitions:
raise ValueError("Experience replay size can only be configured in terms of transitions")
self.allow_duplicates_in_batch_sampling = allow_duplicates_in_batch_sampling
self.db = redis_db
self.redis_connection = redis.Redis(redis_ip, redis_port, self.db)
def length(self) -> int:
"""
Get the number of transitions in the ER
"""
return self.num_transitions()
def num_transitions(self) -> int:
"""
Get the number of transitions in the ER
"""
try:
return self.redis_connection.info(section='keyspace')['db{}'.format(self.db)]['keys']
except Exception as e:
return 0
def sample(self, size: int) -> List[Transition]:
"""
Sample a batch of transitions form the replay buffer. If the requested size is larger than the number
of samples available in the replay buffer then the batch will return empty.
:param size: the size of the batch to sample
:param beta: the beta parameter used for importance sampling
:return: a batch (list) of selected transitions from the replay buffer
"""
transition_idx = dict()
if self.allow_duplicates_in_batch_sampling:
while len(transition_idx) != size:
key = self.redis_connection.randomkey()
transition_idx[key] = pickle.loads(self.redis_connection.get(key))
else:
if self.num_transitions() >= size:
while len(transition_idx) != size:
key = self.redis_connection.randomkey()
if key in transition_idx:
continue
transition_idx[key] = pickle.loads(self.redis_connection.get(key))
else:
raise ValueError("The replay buffer cannot be sampled since there are not enough transitions yet. "
"There are currently {} transitions".format(self.num_transitions()))
batch = transition_idx.values()
return batch
def _enforce_max_length(self) -> None:
"""
Make sure that the size of the replay buffer does not pass the maximum size allowed.
If it passes the max size, the oldest transition in the replay buffer will be removed.
This function does not use locks since it is only called internally
:return: None
"""
granularity, size = self.max_size
if granularity == MemoryGranularity.Transitions:
while size != 0 and self.num_transitions() > size:
self.redis_connection.delete(self.redis_connection.randomkey())
else:
raise ValueError("The granularity of the replay buffer can only be set in terms of transitions")
def store(self, transition: Transition, lock: bool=True) -> None:
"""
Store a new transition in the memory.
:param transition: a transition to store
:param lock: if true, will lock the readers writers lock. this can cause a deadlock if an inheriting class
locks and then calls store with lock = True
:return: None
"""
self.redis_connection.set(uuid.uuid4(), pickle.dumps(transition))
self._enforce_max_length()
def get_transition(self, transition_index: int, lock: bool=True) -> Union[None, Transition]:
"""
Returns the transition in the given index. If the transition does not exist, returns None instead.
:param transition_index: the index of the transition to return
:param lock: use write locking if this is a shared memory
:return: the corresponding transition
"""
return pickle.loads(self.redis_connection.get(transition_index))
def remove_transition(self, transition_index: int, lock: bool=True) -> None:
"""
Remove the transition in the given index.
This does not remove the transition from the segment trees! it is just used to remove the transition
from the transitions list
:param transition_index: the index of the transition to remove
:return: None
"""
self.redis_connection.delete(transition_index)
# for API compatibility
def get(self, transition_index: int, lock: bool=True) -> Union[None, Transition]:
"""
Returns the transition in the given index. If the transition does not exist, returns None instead.
:param transition_index: the index of the transition to return
:return: the corresponding transition
"""
return self.get_transition(transition_index, lock)
# for API compatibility
def remove(self, transition_index: int, lock: bool=True):
"""
Remove the transition in the given index
:param transition_index: the index of the transition to remove
:return: None
"""
self.remove_transition(transition_index, lock)
def clean(self, lock: bool=True) -> None:
"""
Clean the memory by removing all the episodes
:return: None
"""
self.redis_connection.flushall()
# self.transitions = []
def mean_reward(self) -> np.ndarray:
"""
Get the mean reward in the replay buffer
:return: the mean reward
"""
mean = np.mean([pickle.loads(self.redis_connection.get(key)).reward for key in self.redis_connection.keys()])
return mean

View File

@@ -90,7 +90,6 @@ class ExperienceReplay(Memory):
batch = [self.transitions[i] for i in transitions_idx]
self.reader_writer_lock.release_writing()
return batch
def _enforce_max_length(self) -> None:
@@ -115,6 +114,8 @@ class ExperienceReplay(Memory):
locks and then calls store with lock = True
:return: None
"""
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)
if lock:
self.reader_writer_lock.lock_writing_and_reading()

View File

@@ -267,6 +267,9 @@ class PrioritizedExperienceReplay(ExperienceReplay):
:param transition: a transition to store
:return: None
"""
# Calling super.store() so that in case a memory backend is used, the memory backend can store this transition.
super().store(transition)
self.reader_writer_lock.lock_writing_and_reading()
transition_priority = self.maximal_priority

View File

@@ -1,29 +1,46 @@
import os
import uuid
import json
import time
from typing import List
from rl_coach.orchestrators.deploy import Deploy, DeployParameters
from kubernetes import client, config
from rl_coach.memories.backend.memory import MemoryBackendParameters
from rl_coach.memories.backend.memory_impl import get_memory_backend
class RunTypeParameters():
def __init__(self, image: str, command: list(), arguments: list() = None,
run_type: str = "trainer", checkpoint_dir: str = "/checkpoint",
num_replicas: int = 1, orchestration_params: dict=None):
self.image = image
self.command = command
if not arguments:
arguments = list()
self.arguments = arguments
self.run_type = run_type
self.checkpoint_dir = checkpoint_dir
self.num_replicas = num_replicas
if not orchestration_params:
orchestration_params = dict()
self.orchestration_params = orchestration_params
class KubernetesParameters(DeployParameters):
- def __init__(self, name: str, image: str, command: list(), arguments: list() = list(), synchronized: bool = False,
- num_workers: int = 1, kubeconfig: str = None, namespace: str = None, redis_ip: str = None,
- redis_port: int = None, redis_db: int = 0, nfs_server: str = None, nfs_path: str = None,
- checkpoint_dir: str = '/checkpoint'):
- self.image = image
- self.synchronized = synchronized
- self.command = command
- self.arguments = arguments
+ def __init__(self, run_type_params: List[RunTypeParameters], kubeconfig: str = None, namespace: str = "", nfs_server: str = None,
+ nfs_path: str = None, checkpoint_dir: str = '/checkpoint', memory_backend_parameters: MemoryBackendParameters = None):
+ self.run_type_params = {}
+ for run_type_param in run_type_params:
+ self.run_type_params[run_type_param.run_type] = run_type_param
self.kubeconfig = kubeconfig
- self.num_workers = num_workers
self.namespace = namespace
- self.redis_ip = redis_ip
- self.redis_port = redis_port
- self.redis_db = redis_db
self.nfs_server = nfs_server
self.nfs_path = nfs_path
self.checkpoint_dir = checkpoint_dir
- self.name = name
+ self.memory_backend_parameters = memory_backend_parameters

class Kubernetes(Deploy):
@@ -44,17 +61,14 @@ class Kubernetes(Deploy):
if os.environ.get('http_proxy'):
client.Configuration._default.proxy = os.environ.get('http_proxy')

+ self.deploy_parameters.memory_backend_parameters.orchestrator_params = {'namespace': self.deploy_parameters.namespace}
+ self.memory_backend = get_memory_backend(self.deploy_parameters.memory_backend_parameters)

def setup(self) -> bool:
- if not self.deploy_parameters.redis_ip:
- # Need to spin up a redis service and a deployment.
- if not self.deploy_redis():
- print("Failed to setup redis")
- return False
+ self.memory_backend.deploy()
if not self.create_nfs_resources():
return False
return True
def create_nfs_resources(self):
@@ -107,87 +121,24 @@ class Kubernetes(Deploy):
return False
return True
- def deploy_redis(self) -> bool:
- container = client.V1Container(
- name="redis-server",
- image='redis:4-alpine',
- )
- template = client.V1PodTemplateSpec(
- metadata=client.V1ObjectMeta(labels={'app': 'redis-server'}),
- spec=client.V1PodSpec(
- containers=[container]
- )
- )
- deployment_spec = client.V1DeploymentSpec(
- replicas=1,
- template=template,
- selector=client.V1LabelSelector(
- match_labels={'app': 'redis-server'}
- )
- )
- deployment = client.V1Deployment(
- api_version='apps/v1',
- kind='Deployment',
- metadata=client.V1ObjectMeta(name='redis-server', labels={'app': 'redis-server'}),
- spec=deployment_spec
- )
- api_client = client.AppsV1Api()
- try:
- api_client.create_namespaced_deployment(self.deploy_parameters.namespace, deployment)
- except client.rest.ApiException as e:
- print("Got exception: %s\n while creating redis-server", e)
- return False
- core_v1_api = client.CoreV1Api()
- service = client.V1Service(
- api_version='v1',
- kind='Service',
- metadata=client.V1ObjectMeta(
- name='redis-service'
- ),
- spec=client.V1ServiceSpec(
- selector={'app': 'redis-server'},
- ports=[client.V1ServicePort(
- protocol='TCP',
- port=6379,
- target_port=6379
- )]
- )
- )
- try:
- core_v1_api.create_namespaced_service(self.deploy_parameters.namespace, service)
- self.deploy_parameters.redis_ip = 'redis-service.{}.svc'.format(self.deploy_parameters.namespace)
- self.deploy_parameters.redis_port = 6379
- return True
- except client.rest.ApiException as e:
- print("Got exception: %s\n while creating a service for redis-server", e)
- return False
- def deploy(self) -> bool:
- self.deploy_parameters.command += ['--redis_ip', self.deploy_parameters.redis_ip, '--redis_port', '{}'.format(self.deploy_parameters.redis_port)]
- if self.deploy_parameters.synchronized:
- return self.create_k8s_deployment()
- else:
- return self.create_k8s_job()
- def create_k8s_deployment(self) -> bool:
- name = "{}-{}".format(self.deploy_parameters.name, uuid.uuid4())
+ def deploy_trainer(self) -> bool:
+ trainer_params = self.deploy_parameters.run_type_params.get('trainer', None)
+ if not trainer_params:
+ return False
+ trainer_params.command += ['--memory_backend_params', json.dumps(self.deploy_parameters.memory_backend_parameters.__dict__)]
+ name = "{}-{}".format(trainer_params.run_type, uuid.uuid4())
container = client.V1Container(
name=name,
- image=self.deploy_parameters.image,
- command=self.deploy_parameters.command,
- args=self.deploy_parameters.arguments,
+ image=trainer_params.image,
+ command=trainer_params.command,
+ args=trainer_params.arguments,
image_pull_policy='Always',
volume_mounts=[client.V1VolumeMount(
name='nfs-pvc',
- mount_path=self.deploy_parameters.checkpoint_dir
+ mount_path=trainer_params.checkpoint_dir
)]
)
template = client.V1PodTemplateSpec(
@@ -203,7 +154,7 @@ class Kubernetes(Deploy):
),
)
deployment_spec = client.V1DeploymentSpec(
- replicas=self.deploy_parameters.num_workers,
+ replicas=trainer_params.num_replicas,
template=template,
selector=client.V1LabelSelector(
match_labels={'app': name}
@@ -220,23 +171,30 @@ class Kubernetes(Deploy):
api_client = client.AppsV1Api()
try:
api_client.create_namespaced_deployment(self.deploy_parameters.namespace, deployment)
+ trainer_params.orchestration_params['deployment_name'] = name
return True
except client.rest.ApiException as e:
print("Got exception: %s\n while creating deployment", e)
return False

- def create_k8s_job(self):
- name = "{}-{}".format(self.deploy_parameters.name, uuid.uuid4())
+ def deploy_worker(self):
+ worker_params = self.deploy_parameters.run_type_params.get('worker', None)
+ if not worker_params:
+ return False
+ worker_params.command += ['--memory_backend_params', json.dumps(self.deploy_parameters.memory_backend_parameters.__dict__)]
+ name = "{}-{}".format(worker_params.run_type, uuid.uuid4())
container = client.V1Container(
name=name,
- image=self.deploy_parameters.image,
- command=self.deploy_parameters.command,
- args=self.deploy_parameters.arguments,
+ image=worker_params.image,
+ command=worker_params.command,
+ args=worker_params.arguments,
image_pull_policy='Always',
volume_mounts=[client.V1VolumeMount(
name='nfs-pvc',
- mount_path=self.deploy_parameters.checkpoint_dir
+ mount_path=worker_params.checkpoint_dir
)]
)
template = client.V1PodTemplateSpec(
@@ -249,27 +207,104 @@ class Kubernetes(Deploy):
claim_name=self.nfs_pvc_name
)
)],
- restart_policy='Never'
),
)
- job_spec = client.V1JobSpec(
- parallelism=self.deploy_parameters.num_workers,
- template=template,
- completions=2147483647
- )
- job = client.V1Job(
- api_version='batch/v1',
- kind='Job',
- metadata=client.V1ObjectMeta(name=name),
- spec=job_spec
- )
- api_client = client.BatchV1Api()
+ deployment_spec = client.V1DeploymentSpec(
+ replicas=worker_params.num_replicas,
+ template=template,
+ selector=client.V1LabelSelector(
+ match_labels={'app': name}
+ )
+ )
+ deployment = client.V1Deployment(
+ api_version='apps/v1',
+ kind="Deployment",
+ metadata=client.V1ObjectMeta(name=name),
+ spec=deployment_spec
+ )
+ api_client = client.AppsV1Api()
try:
- api_client.create_namespaced_job(self.deploy_parameters.namespace, job)
+ api_client.create_namespaced_deployment(self.deploy_parameters.namespace, deployment)
+ worker_params.orchestration_params['deployment_name'] = name
return True
except client.rest.ApiException as e:
print("Got exception: %s\n while creating deployment", e)
return False
def worker_logs(self):
pass
def trainer_logs(self):
trainer_params = self.deploy_parameters.run_type_params.get('trainer', None)
if not trainer_params:
return
api_client = client.CoreV1Api()
pod = None
try:
pods = api_client.list_namespaced_pod(self.deploy_parameters.namespace, label_selector='app={}'.format(
trainer_params.orchestration_params['deployment_name']
))
pod = pods.items[0]
except client.rest.ApiException as e:
print("Got exception: %s\n while reading pods", e)
return
if not pod:
return
self.tail_log(pod.metadata.name, api_client)
def tail_log(self, pod_name, corev1_api):
while True:
time.sleep(10)
# Try to tail the pod logs
try:
print(corev1_api.read_namespaced_pod_log(
pod_name, self.deploy_parameters.namespace, follow=True
), flush=True)
except client.rest.ApiException as e:
pass
# This part will get executed if the pod is one of the following phases: not ready, failed or terminated.
# Check if the pod has errored out, else just try again.
# Get the pod
try:
pod = corev1_api.read_namespaced_pod(pod_name, self.deploy_parameters.namespace)
except client.rest.ApiException as e:
continue
if not hasattr(pod, 'status') or not pod.status:
continue
if not hasattr(pod.status, 'container_statuses') or not pod.status.container_statuses:
continue
for container_status in pod.status.container_statuses:
if container_status.state.waiting is not None:
if container_status.state.waiting.reason == 'Error' or \
container_status.state.waiting.reason == 'CrashLoopBackOff' or \
container_status.state.waiting.reason == 'ImagePullBackOff' or \
container_status.state.waiting.reason == 'ErrImagePull':
return
if container_status.state.terminated is not None:
return
def undeploy(self):
trainer_params = self.deploy_parameters.run_type_params.get('trainer', None)
api_client = client.AppsV1Api()
delete_options = client.V1DeleteOptions()
if trainer_params:
try:
api_client.delete_namespaced_deployment(trainer_params.orchestration_params['deployment_name'], self.deploy_parameters.namespace, delete_options)
except client.rest.ApiException as e:
print("Got exception: %s\n while deleting trainer", e)
worker_params = self.deploy_parameters.run_type_params.get('worker', None)
if worker_params:
try:
api_client.delete_namespaced_deployment(worker_params.orchestration_params['deployment_name'], self.deploy_parameters.namespace, delete_options)
except client.rest.ApiException as e:
print("Got exception: %s\n while deleting workers", e)
self.memory_backend.undeploy()

View File

@@ -1,51 +1,43 @@
import argparse

- from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes
+ from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, RunTypeParameters
+ from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters

- def main(preset: str, image: str='ajaysudh/testing:coach', redis_ip: str=None, redis_port:int=None, num_workers: int=1, nfs_server: str="", nfs_path: str=""):
+ def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, nfs_server: str="", nfs_path: str="", memory_backend: str=""):
rollout_command = ['python3', 'rl_coach/rollout_worker.py', '-p', preset]
training_command = ['python3', 'rl_coach/training_worker.py', '-p', preset]

- """
- TODO:
- 1. Create a NFS backed PV for checkpointing.
- a. Include that in both (worker, trainer) containers.
- b. Change checkpoint writing logic to always write to a temporary file and then rename.
- 2. Test e2e 1 loop.
- a. Trainer writes a checkpoint
- b. Rollout worker picks it and gathers experience, writes back to redis.
- c. 1 rollout worker, 1 trainer.
- 3. Trainer should be a job (not a deployment)
- a. When all the epochs of training are done, workers should also be deleted.
- 4. Test e2e with multiple rollout workers.
- 5. Test e2e with multiple rollout workers and multiple loops.
- """
- training_params = KubernetesParameters("train", image, training_command, kubeconfig='~/.kube/config', redis_ip=redis_ip, redis_port=redis_port,
- nfs_server=nfs_server, nfs_path=nfs_path)
- training_obj = Kubernetes(training_params)
- if not training_obj.setup():
+ memory_backend_params = RedisPubSubMemoryBackendParameters()
+ worker_run_type_params = RunTypeParameters(image, rollout_command, run_type="worker")
+ trainer_run_type_params = RunTypeParameters(image, training_command, run_type="trainer")
+ orchestration_params = KubernetesParameters([worker_run_type_params, trainer_run_type_params], kubeconfig='~/.kube/config', nfs_server=nfs_server,
+ nfs_path=nfs_path, memory_backend_parameters=memory_backend_params)
+ orchestrator = Kubernetes(orchestration_params)
+ if not orchestrator.setup():
print("Could not setup")
return

- rollout_params = KubernetesParameters("worker", image, rollout_command, kubeconfig='~/.kube/config', redis_ip=training_params.redis_ip, redis_port=training_params.redis_port, num_workers=num_workers)
- rollout_obj = Kubernetes(rollout_params)
- # if not rollout_obj.setup():
- # print("Could not setup")
- if training_obj.deploy():
+ if orchestrator.deploy_trainer():
print("Successfully deployed")
else:
print("Could not deploy")
return

- if rollout_obj.deploy():
+ if orchestrator.deploy_worker():
print("Successfully deployed")
else:
print("Could not deploy")
return

+ try:
+ orchestrator.trainer_logs()
+ except KeyboardInterrupt:
+ pass
+ orchestrator.undeploy()

if __name__ == '__main__':
parser = argparse.ArgumentParser()
@@ -65,6 +57,10 @@ if __name__ == '__main__':
help="(string) Exported path for the nfs server", help="(string) Exported path for the nfs server",
type=str, type=str,
required=True) required=True)
parser.add_argument('--memory_backend',
help="(string) Memory backend to use",
type=str,
default="redispubsub")
# parser.add_argument('--checkpoint_dir',
# help='(string) Path to a folder containing a checkpoint to write the model to.',
@@ -72,4 +68,4 @@ if __name__ == '__main__':
# default='/checkpoint')

args = parser.parse_args()
- main(preset=args.preset, image=args.image, nfs_server=args.nfs_server, nfs_path=args.nfs_path)
+ main(preset=args.preset, image=args.image, nfs_server=args.nfs_server, nfs_path=args.nfs_path, memory_backend=args.memory_backend)

View File

@@ -1,63 +0,0 @@
from rl_coach.agents.dqn_agent import DQNAgentParametersDistributed
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
from rl_coach.environments.gym_environment import Mujoco
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.memories.memory import MemoryGranularity
from rl_coach.schedules import LinearSchedule
####################
# Graph Scheduling #
####################
schedule_params = ScheduleParameters()
schedule_params.improve_steps = TrainingSteps(10000000000)
schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)
schedule_params.evaluation_steps = EnvironmentEpisodes(1)
schedule_params.heatup_steps = EnvironmentSteps(1000)
#########
# Agent #
#########
agent_params = DQNAgentParametersDistributed()
# DQN params
agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100)
agent_params.algorithm.discount = 0.99
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1)
# NN configuration
agent_params.network_wrappers['main'].learning_rate = 0.00025
agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False
# ER size
agent_params.memory.max_size = (MemoryGranularity.Transitions, 40000)
# E-Greedy schedule
agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
################
# Environment #
################
env_params = Mujoco()
env_params.level = 'CartPole-v0'
vis_params = VisualizationParameters()
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
vis_params.dump_mp4 = False
########
# Test #
########
preset_validation_params = PresetValidationParameters()
preset_validation_params.test = True
preset_validation_params.min_reward_threshold = 150
preset_validation_params.max_episodes_to_achieve_reward = 250
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
schedule_params=schedule_params, vis_params=vis_params,
preset_validation_params=preset_validation_params)

View File

@@ -10,11 +10,14 @@ this rollout worker:
import argparse
import time
import os
import json
from rl_coach.base_parameters import TaskParameters
from rl_coach.coach import expand_preset
from rl_coach.core_types import EnvironmentEpisodes, RunPhase
from rl_coach.utils import short_dynamic_import
from rl_coach.memories.backend.memory_impl import construct_memory_params
# Q: specify alternative distributed memory, or should this go in the preset?
# A: preset must define distributed memory to be used. we aren't going to take
@@ -58,9 +61,12 @@ def rollout_worker(graph_manager, checkpoint_dir):
task_parameters = TaskParameters()
task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir
graph_manager.create_graph(task_parameters)

graph_manager.phase = RunPhase.TRAIN
- graph_manager.act(EnvironmentEpisodes(num_steps=10))
+ for i in range(10000000):
+ graph_manager.act(EnvironmentEpisodes(num_steps=10))
+ graph_manager.restore_checkpoint()
graph_manager.phase = RunPhase.UNDEFINED
@@ -82,13 +88,19 @@ def main():
help="(int) Port of the redis server", help="(int) Port of the redis server",
default=6379, default=6379,
type=int) type=int)
parser.add_argument('--memory_backend_params',
help="(string) JSON string of the memory backend params",
type=str)
args = parser.parse_args()

graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)

- graph_manager.agent_params.memory.redis_ip = args.redis_ip
- graph_manager.agent_params.memory.redis_port = args.redis_port
+ if args.memory_backend_params:
+ args.memory_backend_params = json.loads(args.memory_backend_params)
+ if 'run_type' not in args.memory_backend_params:
+ args.memory_backend_params['run_type'] = 'worker'
+ graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(args.memory_backend_params))

rollout_worker(
graph_manager=graph_manager,
checkpoint_dir=args.checkpoint_dir,
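To make the hand-off concrete: deploy_trainer()/deploy_worker() in the orchestrator serialize the backend parameters with json.dumps(params.__dict__) and append them after --memory_backend_params, and the snippet above turns that string back into a parameters object. A sketch of the round trip; the address and channel values are illustrative:

    import json

    from rl_coach.memories.backend.memory_impl import construct_memory_params
    from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters

    params = RedisPubSubMemoryBackendParameters(redis_address='redis-service.default.svc',
                                                redis_port=6379, channel='channel-1234',
                                                run_type='trainer')
    # Orchestrator side: what gets appended after '--memory_backend_params'.
    serialized = json.dumps(params.__dict__)

    # Worker side: what rollout_worker.py / training_worker.py do with the argument.
    decoded = json.loads(serialized)
    decoded.setdefault('run_type', 'worker')
    reconstructed = construct_memory_params(decoded)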

View File

@@ -2,42 +2,17 @@
""" """
import argparse import argparse
import time import time
import json
from rl_coach.base_parameters import TaskParameters
from rl_coach.coach import expand_preset
from rl_coach import core_types
from rl_coach.utils import short_dynamic_import
- from rl_coach.memories.non_episodic.distributed_experience_replay import DistributedExperienceReplay
- from rl_coach.memories.memory import MemoryGranularity
+ from rl_coach.memories.backend.memory_impl import construct_memory_params
# Q: specify alternative distributed memory, or should this go in the preset?
# A: preset must define distributed memory to be used. we aren't going to take a non-distributed preset and automatically distribute it.
def heatup(graph_manager):
memory = DistributedExperienceReplay(max_size=(MemoryGranularity.Transitions, 1000000),
redis_ip=graph_manager.agent_params.memory.redis_ip,
redis_port=graph_manager.agent_params.memory.redis_port)
while(memory.num_transitions() < graph_manager.heatup_steps.num_steps):
time.sleep(1)
class StepsLoop(object):
"""StepsLoop facilitates a simple while loop"""
def __init__(self, steps_counters, phase, steps):
super(StepsLoop, self).__init__()
self.steps_counters = steps_counters
self.phase = phase
self.steps = steps
self.step_end = self._step_count() + steps.num_steps
def _step_count(self):
return self.steps_counters[self.phase][self.steps.__class__]
def continue(self):
return self._step_count() < count_end:
def training_worker(graph_manager, checkpoint_dir):
"""
@@ -51,12 +26,8 @@ def training_worker(graph_manager, checkpoint_dir):
# save randomly initialized graph
graph_manager.save_checkpoint()
# optionally wait for a specific number of transitions to be in memory before training
heatup(graph_manager)
# training loop
- stepper = StepsLoop(graph_manager.total_steps_counters, RunPhase.TRAIN, graph_manager.improve_steps)
- while stepper.continue():
+ while True:
graph_manager.phase = core_types.RunPhase.TRAIN
graph_manager.train(core_types.TrainingSteps(1))
graph_manager.phase = core_types.RunPhase.UNDEFINED
@@ -65,7 +36,6 @@ def training_worker(graph_manager, checkpoint_dir):
graph_manager.save_checkpoint()
# TODO: signal to workers that training is done
def main():
parser = argparse.ArgumentParser()
@@ -85,12 +55,18 @@ def main():
help="(int) Port of the redis server", help="(int) Port of the redis server",
default=6379, default=6379,
type=int) type=int)
parser.add_argument('--memory_backend_params',
help="(string) JSON string of the memory backend params",
type=str)
args = parser.parse_args()

graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)

- graph_manager.agent_params.memory.redis_ip = args.redis_ip
- graph_manager.agent_params.memory.redis_port = args.redis_port
+ if args.memory_backend_params:
+ args.memory_backend_params = json.loads(args.memory_backend_params)
+ if 'run_type' not in args.memory_backend_params:
+ args.memory_backend_params['run_type'] = 'trainer'
+ graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(args.memory_backend_params))

training_worker(
graph_manager=graph_manager,