1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

remove kubernetes dependency (#117)

This commit is contained in:
Gal Leibovich
2018-11-18 18:10:22 +02:00
committed by Gal Novik
parent 430e286c56
commit d4d06aaea6
8 changed files with 28 additions and 17 deletions

View File

@@ -52,6 +52,14 @@ class EmbeddingMergerType(Enum):
#ConcatDepthWise = 2 #ConcatDepthWise = 2
#Multiply = 3 #Multiply = 3
class RunType(Enum):
    """Role a Coach process plays in a distributed run.

    Each member's value is the string spelling of that role. ``__str__``
    returns that value unchanged, so ``str(RunType.TRAINER)`` yields
    ``"trainer"`` rather than ``"RunType.TRAINER"`` and a member can be
    placed directly into a string or command-line argument list.
    """

    ORCHESTRATOR = "orchestrator"
    TRAINER = "trainer"
    ROLLOUT_WORKER = "rollout-worker"

    def __str__(self):
        """Return the member's raw string value."""
        return self.value
# DistributedCoachSynchronizationType provides the synchronization type for distributed Coach.
# The default value is None, which means the algorithm or preset cannot be used with distributed Coach.

View File

@@ -28,7 +28,8 @@ import atexit
import time import time
import sys import sys
import json import json
from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskParameters, DistributedTaskParameters from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskParameters, DistributedTaskParameters, \
RunType
from multiprocessing import Process from multiprocessing import Process
from multiprocessing.managers import BaseManager from multiprocessing.managers import BaseManager
import subprocess import subprocess
@@ -37,7 +38,6 @@ from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port
from rl_coach.agents.human_agent import HumanAgentParameters from rl_coach.agents.human_agent import HumanAgentParameters
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.environments.environment import SingleLevelSelection from rl_coach.environments.environment import SingleLevelSelection
from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, RunType, RunTypeParameters
from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
from rl_coach.memories.backend.memory_impl import construct_memory_params from rl_coach.memories.backend.memory_impl import construct_memory_params
from rl_coach.data_stores.data_store import DataStoreParameters from rl_coach.data_stores.data_store import DataStoreParameters
@@ -119,6 +119,9 @@ def handle_distributed_coach_tasks(graph_manager, args):
def handle_distributed_coach_orchestrator(graph_manager, args): def handle_distributed_coach_orchestrator(graph_manager, args):
from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, \
RunTypeParameters
ckpt_inside_container = "/checkpoint" ckpt_inside_container = "/checkpoint"
rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + sys.argv[1:] rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + sys.argv[1:]
trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + sys.argv[1:] trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + sys.argv[1:]

View File

@@ -1,7 +1,6 @@
import uuid import uuid
from rl_coach.data_stores.data_store import DataStore, DataStoreParameters from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
from kubernetes import client as k8sclient
class NFSDataStoreParameters(DataStoreParameters): class NFSDataStoreParameters(DataStoreParameters):
@@ -37,6 +36,8 @@ class NFSDataStore(DataStore):
return True return True
def get_info(self): def get_info(self):
from kubernetes import client as k8sclient
return k8sclient.V1PersistentVolumeClaimVolumeSource( return k8sclient.V1PersistentVolumeClaimVolumeSource(
claim_name=self.params.pvc_name claim_name=self.params.pvc_name
) )
@@ -58,6 +59,8 @@ class NFSDataStore(DataStore):
pass pass
def deploy_k8s_nfs(self) -> bool: def deploy_k8s_nfs(self) -> bool:
from kubernetes import client as k8sclient
name = "nfs-server-{}".format(uuid.uuid4()) name = "nfs-server-{}".format(uuid.uuid4())
container = k8sclient.V1Container( container = k8sclient.V1Container(
name=name, name=name,
@@ -145,6 +148,8 @@ class NFSDataStore(DataStore):
return True return True
def create_k8s_nfs_resources(self) -> bool: def create_k8s_nfs_resources(self) -> bool:
from kubernetes import client as k8sclient
pv_name = "nfs-ckpt-pv-{}".format(uuid.uuid4()) pv_name = "nfs-ckpt-pv-{}".format(uuid.uuid4())
persistent_volume = k8sclient.V1PersistentVolume( persistent_volume = k8sclient.V1PersistentVolume(
api_version="v1", api_version="v1",
@@ -200,6 +205,8 @@ class NFSDataStore(DataStore):
return True return True
def undeploy_k8s_nfs(self) -> bool: def undeploy_k8s_nfs(self) -> bool:
from kubernetes import client as k8sclient
del_options = k8sclient.V1DeleteOptions() del_options = k8sclient.V1DeleteOptions()
k8s_apps_v1_api_client = k8sclient.AppsV1Api() k8s_apps_v1_api_client = k8sclient.AppsV1Api()
@@ -219,6 +226,8 @@ class NFSDataStore(DataStore):
return True return True
def delete_k8s_nfs_resources(self) -> bool: def delete_k8s_nfs_resources(self) -> bool:
from kubernetes import client as k8sclient
del_options = k8sclient.V1DeleteOptions() del_options = k8sclient.V1DeleteOptions()
k8s_api_client = k8sclient.CoreV1Api() k8s_api_client = k8sclient.CoreV1Api()

View File

@@ -1,5 +1,4 @@
from rl_coach.data_stores.data_store import DataStore, DataStoreParameters from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
from kubernetes import client as k8sclient
from minio import Minio from minio import Minio
from minio.error import ResponseError from minio.error import ResponseError
from configparser import ConfigParser, Error from configparser import ConfigParser, Error

View File

@@ -24,7 +24,7 @@ import contextlib
from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \ from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \
VisualizationParameters, \ VisualizationParameters, \
Parameters, PresetValidationParameters Parameters, PresetValidationParameters, RunType
from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \ from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
EnvironmentSteps, \ EnvironmentSteps, \
StepMethod, Transition StepMethod, Transition
@@ -33,7 +33,6 @@ from rl_coach.level_manager import LevelManager
from rl_coach.logger import screen, Logger from rl_coach.logger import screen, Logger
from rl_coach.utils import set_cpu, start_shell_command_and_wait from rl_coach.utils import set_cpu, start_shell_command_and_wait
from rl_coach.data_stores.data_store_impl import get_data_store from rl_coach.data_stores.data_store_impl import get_data_store
from rl_coach.orchestrators.kubernetes_orchestrator import RunType
from rl_coach.memories.backend.memory_impl import get_memory_backend from rl_coach.memories.backend.memory_impl import get_memory_backend
from rl_coach.data_stores.data_store import SyncFiles from rl_coach.data_stores.data_store import SyncFiles

View File

@@ -3,7 +3,6 @@ import redis
import pickle import pickle
import uuid import uuid
import time import time
from kubernetes import client
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
from rl_coach.core_types import Transition, Episode, EnvironmentSteps, EnvironmentEpisodes from rl_coach.core_types import Transition, Episode, EnvironmentSteps, EnvironmentEpisodes
@@ -48,6 +47,7 @@ class RedisPubSubBackend(MemoryBackend):
if 'namespace' not in self.params.orchestrator_params: if 'namespace' not in self.params.orchestrator_params:
self.params.orchestrator_params['namespace'] = "default" self.params.orchestrator_params['namespace'] = "default"
from kubernetes import client
container = client.V1Container( container = client.V1Container(
name=self.redis_server_name, name=self.redis_server_name,

View File

@@ -5,6 +5,8 @@ import time
from enum import Enum from enum import Enum
from typing import List from typing import List
from configparser import ConfigParser, Error from configparser import ConfigParser, Error
from rl_coach.base_parameters import RunType
from rl_coach.orchestrators.deploy import Deploy, DeployParameters from rl_coach.orchestrators.deploy import Deploy, DeployParameters
from kubernetes import client as k8sclient, config as k8sconfig from kubernetes import client as k8sclient, config as k8sconfig
from rl_coach.memories.backend.memory import MemoryBackendParameters from rl_coach.memories.backend.memory import MemoryBackendParameters
@@ -13,15 +15,6 @@ from rl_coach.data_stores.data_store import DataStoreParameters
from rl_coach.data_stores.data_store_impl import get_data_store from rl_coach.data_stores.data_store_impl import get_data_store
class RunType(Enum):
    """Enumeration of the process roles in a distributed Coach run.

    The value of each member is its string name for that role, and
    ``__str__`` returns that value directly (e.g. ``"rollout-worker"``)
    instead of the default ``ClassName.MEMBER`` form.
    """

    ORCHESTRATOR = "orchestrator"
    TRAINER = "trainer"
    ROLLOUT_WORKER = "rollout-worker"

    def __str__(self):
        """Return the member's raw string value."""
        return self.value
class RunTypeParameters(): class RunTypeParameters():
def __init__(self, image: str, command: list(), arguments: list() = None, def __init__(self, image: str, command: list(), arguments: list() = None,

View File

@@ -20,7 +20,6 @@ import pickle
import redis import redis
import numpy as np import numpy as np
from rl_coach.memories.backend.memory_impl import get_memory_backend
class SharedRunningStatsSubscribe(threading.Thread): class SharedRunningStatsSubscribe(threading.Thread):
@@ -49,6 +48,7 @@ class SharedRunningStats(ABC):
self.pubsub = None self.pubsub = None
if pubsub_params: if pubsub_params:
self.channel = "channel-srs-{}".format(self.name) self.channel = "channel-srs-{}".format(self.name)
from rl_coach.memories.backend.memory_impl import get_memory_backend
self.pubsub = get_memory_backend(pubsub_params) self.pubsub = get_memory_backend(pubsub_params)
subscribe_thread = SharedRunningStatsSubscribe(self) subscribe_thread = SharedRunningStatsSubscribe(self)
subscribe_thread.daemon = True subscribe_thread.daemon = True