1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Adding framework for multinode tests (#149)

* Currently runs CartPole_ClippedPPO and Mujoco_ClippedPPO with inverted_pendulum level.
This commit is contained in:
Ajay Deshpande
2019-02-26 13:53:12 -08:00
committed by Balaji Subramaniam
parent b461a1b8ab
commit 2c1a9dbf20
8 changed files with 210 additions and 24 deletions

View File

@@ -103,7 +103,8 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
task_parameters.checkpoint_save_dir = ckpt_inside_container
training_worker(
graph_manager=graph_manager,
task_parameters=task_parameters
task_parameters=task_parameters,
is_multi_node_test=args.is_multi_node_test
)
if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
@@ -166,30 +167,32 @@ def handle_distributed_coach_orchestrator(args):
orchestrator = Kubernetes(orchestration_params)
if not orchestrator.setup():
print("Could not setup.")
return
return 1
if orchestrator.deploy_trainer():
print("Successfully deployed trainer.")
else:
print("Could not deploy trainer.")
return
return 1
if orchestrator.deploy_worker():
print("Successfully deployed rollout worker(s).")
else:
print("Could not deploy rollout worker(s).")
return
return 1
if args.dump_worker_logs:
screen.log_title("Dumping rollout worker logs in: {}".format(args.experiment_path))
orchestrator.worker_logs(path=args.experiment_path)
exit_code = 1
try:
orchestrator.trainer_logs()
exit_code = orchestrator.trainer_logs()
except KeyboardInterrupt:
pass
orchestrator.undeploy()
return exit_code
class CoachLauncher(object):
@@ -331,7 +334,7 @@ class CoachLauncher(object):
# if no arg is given
if len(sys.argv) == 1:
parser.print_help()
sys.exit(0)
sys.exit(1)
# list available presets
if args.list:
@@ -569,6 +572,9 @@ class CoachLauncher(object):
parser.add_argument('--dump_worker_logs',
help="(flag) Only used in distributed coach. If set, the worker logs are saved in the experiment dir",
action='store_true')
parser.add_argument('--is_multi_node_test',
help=argparse.SUPPRESS,
action='store_true')
return parser
@@ -617,8 +623,7 @@ class CoachLauncher(object):
return
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
handle_distributed_coach_orchestrator(args)
return
exit(handle_distributed_coach_orchestrator(args))
# Single-threaded runs
if args.num_workers == 1: