mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Adding framework for multinode tests (#149)
* Currently runs CartPole_ClippedPPO and Mujoco_ClippedPPO with the inverted_pendulum level.
This commit is contained in:
committed by
Balaji Subramaniam
parent
b461a1b8ab
commit
2c1a9dbf20
@@ -103,7 +103,8 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
|
||||
task_parameters.checkpoint_save_dir = ckpt_inside_container
|
||||
training_worker(
|
||||
graph_manager=graph_manager,
|
||||
task_parameters=task_parameters
|
||||
task_parameters=task_parameters,
|
||||
is_multi_node_test=args.is_multi_node_test
|
||||
)
|
||||
|
||||
if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
|
||||
@@ -166,30 +167,32 @@ def handle_distributed_coach_orchestrator(args):
|
||||
orchestrator = Kubernetes(orchestration_params)
|
||||
if not orchestrator.setup():
|
||||
print("Could not setup.")
|
||||
return
|
||||
return 1
|
||||
|
||||
if orchestrator.deploy_trainer():
|
||||
print("Successfully deployed trainer.")
|
||||
else:
|
||||
print("Could not deploy trainer.")
|
||||
return
|
||||
return 1
|
||||
|
||||
if orchestrator.deploy_worker():
|
||||
print("Successfully deployed rollout worker(s).")
|
||||
else:
|
||||
print("Could not deploy rollout worker(s).")
|
||||
return
|
||||
return 1
|
||||
|
||||
if args.dump_worker_logs:
|
||||
screen.log_title("Dumping rollout worker logs in: {}".format(args.experiment_path))
|
||||
orchestrator.worker_logs(path=args.experiment_path)
|
||||
|
||||
exit_code = 1
|
||||
try:
|
||||
orchestrator.trainer_logs()
|
||||
exit_code = orchestrator.trainer_logs()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
orchestrator.undeploy()
|
||||
return exit_code
|
||||
|
||||
|
||||
class CoachLauncher(object):
|
||||
@@ -331,7 +334,7 @@ class CoachLauncher(object):
|
||||
# if no arg is given
|
||||
if len(sys.argv) == 1:
|
||||
parser.print_help()
|
||||
sys.exit(0)
|
||||
sys.exit(1)
|
||||
|
||||
# list available presets
|
||||
if args.list:
|
||||
@@ -569,6 +572,9 @@ class CoachLauncher(object):
|
||||
parser.add_argument('--dump_worker_logs',
|
||||
help="(flag) Only used in distributed coach. If set, the worker logs are saved in the experiment dir",
|
||||
action='store_true')
|
||||
parser.add_argument('--is_multi_node_test',
|
||||
help=argparse.SUPPRESS,
|
||||
action='store_true')
|
||||
|
||||
return parser
|
||||
|
||||
@@ -617,8 +623,7 @@ class CoachLauncher(object):
|
||||
return
|
||||
|
||||
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
|
||||
handle_distributed_coach_orchestrator(args)
|
||||
return
|
||||
exit(handle_distributed_coach_orchestrator(args))
|
||||
|
||||
# Single-threaded runs
|
||||
if args.num_workers == 1:
|
||||
|
||||
Reference in New Issue
Block a user