mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
remove kubernetes dependency (#117)
This commit is contained in:
@@ -52,6 +52,14 @@ class EmbeddingMergerType(Enum):
|
||||
#ConcatDepthWise = 2
|
||||
#Multiply = 3
|
||||
|
||||
class RunType(Enum):
|
||||
ORCHESTRATOR = "orchestrator"
|
||||
TRAINER = "trainer"
|
||||
ROLLOUT_WORKER = "rollout-worker"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
# DistributedCoachSynchronizationType provides the synchronization type for distributed Coach.
|
||||
# The default value is None, which means the algorithm or preset cannot be used with distributed Coach.
|
||||
|
||||
@@ -28,7 +28,8 @@ import atexit
|
||||
import time
|
||||
import sys
|
||||
import json
|
||||
from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskParameters, DistributedTaskParameters
|
||||
from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskParameters, DistributedTaskParameters, \
|
||||
RunType
|
||||
from multiprocessing import Process
|
||||
from multiprocessing.managers import BaseManager
|
||||
import subprocess
|
||||
@@ -37,7 +38,6 @@ from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port
|
||||
from rl_coach.agents.human_agent import HumanAgentParameters
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, RunType, RunTypeParameters
|
||||
from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
|
||||
from rl_coach.memories.backend.memory_impl import construct_memory_params
|
||||
from rl_coach.data_stores.data_store import DataStoreParameters
|
||||
@@ -119,6 +119,9 @@ def handle_distributed_coach_tasks(graph_manager, args):
|
||||
|
||||
|
||||
def handle_distributed_coach_orchestrator(graph_manager, args):
|
||||
from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, \
|
||||
RunTypeParameters
|
||||
|
||||
ckpt_inside_container = "/checkpoint"
|
||||
rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + sys.argv[1:]
|
||||
trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + sys.argv[1:]
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import uuid
|
||||
|
||||
from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
|
||||
from kubernetes import client as k8sclient
|
||||
|
||||
|
||||
class NFSDataStoreParameters(DataStoreParameters):
|
||||
@@ -37,6 +36,8 @@ class NFSDataStore(DataStore):
|
||||
return True
|
||||
|
||||
def get_info(self):
|
||||
from kubernetes import client as k8sclient
|
||||
|
||||
return k8sclient.V1PersistentVolumeClaimVolumeSource(
|
||||
claim_name=self.params.pvc_name
|
||||
)
|
||||
@@ -58,6 +59,8 @@ class NFSDataStore(DataStore):
|
||||
pass
|
||||
|
||||
def deploy_k8s_nfs(self) -> bool:
|
||||
from kubernetes import client as k8sclient
|
||||
|
||||
name = "nfs-server-{}".format(uuid.uuid4())
|
||||
container = k8sclient.V1Container(
|
||||
name=name,
|
||||
@@ -145,6 +148,8 @@ class NFSDataStore(DataStore):
|
||||
return True
|
||||
|
||||
def create_k8s_nfs_resources(self) -> bool:
|
||||
from kubernetes import client as k8sclient
|
||||
|
||||
pv_name = "nfs-ckpt-pv-{}".format(uuid.uuid4())
|
||||
persistent_volume = k8sclient.V1PersistentVolume(
|
||||
api_version="v1",
|
||||
@@ -200,6 +205,8 @@ class NFSDataStore(DataStore):
|
||||
return True
|
||||
|
||||
def undeploy_k8s_nfs(self) -> bool:
|
||||
from kubernetes import client as k8sclient
|
||||
|
||||
del_options = k8sclient.V1DeleteOptions()
|
||||
|
||||
k8s_apps_v1_api_client = k8sclient.AppsV1Api()
|
||||
@@ -219,6 +226,8 @@ class NFSDataStore(DataStore):
|
||||
return True
|
||||
|
||||
def delete_k8s_nfs_resources(self) -> bool:
|
||||
from kubernetes import client as k8sclient
|
||||
|
||||
del_options = k8sclient.V1DeleteOptions()
|
||||
k8s_api_client = k8sclient.CoreV1Api()
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
|
||||
from kubernetes import client as k8sclient
|
||||
from minio import Minio
|
||||
from minio.error import ResponseError
|
||||
from configparser import ConfigParser, Error
|
||||
|
||||
@@ -24,7 +24,7 @@ import contextlib
|
||||
|
||||
from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \
|
||||
VisualizationParameters, \
|
||||
Parameters, PresetValidationParameters
|
||||
Parameters, PresetValidationParameters, RunType
|
||||
from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
|
||||
EnvironmentSteps, \
|
||||
StepMethod, Transition
|
||||
@@ -33,7 +33,6 @@ from rl_coach.level_manager import LevelManager
|
||||
from rl_coach.logger import screen, Logger
|
||||
from rl_coach.utils import set_cpu, start_shell_command_and_wait
|
||||
from rl_coach.data_stores.data_store_impl import get_data_store
|
||||
from rl_coach.orchestrators.kubernetes_orchestrator import RunType
|
||||
from rl_coach.memories.backend.memory_impl import get_memory_backend
|
||||
from rl_coach.data_stores.data_store import SyncFiles
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import redis
|
||||
import pickle
|
||||
import uuid
|
||||
import time
|
||||
from kubernetes import client
|
||||
|
||||
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
|
||||
from rl_coach.core_types import Transition, Episode, EnvironmentSteps, EnvironmentEpisodes
|
||||
@@ -48,6 +47,7 @@ class RedisPubSubBackend(MemoryBackend):
|
||||
|
||||
if 'namespace' not in self.params.orchestrator_params:
|
||||
self.params.orchestrator_params['namespace'] = "default"
|
||||
from kubernetes import client
|
||||
|
||||
container = client.V1Container(
|
||||
name=self.redis_server_name,
|
||||
|
||||
@@ -5,6 +5,8 @@ import time
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
from configparser import ConfigParser, Error
|
||||
|
||||
from rl_coach.base_parameters import RunType
|
||||
from rl_coach.orchestrators.deploy import Deploy, DeployParameters
|
||||
from kubernetes import client as k8sclient, config as k8sconfig
|
||||
from rl_coach.memories.backend.memory import MemoryBackendParameters
|
||||
@@ -13,15 +15,6 @@ from rl_coach.data_stores.data_store import DataStoreParameters
|
||||
from rl_coach.data_stores.data_store_impl import get_data_store
|
||||
|
||||
|
||||
class RunType(Enum):
|
||||
ORCHESTRATOR = "orchestrator"
|
||||
TRAINER = "trainer"
|
||||
ROLLOUT_WORKER = "rollout-worker"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class RunTypeParameters():
|
||||
|
||||
def __init__(self, image: str, command: list(), arguments: list() = None,
|
||||
|
||||
@@ -20,7 +20,6 @@ import pickle
|
||||
import redis
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.memories.backend.memory_impl import get_memory_backend
|
||||
|
||||
|
||||
class SharedRunningStatsSubscribe(threading.Thread):
|
||||
@@ -49,6 +48,7 @@ class SharedRunningStats(ABC):
|
||||
self.pubsub = None
|
||||
if pubsub_params:
|
||||
self.channel = "channel-srs-{}".format(self.name)
|
||||
from rl_coach.memories.backend.memory_impl import get_memory_backend
|
||||
self.pubsub = get_memory_backend(pubsub_params)
|
||||
subscribe_thread = SharedRunningStatsSubscribe(self)
|
||||
subscribe_thread.daemon = True
|
||||
|
||||
Reference in New Issue
Block a user