1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Adding worker logs and plumbed task_parameters to distributed coach (#130)

This commit is contained in:
Ajay Deshpande
2018-11-23 15:35:11 -08:00
committed by Balaji Subramaniam
parent 2b4c9c6774
commit 4a6c404070
5 changed files with 84 additions and 41 deletions

View File

@@ -84,7 +84,7 @@ def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'
graph_manager.close()
def handle_distributed_coach_tasks(graph_manager, args):
def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
ckpt_inside_container = "/checkpoint"
memory_backend_params = None
@@ -100,22 +100,24 @@ def handle_distributed_coach_tasks(graph_manager, args):
graph_manager.data_store_params = data_store_params
if args.distributed_coach_run_type == RunType.TRAINER:
task_parameters.checkpoint_save_dir = ckpt_inside_container
training_worker(
graph_manager=graph_manager,
checkpoint_dir=ckpt_inside_container
task_parameters=task_parameters
)
if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
task_parameters.checkpoint_restore_dir = ckpt_inside_container
data_store = None
if args.data_store_params:
data_store = get_data_store(data_store_params)
wait_for_checkpoint(checkpoint_dir=ckpt_inside_container, data_store=data_store)
rollout_worker(
graph_manager=graph_manager,
checkpoint_dir=ckpt_inside_container,
data_store=data_store,
num_workers=args.num_workers
num_workers=args.num_workers,
task_parameters=task_parameters
)
@@ -124,8 +126,16 @@ def handle_distributed_coach_orchestrator(args):
RunTypeParameters
ckpt_inside_container = "/checkpoint"
rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + sys.argv[1:]
trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + sys.argv[1:]
arg_list = sys.argv[1:]
try:
i = arg_list.index('--distributed_coach_run_type')
arg_list.pop(i)
arg_list.pop(i)
except ValueError:
pass
trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + arg_list
rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + arg_list
if '--experiment_name' not in rollout_command:
rollout_command = rollout_command + ['--experiment_name', args.experiment_name]
@@ -170,6 +180,10 @@ def handle_distributed_coach_orchestrator(args):
print("Could not deploy rollout worker(s).")
return
if args.dump_worker_logs:
screen.log_title("Dumping rollout worker logs in: {}".format(args.experiment_path))
orchestrator.worker_logs(path=args.experiment_path)
try:
orchestrator.trainer_logs()
except KeyboardInterrupt:
@@ -321,6 +335,9 @@ class CoachLauncher(object):
if args.list:
self.display_all_presets_and_exit()
if args.distributed_coach and not args.checkpoint_save_secs:
screen.error("Distributed coach requires --checkpoint_save_secs or -s")
# Read args from config file for distributed Coach.
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
coach_config = ConfigParser({
@@ -546,10 +563,13 @@ class CoachLauncher(object):
default=RunType.ORCHESTRATOR,
choices=list(RunType))
parser.add_argument('-asc', '--apply_stop_condition',
help="(flag) If set, this will apply a stop condition on the run, defined by reaching a"
help="(flag) If set, this will apply a stop condition on the run, defined by reaching a"
"target success rate as set by the environment or a custom success rate as defined "
"in the preset. ",
action='store_true')
parser.add_argument('--dump_worker_logs',
help="(flag) Only used in distributed coach. If set, the worker logs are saved in the experiment dir",
action='store_true')
return parser
@@ -570,26 +590,6 @@ class CoachLauncher(object):
atexit.register(logger.summarize_experiment)
screen.change_terminal_title(args.experiment_name)
# open dashboard
if args.open_dashboard:
open_dashboard(args.experiment_path)
if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
handle_distributed_coach_tasks(graph_manager, args)
return
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
handle_distributed_coach_orchestrator(args)
return
# Single-threaded runs
if args.num_workers == 1:
self.start_single_threaded(graph_manager, args)
else:
self.start_multi_threaded(graph_manager, args)
def start_single_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
# Start the training or evaluation
task_parameters = TaskParameters(
framework_type=args.framework,
evaluate_only=args.evaluate,
@@ -603,6 +603,26 @@ class CoachLauncher(object):
apply_stop_condition=args.apply_stop_condition
)
# open dashboard
if args.open_dashboard:
open_dashboard(args.experiment_path)
if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
handle_distributed_coach_tasks(graph_manager, args, task_parameters)
return
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
handle_distributed_coach_orchestrator(args)
return
# Single-threaded runs
if args.num_workers == 1:
self.start_single_threaded(task_parameters, graph_manager, args)
else:
self.start_multi_threaded(graph_manager, args)
def start_single_threaded(self, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
# Start the training or evaluation
start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
def start_multi_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):

View File

@@ -111,6 +111,7 @@ class RedisPubSubBackend(MemoryBackend):
return False
def undeploy(self):
from kubernetes import client
if self.params.deployed:
return

View File

@@ -2,9 +2,11 @@ import os
import uuid
import json
import time
import sys
from enum import Enum
from typing import List
from configparser import ConfigParser, Error
from multiprocessing import Process
from rl_coach.base_parameters import RunType
from rl_coach.orchestrators.deploy import Deploy, DeployParameters
@@ -255,8 +257,35 @@ class Kubernetes(Deploy):
print("Got exception: %s\n while creating Job", e)
return False
def worker_logs(self):
pass
def worker_logs(self, path='./logs'):
worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
if not worker_params:
return
api_client = k8sclient.CoreV1Api()
pods = None
try:
pods = api_client.list_namespaced_pod(self.params.namespace, label_selector='app={}'.format(
worker_params.orchestration_params['job_name']
))
# pod = pods.items[0]
except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while reading pods", e)
return
if not pods or len(pods.items) == 0:
return
for pod in pods.items:
Process(target=self._tail_log_file, args=(pod.metadata.name, api_client, self.params.namespace, path)).start()
def _tail_log_file(self, pod_name, api_client, namespace, path):
if not os.path.exists(path):
os.mkdir(path)
sys.stdout = open(os.path.join(path, pod_name), 'w')
self.tail_log(pod_name, api_client)
def trainer_logs(self):
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)

View File

@@ -68,21 +68,17 @@ def get_latest_checkpoint(checkpoint_dir):
rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
return int(rel_path.split('_Step')[0])
return 0
def should_stop(checkpoint_dir):
return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))
def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
"""
wait for first checkpoint then perform rollouts using the model
"""
wait_for_checkpoint(checkpoint_dir)
task_parameters = TaskParameters()
task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir
checkpoint_dir = task_parameters.checkpoint_restore_dir
wait_for_checkpoint(checkpoint_dir, data_store)
graph_manager.create_graph(task_parameters)
with graph_manager.phase_context(RunPhase.TRAIN):

View File

@@ -12,14 +12,11 @@ def data_store_ckpt_save(data_store):
time.sleep(10)
def training_worker(graph_manager, checkpoint_dir):
def training_worker(graph_manager, task_parameters):
"""
restore a checkpoint then perform rollouts using the restored model
"""
# initialize graph
task_parameters = TaskParameters()
task_parameters.__dict__['checkpoint_save_dir'] = checkpoint_dir
task_parameters.__dict__['checkpoint_save_secs'] = 20
graph_manager.create_graph(task_parameters)
# save randomly initialized graph