mirror of https://github.com/gryf/coach.git synced 2025-12-18 19:50:17 +01:00

Adding worker logs and plumbed task_parameters to distributed coach (#130)

Ajay Deshpande authored on 2018-11-23 15:35:11 -08:00; committed by Balaji Subramaniam
parent 2b4c9c6774
commit 4a6c404070
5 changed files with 84 additions and 41 deletions
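
For orientation, the core of the change is that the launcher now builds a single TaskParameters object once and hands it to the distributed-coach workers, instead of each worker rebuilding one from a bare checkpoint path. The following standalone sketch illustrates that plumbing pattern only; the TaskParameters stand-in, launch() helper, and the field values are simplified placeholders, not the rl_coach implementations.

# Standalone sketch of the plumbing this commit introduces; the class and
# functions here are simplified stand-ins, not rl_coach code.

class TaskParameters(object):
    # assumed fields only -- the real class carries many more settings
    def __init__(self):
        self.checkpoint_save_dir = None
        self.checkpoint_restore_dir = None
        self.checkpoint_save_secs = None


def training_worker(graph_manager, task_parameters):
    # the worker reads everything it needs from the shared task_parameters
    print("trainer saving checkpoints to", task_parameters.checkpoint_save_dir)


def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
    print("rollout restoring checkpoints from", task_parameters.checkpoint_restore_dir)


def launch(run_type, ckpt_inside_container="/checkpoint"):
    # built once by the launcher and passed down, mirroring handle_distributed_coach_tasks
    task_parameters = TaskParameters()
    task_parameters.checkpoint_save_secs = 60
    if run_type == "trainer":
        task_parameters.checkpoint_save_dir = ckpt_inside_container
        training_worker(graph_manager=None, task_parameters=task_parameters)
    else:
        task_parameters.checkpoint_restore_dir = ckpt_inside_container
        rollout_worker(graph_manager=None, data_store=None, num_workers=1,
                       task_parameters=task_parameters)


if __name__ == "__main__":
    launch("trainer")
    launch("rollout")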

View File

@@ -84,7 +84,7 @@ def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'
     graph_manager.close()


-def handle_distributed_coach_tasks(graph_manager, args):
+def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
     ckpt_inside_container = "/checkpoint"
     memory_backend_params = None
@@ -100,22 +100,24 @@ def handle_distributed_coach_tasks(graph_manager, args):
     graph_manager.data_store_params = data_store_params

     if args.distributed_coach_run_type == RunType.TRAINER:
+        task_parameters.checkpoint_save_dir = ckpt_inside_container
         training_worker(
             graph_manager=graph_manager,
-            checkpoint_dir=ckpt_inside_container
+            task_parameters=task_parameters
         )

     if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
+        task_parameters.checkpoint_restore_dir = ckpt_inside_container
         data_store = None
         if args.data_store_params:
             data_store = get_data_store(data_store_params)

+        wait_for_checkpoint(checkpoint_dir=ckpt_inside_container, data_store=data_store)
         rollout_worker(
             graph_manager=graph_manager,
-            checkpoint_dir=ckpt_inside_container,
             data_store=data_store,
-            num_workers=args.num_workers
+            num_workers=args.num_workers,
+            task_parameters=task_parameters
         )
@@ -124,8 +126,16 @@ def handle_distributed_coach_orchestrator(args):
         RunTypeParameters

     ckpt_inside_container = "/checkpoint"
-    rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + sys.argv[1:]
-    trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + sys.argv[1:]
+    arg_list = sys.argv[1:]
+    try:
+        i = arg_list.index('--distributed_coach_run_type')
+        arg_list.pop(i)
+        arg_list.pop(i)
+    except ValueError:
+        pass
+    trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + arg_list
+    rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + arg_list

     if '--experiment_name' not in rollout_command:
         rollout_command = rollout_command + ['--experiment_name', args.experiment_name]
@@ -170,6 +180,10 @@ def handle_distributed_coach_orchestrator(args):
print("Could not deploy rollout worker(s).") print("Could not deploy rollout worker(s).")
return return
if args.dump_worker_logs:
screen.log_title("Dumping rollout worker logs in: {}".format(args.experiment_path))
orchestrator.worker_logs(path=args.experiment_path)
try: try:
orchestrator.trainer_logs() orchestrator.trainer_logs()
except KeyboardInterrupt: except KeyboardInterrupt:
@@ -321,6 +335,9 @@ class CoachLauncher(object):
         if args.list:
             self.display_all_presets_and_exit()

+        if args.distributed_coach and not args.checkpoint_save_secs:
+            screen.error("Distributed coach requires --checkpoint_save_secs or -s")
+
         # Read args from config file for distributed Coach.
         if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
             coach_config = ConfigParser({
@@ -550,6 +567,9 @@ class CoachLauncher(object):
"target success rate as set by the environment or a custom success rate as defined " "target success rate as set by the environment or a custom success rate as defined "
"in the preset. ", "in the preset. ",
action='store_true') action='store_true')
parser.add_argument('--dump_worker_logs',
help="(flag) Only used in distributed coach. If set, the worker logs are saved in the experiment dir",
action='store_true')
return parser return parser
@@ -570,26 +590,6 @@ class CoachLauncher(object):
         atexit.register(logger.summarize_experiment)
         screen.change_terminal_title(args.experiment_name)

-        # open dashboard
-        if args.open_dashboard:
-            open_dashboard(args.experiment_path)
-
-        if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
-            handle_distributed_coach_tasks(graph_manager, args)
-            return
-
-        if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
-            handle_distributed_coach_orchestrator(args)
-            return
-
-        # Single-threaded runs
-        if args.num_workers == 1:
-            self.start_single_threaded(graph_manager, args)
-        else:
-            self.start_multi_threaded(graph_manager, args)
-
-    def start_single_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
-        # Start the training or evaluation
         task_parameters = TaskParameters(
             framework_type=args.framework,
             evaluate_only=args.evaluate,
@@ -603,6 +603,26 @@ class CoachLauncher(object):
             apply_stop_condition=args.apply_stop_condition
         )

+        # open dashboard
+        if args.open_dashboard:
+            open_dashboard(args.experiment_path)
+
+        if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
+            handle_distributed_coach_tasks(graph_manager, args, task_parameters)
+            return
+
+        if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
+            handle_distributed_coach_orchestrator(args)
+            return
+
+        # Single-threaded runs
+        if args.num_workers == 1:
+            self.start_single_threaded(task_parameters, graph_manager, args)
+        else:
+            self.start_multi_threaded(graph_manager, args)
+
+    def start_single_threaded(self, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
+        # Start the training or evaluation
         start_graph(graph_manager=graph_manager, task_parameters=task_parameters)

     def start_multi_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):

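A note on the handle_distributed_coach_orchestrator hunk above: the new code strips any existing '--distributed_coach_run_type' flag and its value from the forwarded arguments before appending the run type per command, so the flag is never duplicated. A minimal standalone version of that idiom follows; the sample argument list is made up purely for illustration.

# Standalone version of the argument-rewrite idiom above; the sample argv is illustrative.

def strip_run_type(argv):
    arg_list = list(argv)
    try:
        i = arg_list.index('--distributed_coach_run_type')
        arg_list.pop(i)   # drop the flag itself
        arg_list.pop(i)   # drop its value, which has shifted into index i
    except ValueError:
        pass              # flag not present: nothing to strip
    return arg_list


print(strip_run_type(['-p', 'CartPole_DQN', '--distributed_coach_run_type', 'orchestrator']))
# prints: ['-p', 'CartPole_DQN']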
View File

@@ -111,6 +111,7 @@ class RedisPubSubBackend(MemoryBackend):
             return False

     def undeploy(self):
+        from kubernetes import client
         if self.params.deployed:
             return

View File

@@ -2,9 +2,11 @@ import os
 import uuid
 import json
 import time
+import sys
 from enum import Enum
 from typing import List
 from configparser import ConfigParser, Error
+from multiprocessing import Process

 from rl_coach.base_parameters import RunType
 from rl_coach.orchestrators.deploy import Deploy, DeployParameters
@@ -255,8 +257,35 @@ class Kubernetes(Deploy):
print("Got exception: %s\n while creating Job", e) print("Got exception: %s\n while creating Job", e)
return False return False
def worker_logs(self): def worker_logs(self, path='./logs'):
pass worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
if not worker_params:
return
api_client = k8sclient.CoreV1Api()
pods = None
try:
pods = api_client.list_namespaced_pod(self.params.namespace, label_selector='app={}'.format(
worker_params.orchestration_params['job_name']
))
# pod = pods.items[0]
except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while reading pods", e)
return
if not pods or len(pods.items) == 0:
return
for pod in pods.items:
Process(target=self._tail_log_file, args=(pod.metadata.name, api_client, self.params.namespace, path)).start()
def _tail_log_file(self, pod_name, api_client, namespace, path):
if not os.path.exists(path):
os.mkdir(path)
sys.stdout = open(os.path.join(path, pod_name), 'w')
self.tail_log(pod_name, api_client)
def trainer_logs(self): def trainer_logs(self):
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None) trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)

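The worker_logs/_tail_log_file pair above fans out one multiprocessing.Process per rollout pod and redirects each child's stdout into a per-pod file before tailing. Below is a self-contained sketch of that fan-out-and-redirect pattern; fake_tail, the pod names, and the './logs' path are placeholders, and the Kubernetes log stream is not involved.

# Self-contained sketch of the fan-out pattern used by worker_logs above;
# fake_tail stands in for tailing a Kubernetes pod's log stream.
import os
import sys
from multiprocessing import Process


def fake_tail(name):
    # placeholder for streaming a pod's log
    for i in range(3):
        print("{}: log line {}".format(name, i))


def tail_into_file(name, path):
    # exist_ok avoids a race when several children create the directory at once
    os.makedirs(path, exist_ok=True)
    # redirect this child process's stdout into its own file, one file per pod
    sys.stdout = open(os.path.join(path, name), 'w')
    fake_tail(name)


if __name__ == '__main__':
    pods = ['rollout-worker-0', 'rollout-worker-1']
    procs = [Process(target=tail_into_file, args=(pod, './logs')) for pod in pods]
    for p in procs:
        p.start()
    for p in procs:
        p.join()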
View File

@@ -68,21 +68,17 @@ def get_latest_checkpoint(checkpoint_dir):
         rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
         return int(rel_path.split('_Step')[0])
-    return 0


 def should_stop(checkpoint_dir):
     return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))


-def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
+def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
     """
     wait for first checkpoint then perform rollouts using the model
     """
-    wait_for_checkpoint(checkpoint_dir)
-
-    task_parameters = TaskParameters()
-    task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir
+    checkpoint_dir = task_parameters.checkpoint_restore_dir
+    wait_for_checkpoint(checkpoint_dir, data_store)

     graph_manager.create_graph(task_parameters)

     with graph_manager.phase_context(RunPhase.TRAIN):

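The hunk above does not show wait_for_checkpoint's body, only that it now also receives the data store. Purely as an illustration of the kind of loop that signature suggests, here is a generic polling sketch; it is not Coach's implementation, and the load_from_store() call, the 'checkpoint' filename, and the 10-second interval are assumptions.

# Generic polling sketch, NOT rl_coach's wait_for_checkpoint: it only
# illustrates why the call can use the data store -- to sync files into
# checkpoint_dir before each existence check.
import os
import time


def wait_for_first_checkpoint(checkpoint_dir, data_store=None, poll_secs=10):
    # blocks until a checkpoint index file shows up (loops forever otherwise)
    while True:
        if data_store is not None:
            data_store.load_from_store()   # assumed sync method on the store
        if os.path.exists(os.path.join(checkpoint_dir, 'checkpoint')):
            return
        time.sleep(poll_secs)

# usage (would block until the trainer writes its first checkpoint):
# wait_for_first_checkpoint('/checkpoint', data_store=None)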
View File

@@ -12,14 +12,11 @@ def data_store_ckpt_save(data_store):
         time.sleep(10)


-def training_worker(graph_manager, checkpoint_dir):
+def training_worker(graph_manager, task_parameters):
     """
     restore a checkpoint then perform rollouts using the restored model
     """
     # initialize graph
-    task_parameters = TaskParameters()
-    task_parameters.__dict__['checkpoint_save_dir'] = checkpoint_dir
-    task_parameters.__dict__['checkpoint_save_secs'] = 20
     graph_manager.create_graph(task_parameters)

     # save randomly initialized graph