diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py
index 13d1b8d..2dd1e9a 100644
--- a/rl_coach/agents/agent.py
+++ b/rl_coach/agents/agent.py
@@ -33,6 +33,7 @@ from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperi
 from rl_coach.spaces import SpacesDefinition, VectorObservationSpace, GoalsSpace, AttentionActionSpace
 from rl_coach.utils import Signal, force_list
 from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
+from rl_coach.memories.backend.memory_impl import get_memory_backend
 
 
 class Agent(AgentInterface):
@@ -76,6 +77,14 @@ class Agent(AgentInterface):
         # modules
         self.memory = dynamic_import_and_instantiate_module_from_params(self.ap.memory)
+        if hasattr(self.ap.memory, 'memory_backend_params'):
+            self.memory_backend = get_memory_backend(self.ap.memory.memory_backend_params)
+
+            if self.ap.memory.memory_backend_params.run_type == 'trainer':
+                self.memory_backend.subscribe(self.memory)
+            else:
+                self.memory.set_memory_backend(self.memory_backend)
+
         if agent_parameters.memory.load_memory_from_file_path:
             screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
                              .format(agent_parameters.memory.load_memory_from_file_path))
@@ -534,6 +543,9 @@ class Agent(AgentInterface):
         Determine if we should start a training phase according to the number of steps passed since the last training
         :return: boolean: True if we should start a training phase
         """
+
+        if hasattr(self.ap.memory, 'memory_backend_params'):
+            self.total_steps_counter = self.call_memory('num_transitions')
         step_method = self.ap.algorithm.num_consecutive_playing_steps
         if step_method.__class__ == EnvironmentEpisodes:
             should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps
diff --git a/rl_coach/agents/clipped_ppo_agent.py b/rl_coach/agents/clipped_ppo_agent.py
index ba6851b..0091841 100644
--- a/rl_coach/agents/clipped_ppo_agent.py
+++ b/rl_coach/agents/clipped_ppo_agent.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation 
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
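A note on the agent wiring above: the same `memory_backend_params` object selects between trainer-side and worker-side behavior. Below is a minimal standalone sketch of that flow, outside the `Agent` class; the parameter values are illustrative only, and a reachable Redis server is assumed for the worker path to actually publish anything.

```python
from rl_coach.memories.backend.memory_impl import get_memory_backend
from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplay
from rl_coach.memories.memory import MemoryGranularity

# Illustrative values; in a real run these arrive via the
# --memory_backend_params CLI flag (see rollout_worker.py / training_worker.py below).
params = RedisPubSubMemoryBackendParameters(redis_address='localhost', redis_port=6379,
                                            channel='demo-channel', run_type='worker')

memory = ExperienceReplay(max_size=(MemoryGranularity.Transitions, 1000000))
backend = get_memory_backend(params)

if params.run_type == 'trainer':
    # Trainer: a RedisSub thread feeds the local replay buffer with
    # transitions published by the rollout workers.
    backend.subscribe(memory)
else:
    # Worker: every memory.store() now also publishes the transition to the
    # Redis channel through Memory.store() -> MemoryBackend.store().
    memory.set_memory_backend(backend)
```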
diff --git a/rl_coach/agents/dqn_agent.py b/rl_coach/agents/dqn_agent.py
index b300b79..f261a08 100644
--- a/rl_coach/agents/dqn_agent.py
+++ b/rl_coach/agents/dqn_agent.py
@@ -27,7 +27,6 @@ from rl_coach.architectures.tensorflow_components.embedders.embedder import Inpu
 from rl_coach.core_types import EnvironmentSteps
 from rl_coach.exploration_policies.e_greedy import EGreedyParameters
 from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
-from rl_coach.memories.non_episodic.distributed_experience_replay import DistributedExperienceReplayParameters
 from rl_coach.schedules import LinearSchedule
 
 
@@ -51,20 +50,6 @@ class DQNNetworkParameters(NetworkParameters):
         self.create_target_network = True
 
 
-class DQNAgentParametersDistributed(AgentParameters):
-    def __init__(self):
-        super().__init__(algorithm=DQNAlgorithmParameters(),
-                         exploration=EGreedyParameters(),
-                         memory=DistributedExperienceReplayParameters(),
-                         networks={"main": DQNNetworkParameters()})
-        self.exploration.epsilon_schedule = LinearSchedule(1, 0.1, 1000000)
-        self.exploration.evaluation_epsilon = 0.05
-
-    @property
-    def path(self):
-        return 'rl_coach.agents.dqn_agent:DQNAgent'
-
-
 class DQNAgentParameters(AgentParameters):
     def __init__(self):
         super().__init__(algorithm=DQNAlgorithmParameters(),
diff --git a/rl_coach/memories/backend/__init__.py b/rl_coach/memories/backend/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/rl_coach/memories/backend/memory.py b/rl_coach/memories/backend/memory.py
new file mode 100644
index 0000000..4199cf6
--- /dev/null
+++ b/rl_coach/memories/backend/memory.py
@@ -0,0 +1,36 @@
+
+
+class MemoryBackendParameters(object):
+
+    def __init__(self, store_type, orchestrator_type, run_type, deployed: bool = False):
+        self.store_type = store_type
+        self.orchestrator_type = orchestrator_type
+        self.run_type = run_type
+        self.deployed = deployed
+
+
+class MemoryBackend(object):
+
+    def __init__(self, params: MemoryBackendParameters):
+        pass
+
+    def deploy(self):
+        raise NotImplementedError("Not yet implemented")
+
+    def get_endpoint(self):
+        raise NotImplementedError("Not yet implemented")
+
+    def undeploy(self):
+        raise NotImplementedError("Not yet implemented")
+
+    def sample(self, size: int):
+        raise NotImplementedError("Not yet implemented")
+
+    def store(self, obj):
+        raise NotImplementedError("Not yet implemented")
+
+    def store_episode(self, obj):
+        raise NotImplementedError("Not yet implemented")
+
+    def subscribe(self, memory):
+        raise NotImplementedError("Not yet implemented")
diff --git a/rl_coach/memories/backend/memory_impl.py b/rl_coach/memories/backend/memory_impl.py
new file mode 100644
index 0000000..7470ef5
--- /dev/null
+++ b/rl_coach/memories/backend/memory_impl.py
@@ -0,0 +1,21 @@
+
+from rl_coach.memories.backend.memory import MemoryBackendParameters
+from rl_coach.memories.backend.redis import RedisPubSubBackend, RedisPubSubMemoryBackendParameters
+
+
+def get_memory_backend(params: MemoryBackendParameters):
+
+    backend = None
+    if type(params) == RedisPubSubMemoryBackendParameters:
+        backend = RedisPubSubBackend(params)
+
+    return backend
+
+
+def construct_memory_params(json: dict):
+
+    if json['store_type'] == 'redispubsub':
+        memory_params = RedisPubSubMemoryBackendParameters(
+            json['redis_address'], json['redis_port'], channel=json.get('channel', ''), run_type=json['run_type']
+        )
+        return memory_params
diff --git a/rl_coach/memories/backend/redis.py b/rl_coach/memories/backend/redis.py
new file mode 100644
index 0000000..f060d45
--- /dev/null
+++ b/rl_coach/memories/backend/redis.py
@@ -0,0 +1,160 @@
+
+import redis
+import pickle
+import uuid
+import threading
+from kubernetes import client
+
+from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
+from rl_coach.memories.memory import Memory
+from rl_coach.core_types import Transition, Episode
+
+
+class RedisPubSubMemoryBackendParameters(MemoryBackendParameters):
+
+    def __init__(self, redis_address: str="", redis_port: int=6379, channel: str="channel-{}".format(uuid.uuid4()),
+                 orchestrator_params: dict=None, run_type='trainer', orchestrator_type: str = "kubernetes", deployed: bool = False):
+        self.redis_address = redis_address
+        self.redis_port = redis_port
+        self.channel = channel
+        if not orchestrator_params:
+            orchestrator_params = {}
+        self.orchestrator_params = orchestrator_params
+        self.run_type = run_type
+        self.store_type = "redispubsub"
+        self.orchestrator_type = orchestrator_type
+        self.deployed = deployed
+
+
+class RedisPubSubBackend(MemoryBackend):
+
+    def __init__(self, params: RedisPubSubMemoryBackendParameters):
+        self.params = params
+        self.redis_connection = redis.Redis(self.params.redis_address, self.params.redis_port)
+
+    def store(self, obj):
+        self.redis_connection.publish(self.params.channel, pickle.dumps(obj))
+
+    def deploy(self):
+        if not self.params.deployed:
+            if self.params.orchestrator_type == 'kubernetes':
+                self.deploy_kubernetes()
+            self.params.deployed = True
+
+    def deploy_kubernetes(self):
+
+        if 'namespace' not in self.params.orchestrator_params:
+            self.params.orchestrator_params['namespace'] = "default"
+
+        container = client.V1Container(
+            name="redis-server",
+            image='redis:4-alpine',
+        )
+        template = client.V1PodTemplateSpec(
+            metadata=client.V1ObjectMeta(labels={'app': 'redis-server'}),
+            spec=client.V1PodSpec(
+                containers=[container]
+            )
+        )
+        deployment_spec = client.V1DeploymentSpec(
+            replicas=1,
+            template=template,
+            selector=client.V1LabelSelector(
+                match_labels={'app': 'redis-server'}
+            )
+        )
+
+        deployment = client.V1Deployment(
+            api_version='apps/v1',
+            kind='Deployment',
+            metadata=client.V1ObjectMeta(name='redis-server', labels={'app': 'redis-server'}),
+            spec=deployment_spec
+        )
+
+        api_client = client.AppsV1Api()
+        try:
+            api_client.create_namespaced_deployment(self.params.orchestrator_params['namespace'], deployment)
+        except client.rest.ApiException as e:
+            print("Got exception: %s\n while creating redis-server", e)
+            return False
+
+        core_v1_api = client.CoreV1Api()
+
+        service = client.V1Service(
+            api_version='v1',
+            kind='Service',
+            metadata=client.V1ObjectMeta(
+                name='redis-service'
+            ),
+            spec=client.V1ServiceSpec(
+                selector={'app': 'redis-server'},
+                ports=[client.V1ServicePort(
+                    protocol='TCP',
+                    port=6379,
+                    target_port=6379
+                )]
+            )
+        )
+
+        try:
+            core_v1_api.create_namespaced_service(self.params.orchestrator_params['namespace'], service)
+            self.params.redis_address = 'redis-service.{}.svc'.format(self.params.orchestrator_params['namespace'])
+            self.params.redis_port = 6379
+            return True
+        except client.rest.ApiException as e:
+            print("Got exception: %s\n while creating a service for redis-server", e)
+            return False
+
+    def undeploy(self):
+        if not self.params.deployed:
+            return
+        api_client = client.AppsV1Api()
+        delete_options = client.V1DeleteOptions()
+        try:
+            api_client.delete_namespaced_deployment('redis-server', self.params.orchestrator_params['namespace'], delete_options)
+        except client.rest.ApiException as e:
+            print("Got exception: %s\n while deleting redis-server", e)
+
+        api_client = client.CoreV1Api()
+        try:
+            api_client.delete_namespaced_service('redis-service', self.params.orchestrator_params['namespace'], delete_options)
+        except client.rest.ApiException as e:
+            print("Got exception: %s\n while deleting redis-server", e)
+
+        self.params.deployed = False
+
+    def sample(self, size):
+        pass
+
+    def subscribe(self, memory):
+        redis_sub = RedisSub(memory, redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
+        redis_sub.daemon = True
+        redis_sub.start()
+
+    def get_endpoint(self):
+        return {'redis_address': self.params.redis_address,
+                'redis_port': self.params.redis_port}
+
+
+class RedisSub(threading.Thread):
+
+    def __init__(self, memory: Memory, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
+        super().__init__()
+        self.redis_connection = redis.Redis(redis_address, redis_port)
+        self.pubsub = self.redis_connection.pubsub()
+        self.subscriber = None
+        self.memory = memory
+        self.channel = channel
+        self.subscriber = self.pubsub.subscribe(self.channel)
+
+    def run(self):
+        for message in self.pubsub.listen():
+            if message and 'data' in message:
+                try:
+                    obj = pickle.loads(message['data'])
+                    if type(obj) == Transition:
+                        self.memory.store(obj)
+                    elif type(obj) == Episode:
+                        self.memory.store_episode(obj)
+                except Exception:
+                    continue
diff --git a/rl_coach/memories/episodic/episodic_experience_replay.py b/rl_coach/memories/episodic/episodic_experience_replay.py
index 79f9aaf..f174a74 100644
--- a/rl_coach/memories/episodic/episodic_experience_replay.py
+++ b/rl_coach/memories/episodic/episodic_experience_replay.py
@@ -160,6 +160,8 @@ class EpisodicExperienceReplay(Memory):
         :param transition: a transition to store
         :return: None
         """
+        # Calling super().store() so that if a memory backend is used, it can store this transition.
+        super().store(transition)
         self.reader_writer_lock.lock_writing_and_reading()
 
         if len(self._buffer) == 0:
@@ -181,6 +183,9 @@ class EpisodicExperienceReplay(Memory):
         :param episode: the new episode to store
         :return: None
         """
+        # Calling super().store_episode() so that if a memory backend is used, it can store this episode.
+        super().store_episode(episode)
+
         if lock:
             self.reader_writer_lock.lock_writing_and_reading()
diff --git a/rl_coach/memories/episodic/episodic_hindsight_experience_replay.py b/rl_coach/memories/episodic/episodic_hindsight_experience_replay.py
index c30f451..ad006f1 100644
--- a/rl_coach/memories/episodic/episodic_hindsight_experience_replay.py
+++ b/rl_coach/memories/episodic/episodic_hindsight_experience_replay.py
@@ -106,6 +106,10 @@ class EpisodicHindsightExperienceReplay(EpisodicExperienceReplay):
         ]
 
     def store_episode(self, episode: Episode, lock: bool=True) -> None:
+
+        # Calling super().store_episode() so that if a memory backend is used, it can store this episode.
+        super().store_episode(episode)
+
         # generate hindsight transitions only when an episode is finished
         last_episode_transitions = copy.copy(episode.transitions)
diff --git a/rl_coach/memories/episodic/episodic_hrl_hindsight_experience_replay.py b/rl_coach/memories/episodic/episodic_hrl_hindsight_experience_replay.py
index c433c5e..9ec155d 100644
--- a/rl_coach/memories/episodic/episodic_hrl_hindsight_experience_replay.py
+++ b/rl_coach/memories/episodic/episodic_hrl_hindsight_experience_replay.py
@@ -59,6 +59,10 @@ class EpisodicHRLHindsightExperienceReplay(EpisodicHindsightExperienceReplay):
         # for a layer producing sub-goals, we will replace in hindsight the action (sub-goal) given to the lower
         # level with the actual achieved goal. the achieved goal (and observation) seen is assumed to be the same
         # for all levels - we can use this level's achieved goal instead of the lower level's one
+
+        # Calling super().store_episode() so that if a memory backend is used, it can store this episode.
+        super().store_episode(episode)
+
         for transition in episode.transitions:
             new_achieved_goal = transition.next_state[self.goals_space.goal_name]
             transition.action = new_achieved_goal
diff --git a/rl_coach/memories/memory.py b/rl_coach/memories/memory.py
index e5285e0..5c56cd0 100644
--- a/rl_coach/memories/memory.py
+++ b/rl_coach/memories/memory.py
@@ -18,6 +18,7 @@ from enum import Enum
 from typing import Tuple
 
 from rl_coach.base_parameters import Parameters
+from rl_coach.memories.backend.memory import MemoryBackend
 
 
 class MemoryGranularity(Enum):
@@ -32,7 +33,6 @@ class MemoryParameters(Parameters):
         self.shared_memory = False
         self.load_memory_from_file_path = None
-
     @property
     def path(self):
         return 'rl_coach.memories.memory:Memory'
@@ -45,9 +45,16 @@
         """
         self.max_size = max_size
         self._length = 0
+        self.memory_backend = None
 
     def store(self, obj):
-        raise NotImplementedError("")
+        if self.memory_backend:
+            self.memory_backend.store(obj)
+
+    def store_episode(self, episode):
+        if self.memory_backend:
+            for transition in episode:
+                self.memory_backend.store(transition)
 
     def get(self, index):
         raise NotImplementedError("")
@@ -64,4 +71,5 @@
     def clean(self):
         raise NotImplementedError("")
 
-
+    def set_memory_backend(self, memory_backend: MemoryBackend):
+        self.memory_backend = memory_backend
diff --git a/rl_coach/memories/non_episodic/balanced_experience_replay.py b/rl_coach/memories/non_episodic/balanced_experience_replay.py
index 24f2c19..7bfb48b 100644
--- a/rl_coach/memories/non_episodic/balanced_experience_replay.py
+++ b/rl_coach/memories/non_episodic/balanced_experience_replay.py
@@ -72,6 +72,8 @@ class BalancedExperienceReplay(ExperienceReplay):
                locks and then calls store with lock = True
         :return: None
         """
+        # Calling super().store() so that if a memory backend is used, it can store this transition.
+        super().store(transition)
         if lock:
             self.reader_writer_lock.lock_writing_and_reading()
diff --git a/rl_coach/memories/non_episodic/distributed_experience_replay.py b/rl_coach/memories/non_episodic/distributed_experience_replay.py
deleted file mode 100644
index 61b24b0..0000000
--- a/rl_coach/memories/non_episodic/distributed_experience_replay.py
+++ /dev/null
@@ -1,182 +0,0 @@
-#
-# Copyright (c) 2017 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from typing import List, Tuple, Union
-
-import numpy as np
-import redis
-import uuid
-import pickle
-
-from rl_coach.core_types import Transition
-from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters
-
-
-class DistributedExperienceReplayParameters(MemoryParameters):
-    def __init__(self):
-        super().__init__()
-        self.max_size = (MemoryGranularity.Transitions, 1000000)
-        self.allow_duplicates_in_batch_sampling = True
-        self.redis_ip = 'localhost'
-        self.redis_port = 6379
-        self.redis_db = 0
-
-    @property
-    def path(self):
-        return 'rl_coach.memories.non_episodic.distributed_experience_replay:DistributedExperienceReplay'
-
-
-class DistributedExperienceReplay(Memory):
-    """
-    A regular replay buffer which stores transition without any additional structure
-    """
-    def __init__(self, max_size: Tuple[MemoryGranularity, int], allow_duplicates_in_batch_sampling: bool=True,
-                 redis_ip='localhost', redis_port=6379, redis_db=0):
-        """
-        :param max_size: the maximum number of transitions or episodes to hold in the memory
-        :param allow_duplicates_in_batch_sampling: allow having the same transition multiple times in a batch
-        """
-
-        super().__init__(max_size)
-        if max_size[0] != MemoryGranularity.Transitions:
-            raise ValueError("Experience replay size can only be configured in terms of transitions")
-        self.allow_duplicates_in_batch_sampling = allow_duplicates_in_batch_sampling
-
-        self.db = redis_db
-        self.redis_connection = redis.Redis(redis_ip, redis_port, self.db)
-
-    def length(self) -> int:
-        """
-        Get the number of transitions in the ER
-        """
-        return self.num_transitions()
-
-    def num_transitions(self) -> int:
-        """
-        Get the number of transitions in the ER
-        """
-        try:
-            return self.redis_connection.info(section='keyspace')['db{}'.format(self.db)]['keys']
-        except Exception as e:
-            return 0
-
-    def sample(self, size: int) -> List[Transition]:
-        """
-        Sample a batch of transitions form the replay buffer. If the requested size is larger than the number
-        of samples available in the replay buffer then the batch will return empty.
-        :param size: the size of the batch to sample
-        :param beta: the beta parameter used for importance sampling
-        :return: a batch (list) of selected transitions from the replay buffer
-        """
-        transition_idx = dict()
-        if self.allow_duplicates_in_batch_sampling:
-            while len(transition_idx) != size:
-                key = self.redis_connection.randomkey()
-                transition_idx[key] = pickle.loads(self.redis_connection.get(key))
-        else:
-            if self.num_transitions() >= size:
-                while len(transition_idx) != size:
-                    key = self.redis_connection.randomkey()
-                    if key in transition_idx:
-                        continue
-                    transition_idx[key] = pickle.loads(self.redis_connection.get(key))
-            else:
-                raise ValueError("The replay buffer cannot be sampled since there are not enough transitions yet. "
-                                 "There are currently {} transitions".format(self.num_transitions()))
-
-        batch = transition_idx.values()
-
-        return batch
-
-    def _enforce_max_length(self) -> None:
-        """
-        Make sure that the size of the replay buffer does not pass the maximum size allowed.
-        If it passes the max size, the oldest transition in the replay buffer will be removed.
-        This function does not use locks since it is only called internally
-        :return: None
-        """
-        granularity, size = self.max_size
-        if granularity == MemoryGranularity.Transitions:
-            while size != 0 and self.num_transitions() > size:
-                self.redis_connection.delete(self.redis_connection.randomkey())
-        else:
-            raise ValueError("The granularity of the replay buffer can only be set in terms of transitions")
-
-    def store(self, transition: Transition, lock: bool=True) -> None:
-        """
-        Store a new transition in the memory.
-        :param transition: a transition to store
-        :param lock: if true, will lock the readers writers lock. this can cause a deadlock if an inheriting class
-               locks and then calls store with lock = True
-        :return: None
-        """
-        self.redis_connection.set(uuid.uuid4(), pickle.dumps(transition))
-        self._enforce_max_length()
-
-    def get_transition(self, transition_index: int, lock: bool=True) -> Union[None, Transition]:
-        """
-        Returns the transition in the given index. If the transition does not exist, returns None instead.
-        :param transition_index: the index of the transition to return
-        :param lock: use write locking if this is a shared memory
-        :return: the corresponding transition
-        """
-        return pickle.loads(self.redis_connection.get(transition_index))
-
-    def remove_transition(self, transition_index: int, lock: bool=True) -> None:
-        """
-        Remove the transition in the given index.
-
-        This does not remove the transition from the segment trees! it is just used to remove the transition
-        from the transitions list
-        :param transition_index: the index of the transition to remove
-        :return: None
-        """
-        self.redis_connection.delete(transition_index)
-
-    # for API compatibility
-    def get(self, transition_index: int, lock: bool=True) -> Union[None, Transition]:
-        """
-        Returns the transition in the given index. If the transition does not exist, returns None instead.
-        :param transition_index: the index of the transition to return
-        :return: the corresponding transition
-        """
-        return self.get_transition(transition_index, lock)
-
-    # for API compatibility
-    def remove(self, transition_index: int, lock: bool=True):
-        """
-        Remove the transition in the given index
-        :param transition_index: the index of the transition to remove
-        :return: None
-        """
-        self.remove_transition(transition_index, lock)
-
-    def clean(self, lock: bool=True) -> None:
-        """
-        Clean the memory by removing all the episodes
-        :return: None
-        """
-        self.redis_connection.flushall()
-        # self.transitions = []
-
-    def mean_reward(self) -> np.ndarray:
-        """
-        Get the mean reward in the replay buffer
-        :return: the mean reward
-        """
-        mean = np.mean([pickle.loads(self.redis_connection.get(key)).reward for key in self.redis_connection.keys()])
-
-        return mean
diff --git a/rl_coach/memories/non_episodic/experience_replay.py b/rl_coach/memories/non_episodic/experience_replay.py
index 4887e49..b3a2043 100644
--- a/rl_coach/memories/non_episodic/experience_replay.py
+++ b/rl_coach/memories/non_episodic/experience_replay.py
@@ -90,7 +90,6 @@ class ExperienceReplay(Memory):
         batch = [self.transitions[i] for i in transitions_idx]
 
         self.reader_writer_lock.release_writing()
-
         return batch
 
     def _enforce_max_length(self) -> None:
@@ -115,6 +114,8 @@ class ExperienceReplay(Memory):
                locks and then calls store with lock = True
         :return: None
         """
+        # Calling super().store() so that if a memory backend is used, it can store this transition.
+        super().store(transition)
         if lock:
             self.reader_writer_lock.lock_writing_and_reading()
diff --git a/rl_coach/memories/non_episodic/prioritized_experience_replay.py b/rl_coach/memories/non_episodic/prioritized_experience_replay.py
index a8fdcfc..544df48 100644
--- a/rl_coach/memories/non_episodic/prioritized_experience_replay.py
+++ b/rl_coach/memories/non_episodic/prioritized_experience_replay.py
@@ -267,6 +267,9 @@ class PrioritizedExperienceReplay(ExperienceReplay):
         :param transition: a transition to store
         :return: None
         """
+        # Calling super().store() so that if a memory backend is used, it can store this transition.
+        super().store(transition)
+
         self.reader_writer_lock.lock_writing_and_reading()
 
         transition_priority = self.maximal_priority
diff --git a/rl_coach/orchestrators/deploy.py b/rl_coach/orchestrators/deploy.py
index d99e1d1..36b8b34 100644
--- a/rl_coach/orchestrators/deploy.py
+++ b/rl_coach/orchestrators/deploy.py
@@ -16,4 +16,4 @@ class Deploy(object):
         pass
 
     def deploy(self) -> bool:
-        pass
\ No newline at end of file
+        pass
diff --git a/rl_coach/orchestrators/kubernetes_orchestrator.py b/rl_coach/orchestrators/kubernetes_orchestrator.py
index 71336c7..b22b681 100644
--- a/rl_coach/orchestrators/kubernetes_orchestrator.py
+++ b/rl_coach/orchestrators/kubernetes_orchestrator.py
@@ -1,29 +1,46 @@
 import os
 import uuid
+import json
+import time
+from typing import List
 
 from rl_coach.orchestrators.deploy import Deploy, DeployParameters
 from kubernetes import client, config
+from rl_coach.memories.backend.memory import MemoryBackendParameters
+from rl_coach.memories.backend.memory_impl import get_memory_backend
+
+
+class RunTypeParameters():
+
+    def __init__(self, image: str, command: list, arguments: list = None,
+                 run_type: str = "trainer", checkpoint_dir: str = "/checkpoint",
+                 num_replicas: int = 1, orchestration_params: dict = None):
+        self.image = image
+        self.command = command
+        if not arguments:
+            arguments = list()
+        self.arguments = arguments
+        self.run_type = run_type
+        self.checkpoint_dir = checkpoint_dir
+        self.num_replicas = num_replicas
+        if not orchestration_params:
+            orchestration_params = dict()
+        self.orchestration_params = orchestration_params
 
 
 class KubernetesParameters(DeployParameters):
 
-    def __init__(self, name: str, image: str, command: list(), arguments: list() = list(), synchronized: bool = False,
-                 num_workers: int = 1, kubeconfig: str = None, namespace: str = None, redis_ip: str = None,
-                 redis_port: int = None, redis_db: int = 0, nfs_server: str = None, nfs_path: str = None,
-                 checkpoint_dir: str = '/checkpoint'):
-        self.image = image
-        self.synchronized = synchronized
-        self.command = command
-        self.arguments = arguments
+    def __init__(self, run_type_params: List[RunTypeParameters], kubeconfig: str = None, namespace: str = "", nfs_server: str = None,
+                 nfs_path: str = None, checkpoint_dir: str = '/checkpoint', memory_backend_parameters: MemoryBackendParameters = None):
+
+        self.run_type_params = {}
+        for run_type_param in run_type_params:
+            self.run_type_params[run_type_param.run_type] = run_type_param
         self.kubeconfig = kubeconfig
-        self.num_workers = num_workers
         self.namespace = namespace
-        self.redis_ip = redis_ip
-        self.redis_port = redis_port
-        self.redis_db = redis_db
         self.nfs_server = nfs_server
         self.nfs_path = nfs_path
         self.checkpoint_dir = checkpoint_dir
-        self.name = name
+        self.memory_backend_parameters = memory_backend_parameters
 
 
 class Kubernetes(Deploy):
@@ -44,17 +61,14 @@ class Kubernetes(Deploy):
         if os.environ.get('http_proxy'):
             client.Configuration._default.proxy = os.environ.get('http_proxy')
 
+        self.deploy_parameters.memory_backend_parameters.orchestrator_params = {'namespace': self.deploy_parameters.namespace}
+        self.memory_backend = get_memory_backend(self.deploy_parameters.memory_backend_parameters)
+
     def setup(self) -> bool:
 
-        if not self.deploy_parameters.redis_ip:
-            # Need to spin up a redis service and a deployment.
-            if not self.deploy_redis():
-                print("Failed to setup redis")
-                return False
-
+        self.memory_backend.deploy()
         if not self.create_nfs_resources():
             return False
-
         return True
 
     def create_nfs_resources(self):
@@ -107,87 +121,24 @@ class Kubernetes(Deploy):
                 return False
         return True
 
-    def deploy_redis(self) -> bool:
-        container = client.V1Container(
-            name="redis-server",
-            image='redis:4-alpine',
-        )
-        template = client.V1PodTemplateSpec(
-            metadata=client.V1ObjectMeta(labels={'app': 'redis-server'}),
-            spec=client.V1PodSpec(
-                containers=[container]
-            )
-        )
-        deployment_spec = client.V1DeploymentSpec(
-            replicas=1,
-            template=template,
-            selector=client.V1LabelSelector(
-                match_labels={'app': 'redis-server'}
-            )
-        )
+    def deploy_trainer(self) -> bool:
 
-        deployment = client.V1Deployment(
-            api_version='apps/v1',
-            kind='Deployment',
-            metadata=client.V1ObjectMeta(name='redis-server', labels={'app': 'redis-server'}),
-            spec=deployment_spec
-        )
-
-        api_client = client.AppsV1Api()
-        try:
-            api_client.create_namespaced_deployment(self.deploy_parameters.namespace, deployment)
-        except client.rest.ApiException as e:
-            print("Got exception: %s\n while creating redis-server", e)
+        trainer_params = self.deploy_parameters.run_type_params.get('trainer', None)
+        if not trainer_params:
             return False
 
-        core_v1_api = client.CoreV1Api()
-
-        service = client.V1Service(
-            api_version='v1',
-            kind='Service',
-            metadata=client.V1ObjectMeta(
-                name='redis-service'
-            ),
-            spec=client.V1ServiceSpec(
-                selector={'app': 'redis-server'},
-                ports=[client.V1ServicePort(
-                    protocol='TCP',
-                    port=6379,
-                    target_port=6379
-                )]
-            )
-        )
-
-        try:
-            core_v1_api.create_namespaced_service(self.deploy_parameters.namespace, service)
-            self.deploy_parameters.redis_ip = 'redis-service.{}.svc'.format(self.deploy_parameters.namespace)
-            self.deploy_parameters.redis_port = 6379
-            return True
-        except client.rest.ApiException as e:
-            print("Got exception: %s\n while creating a service for redis-server", e)
-            return False
-
-    def deploy(self) -> bool:
-
-        self.deploy_parameters.command += ['--redis_ip', self.deploy_parameters.redis_ip, '--redis_port', '{}'.format(self.deploy_parameters.redis_port)]
-
-        if self.deploy_parameters.synchronized:
-            return self.create_k8s_deployment()
-        else:
-            return self.create_k8s_job()
-
-    def create_k8s_deployment(self) -> bool:
-        name = "{}-{}".format(self.deploy_parameters.name, uuid.uuid4())
+        trainer_params.command += ['--memory_backend_params', json.dumps(self.deploy_parameters.memory_backend_parameters.__dict__)]
+        name = "{}-{}".format(trainer_params.run_type, uuid.uuid4())
 
         container = client.V1Container(
             name=name,
-            image=self.deploy_parameters.image,
-            command=self.deploy_parameters.command,
-            args=self.deploy_parameters.arguments,
+            image=trainer_params.image,
+            command=trainer_params.command,
+            args=trainer_params.arguments,
             image_pull_policy='Always',
             volume_mounts=[client.V1VolumeMount(
                 name='nfs-pvc',
-                mount_path=self.deploy_parameters.checkpoint_dir
+                mount_path=trainer_params.checkpoint_dir
             )]
         )
         template = client.V1PodTemplateSpec(
@@ -203,7 +154,7 @@ class Kubernetes(Deploy):
             ),
         )
         deployment_spec = client.V1DeploymentSpec(
-            replicas=self.deploy_parameters.num_workers,
+            replicas=trainer_params.num_replicas,
             template=template,
             selector=client.V1LabelSelector(
                 match_labels={'app': name}
@@ -220,23 +171,30 @@ class Kubernetes(Deploy):
         api_client = client.AppsV1Api()
         try:
             api_client.create_namespaced_deployment(self.deploy_parameters.namespace, deployment)
+            trainer_params.orchestration_params['deployment_name'] = name
             return True
         except client.rest.ApiException as e:
             print("Got exception: %s\n while creating deployment", e)
             return False
 
-    def create_k8s_job(self):
-        name = "{}-{}".format(self.deploy_parameters.name, uuid.uuid4())
+    def deploy_worker(self):
+
+        worker_params = self.deploy_parameters.run_type_params.get('worker', None)
+        if not worker_params:
+            return False
+
+        worker_params.command += ['--memory_backend_params', json.dumps(self.deploy_parameters.memory_backend_parameters.__dict__)]
+        name = "{}-{}".format(worker_params.run_type, uuid.uuid4())
 
         container = client.V1Container(
             name=name,
-            image=self.deploy_parameters.image,
-            command=self.deploy_parameters.command,
-            args=self.deploy_parameters.arguments,
+            image=worker_params.image,
+            command=worker_params.command,
+            args=worker_params.arguments,
             image_pull_policy='Always',
             volume_mounts=[client.V1VolumeMount(
                 name='nfs-pvc',
-                mount_path=self.deploy_parameters.checkpoint_dir
+                mount_path=worker_params.checkpoint_dir
             )]
         )
         template = client.V1PodTemplateSpec(
@@ -249,27 +207,104 @@ class Kubernetes(Deploy):
                     claim_name=self.nfs_pvc_name
                 )
             )],
-            restart_policy='Never'
             ),
         )
-        job_spec = client.V1JobSpec(
-            parallelism=self.deploy_parameters.num_workers,
+        deployment_spec = client.V1DeploymentSpec(
+            replicas=worker_params.num_replicas,
             template=template,
-            completions=2147483647
+            selector=client.V1LabelSelector(
+                match_labels={'app': name}
+            )
         )
-
-        job = client.V1Job(
-            api_version='batch/v1',
-            kind='Job',
+        deployment = client.V1Deployment(
+            api_version='apps/v1',
+            kind="Deployment",
             metadata=client.V1ObjectMeta(name=name),
-            spec=job_spec
+            spec=deployment_spec
         )
 
-        api_client = client.BatchV1Api()
+        api_client = client.AppsV1Api()
         try:
-            api_client.create_namespaced_job(self.deploy_parameters.namespace, job)
+            api_client.create_namespaced_deployment(self.deploy_parameters.namespace, deployment)
+            worker_params.orchestration_params['deployment_name'] = name
             return True
         except client.rest.ApiException as e:
             print("Got exception: %s\n while creating deployment", e)
             return False
+
+    def worker_logs(self):
+        pass
+
+    def trainer_logs(self):
+        trainer_params = self.deploy_parameters.run_type_params.get('trainer', None)
+        if not trainer_params:
+            return
+
+        api_client = client.CoreV1Api()
+        pod = None
+        try:
+            pods = api_client.list_namespaced_pod(self.deploy_parameters.namespace, label_selector='app={}'.format(
+                trainer_params.orchestration_params['deployment_name']
+            ))
+            pod = pods.items[0]
+        except client.rest.ApiException as e:
+            print("Got exception: %s\n while reading pods", e)
+            return
+
+        if not pod:
+            return
+
+        self.tail_log(pod.metadata.name, api_client)
+
+    def tail_log(self, pod_name, corev1_api):
+        while True:
+            time.sleep(10)
+            # Try to tail the pod logs
+            try:
+                print(corev1_api.read_namespaced_pod_log(
+                    pod_name, self.deploy_parameters.namespace, follow=True
+                ), flush=True)
+            except client.rest.ApiException as e:
+                pass
+
+            # This part will get executed if the pod is in one of the following phases: not ready, failed or terminated.
+            # Check if the pod has errored out, else just try again.
+            # Get the pod
+            try:
+                pod = corev1_api.read_namespaced_pod(pod_name, self.deploy_parameters.namespace)
+            except client.rest.ApiException as e:
+                continue
+
+            if not hasattr(pod, 'status') or not pod.status:
+                continue
+            if not hasattr(pod.status, 'container_statuses') or not pod.status.container_statuses:
+                continue
+
+            for container_status in pod.status.container_statuses:
+                if container_status.state.waiting is not None:
+                    if container_status.state.waiting.reason == 'Error' or \
+                       container_status.state.waiting.reason == 'CrashLoopBackOff' or \
+                       container_status.state.waiting.reason == 'ImagePullBackOff' or \
+                       container_status.state.waiting.reason == 'ErrImagePull':
+                        return
+                if container_status.state.terminated is not None:
+                    return
+
+    def undeploy(self):
+        trainer_params = self.deploy_parameters.run_type_params.get('trainer', None)
+        api_client = client.AppsV1Api()
+        delete_options = client.V1DeleteOptions()
+        if trainer_params:
+            try:
+                api_client.delete_namespaced_deployment(trainer_params.orchestration_params['deployment_name'], self.deploy_parameters.namespace, delete_options)
+            except client.rest.ApiException as e:
+                print("Got exception: %s\n while deleting trainer", e)
+
+        worker_params = self.deploy_parameters.run_type_params.get('worker', None)
+        if worker_params:
+            try:
+                api_client.delete_namespaced_deployment(worker_params.orchestration_params['deployment_name'], self.deploy_parameters.namespace, delete_options)
+            except client.rest.ApiException as e:
+                print("Got exception: %s\n while deleting workers", e)
+
+        self.memory_backend.undeploy()
diff --git a/rl_coach/orchestrators/start_training.py b/rl_coach/orchestrators/start_training.py
index ba28a25..547bb5c 100644
--- a/rl_coach/orchestrators/start_training.py
+++ b/rl_coach/orchestrators/start_training.py
@@ -1,51 +1,43 @@
 import argparse
 
-from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes
+from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, RunTypeParameters
+from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
 
 
-def main(preset: str, image: str='ajaysudh/testing:coach', redis_ip: str=None, redis_port: int=None, num_workers: int=1, nfs_server: str="", nfs_path: str=""):
+def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, nfs_server: str="", nfs_path: str="", memory_backend: str=""):
     rollout_command = ['python3', 'rl_coach/rollout_worker.py', '-p', preset]
     training_command = ['python3', 'rl_coach/training_worker.py', '-p', preset]
 
-    """
-    TODO:
-    1. Create a NFS backed PV for checkpointing.
-        a. Include that in both (worker, trainer) containers.
-        b. Change checkpoint writing logic to always write to a temporary file and then rename.
-    2. Test e2e 1 loop.
-        a. Trainer writes a checkpoint
-        b. Rollout worker picks it and gathers experience, writes back to redis.
-        c. 1 rollout worker, 1 trainer.
-    3. Trainer should be a job (not a deployment)
-        a. When all the epochs of training are done, workers should also be deleted.
-    4. Test e2e with multiple rollout workers.
-    5. Test e2e with multiple rollout workers and multiple loops.
- """ + memory_backend_params = RedisPubSubMemoryBackendParameters() - training_params = KubernetesParameters("train", image, training_command, kubeconfig='~/.kube/config', redis_ip=redis_ip, redis_port=redis_port, - nfs_server=nfs_server, nfs_path=nfs_path) - training_obj = Kubernetes(training_params) - if not training_obj.setup(): + worker_run_type_params = RunTypeParameters(image, rollout_command, run_type="worker") + trainer_run_type_params = RunTypeParameters(image, training_command, run_type="trainer") + + orchestration_params = KubernetesParameters([worker_run_type_params, trainer_run_type_params], kubeconfig='~/.kube/config', nfs_server=nfs_server, + nfs_path=nfs_path, memory_backend_parameters=memory_backend_params) + orchestrator = Kubernetes(orchestration_params) + if not orchestrator.setup(): print("Could not setup") return - rollout_params = KubernetesParameters("worker", image, rollout_command, kubeconfig='~/.kube/config', redis_ip=training_params.redis_ip, redis_port=training_params.redis_port, num_workers=num_workers) - rollout_obj = Kubernetes(rollout_params) - # if not rollout_obj.setup(): - # print("Could not setup") - - if training_obj.deploy(): + if orchestrator.deploy_trainer(): print("Successfully deployed") else: print("Could not deploy") return - if rollout_obj.deploy(): + if orchestrator.deploy_worker(): print("Successfully deployed") else: print("Could not deploy") return + try: + orchestrator.trainer_logs() + except KeyboardInterrupt: + pass + orchestrator.undeploy() + if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -65,6 +57,10 @@ if __name__ == '__main__': help="(string) Exported path for the nfs server", type=str, required=True) + parser.add_argument('--memory_backend', + help="(string) Memory backend to use", + type=str, + default="redispubsub") # parser.add_argument('--checkpoint_dir', # help='(string) Path to a folder containing a checkpoint to write the model to.', @@ -72,4 +68,4 @@ if __name__ == '__main__': # default='/checkpoint') args = parser.parse_args() - main(preset=args.preset, image=args.image, nfs_server=args.nfs_server, nfs_path=args.nfs_path) + main(preset=args.preset, image=args.image, nfs_server=args.nfs_server, nfs_path=args.nfs_path, memory_backend=args.memory_backend) diff --git a/rl_coach/presets/CartPole_DQN_distributed.py b/rl_coach/presets/CartPole_DQN_distributed.py deleted file mode 100644 index f4259c4..0000000 --- a/rl_coach/presets/CartPole_DQN_distributed.py +++ /dev/null @@ -1,63 +0,0 @@ -from rl_coach.agents.dqn_agent import DQNAgentParametersDistributed -from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters -from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase -from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod -from rl_coach.environments.gym_environment import Mujoco -from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager -from rl_coach.graph_managers.graph_manager import ScheduleParameters -from rl_coach.memories.memory import MemoryGranularity -from rl_coach.schedules import LinearSchedule - - - -#################### -# Graph Scheduling # -#################### - -schedule_params = ScheduleParameters() -schedule_params.improve_steps = TrainingSteps(10000000000) -schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10) -schedule_params.evaluation_steps = EnvironmentEpisodes(1) -schedule_params.heatup_steps = EnvironmentSteps(1000) - -######### -# Agent 
# -######### -agent_params = DQNAgentParametersDistributed() - -# DQN params -agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100) -agent_params.algorithm.discount = 0.99 -agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1) - -# NN configuration -agent_params.network_wrappers['main'].learning_rate = 0.00025 -agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False - -# ER size -agent_params.memory.max_size = (MemoryGranularity.Transitions, 40000) - -# E-Greedy schedule -agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000) - -################ -# Environment # -################ -env_params = Mujoco() -env_params.level = 'CartPole-v0' - -vis_params = VisualizationParameters() -vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()] -vis_params.dump_mp4 = False - -######## -# Test # -######## -preset_validation_params = PresetValidationParameters() -preset_validation_params.test = True -preset_validation_params.min_reward_threshold = 150 -preset_validation_params.max_episodes_to_achieve_reward = 250 - -graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params, - schedule_params=schedule_params, vis_params=vis_params, - preset_validation_params=preset_validation_params) diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py index 27845a6..7d8a371 100644 --- a/rl_coach/rollout_worker.py +++ b/rl_coach/rollout_worker.py @@ -10,11 +10,14 @@ this rollout worker: import argparse import time import os +import json from rl_coach.base_parameters import TaskParameters from rl_coach.coach import expand_preset from rl_coach.core_types import EnvironmentEpisodes, RunPhase from rl_coach.utils import short_dynamic_import +from rl_coach.memories.backend.memory_impl import construct_memory_params + # Q: specify alternative distributed memory, or should this go in the preset? # A: preset must define distributed memory to be used. 
we aren't going to take @@ -58,9 +61,12 @@ def rollout_worker(graph_manager, checkpoint_dir): task_parameters = TaskParameters() task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir graph_manager.create_graph(task_parameters) - graph_manager.phase = RunPhase.TRAIN - graph_manager.act(EnvironmentEpisodes(num_steps=10)) + + for i in range(10000000): + graph_manager.act(EnvironmentEpisodes(num_steps=10)) + graph_manager.restore_checkpoint() + graph_manager.phase = RunPhase.UNDEFINED @@ -82,13 +88,19 @@ def main(): help="(int) Port of the redis server", default=6379, type=int) + parser.add_argument('--memory_backend_params', + help="(string) JSON string of the memory backend params", + type=str) + args = parser.parse_args() graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True) - graph_manager.agent_params.memory.redis_ip = args.redis_ip - graph_manager.agent_params.memory.redis_port = args.redis_port - + if args.memory_backend_params: + args.memory_backend_params = json.loads(args.memory_backend_params) + if 'run_type' not in args.memory_backend_params: + args.memory_backend_params['run_type'] = 'worker' + graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(args.memory_backend_params)) rollout_worker( graph_manager=graph_manager, checkpoint_dir=args.checkpoint_dir, diff --git a/rl_coach/training_worker.py b/rl_coach/training_worker.py index d81f54c..85f2052 100644 --- a/rl_coach/training_worker.py +++ b/rl_coach/training_worker.py @@ -2,42 +2,17 @@ """ import argparse import time +import json from rl_coach.base_parameters import TaskParameters from rl_coach.coach import expand_preset from rl_coach import core_types from rl_coach.utils import short_dynamic_import -from rl_coach.memories.non_episodic.distributed_experience_replay import DistributedExperienceReplay -from rl_coach.memories.memory import MemoryGranularity +from rl_coach.memories.backend.memory_impl import construct_memory_params # Q: specify alternative distributed memory, or should this go in the preset? # A: preset must define distributed memory to be used. we aren't going to take a non-distributed preset and automatically distribute it. 
-def heatup(graph_manager):
-    memory = DistributedExperienceReplay(max_size=(MemoryGranularity.Transitions, 1000000),
-                                         redis_ip=graph_manager.agent_params.memory.redis_ip,
-                                         redis_port=graph_manager.agent_params.memory.redis_port)
-
-    while(memory.num_transitions() < graph_manager.heatup_steps.num_steps):
-        time.sleep(1)
-
-
-class StepsLoop(object):
-    """StepsLoop facilitates a simple while loop"""
-    def __init__(self, steps_counters, phase, steps):
-        super(StepsLoop, self).__init__()
-        self.steps_counters = steps_counters
-        self.phase = phase
-        self.steps = steps
-
-        self.step_end = self._step_count() + steps.num_steps
-
-    def _step_count(self):
-        return self.steps_counters[self.phase][self.steps.__class__]
-
-    def continue(self):
-        return self._step_count() < count_end:
-
 
 def training_worker(graph_manager, checkpoint_dir):
     """
@@ -51,12 +26,8 @@ def training_worker(graph_manager, checkpoint_dir):
     # save randomly initialized graph
     graph_manager.save_checkpoint()
 
-    # optionally wait for a specific number of transitions to be in memory before training
-    heatup(graph_manager)
-
     # training loop
-    stepper = StepsLoop(graph_manager.total_steps_counters, RunPhase.TRAIN, graph_manager.improve_steps)
-    while stepper.continue():
+    while True:
         graph_manager.phase = core_types.RunPhase.TRAIN
         graph_manager.train(core_types.TrainingSteps(1))
         graph_manager.phase = core_types.RunPhase.UNDEFINED
@@ -65,7 +36,6 @@ def training_worker(graph_manager, checkpoint_dir):
 
         graph_manager.save_checkpoint()
 
-    # TODO: signal to workers that training is done
 
 def main():
     parser = argparse.ArgumentParser()
@@ -85,12 +55,18 @@ def main():
                         help="(int) Port of the redis server",
                         default=6379,
                         type=int)
+    parser.add_argument('--memory_backend_params',
+                        help="(string) JSON string of the memory backend params",
+                        type=str)
 
     args = parser.parse_args()
 
     graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
-    graph_manager.agent_params.memory.redis_ip = args.redis_ip
-    graph_manager.agent_params.memory.redis_port = args.redis_port
+    if args.memory_backend_params:
+        args.memory_backend_params = json.loads(args.memory_backend_params)
+        if 'run_type' not in args.memory_backend_params:
+            args.memory_backend_params['run_type'] = 'trainer'
+        graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(args.memory_backend_params))
 
     training_worker(
         graph_manager=graph_manager,
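For reviewers, a quick sanity check of the CLI hand-off used by `deploy_trainer()`/`deploy_worker()` and the two worker entry points: the orchestrator flattens the backend parameters with `json.dumps(params.__dict__)`, and `construct_memory_params()` rebuilds them on the other side. A sketch, assuming every parameter field stays JSON-serializable (true for the defaults in this diff); the address and channel values are illustrative:

```python
import json

from rl_coach.memories.backend.memory_impl import construct_memory_params
from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters

# Orchestrator side: flatten the parameters into a single CLI argument,
# as deploy_trainer()/deploy_worker() do with --memory_backend_params.
params = RedisPubSubMemoryBackendParameters(redis_address='redis-service.default.svc',
                                            redis_port=6379, channel='demo-channel')
cli_arg = json.dumps(params.__dict__)

# Worker/trainer side: parse the JSON back and rebuild the parameters object;
# the entry points fill in run_type when it is missing.
parsed = json.loads(cli_arg)
if 'run_type' not in parsed:
    parsed['run_type'] = 'worker'
rebuilt = construct_memory_params(parsed)

assert rebuilt.store_type == 'redispubsub'
assert rebuilt.redis_address == params.redis_address
```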