mirror of https://github.com/gryf/coach.git, synced 2025-12-18 03:30:19 +01:00

commit 4a6c404070 (parent 2b4c9c6774), committed by Balaji Subramaniam

    Adding worker logs and plumbed task_parameters to distributed coach (#130)
--- a/rl_coach/coach.py
+++ b/rl_coach/coach.py
@@ -84,7 +84,7 @@ def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'
     graph_manager.close()
 
 
-def handle_distributed_coach_tasks(graph_manager, args):
+def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
     ckpt_inside_container = "/checkpoint"
 
     memory_backend_params = None
@@ -100,22 +100,24 @@ def handle_distributed_coach_tasks(graph_manager, args):
         graph_manager.data_store_params = data_store_params
 
     if args.distributed_coach_run_type == RunType.TRAINER:
+        task_parameters.checkpoint_save_dir = ckpt_inside_container
         training_worker(
             graph_manager=graph_manager,
-            checkpoint_dir=ckpt_inside_container
+            task_parameters=task_parameters
         )
 
     if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
+        task_parameters.checkpoint_restore_dir = ckpt_inside_container
         data_store = None
         if args.data_store_params:
             data_store = get_data_store(data_store_params)
 
         wait_for_checkpoint(checkpoint_dir=ckpt_inside_container, data_store=data_store)
 
         rollout_worker(
             graph_manager=graph_manager,
-            checkpoint_dir=ckpt_inside_container,
             data_store=data_store,
-            num_workers=args.num_workers
+            num_workers=args.num_workers,
+            task_parameters=task_parameters
         )
 
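Net effect of the two hunks above: both worker entry points stop taking a bare checkpoint_dir string and instead receive the whole task_parameters object, with the in-container checkpoint path written into checkpoint_save_dir for the trainer and checkpoint_restore_dir for the rollout worker. A minimal sketch of the pattern, assuming hypothetical stand-ins (SimpleTaskParameters and both worker stubs are illustrative, not rl_coach's real classes):

    # Minimal sketch of the parameter-object plumbing shown above.
    # SimpleTaskParameters and the worker stubs are hypothetical
    # stand-ins, not rl_coach's actual implementations.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class SimpleTaskParameters:
        checkpoint_save_dir: Optional[str] = None
        checkpoint_restore_dir: Optional[str] = None

    def training_worker(graph_manager, task_parameters: SimpleTaskParameters) -> None:
        # The trainer only needs to know where to *save* checkpoints.
        print("trainer saving checkpoints to", task_parameters.checkpoint_save_dir)

    def rollout_worker(graph_manager, task_parameters: SimpleTaskParameters) -> None:
        # The rollout worker only needs to know where to *restore* from.
        print("rollout worker restoring from", task_parameters.checkpoint_restore_dir)

    ckpt_inside_container = "/checkpoint"
    task_parameters = SimpleTaskParameters()

    task_parameters.checkpoint_save_dir = ckpt_inside_container
    training_worker(graph_manager=None, task_parameters=task_parameters)

    task_parameters.checkpoint_restore_dir = ckpt_inside_container
    rollout_worker(graph_manager=None, task_parameters=task_parameters)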
@@ -124,8 +126,16 @@ def handle_distributed_coach_orchestrator(args):
         RunTypeParameters
 
     ckpt_inside_container = "/checkpoint"
-    rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + sys.argv[1:]
-    trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + sys.argv[1:]
+    arg_list = sys.argv[1:]
+    try:
+        i = arg_list.index('--distributed_coach_run_type')
+        arg_list.pop(i)
+        arg_list.pop(i)
+    except ValueError:
+        pass
+
+    trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + arg_list
+    rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + arg_list
 
     if '--experiment_name' not in rollout_command:
         rollout_command = rollout_command + ['--experiment_name', args.experiment_name]
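The new try/except block strips any --distributed_coach_run_type pair the user passed on the original command line, so the orchestrator can prepend the correct run type for each of the two commands it rebuilds. Calling pop(i) twice is the key idiom: the first call removes the flag, and its value slides into index i for the second call. A standalone sketch of the idiom (strip_flag is a hypothetical helper, not part of rl_coach):

    # Standalone sketch of the flag-stripping idiom used above;
    # strip_flag is a hypothetical helper, not part of rl_coach.
    from typing import List

    def strip_flag(arg_list: List[str], flag: str) -> List[str]:
        args = list(arg_list)  # do not mutate the caller's list
        try:
            i = args.index(flag)
            args.pop(i)  # remove the flag itself
            args.pop(i)  # its value has shifted into index i
        except ValueError:
            pass  # flag was not present
        except IndexError:
            pass  # flag was the last token, with no value after it
        return args

    print(strip_flag(['-p', 'X', '--distributed_coach_run_type', 'trainer', '-n', '4'],
                     '--distributed_coach_run_type'))
    # ['-p', 'X', '-n', '4']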
@@ -170,6 +180,10 @@ def handle_distributed_coach_orchestrator(args):
         print("Could not deploy rollout worker(s).")
         return
 
+    if args.dump_worker_logs:
+        screen.log_title("Dumping rollout worker logs in: {}".format(args.experiment_path))
+        orchestrator.worker_logs(path=args.experiment_path)
+
     try:
         orchestrator.trainer_logs()
     except KeyboardInterrupt:
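This hunk is the worker-logs half of the commit: with --dump_worker_logs set, the orchestrator saves each rollout worker's logs under the experiment path before it starts streaming trainer logs. A rough sketch of dumping per-worker logs to a directory, assuming hypothetical helpers (fetch_worker_logs and the worker names are placeholders; orchestrator.worker_logs wraps Coach's actual mechanism):

    # Rough sketch of dumping per-worker logs into an experiment dir.
    # fetch_worker_logs and the worker names are hypothetical
    # placeholders for however the orchestrator reads worker output.
    import os

    def fetch_worker_logs(worker_name: str) -> str:
        return "log output of {}\n".format(worker_name)  # placeholder

    def dump_worker_logs(path: str, workers) -> None:
        os.makedirs(path, exist_ok=True)
        for name in workers:
            log_file = os.path.join(path, "worker_{}.log".format(name))
            with open(log_file, "w") as f:
                f.write(fetch_worker_logs(name))

    dump_worker_logs("/tmp/experiment", ["rollout-0", "rollout-1"])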
@@ -321,6 +335,9 @@ class CoachLauncher(object):
         if args.list:
             self.display_all_presets_and_exit()
 
+        if args.distributed_coach and not args.checkpoint_save_secs:
+            screen.error("Distributed coach requires --checkpoint_save_secs or -s")
+
         # Read args from config file for distributed Coach.
         if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
             coach_config = ConfigParser({
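The guard makes a real dependency explicit: distributed rollout workers only pick up new policies through checkpoints, so a run without a periodic save interval would never propagate training progress. The same fail-fast check, sketched against a bare argparse parser (a hypothetical reduction of Coach's CLI; parser.error aborts with exit status 2, where Coach uses screen.error):

    # Minimal sketch of failing fast on an invalid option combination.
    # This bare parser is a hypothetical reduction of Coach's CLI.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-dc', '--distributed_coach', action='store_true')
    parser.add_argument('-s', '--checkpoint_save_secs', type=int, default=None)

    args = parser.parse_args(['--distributed_coach'])  # no -s given
    if args.distributed_coach and not args.checkpoint_save_secs:
        # parser.error prints the message and exits with status 2.
        parser.error("Distributed coach requires --checkpoint_save_secs or -s")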
@@ -546,10 +563,13 @@ class CoachLauncher(object):
                             default=RunType.ORCHESTRATOR,
                             choices=list(RunType))
         parser.add_argument('-asc', '--apply_stop_condition',
                             help="(flag) If set, this will apply a stop condition on the run, defined by reaching a"
                                  "target success rate as set by the environment or a custom success rate as defined "
                                  "in the preset. ",
                             action='store_true')
+        parser.add_argument('--dump_worker_logs',
+                            help="(flag) Only used in distributed coach. If set, the worker logs are saved in the experiment dir",
+                            action='store_true')
 
         return parser
 
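With the flag registered next to --apply_stop_condition, a distributed run that also dumps worker logs would be launched roughly as follows (the preset placeholder and exact option spellings are inferred from the attribute names in this diff, so treat the invocation as an assumption):

    python3 rl_coach/coach.py -p <preset> --distributed_coach --checkpoint_save_secs 60 --dump_worker_logs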
@@ -570,26 +590,6 @@ class CoachLauncher(object):
         atexit.register(logger.summarize_experiment)
         screen.change_terminal_title(args.experiment_name)
 
-        # open dashboard
-        if args.open_dashboard:
-            open_dashboard(args.experiment_path)
-
-        if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
-            handle_distributed_coach_tasks(graph_manager, args)
-            return
-
-        if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
-            handle_distributed_coach_orchestrator(args)
-            return
-
-        # Single-threaded runs
-        if args.num_workers == 1:
-            self.start_single_threaded(graph_manager, args)
-        else:
-            self.start_multi_threaded(graph_manager, args)
-
-    def start_single_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
-        # Start the training or evaluation
         task_parameters = TaskParameters(
             framework_type=args.framework,
             evaluate_only=args.evaluate,
@@ -603,6 +603,26 @@ class CoachLauncher(object):
             apply_stop_condition=args.apply_stop_condition
         )
 
+        # open dashboard
+        if args.open_dashboard:
+            open_dashboard(args.experiment_path)
+
+        if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
+            handle_distributed_coach_tasks(graph_manager, args, task_parameters)
+            return
+
+        if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
+            handle_distributed_coach_orchestrator(args)
+            return
+
+        # Single-threaded runs
+        if args.num_workers == 1:
+            self.start_single_threaded(task_parameters, graph_manager, args)
+        else:
+            self.start_multi_threaded(graph_manager, args)
+
+    def start_single_threaded(self, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
+        # Start the training or evaluation
         start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
 
     def start_multi_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
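The last two hunks are a move rather than new logic: the TaskParameters construction leaves start_single_threaded and now runs in run_graph_manager before any dispatch, so the distributed branches and the single-threaded path all share one object, and start_single_threaded shrinks to a single start_graph call. A compact sketch of the resulting control flow (all names are illustrative simplifications, not rl_coach's launcher API):

    # Compact sketch of the reordered launcher flow: build the shared
    # parameter object once, then dispatch on run type.
    # All names here are illustrative, not rl_coach's real signatures.
    def make_task_parameters(args):
        return {'checkpoint_save_dir': None, 'checkpoint_restore_dir': None}

    def handle_distributed_tasks(graph_manager, args, task_parameters):
        print("worker run with", task_parameters)

    def handle_orchestrator(args):
        print("orchestrator re-launches coach.py for trainer and rollout workers")

    def start_single_threaded(task_parameters, graph_manager, args):
        print("start_graph with", task_parameters)

    def run(args, graph_manager):
        task_parameters = make_task_parameters(args)  # built once, up front
        if args['distributed'] and args['run_type'] != 'orchestrator':
            handle_distributed_tasks(graph_manager, args, task_parameters)
            return
        if args['distributed'] and args['run_type'] == 'orchestrator':
            handle_orchestrator(args)
            return
        start_single_threaded(task_parameters, graph_manager, args)

    run({'distributed': False, 'run_type': 'orchestrator'}, graph_manager=None)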