From ce9838a7d68f2c5a5a7a4b287e09f1eaac1b8747 Mon Sep 17 00:00:00 2001 From: Ajay Deshpande Date: Fri, 14 Sep 2018 15:58:57 -0700 Subject: [PATCH] Adding kubernetes orchestrator for rollouts, adding requirements for incremental docker builds --- requirements.txt | 16 ++ rl_coach/agents/dqn_agent.py | 15 ++ .../distributed_experience_replay.py | 24 +-- rl_coach/orchestrators/__init__.py | 15 ++ rl_coach/orchestrators/deploy.py | 19 +++ .../orchestrators/kubernetes_orchestrator.py | 153 ++++++++++++++++++ rl_coach/orchestrators/test.py | 18 +++ rl_coach/presets/CartPole_DQN_distributed.py | 71 ++++++++ rl_coach/rollout_worker.py | 2 + setup.py | 9 +- 10 files changed, 327 insertions(+), 15 deletions(-) create mode 100644 requirements.txt create mode 100644 rl_coach/orchestrators/__init__.py create mode 100644 rl_coach/orchestrators/deploy.py create mode 100644 rl_coach/orchestrators/kubernetes_orchestrator.py create mode 100644 rl_coach/orchestrators/test.py create mode 100644 rl_coach/presets/CartPole_DQN_distributed.py diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..463dd95 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,16 @@ +annoy==1.8.3 +Pillow==4.3.0 +matplotlib==2.0.2 +numpy==1.14.5 +pandas==0.22.0 +pygame==1.9.3 +PyOpenGL==3.1.0 +scipy==0.19.0 +scikit-image==0.13.0 +box2d==2.3.2 +gym==0.10.5 +bokeh==0.13.0 +futures==3.1.1 +wxPython==4.0.1 +kubernetes==7.0.0 +redis==2.10.6 \ No newline at end of file diff --git a/rl_coach/agents/dqn_agent.py b/rl_coach/agents/dqn_agent.py index f261a08..b300b79 100644 --- a/rl_coach/agents/dqn_agent.py +++ b/rl_coach/agents/dqn_agent.py @@ -27,6 +27,7 @@ from rl_coach.architectures.tensorflow_components.embedders.embedder import Inpu from rl_coach.core_types import EnvironmentSteps from rl_coach.exploration_policies.e_greedy import EGreedyParameters from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters +from rl_coach.memories.non_episodic.distributed_experience_replay import DistributedExperienceReplayParameters from rl_coach.schedules import LinearSchedule @@ -50,6 +51,20 @@ class DQNNetworkParameters(NetworkParameters): self.create_target_network = True +class DQNAgentParametersDistributed(AgentParameters): + def __init__(self): + super().__init__(algorithm=DQNAlgorithmParameters(), + exploration=EGreedyParameters(), + memory=DistributedExperienceReplayParameters(), + networks={"main": DQNNetworkParameters()}) + self.exploration.epsilon_schedule = LinearSchedule(1, 0.1, 1000000) + self.exploration.evaluation_epsilon = 0.05 + + @property + def path(self): + return 'rl_coach.agents.dqn_agent:DQNAgent' + + class DQNAgentParameters(AgentParameters): def __init__(self): super().__init__(algorithm=DQNAlgorithmParameters(), diff --git a/rl_coach/memories/non_episodic/distributed_experience_replay.py b/rl_coach/memories/non_episodic/distributed_experience_replay.py index d6bdf9c..6b14e81 100644 --- a/rl_coach/memories/non_episodic/distributed_experience_replay.py +++ b/rl_coach/memories/non_episodic/distributed_experience_replay.py @@ -14,7 +14,7 @@ # limitations under the License. # -from typing import List, Tuple, Union, Dict, Any +from typing import List, Tuple, Union import numpy as np import redis @@ -23,7 +23,6 @@ import pickle from rl_coach.core_types import Transition from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters -from rl_coach.utils import ReaderWriterLock class DistributedExperienceReplayParameters(MemoryParameters): @@ -31,6 +30,9 @@ class DistributedExperienceReplayParameters(MemoryParameters): super().__init__() self.max_size = (MemoryGranularity.Transitions, 1000000) self.allow_duplicates_in_batch_sampling = True + self.redis_ip = 'localhost' + self.redis_port = 6379 + self.redis_db = 0 @property def path(self): @@ -41,19 +43,19 @@ class DistributedExperienceReplay(Memory): """ A regular replay buffer which stores transition without any additional structure """ - def __init__(self, max_size: Tuple[MemoryGranularity, int], allow_duplicates_in_batch_sampling: bool=True, - redis_ip = 'localhost', redis_port = 6379, db = 0): + def __init__(self, max_size: Tuple[MemoryGranularity, int], allow_duplicates_in_batch_sampling: bool=True, + redis_ip='localhost', redis_port=6379, redis_db=0): """ :param max_size: the maximum number of transitions or episodes to hold in the memory :param allow_duplicates_in_batch_sampling: allow having the same transition multiple times in a batch """ - + super().__init__(max_size) if max_size[0] != MemoryGranularity.Transitions: raise ValueError("Experience replay size can only be configured in terms of transitions") self.allow_duplicates_in_batch_sampling = allow_duplicates_in_batch_sampling - self.db = db + self.db = redis_db self.redis_connection = redis.Redis(redis_ip, redis_port, self.db) def length(self) -> int: @@ -67,7 +69,7 @@ class DistributedExperienceReplay(Memory): Get the number of transitions in the ER """ return self.redis_connection.info(section='keyspace')['db{}'.format(self.db)]['keys'] - + def sample(self, size: int) -> List[Transition]: """ Sample a batch of transitions form the replay buffer. If the requested size is larger than the number @@ -75,7 +77,7 @@ class DistributedExperienceReplay(Memory): :param size: the size of the batch to sample :param beta: the beta parameter used for importance sampling :return: a batch (list) of selected transitions from the replay buffer - """ + """ transition_idx = dict() if self.allow_duplicates_in_batch_sampling: while len(transition_idx) != size: @@ -129,7 +131,7 @@ class DistributedExperienceReplay(Memory): :return: the corresponding transition """ return pickle.loads(self.redis_connection.get(transition_index)) - + def remove_transition(self, transition_index: int, lock: bool=True) -> None: """ Remove the transition in the given index. @@ -140,7 +142,7 @@ class DistributedExperienceReplay(Memory): :return: None """ self.redis_connection.delete(transition_index) - + # for API compatibility def get(self, transition_index: int, lock: bool=True) -> Union[None, Transition]: """ @@ -173,5 +175,5 @@ class DistributedExperienceReplay(Memory): :return: the mean reward """ mean = np.mean([pickle.loads(self.redis_connection.get(key)).reward for key in self.redis_connection.keys()]) - + return mean diff --git a/rl_coach/orchestrators/__init__.py b/rl_coach/orchestrators/__init__.py new file mode 100644 index 0000000..cf26739 --- /dev/null +++ b/rl_coach/orchestrators/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/rl_coach/orchestrators/deploy.py b/rl_coach/orchestrators/deploy.py new file mode 100644 index 0000000..d99e1d1 --- /dev/null +++ b/rl_coach/orchestrators/deploy.py @@ -0,0 +1,19 @@ + + + +class DeployParameters(object): + + def __init__(self): + pass + + +class Deploy(object): + + def __init__(self, deploy_parameters): + self.deploy_parameters = deploy_parameters + + def setup(self) -> bool: + pass + + def deploy(self) -> bool: + pass \ No newline at end of file diff --git a/rl_coach/orchestrators/kubernetes_orchestrator.py b/rl_coach/orchestrators/kubernetes_orchestrator.py new file mode 100644 index 0000000..9b21ee4 --- /dev/null +++ b/rl_coach/orchestrators/kubernetes_orchestrator.py @@ -0,0 +1,153 @@ + +from rl_coach.orchestrators.deploy import Deploy, DeployParameters +from kubernetes import client, config + + +class KubernetesParameters(DeployParameters): + + def __init__(self, image: str, command: list(), arguments: list() = list(), synchronized: bool = False, + num_workers: int = 1, kubeconfig: str = None, namespace: str = None, redis_ip: str = None, + redis_port: int = None, redis_db: int = 0): + self.image = image + self.synchronized = synchronized + self.command = command + self.arguments = arguments + self.kubeconfig = kubeconfig + self.num_workers = num_workers + self.namespace = namespace + self.redis_ip = redis_ip + self.redis_port = redis_port + self.redis_db = redis_db + + +class Kubernetes(Deploy): + + def __init__(self, deploy_parameters: KubernetesParameters): + super().__init__(deploy_parameters) + self.deploy_parameters = deploy_parameters + + def setup(self) -> bool: + if self.deploy_parameters.kubeconfig: + config.load_kube_config() + else: + config.load_incluster_config() + + if not self.deploy_parameters.namespace: + _, current_context = config.list_kube_config_contexts() + self.deploy_parameters.namespace = current_context['context']['namespace'] + + if not self.deploy_parameters.redis_ip: + # Need to spin up a redis service and a deployment. + if not self.deploy_redis(): + print("Failed to setup redis") + return False + + self.deploy_parameters.command += ['-r', self.deploy_parameters.redis_ip, '-p', '{}'.format(self.deploy_parameters.redis_port)] + + return True + + def deploy_redis(self) -> bool: + container = client.V1Container( + name="redis-server", + image='redis:4-alpine', + ) + template = client.V1PodTemplateSpec( + metadata=client.V1ObjectMeta(labels={'app': 'redis-server'}), + spec=client.V1PodSpec( + containers=[container] + ) + ) + deployment_spec = client.V1DeploymentSpec( + replicas=1, + template=template, + selector=client.V1LabelSelector( + match_labels={'app': 'redis-server'} + ) + ) + + deployment = client.V1Deployment( + api_version='apps/v1', + kind='Deployment', + metadata=client.V1ObjectMeta(name='redis-server', labels={'app': 'redis-server'}), + spec=deployment_spec + ) + + api_client = client.AppsV1Api() + try: + api_client.create_namespaced_deployment(self.deploy_parameters.namespace, deployment) + except client.rest.ApiException as e: + print("Got exception: %s\n while creating redis-server", e) + return False + + core_v1_api = client.CoreV1Api() + + service = client.V1Service( + api_version='v1', + kind='Service', + metadata=client.V1ObjectMeta( + name='redis-service' + ), + spec=client.V1ServiceSpec( + selector={'app': 'redis-server'}, + ports=[client.V1ServicePort( + protocol='TCP', + port=6379, + target_port=6379 + )] + ) + ) + + try: + core_v1_api.create_namespaced_service(self.deploy_parameters.namespace, service) + self.deploy_parameters.redis_ip = 'redis-service.{}.svc'.format(self.deploy_parameters.namespace) + self.deploy_parameters.redis_port = 6379 + return True + except client.rest.ApiException as e: + print("Got exception: %s\n while creating a service for redis-server", e) + return False + + def deploy(self) -> bool: + if self.deploy_parameters.synchronized: + return self.create_k8s_job() + else: + return self.create_k8s_deployment() + + def create_k8s_deployment(self) -> bool: + container = client.V1Container( + name="worker", + image=self.deploy_parameters.image, + command=self.deploy_parameters.command, + args=self.deploy_parameters.arguments, + image_pull_policy='Always' + ) + template = client.V1PodTemplateSpec( + metadata=client.V1ObjectMeta(labels={'app': 'worker'}), + spec=client.V1PodSpec( + containers=[container] + ) + ) + deployment_spec = client.V1DeploymentSpec( + replicas=self.deploy_parameters.num_workers, + template=template, + selector=client.V1LabelSelector( + match_labels={'app': 'worker'} + ) + ) + + deployment = client.V1Deployment( + api_version='apps/v1', + kind='Deployment', + metadata=client.V1ObjectMeta(name='rollout-worker'), + spec=deployment_spec + ) + + api_client = client.AppsV1Api() + try: + api_client.create_namespaced_deployment(self.deploy_parameters.namespace, deployment) + return True + except client.rest.ApiException as e: + print("Got exception: %s\n while creating deployment", e) + return False + + def create_k8s_job(self): + pass diff --git a/rl_coach/orchestrators/test.py b/rl_coach/orchestrators/test.py new file mode 100644 index 0000000..56428e0 --- /dev/null +++ b/rl_coach/orchestrators/test.py @@ -0,0 +1,18 @@ +from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes + +# image = 'gcr.io/constant-cubist-173123/coach:latest' +image = 'ajaysudh/testing:coach' +command = ['python3', 'rl_coach/rollout_worker.py'] +# command = ['sleep', '10h'] + +params = KubernetesParameters(image, command, kubeconfig='~/.kube/config', redis_ip='redis-service.ajay.svc', redis_port=6379, num_workers=10) +# params = KubernetesParameters(image, command, kubeconfig='~/.kube/config') + +obj = Kubernetes(params) +if not obj.setup(): + print("Could not setup") + +if obj.deploy(): + print("Successfully deployed") +else: + print("Could not deploy") diff --git a/rl_coach/presets/CartPole_DQN_distributed.py b/rl_coach/presets/CartPole_DQN_distributed.py new file mode 100644 index 0000000..d3e8513 --- /dev/null +++ b/rl_coach/presets/CartPole_DQN_distributed.py @@ -0,0 +1,71 @@ +from rl_coach.agents.dqn_agent import DQNAgentParametersDistributed +from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters +from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase +from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod +from rl_coach.environments.gym_environment import Mujoco +from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager +from rl_coach.graph_managers.graph_manager import ScheduleParameters +from rl_coach.memories.memory import MemoryGranularity +from rl_coach.schedules import LinearSchedule + + +def construct_graph(redis_ip='localhost', redis_port=6379): + #################### + # Graph Scheduling # + #################### + + schedule_params = ScheduleParameters() + schedule_params.improve_steps = TrainingSteps(10000000000) + schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10) + schedule_params.evaluation_steps = EnvironmentEpisodes(1) + schedule_params.heatup_steps = EnvironmentSteps(1000) + + ######### + # Agent # + ######### + agent_params = DQNAgentParametersDistributed() + + # DQN params + agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100) + agent_params.algorithm.discount = 0.99 + agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1) + + # NN configuration + agent_params.network_wrappers['main'].learning_rate = 0.00025 + agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False + + # ER size + agent_params.memory.max_size = (MemoryGranularity.Transitions, 40000) + + # E-Greedy schedule + agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000) + + # Redis parameters + agent_params.memory.redis_ip = redis_ip + agent_params.memory.redis_port = redis_port + + ################ + # Environment # + ################ + env_params = Mujoco() + env_params.level = 'CartPole-v0' + + vis_params = VisualizationParameters() + vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()] + vis_params.dump_mp4 = False + + ######## + # Test # + ######## + preset_validation_params = PresetValidationParameters() + preset_validation_params.test = True + preset_validation_params.min_reward_threshold = 150 + preset_validation_params.max_episodes_to_achieve_reward = 250 + + graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params, + schedule_params=schedule_params, vis_params=vis_params, + preset_validation_params=preset_validation_params) + return graph_manager + + +graph_manager = construct_graph() diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py index 90035cd..f7a29f0 100644 --- a/rl_coach/rollout_worker.py +++ b/rl_coach/rollout_worker.py @@ -20,6 +20,7 @@ def rollout_worker(graph_manager, checkpoint_dir): task_parameters = TaskParameters() task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir graph_manager.create_graph(task_parameters) + graph_manager.phase = RunPhase.TRAIN graph_manager.act(EnvironmentEpisodes(num_steps=10)) graph_manager.phase = RunPhase.UNDEFINED @@ -44,5 +45,6 @@ def main(): checkpoint_dir=args.checkpoint_dir, ) + if __name__ == '__main__': main() diff --git a/setup.py b/setup.py index 76d8285..e77cd70 100644 --- a/setup.py +++ b/setup.py @@ -47,10 +47,11 @@ here = path.abspath(path.dirname(__file__)) with open(path.join(here, 'README.md'), encoding='utf-8') as f: long_description = f.read() -install_requires=[ - 'annoy==1.8.3', 'Pillow==4.3.0', 'matplotlib==2.0.2', 'numpy==1.14.5', 'pandas==0.22.0', - 'pygame==1.9.3', 'PyOpenGL==3.1.0', 'scipy==0.19.0', 'scikit-image==0.13.0', - 'box2d==2.3.2', 'gym==0.10.5', 'bokeh==0.13.0', 'futures==3.1.1', 'wxPython==4.0.1'] +install_requires = list() + +with open(path.join(here, 'requirements.txt'), 'r') as f: + for line in f: + install_requires.append(line.strip()) # check if system has CUDA enabled GPU p = subprocess.Popen(['command -v nvidia-smi'], stdout=subprocess.PIPE, shell=True)