1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

remove kubernetes dependency (#117)

This commit is contained in:
Gal Leibovich
2018-11-18 18:10:22 +02:00
committed by Gal Novik
parent 430e286c56
commit d4d06aaea6
8 changed files with 28 additions and 17 deletions

View File

@@ -52,6 +52,14 @@ class EmbeddingMergerType(Enum):
#ConcatDepthWise = 2
#Multiply = 3
class RunType(Enum):
ORCHESTRATOR = "orchestrator"
TRAINER = "trainer"
ROLLOUT_WORKER = "rollout-worker"
def __str__(self):
return self.value
# DistributedCoachSynchronizationType provides the synchronization type for distributed Coach.
# The default value is None, which means the algorithm or preset cannot be used with distributed Coach.

View File

@@ -28,7 +28,8 @@ import atexit
import time
import sys
import json
from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskParameters, DistributedTaskParameters
from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskParameters, DistributedTaskParameters, \
RunType
from multiprocessing import Process
from multiprocessing.managers import BaseManager
import subprocess
@@ -37,7 +38,6 @@ from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port
from rl_coach.agents.human_agent import HumanAgentParameters
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.environments.environment import SingleLevelSelection
from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, RunType, RunTypeParameters
from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
from rl_coach.memories.backend.memory_impl import construct_memory_params
from rl_coach.data_stores.data_store import DataStoreParameters
@@ -119,6 +119,9 @@ def handle_distributed_coach_tasks(graph_manager, args):
def handle_distributed_coach_orchestrator(graph_manager, args):
from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, \
RunTypeParameters
ckpt_inside_container = "/checkpoint"
rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + sys.argv[1:]
trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + sys.argv[1:]

View File

@@ -1,7 +1,6 @@
import uuid
from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
from kubernetes import client as k8sclient
class NFSDataStoreParameters(DataStoreParameters):
@@ -37,6 +36,8 @@ class NFSDataStore(DataStore):
return True
def get_info(self):
from kubernetes import client as k8sclient
return k8sclient.V1PersistentVolumeClaimVolumeSource(
claim_name=self.params.pvc_name
)
@@ -58,6 +59,8 @@ class NFSDataStore(DataStore):
pass
def deploy_k8s_nfs(self) -> bool:
from kubernetes import client as k8sclient
name = "nfs-server-{}".format(uuid.uuid4())
container = k8sclient.V1Container(
name=name,
@@ -145,6 +148,8 @@ class NFSDataStore(DataStore):
return True
def create_k8s_nfs_resources(self) -> bool:
from kubernetes import client as k8sclient
pv_name = "nfs-ckpt-pv-{}".format(uuid.uuid4())
persistent_volume = k8sclient.V1PersistentVolume(
api_version="v1",
@@ -200,6 +205,8 @@ class NFSDataStore(DataStore):
return True
def undeploy_k8s_nfs(self) -> bool:
from kubernetes import client as k8sclient
del_options = k8sclient.V1DeleteOptions()
k8s_apps_v1_api_client = k8sclient.AppsV1Api()
@@ -219,6 +226,8 @@ class NFSDataStore(DataStore):
return True
def delete_k8s_nfs_resources(self) -> bool:
from kubernetes import client as k8sclient
del_options = k8sclient.V1DeleteOptions()
k8s_api_client = k8sclient.CoreV1Api()

View File

@@ -1,5 +1,4 @@
from rl_coach.data_stores.data_store import DataStore, DataStoreParameters
from kubernetes import client as k8sclient
from minio import Minio
from minio.error import ResponseError
from configparser import ConfigParser, Error

View File

@@ -24,7 +24,7 @@ import contextlib
from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \
VisualizationParameters, \
Parameters, PresetValidationParameters
Parameters, PresetValidationParameters, RunType
from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
EnvironmentSteps, \
StepMethod, Transition
@@ -33,7 +33,6 @@ from rl_coach.level_manager import LevelManager
from rl_coach.logger import screen, Logger
from rl_coach.utils import set_cpu, start_shell_command_and_wait
from rl_coach.data_stores.data_store_impl import get_data_store
from rl_coach.orchestrators.kubernetes_orchestrator import RunType
from rl_coach.memories.backend.memory_impl import get_memory_backend
from rl_coach.data_stores.data_store import SyncFiles

View File

@@ -3,7 +3,6 @@ import redis
import pickle
import uuid
import time
from kubernetes import client
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
from rl_coach.core_types import Transition, Episode, EnvironmentSteps, EnvironmentEpisodes
@@ -48,6 +47,7 @@ class RedisPubSubBackend(MemoryBackend):
if 'namespace' not in self.params.orchestrator_params:
self.params.orchestrator_params['namespace'] = "default"
from kubernetes import client
container = client.V1Container(
name=self.redis_server_name,

View File

@@ -5,6 +5,8 @@ import time
from enum import Enum
from typing import List
from configparser import ConfigParser, Error
from rl_coach.base_parameters import RunType
from rl_coach.orchestrators.deploy import Deploy, DeployParameters
from kubernetes import client as k8sclient, config as k8sconfig
from rl_coach.memories.backend.memory import MemoryBackendParameters
@@ -13,15 +15,6 @@ from rl_coach.data_stores.data_store import DataStoreParameters
from rl_coach.data_stores.data_store_impl import get_data_store
class RunType(Enum):
ORCHESTRATOR = "orchestrator"
TRAINER = "trainer"
ROLLOUT_WORKER = "rollout-worker"
def __str__(self):
return self.value
class RunTypeParameters():
def __init__(self, image: str, command: list(), arguments: list() = None,

View File

@@ -20,7 +20,6 @@ import pickle
import redis
import numpy as np
from rl_coach.memories.backend.memory_impl import get_memory_backend
class SharedRunningStatsSubscribe(threading.Thread):
@@ -49,6 +48,7 @@ class SharedRunningStats(ABC):
self.pubsub = None
if pubsub_params:
self.channel = "channel-srs-{}".format(self.name)
from rl_coach.memories.backend.memory_impl import get_memory_backend
self.pubsub = get_memory_backend(pubsub_params)
subscribe_thread = SharedRunningStatsSubscribe(self)
subscribe_thread.daemon = True