Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 11:10:20 +01:00)
Integrate coach.py params with distributed Coach. (#42)
* Integrate coach.py params with distributed Coach.
* Minor improvements:
  - Use enums instead of constants.
  - Reduce code duplication.
  - Ask experiment name with timeout.
Commit 7e7006305a (parent 95b4fc6888), committed via GitHub.
dist-coach-config.template (new file, 7 lines)
@@ -0,0 +1,7 @@
+[coach]
+image = <insert-image-name>
+memory_backend = redispubsub
+data_store = s3
+s3_end_point = s3.amazonaws.com
+s3_bucket_name = <insert-s3-bucket-name>
+s3_creds_file = <insert-path-for-s3-creds-file>
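For reference, a minimal sketch of how a [coach] config of this shape can be read with Python's ConfigParser, mirroring the defaults used by coach.py later in this diff; the inline values (image name, bucket name) are hypothetical placeholders.

# Sketch only: parsing a [coach] config of the shape defined by the template above.
# The defaults mirror those used in coach.py below; the inline values are made up.
from configparser import ConfigParser

defaults = {
    'image': '',
    'memory_backend': 'redispubsub',
    'data_store': 's3',
    's3_end_point': 's3.amazonaws.com',
    's3_bucket_name': '',
    's3_creds_file': ''
}

coach_config = ConfigParser(defaults)
coach_config.read_string("""
[coach]
image = my-registry/coach:latest
s3_bucket_name = my-coach-bucket
""")

image = coach_config.get('coach', 'image')                # 'my-registry/coach:latest'
data_store = coach_config.get('coach', 'data_store')      # falls back to 's3'
s3_end_point = coach_config.get('coach', 's3_end_point')  # falls back to 's3.amazonaws.com'
print(image, data_store, s3_end_point)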
@@ -53,6 +53,18 @@ class EmbeddingMergerType(Enum):
     #Multiply = 3
 
 
+# DistributedCoachSynchronizationType provides the synchronization type for distributed Coach.
+# The default value is None, which means the algorithm or preset cannot be used with distributed Coach.
+class DistributedCoachSynchronizationType(Enum):
+    # In SYNC mode, the trainer waits for all the experiences to be gathered from distributed rollout workers before
+    # training a new policy and the rollout workers wait for a new policy before gathering experiences.
+    SYNC = "sync"
+
+    # In ASYNC mode, the trainer doesn't wait for any set of experiences to be gathered from distributed rollout workers
+    # and the rollout workers continuously gather experiences, loading new policies whenever they become available.
+    ASYNC = "async"
+
+
 def iterable_to_items(obj):
     if isinstance(obj, dict) or isinstance(obj, OrderedDict) or isinstance(obj, types.MappingProxyType):
         items = obj.items()
@@ -154,6 +166,9 @@ class AlgorithmParameters(Parameters):
         # intrinsic reward
         self.scale_external_reward_by_intrinsic_reward_value = False
 
+        # Distributed Coach params
+        self.distributed_coach_synchronization_type = None
+
 
 class PresetValidationParameters(Parameters):
     def __init__(self):
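A preset opts in to distributed Coach by setting this new attribute; a minimal sketch, assuming a ClippedPPO preset like the one modified later in this diff (the rest of the preset setup is omitted):

# Sketch: enabling distributed Coach from a preset by setting the new algorithm attribute.
# Any agent parameters object works the same way; ClippedPPO is only used as an example here.
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
from rl_coach.base_parameters import DistributedCoachSynchronizationType

agent_params = ClippedPPOAgentParameters()

# SYNC: the trainer waits for all rollout experiences and workers wait for each new policy.
# ASYNC: trainer and workers proceed independently, picking up new data/policies as available.
agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC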
@@ -17,6 +17,7 @@ import sys
 sys.path.append('.')
 
 import copy
+from configparser import ConfigParser, Error
 from rl_coach.core_types import EnvironmentSteps
 import os
 from rl_coach import logger
@@ -26,6 +27,7 @@ import argparse
 import atexit
 import time
 import sys
+import json
 from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskParameters, DistributedTaskParameters
 from multiprocessing import Process
 from multiprocessing.managers import BaseManager
@@ -35,6 +37,14 @@ from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port
 from rl_coach.agents.human_agent import HumanAgentParameters
 from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
 from rl_coach.environments.environment import SingleLevelSelection
+from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, RunType, RunTypeParameters
+from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
+from rl_coach.memories.backend.memory_impl import construct_memory_params
+from rl_coach.data_stores.data_store import DataStoreParameters
+from rl_coach.data_stores.s3_data_store import S3DataStoreParameters
+from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
+from rl_coach.training_worker import training_worker
+from rl_coach.rollout_worker import rollout_worker, wait_for_checkpoint
 
 
 if len(set(failed_imports)) > 0:
@@ -108,6 +118,7 @@ def display_all_presets_and_exit():
         print(preset)
     sys.exit(0)
 
+
 def expand_preset(preset):
     if preset.lower() in [p.lower() for p in list_all_presets()]:
         preset = "{}.py:graph_manager".format(os.path.join(get_base_dir(), 'presets', preset))
@@ -150,6 +161,49 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
     if args.list:
         display_all_presets_and_exit()
 
+    # Read args from config file for distributed Coach.
+    if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
+        coach_config = ConfigParser({
+            'image': '',
+            'memory_backend': 'redispubsub',
+            'data_store': 's3',
+            's3_end_point': 's3.amazonaws.com',
+            's3_bucket_name': '',
+            's3_creds_file': ''
+        })
+        try:
+            coach_config.read(args.distributed_coach_config_path)
+            args.image = coach_config.get('coach', 'image')
+            args.memory_backend = coach_config.get('coach', 'memory_backend')
+            args.data_store = coach_config.get('coach', 'data_store')
+            args.s3_end_point = coach_config.get('coach', 's3_end_point')
+            args.s3_bucket_name = coach_config.get('coach', 's3_bucket_name')
+            args.s3_creds_file = coach_config.get('coach', 's3_creds_file')
+        except Error as e:
+            screen.error("Error when reading distributed Coach config file: {}".format(e))
+
+        if args.image == '':
+            screen.error("Image cannot be empty.")
+
+        data_store_choices = ['s3']
+        if args.data_store not in data_store_choices:
+            screen.warning("{} data store is unsupported.".format(args.data_store))
+            screen.error("Supported data stores are {}.".format(data_store_choices))
+
+        memory_backend_choices = ['redispubsub']
+        if args.memory_backend not in memory_backend_choices:
+            screen.warning("{} memory backend is not supported.".format(args.memory_backend))
+            screen.error("Supported memory backends are {}.".format(memory_backend_choices))
+
+        if args.s3_bucket_name == '':
+            screen.error("S3 bucket name cannot be empty.")
+
+        if args.s3_creds_file == '':
+            args.s3_creds_file = None
+
+    if args.play and args.distributed_coach:
+        screen.error("Playing is not supported in distributed Coach.")
+
     # replace a short preset name with the full path
     if args.preset is not None:
         args.preset = expand_preset(args.preset)
@@ -217,6 +271,94 @@ def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'
     graph_manager.improve()
 
 
+def handle_distributed_coach_tasks(graph_manager, args):
+    ckpt_inside_container = "/checkpoint"
+
+    memory_backend_params = None
+    if args.memory_backend_params:
+        memory_backend_params = json.loads(args.memory_backend_params)
+        memory_backend_params['run_type'] = str(args.distributed_coach_run_type)
+        graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(memory_backend_params))
+
+    data_store_params = None
+    if args.data_store_params:
+        data_store_params = construct_data_store_params(json.loads(args.data_store_params))
+        data_store_params.checkpoint_dir = ckpt_inside_container
+        graph_manager.data_store_params = data_store_params
+
+    if args.distributed_coach_run_type == RunType.TRAINER:
+        training_worker(
+            graph_manager=graph_manager,
+            checkpoint_dir=ckpt_inside_container
+        )
+
+    if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
+        data_store = None
+        if args.data_store_params:
+            data_store = get_data_store(data_store_params)
+            wait_for_checkpoint(checkpoint_dir=ckpt_inside_container, data_store=data_store)
+
+        rollout_worker(
+            graph_manager=graph_manager,
+            checkpoint_dir=ckpt_inside_container,
+            data_store=data_store,
+            num_workers=args.num_workers
+        )
+
+
+def handle_distributed_coach_orchestrator(graph_manager, args):
+    ckpt_inside_container = "/checkpoint"
+    rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + sys.argv[1:]
+    trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + sys.argv[1:]
+
+    if '--experiment_name' not in rollout_command:
+        rollout_command = rollout_command + ['--experiment_name', args.experiment_name]
+
+    if '--experiment_name' not in trainer_command:
+        trainer_command = trainer_command + ['--experiment_name', args.experiment_name]
+
+    memory_backend_params = None
+    if args.memory_backend == "redispubsub":
+        memory_backend_params = RedisPubSubMemoryBackendParameters()
+
+    ds_params_instance = None
+    if args.data_store == "s3":
+        ds_params = DataStoreParameters("s3", "", "")
+        ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=args.s3_end_point, bucket_name=args.s3_bucket_name,
+                                                   creds_file=args.s3_creds_file, checkpoint_dir=ckpt_inside_container)
+
+    worker_run_type_params = RunTypeParameters(args.image, rollout_command, run_type=str(RunType.ROLLOUT_WORKER), num_replicas=args.num_workers)
+    trainer_run_type_params = RunTypeParameters(args.image, trainer_command, run_type=str(RunType.TRAINER))
+
+    orchestration_params = KubernetesParameters([worker_run_type_params, trainer_run_type_params],
+                                                kubeconfig='~/.kube/config',
+                                                memory_backend_parameters=memory_backend_params,
+                                                data_store_params=ds_params_instance)
+    orchestrator = Kubernetes(orchestration_params)
+    if not orchestrator.setup():
+        print("Could not setup.")
+        return
+
+    if orchestrator.deploy_trainer():
+        print("Successfully deployed trainer.")
+    else:
+        print("Could not deploy trainer.")
+        return
+
+    if orchestrator.deploy_worker():
+        print("Successfully deployed rollout worker(s).")
+    else:
+        print("Could not deploy rollout worker(s).")
+        return
+
+    try:
+        orchestrator.trainer_logs()
+    except KeyboardInterrupt:
+        pass
+
+    orchestrator.undeploy()
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('-p', '--preset',
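The orchestrator and the pods it launches exchange these parameters as JSON strings on the command line: deploy_trainer()/deploy_worker() serialize the parameter objects with json.dumps(obj.__dict__), and handle_distributed_coach_tasks() above re-hydrates them. A standalone sketch of that round trip, using an illustrative class rather than one of the real rl_coach parameter types:

# Standalone sketch of the JSON hand-off pattern used above; the class below is
# illustrative only and is not a real rl_coach parameter class.
import json

class ExampleMemoryBackendParams(object):
    def __init__(self):
        self.store = 'redispubsub'
        self.run_type = None

# Orchestrator side: flatten the object and append it to the pod command line.
params = ExampleMemoryBackendParams()
command = ['python3', 'rl_coach/coach.py', '--memory_backend_params', json.dumps(params.__dict__)]

# Worker side: parse the JSON back and tag it with the run type, as in
# handle_distributed_coach_tasks().
restored = json.loads(command[-1])
restored['run_type'] = 'rollout-worker'
print(restored)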
@@ -329,11 +471,35 @@ def main():
                         help="(int) A seed to use for running the experiment",
                         default=None,
                         type=int)
+    parser.add_argument('-dc', '--distributed_coach',
+                        help="(flag) Use distributed Coach.",
+                        action='store_true')
+    parser.add_argument('-dcp', '--distributed_coach_config_path',
+                        help="(string) Path to config file when using distributed rollout workers."
+                             "Only distributed Coach parameters should be provided through this config file."
+                             "Rest of the parameters are provided using Coach command line options."
+                             "Used only with --distributed_coach flag."
+                             "Ignored if --distributed_coach flag is not used.",
+                        type=str)
+    parser.add_argument('--memory_backend_params',
+                        help=argparse.SUPPRESS,
+                        type=str)
+    parser.add_argument('--data_store_params',
+                        help=argparse.SUPPRESS,
+                        type=str)
+    parser.add_argument('--distributed_coach_run_type',
+                        help=argparse.SUPPRESS,
+                        type=RunType,
+                        default=RunType.ORCHESTRATOR,
+                        choices=list(RunType))
 
     args = parse_arguments(parser)
 
     graph_manager = get_graph_manager_from_args(args)
 
+    if args.distributed_coach and not graph_manager.agent_params.algorithm.distributed_coach_synchronization_type:
+        screen.error("{} preset is not supported using distributed Coach.".format(args.preset))
+
     # Intel optimized TF seems to run significantly faster when limiting to a single OMP thread.
     # This will not affect GPU runs.
     os.environ["OMP_NUM_THREADS"] = "1"
@@ -343,7 +509,7 @@ def main():
     os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(args.tf_verbosity)
 
     # turn off the summary at the end of the run if necessary
-    if not args.no_summary:
+    if not args.no_summary and not args.distributed_coach:
         atexit.register(logger.summarize_experiment)
         screen.change_terminal_title(args.experiment_name)
 
@@ -351,6 +517,14 @@ def main():
     if args.open_dashboard:
         open_dashboard(args.experiment_path)
 
+    if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
+        handle_distributed_coach_tasks(graph_manager, args)
+        return
+
+    if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
+        handle_distributed_coach_orchestrator(graph_manager, args)
+        return
+
     # Single-threaded runs
     if args.num_workers == 1:
         # Start the training or evaluation
@@ -33,6 +33,7 @@ from rl_coach.level_manager import LevelManager
 from rl_coach.logger import screen, Logger
 from rl_coach.utils import set_cpu, start_shell_command_and_wait
 from rl_coach.data_stores.data_store_impl import get_data_store
+from rl_coach.orchestrators.kubernetes_orchestrator import RunType
 
 
 class ScheduleParameters(Parameters):
@@ -361,9 +362,10 @@ class GraphManager(object):
         self.verify_graph_was_created()
 
         if hasattr(self, 'data_store_params') and hasattr(self.agent_params.memory, 'memory_backend_params'):
-            if self.agent_params.memory.memory_backend_params.run_type == "worker":
+            if self.agent_params.memory.memory_backend_params.run_type == str(RunType.ROLLOUT_WORKER):
                 data_store = get_data_store(self.data_store_params)
                 data_store.load_from_store()
 
         # perform several steps of playing
         count_end = self.current_step_counter + steps
         while self.current_step_counter < count_end:
@@ -18,7 +18,9 @@ import datetime
 import os
 import re
 import shutil
+import signal
 import time
+import uuid
 from subprocess import Popen, PIPE
 from typing import Union
 
@@ -90,6 +92,23 @@ class ScreenLogger(object):
     def ask_input(self, title):
         return input("{}{}{}".format(Colors.BG_CYAN, title, Colors.END))
 
+    def ask_input_with_timeout(self, title, timeout, msg_if_timeout='Timeout expired.'):
+        class TimeoutExpired(Exception):
+            pass
+
+        def timeout_alarm_handler(signum, frame):
+            raise TimeoutExpired
+
+        signal.signal(signal.SIGALRM, timeout_alarm_handler)
+        signal.alarm(timeout)
+
+        try:
+            return input("{}{}{}".format(Colors.BG_CYAN, title, Colors.END))
+        except TimeoutExpired:
+            self.warning(msg_if_timeout)
+        finally:
+            signal.alarm(0)
+
     def ask_yes_no(self, title: str, default: Union[None, bool] = None):
         """
         Ask the user for a yes / no question and return True if the answer is yes and False otherwise.
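Usage sketch for the new prompt helper. It relies on SIGALRM, so it only applies on Unix-like systems and from the main thread; get_experiment_name() below uses it exactly like this:

# Sketch: asking for an experiment name but giving up after 60 seconds,
# matching how get_experiment_name() uses the new helper below.
from rl_coach.logger import screen

experiment_name = screen.ask_input_with_timeout("Please enter an experiment name: ", 60,
                                                 "Timeout waiting for experiment name.")
if not experiment_name:
    experiment_name = ''  # the helper returns None when the timeout expires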
@@ -333,10 +352,14 @@ def get_experiment_name(initial_experiment_name=''):
     match = None
     while match is None:
         if initial_experiment_name == '':
-            experiment_name = screen.ask_input("Please enter an experiment name: ")
+            msg_if_timeout = "Timeout waiting for experiment name."
+            experiment_name = screen.ask_input_with_timeout("Please enter an experiment name: ", 60, msg_if_timeout)
         else:
             experiment_name = initial_experiment_name
 
+        if not experiment_name:
+            experiment_name = ''
+
         experiment_name = experiment_name.replace(" ", "_")
         match = re.match("^$|^[\w -/]{1,1000}$", experiment_name)
 
@@ -2,6 +2,7 @@ import os
 import uuid
 import json
 import time
+from enum import Enum
 from typing import List
 from configparser import ConfigParser, Error
 from rl_coach.orchestrators.deploy import Deploy, DeployParameters
@@ -12,10 +13,19 @@ from rl_coach.data_stores.data_store import DataStoreParameters
 from rl_coach.data_stores.data_store_impl import get_data_store
 
 
+class RunType(Enum):
+    ORCHESTRATOR = "orchestrator"
+    TRAINER = "trainer"
+    ROLLOUT_WORKER = "rollout-worker"
+
+    def __str__(self):
+        return self.value
+
+
 class RunTypeParameters():
 
     def __init__(self, image: str, command: list(), arguments: list() = None,
-                 run_type: str = "trainer", checkpoint_dir: str = "/checkpoint",
+                 run_type: str = str(RunType.TRAINER), checkpoint_dir: str = "/checkpoint",
                  num_replicas: int = 1, orchestration_params: dict=None):
         self.image = image
         self.command = command
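Since RunType overrides __str__ to return its value, str(...) yields the plain strings that are used as run_type_params dictionary keys and as run-type labels in the rest of this diff. A self-contained illustration (the enum is copied here so the snippet runs on its own):

# Self-contained copy of the RunType enum to show how str() maps members to the
# plain strings used as dictionary keys in the Kubernetes orchestrator above.
from enum import Enum

class RunType(Enum):
    ORCHESTRATOR = "orchestrator"
    TRAINER = "trainer"
    ROLLOUT_WORKER = "rollout-worker"

    def __str__(self):
        return self.value

run_type_params = {str(RunType.TRAINER): 'trainer params',
                   str(RunType.ROLLOUT_WORKER): 'worker params'}

assert str(RunType.TRAINER) == "trainer"
assert run_type_params.get(str(RunType.ROLLOUT_WORKER)) == 'worker params'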
@@ -97,12 +107,12 @@ class Kubernetes(Deploy):
 
     def deploy_trainer(self) -> bool:
 
-        trainer_params = self.params.run_type_params.get('trainer', None)
+        trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
         if not trainer_params:
             return False
 
-        trainer_params.command += ['--memory-backend-params', json.dumps(self.params.memory_backend_parameters.__dict__)]
-        trainer_params.command += ['--data-store-params', json.dumps(self.params.data_store_params.__dict__)]
+        trainer_params.command += ['--memory_backend_params', json.dumps(self.params.memory_backend_parameters.__dict__)]
+        trainer_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)]
 
         name = "{}-{}".format(trainer_params.run_type, uuid.uuid4())
 
@@ -175,13 +185,13 @@ class Kubernetes(Deploy):
 
     def deploy_worker(self):
 
-        worker_params = self.params.run_type_params.get('worker', None)
+        worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
         if not worker_params:
             return False
 
-        worker_params.command += ['--memory-backend-params', json.dumps(self.params.memory_backend_parameters.__dict__)]
-        worker_params.command += ['--data-store-params', json.dumps(self.params.data_store_params.__dict__)]
-        worker_params.command += ['--num-workers', '{}'.format(worker_params.num_replicas)]
+        worker_params.command += ['--memory_backend_params', json.dumps(self.params.memory_backend_parameters.__dict__)]
+        worker_params.command += ['--data_store_params', json.dumps(self.params.data_store_params.__dict__)]
+        worker_params.command += ['--num_workers', '{}'.format(worker_params.num_replicas)]
 
         name = "{}-{}".format(worker_params.run_type, uuid.uuid4())
 
@@ -255,7 +265,7 @@ class Kubernetes(Deploy):
         pass
 
     def trainer_logs(self):
-        trainer_params = self.params.run_type_params.get('trainer', None)
+        trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
         if not trainer_params:
             return
 
@@ -313,7 +323,7 @@ class Kubernetes(Deploy):
             return
 
     def undeploy(self):
-        trainer_params = self.params.run_type_params.get('trainer', None)
+        trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
         api_client = k8sclient.AppsV1Api()
         delete_options = k8sclient.V1DeleteOptions()
         if trainer_params:
@@ -321,7 +331,7 @@ class Kubernetes(Deploy):
                 api_client.delete_namespaced_deployment(trainer_params.orchestration_params['deployment_name'], self.params.namespace, delete_options)
             except k8sclient.rest.ApiException as e:
                 print("Got exception: %s\n while deleting trainer", e)
-        worker_params = self.params.run_type_params.get('worker', None)
+        worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
         if worker_params:
             try:
                 api_client.delete_namespaced_deployment(worker_params.orchestration_params['deployment_name'], self.params.namespace, delete_options)
@@ -1,120 +0,0 @@
-import argparse
-
-from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, RunTypeParameters
-from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
-from rl_coach.data_stores.data_store import DataStoreParameters
-from rl_coach.data_stores.s3_data_store import S3DataStoreParameters
-from rl_coach.data_stores.nfs_data_store import NFSDataStoreParameters
-
-
-def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, nfs_server: str=None, nfs_path: str=None,
-         memory_backend: str=None, data_store: str=None, s3_end_point: str=None, s3_bucket_name: str=None,
-         s3_creds_file: str=None, policy_type: str="OFF"):
-    rollout_command = ['python3', 'rl_coach/rollout_worker.py', '-p', preset, '--policy-type', policy_type]
-    training_command = ['python3', 'rl_coach/training_worker.py', '-p', preset, '--policy-type', policy_type]
-
-    memory_backend_params = None
-    if memory_backend == "redispubsub":
-        memory_backend_params = RedisPubSubMemoryBackendParameters()
-
-    ds_params_instance = None
-    if data_store == "s3":
-        ds_params = DataStoreParameters("s3", "", "")
-        ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=s3_end_point, bucket_name=s3_bucket_name,
-                                                   creds_file=s3_creds_file, checkpoint_dir="/checkpoint")
-    elif data_store == "nfs":
-        ds_params = DataStoreParameters("nfs", "kubernetes", {"namespace": "default"})
-        ds_params_instance = NFSDataStoreParameters(ds_params)
-
-    worker_run_type_params = RunTypeParameters(image, rollout_command, run_type="worker", num_replicas=num_workers)
-    trainer_run_type_params = RunTypeParameters(image, training_command, run_type="trainer")
-
-    orchestration_params = KubernetesParameters([worker_run_type_params, trainer_run_type_params],
-                                                kubeconfig='~/.kube/config', nfs_server=nfs_server, nfs_path=nfs_path,
-                                                memory_backend_parameters=memory_backend_params,
-                                                data_store_params=ds_params_instance)
-    orchestrator = Kubernetes(orchestration_params)
-    if not orchestrator.setup():
-        print("Could not setup")
-        return
-
-    if orchestrator.deploy_trainer():
-        print("Successfully deployed")
-    else:
-        print("Could not deploy")
-        return
-
-    if orchestrator.deploy_worker():
-        print("Successfully deployed")
-    else:
-        print("Could not deploy")
-        return
-
-    try:
-        orchestrator.trainer_logs()
-    except KeyboardInterrupt:
-        pass
-    orchestrator.undeploy()
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--image',
-                        help="(string) Name of a docker image.",
-                        type=str,
-                        required=True)
-    parser.add_argument('-p', '--preset',
-                        help="(string) Name of a preset to run (class name from the 'presets' directory).",
-                        type=str,
-                        required=True)
-    parser.add_argument('--memory-backend',
-                        help="(string) Memory backend to use.",
-                        type=str,
-                        choices=['redispubsub'],
-                        default="redispubsub")
-    parser.add_argument('--data-store',
-                        help="(string) Data store to use.",
-                        type=str,
-                        choices=['s3', 'nfs'],
-                        default="s3")
-    parser.add_argument('--nfs-server',
-                        help="(string) Addresss of the nfs server.",
-                        type=str,
-                        required=False)
-    parser.add_argument('--nfs-path',
-                        help="(string) Exported path for the nfs server.",
-                        type=str,
-                        required=False)
-    parser.add_argument('--s3-end-point',
-                        help="(string) S3 endpoint to use when S3 data store is used.",
-                        type=str,
-                        required=False)
-    parser.add_argument('--s3-bucket-name',
-                        help="(string) S3 bucket name to use when S3 data store is used.",
-                        type=str,
-                        required=False)
-    parser.add_argument('--s3-creds-file',
-                        help="(string) S3 credentials file to use when S3 data store is used.",
-                        type=str,
-                        required=False)
-    parser.add_argument('--num-workers',
-                        help="(string) Number of rollout workers.",
-                        type=int,
-                        required=False,
-                        default=1)
-    parser.add_argument('--policy-type',
-                        help="(string) The type of policy: OFF/ON.",
-                        type=str,
-                        choices=['ON', 'OFF'],
-                        default='OFF')
-
-    # parser.add_argument('--checkpoint_dir',
-    #                     help='(string) Path to a folder containing a checkpoint to write the model to.',
-    #                     type=str,
-    #                     default='/checkpoint')
-    args = parser.parse_args()
-
-    main(preset=args.preset, image=args.image, nfs_server=args.nfs_server, nfs_path=args.nfs_path,
-         memory_backend=args.memory_backend, data_store=args.data_store, s3_end_point=args.s3_end_point,
-         s3_bucket_name=args.s3_bucket_name, s3_creds_file=args.s3_creds_file, num_workers=args.num_workers,
-         policy_type=args.policy_type)
@@ -1,18 +0,0 @@
-from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes
-
-# image = 'gcr.io/constant-cubist-173123/coach:latest'
-image = 'ajaysudh/testing:coach'
-command = ['python3', 'rl_coach/rollout_worker.py', '-p', 'CartPole_DQN_distributed']
-# command = ['sleep', '10h']
-
-params = KubernetesParameters(image, command, kubeconfig='~/.kube/config', redis_ip='redis-service.ajay.svc', redis_port=6379, num_workers=1)
-# params = KubernetesParameters(image, command, kubeconfig='~/.kube/config')
-
-obj = Kubernetes(params)
-if not obj.setup():
-    print("Could not setup")
-
-if obj.deploy():
-    print("Successfully deployed")
-else:
-    print("Could not deploy")
@@ -1,5 +1,5 @@
 from rl_coach.agents.dqn_agent import DQNAgentParameters
-from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
+from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType
 from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
 from rl_coach.environments.gym_environment import GymVectorEnvironment
 from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
@@ -1,6 +1,6 @@
 from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
 from rl_coach.architectures.tensorflow_components.layers import Dense
-from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
+from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType
 from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
 from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
 from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
@@ -44,6 +44,9 @@ agent_params.algorithm.optimization_epochs = 10
 agent_params.algorithm.estimate_state_value_using_gae = True
 agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(2048)
 
+# Distributed Coach synchronization type.
+agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC
+
 agent_params.exploration = EGreedyParameters()
 agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
 agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
@@ -7,28 +7,16 @@ this rollout worker:
 - exits
 """
 
-import argparse
 import time
 import os
-import json
 import math
 
-from threading import Thread
-
-from rl_coach.base_parameters import TaskParameters
-from rl_coach.coach import expand_preset
+from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
 from rl_coach.core_types import EnvironmentSteps, RunPhase
-from rl_coach.utils import short_dynamic_import
-from rl_coach.memories.backend.memory_impl import construct_memory_params
-from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
 from google.protobuf import text_format
 from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
 
 
-# Q: specify alternative distributed memory, or should this go in the preset?
-# A: preset must define distributed memory to be used. we aren't going to take
-# a non-distributed preset and automatically distribute it.
-
 def has_checkpoint(checkpoint_dir):
     """
     True if a checkpoint is present in checkpoint_dir
@@ -39,6 +27,7 @@ def has_checkpoint(checkpoint_dir):
 
     return False
 
+
 def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
     """
     block until there is a checkpoint in checkpoint_dir
@@ -79,7 +68,7 @@ def get_latest_checkpoint(checkpoint_dir):
     return int(rel_path.split('_Step')[0])
 
 
-def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers, policy_type):
+def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
     """
     wait for first checkpoint then perform rollouts using the model
     """
@@ -102,7 +91,7 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers, polic
 
         new_checkpoint = get_latest_checkpoint(checkpoint_dir)
 
-        if policy_type == 'ON':
+        if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
             while new_checkpoint < last_checkpoint + 1:
                 if data_store:
                     data_store.load_from_store()
@@ -110,64 +99,8 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers, polic
 
             graph_manager.restore_checkpoint()
 
-        if policy_type == "OFF":
-
+        if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.ASYNC:
             if new_checkpoint > last_checkpoint:
                 graph_manager.restore_checkpoint()
 
             last_checkpoint = new_checkpoint
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('-p', '--preset',
-                        help="(string) Name of a preset to run (class name from the 'presets' directory.)",
-                        type=str,
-                        required=True)
-    parser.add_argument('--checkpoint-dir',
-                        help='(string) Path to a folder containing a checkpoint to restore the model from.',
-                        type=str,
-                        default='/checkpoint')
-    parser.add_argument('--memory-backend-params',
-                        help="(string) JSON string of the memory backend params",
-                        type=str)
-    parser.add_argument('--data-store-params',
-                        help="(string) JSON string of the data store params",
-                        type=str)
-    parser.add_argument('--num-workers',
-                        help="(int) The number of workers started in this pool",
-                        type=int,
-                        default=1)
-    parser.add_argument('--policy-type',
-                        help="(string) The type of policy: OFF/ON",
-                        type=str,
-                        default='OFF')
-
-    args = parser.parse_args()
-    graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
-
-    data_store = None
-    if args.memory_backend_params:
-        args.memory_backend_params = json.loads(args.memory_backend_params)
-        args.memory_backend_params['run_type'] = 'worker'
-        graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(args.memory_backend_params))
-
-    if args.data_store_params:
-        data_store_params = construct_data_store_params(json.loads(args.data_store_params))
-        data_store_params.checkpoint_dir = args.checkpoint_dir
-        graph_manager.data_store_params = data_store_params
-        data_store = get_data_store(data_store_params)
-        wait_for_checkpoint(checkpoint_dir=args.checkpoint_dir, data_store=data_store)
-        # thread = Thread(target = data_store_ckpt_load, args = [data_store])
-        # thread.start()
-
-    rollout_worker(
-        graph_manager=graph_manager,
-        checkpoint_dir=args.checkpoint_dir,
-        data_store=data_store,
-        num_workers=args.num_workers,
-        policy_type=args.policy_type
-    )
-
-if __name__ == '__main__':
-    main()
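The synchronization type decides how a rollout worker picks up new policies: SYNC blocks until a strictly newer checkpoint exists, while ASYNC just uses a newer one if it happens to be available. A standalone sketch of that branching, with the checkpoint lookup and restore stubbed out:

# Standalone sketch of the SYNC/ASYNC branching used in rollout_worker() above;
# get_latest and restore are stand-ins for the real checkpoint helpers.
from enum import Enum

class SyncType(Enum):
    SYNC = "sync"
    ASYNC = "async"

def maybe_update_policy(sync_type, last_checkpoint, get_latest, restore):
    new_checkpoint = get_latest()
    if sync_type == SyncType.SYNC:
        # Block until a strictly newer policy has been published.
        while new_checkpoint < last_checkpoint + 1:
            new_checkpoint = get_latest()
        restore()
    elif sync_type == SyncType.ASYNC:
        # Opportunistically load a newer policy, otherwise keep rolling out.
        if new_checkpoint > last_checkpoint:
            restore()
    return new_checkpoint

# Example: the worker last loaded checkpoint 2 and checkpoint 3 is already available.
print(maybe_update_policy(SyncType.ASYNC, 2, lambda: 3, lambda: print("restored")))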
@@ -1,24 +1,18 @@
 """
 """
-import argparse
 import time
-import json
 
-from threading import Thread
-
-from rl_coach.base_parameters import TaskParameters
-from rl_coach.coach import expand_preset
+from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
 from rl_coach import core_types
-from rl_coach.utils import short_dynamic_import
-from rl_coach.memories.backend.memory_impl import construct_memory_params
-from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
 
 def data_store_ckpt_save(data_store):
     while True:
         data_store.save_to_store()
         time.sleep(10)
 
-def training_worker(graph_manager, checkpoint_dir, policy_type):
+def training_worker(graph_manager, checkpoint_dir):
     """
     restore a checkpoint then perform rollouts using the restored model
     """
@@ -49,55 +43,7 @@ def training_worker(graph_manager, checkpoint_dir, policy_type):
             graph_manager.evaluate(graph_manager.evaluation_steps)
             eval_offset += 1
 
-        if policy_type == 'ON':
+        if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
             graph_manager.save_checkpoint()
         else:
             graph_manager.occasionally_save_checkpoint()
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('-p', '--preset',
-                        help="(string) Name of a preset to run (class name from the 'presets' directory.)",
-                        type=str,
-                        required=True)
-    parser.add_argument('--checkpoint-dir',
-                        help='(string) Path to a folder containing a checkpoint to write the model to.',
-                        type=str,
-                        default='/checkpoint')
-    parser.add_argument('--memory-backend-params',
-                        help="(string) JSON string of the memory backend params",
-                        type=str)
-    parser.add_argument('--data-store-params',
-                        help="(string) JSON string of the data store params",
-                        type=str)
-    parser.add_argument('--policy-type',
-                        help="(string) The type of policy: OFF/ON",
-                        type=str,
-                        default='OFF')
-    args = parser.parse_args()
-
-    graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
-
-    if args.memory_backend_params:
-        args.memory_backend_params = json.loads(args.memory_backend_params)
-        args.memory_backend_params['run_type'] = 'trainer'
-        graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(args.memory_backend_params))
-
-    if args.data_store_params:
-        data_store_params = construct_data_store_params(json.loads(args.data_store_params))
-        data_store_params.checkpoint_dir = args.checkpoint_dir
-        graph_manager.data_store_params = data_store_params
-        # data_store = get_data_store(data_store_params)
-        # thread = Thread(target = data_store_ckpt_save, args = [data_store])
-        # thread.start()
-
-    training_worker(
-        graph_manager=graph_manager,
-        checkpoint_dir=args.checkpoint_dir,
-        policy_type=args.policy_type
-    )
-
-
-if __name__ == '__main__':
-    main()
@@ -23,6 +23,7 @@ import signal
 import sys
 import threading
 import time
+import traceback
 from multiprocessing import Manager
 from subprocess import Popen
 from typing import List, Tuple
@@ -30,6 +31,8 @@ from typing import List, Tuple
 import atexit
 import numpy as np
 
+from rl_coach.logger import screen
+
 killed_processes = []
 
 eps = np.finfo(np.float32).eps