mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Adding worker logs and plumbed task_parameters to distributed coach (#130)
This commit is contained in:
committed by
Balaji Subramaniam
parent
2b4c9c6774
commit
4a6c404070
@@ -84,7 +84,7 @@ def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'
|
||||
graph_manager.close()
|
||||
|
||||
|
||||
def handle_distributed_coach_tasks(graph_manager, args):
|
||||
def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
|
||||
ckpt_inside_container = "/checkpoint"
|
||||
|
||||
memory_backend_params = None
|
||||
@@ -100,22 +100,24 @@ def handle_distributed_coach_tasks(graph_manager, args):
|
||||
graph_manager.data_store_params = data_store_params
|
||||
|
||||
if args.distributed_coach_run_type == RunType.TRAINER:
|
||||
task_parameters.checkpoint_save_dir = ckpt_inside_container
|
||||
training_worker(
|
||||
graph_manager=graph_manager,
|
||||
checkpoint_dir=ckpt_inside_container
|
||||
task_parameters=task_parameters
|
||||
)
|
||||
|
||||
if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
|
||||
task_parameters.checkpoint_restore_dir = ckpt_inside_container
|
||||
|
||||
data_store = None
|
||||
if args.data_store_params:
|
||||
data_store = get_data_store(data_store_params)
|
||||
wait_for_checkpoint(checkpoint_dir=ckpt_inside_container, data_store=data_store)
|
||||
|
||||
rollout_worker(
|
||||
graph_manager=graph_manager,
|
||||
checkpoint_dir=ckpt_inside_container,
|
||||
data_store=data_store,
|
||||
num_workers=args.num_workers
|
||||
num_workers=args.num_workers,
|
||||
task_parameters=task_parameters
|
||||
)
|
||||
|
||||
|
||||
@@ -124,8 +126,16 @@ def handle_distributed_coach_orchestrator(args):
|
||||
RunTypeParameters
|
||||
|
||||
ckpt_inside_container = "/checkpoint"
|
||||
rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + sys.argv[1:]
|
||||
trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + sys.argv[1:]
|
||||
arg_list = sys.argv[1:]
|
||||
try:
|
||||
i = arg_list.index('--distributed_coach_run_type')
|
||||
arg_list.pop(i)
|
||||
arg_list.pop(i)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + arg_list
|
||||
rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + arg_list
|
||||
|
||||
if '--experiment_name' not in rollout_command:
|
||||
rollout_command = rollout_command + ['--experiment_name', args.experiment_name]
|
||||
@@ -170,6 +180,10 @@ def handle_distributed_coach_orchestrator(args):
|
||||
print("Could not deploy rollout worker(s).")
|
||||
return
|
||||
|
||||
if args.dump_worker_logs:
|
||||
screen.log_title("Dumping rollout worker logs in: {}".format(args.experiment_path))
|
||||
orchestrator.worker_logs(path=args.experiment_path)
|
||||
|
||||
try:
|
||||
orchestrator.trainer_logs()
|
||||
except KeyboardInterrupt:
|
||||
@@ -321,6 +335,9 @@ class CoachLauncher(object):
|
||||
if args.list:
|
||||
self.display_all_presets_and_exit()
|
||||
|
||||
if args.distributed_coach and not args.checkpoint_save_secs:
|
||||
screen.error("Distributed coach requires --checkpoint_save_secs or -s")
|
||||
|
||||
# Read args from config file for distributed Coach.
|
||||
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
|
||||
coach_config = ConfigParser({
|
||||
@@ -550,6 +567,9 @@ class CoachLauncher(object):
|
||||
"target success rate as set by the environment or a custom success rate as defined "
|
||||
"in the preset. ",
|
||||
action='store_true')
|
||||
parser.add_argument('--dump_worker_logs',
|
||||
help="(flag) Only used in distributed coach. If set, the worker logs are saved in the experiment dir",
|
||||
action='store_true')
|
||||
|
||||
return parser
|
||||
|
||||
@@ -570,26 +590,6 @@ class CoachLauncher(object):
|
||||
atexit.register(logger.summarize_experiment)
|
||||
screen.change_terminal_title(args.experiment_name)
|
||||
|
||||
# open dashboard
|
||||
if args.open_dashboard:
|
||||
open_dashboard(args.experiment_path)
|
||||
|
||||
if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
|
||||
handle_distributed_coach_tasks(graph_manager, args)
|
||||
return
|
||||
|
||||
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
|
||||
handle_distributed_coach_orchestrator(args)
|
||||
return
|
||||
|
||||
# Single-threaded runs
|
||||
if args.num_workers == 1:
|
||||
self.start_single_threaded(graph_manager, args)
|
||||
else:
|
||||
self.start_multi_threaded(graph_manager, args)
|
||||
|
||||
def start_single_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
|
||||
# Start the training or evaluation
|
||||
task_parameters = TaskParameters(
|
||||
framework_type=args.framework,
|
||||
evaluate_only=args.evaluate,
|
||||
@@ -603,6 +603,26 @@ class CoachLauncher(object):
|
||||
apply_stop_condition=args.apply_stop_condition
|
||||
)
|
||||
|
||||
# open dashboard
|
||||
if args.open_dashboard:
|
||||
open_dashboard(args.experiment_path)
|
||||
|
||||
if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
|
||||
handle_distributed_coach_tasks(graph_manager, args, task_parameters)
|
||||
return
|
||||
|
||||
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
|
||||
handle_distributed_coach_orchestrator(args)
|
||||
return
|
||||
|
||||
# Single-threaded runs
|
||||
if args.num_workers == 1:
|
||||
self.start_single_threaded(task_parameters, graph_manager, args)
|
||||
else:
|
||||
self.start_multi_threaded(graph_manager, args)
|
||||
|
||||
def start_single_threaded(self, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
|
||||
# Start the training or evaluation
|
||||
start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
|
||||
|
||||
def start_multi_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
|
||||
|
||||
@@ -111,6 +111,7 @@ class RedisPubSubBackend(MemoryBackend):
|
||||
return False
|
||||
|
||||
def undeploy(self):
|
||||
from kubernetes import client
|
||||
if self.params.deployed:
|
||||
return
|
||||
|
||||
|
||||
@@ -2,9 +2,11 @@ import os
|
||||
import uuid
|
||||
import json
|
||||
import time
|
||||
import sys
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
from configparser import ConfigParser, Error
|
||||
from multiprocessing import Process
|
||||
|
||||
from rl_coach.base_parameters import RunType
|
||||
from rl_coach.orchestrators.deploy import Deploy, DeployParameters
|
||||
@@ -255,8 +257,35 @@ class Kubernetes(Deploy):
|
||||
print("Got exception: %s\n while creating Job", e)
|
||||
return False
|
||||
|
||||
def worker_logs(self):
|
||||
pass
|
||||
def worker_logs(self, path='./logs'):
|
||||
worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
|
||||
if not worker_params:
|
||||
return
|
||||
|
||||
api_client = k8sclient.CoreV1Api()
|
||||
pods = None
|
||||
try:
|
||||
pods = api_client.list_namespaced_pod(self.params.namespace, label_selector='app={}'.format(
|
||||
worker_params.orchestration_params['job_name']
|
||||
))
|
||||
|
||||
# pod = pods.items[0]
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while reading pods", e)
|
||||
return
|
||||
|
||||
if not pods or len(pods.items) == 0:
|
||||
return
|
||||
|
||||
for pod in pods.items:
|
||||
Process(target=self._tail_log_file, args=(pod.metadata.name, api_client, self.params.namespace, path)).start()
|
||||
|
||||
def _tail_log_file(self, pod_name, api_client, namespace, path):
|
||||
if not os.path.exists(path):
|
||||
os.mkdir(path)
|
||||
|
||||
sys.stdout = open(os.path.join(path, pod_name), 'w')
|
||||
self.tail_log(pod_name, api_client)
|
||||
|
||||
def trainer_logs(self):
|
||||
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
|
||||
|
||||
@@ -68,21 +68,17 @@ def get_latest_checkpoint(checkpoint_dir):
|
||||
rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
|
||||
return int(rel_path.split('_Step')[0])
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def should_stop(checkpoint_dir):
|
||||
return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))
|
||||
|
||||
|
||||
def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
|
||||
def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
|
||||
"""
|
||||
wait for first checkpoint then perform rollouts using the model
|
||||
"""
|
||||
wait_for_checkpoint(checkpoint_dir)
|
||||
|
||||
task_parameters = TaskParameters()
|
||||
task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir
|
||||
checkpoint_dir = task_parameters.checkpoint_restore_dir
|
||||
wait_for_checkpoint(checkpoint_dir, data_store)
|
||||
|
||||
graph_manager.create_graph(task_parameters)
|
||||
with graph_manager.phase_context(RunPhase.TRAIN):
|
||||
|
||||
@@ -12,14 +12,11 @@ def data_store_ckpt_save(data_store):
|
||||
time.sleep(10)
|
||||
|
||||
|
||||
def training_worker(graph_manager, checkpoint_dir):
|
||||
def training_worker(graph_manager, task_parameters):
|
||||
"""
|
||||
restore a checkpoint then perform rollouts using the restored model
|
||||
"""
|
||||
# initialize graph
|
||||
task_parameters = TaskParameters()
|
||||
task_parameters.__dict__['checkpoint_save_dir'] = checkpoint_dir
|
||||
task_parameters.__dict__['checkpoint_save_secs'] = 20
|
||||
graph_manager.create_graph(task_parameters)
|
||||
|
||||
# save randomly initialized graph
|
||||
|
||||
Reference in New Issue
Block a user