mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Uploading checkpoint if crd provided (#191)

* Uploading checkpoint if crd provided
* Changing the calculation of total steps because of a recent change in core_types

Fixes #195
Author: Ajay Deshpande, 2019-04-26 12:27:33 -07:00
Committed by: Scott Leishman
Parent: b3db9ce77d
Commit: 33dc29ee99
8 changed files with 122 additions and 40 deletions
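At a high level, the change means: when a checkpoint restore dir ("crd") is supplied on the command line, its contents get uploaded to the data store during setup, so that workers started later inside containers can restore from it. A minimal self-contained sketch of that idea; LocalFileDataStore and setup_data_store are hypothetical stand-ins, not Coach's actual classes:

import os
import shutil

class LocalFileDataStore:
    """Toy stand-in for Coach's data-store abstraction (hypothetical)."""
    def __init__(self, store_root):
        self.store_root = store_root

    def save_to_store(self, checkpoint_dir):
        # Copy every file from the local checkpoint dir into the shared store.
        os.makedirs(self.store_root, exist_ok=True)
        for name in os.listdir(checkpoint_dir):
            src = os.path.join(checkpoint_dir, name)
            if os.path.isfile(src):
                shutil.copy(src, self.store_root)

def setup_data_store(data_store, checkpoint_restore_dir=None):
    # Mirrors the shape of the change: upload pre-existing checkpoints
    # only when a restore dir was actually provided on the command line.
    if checkpoint_restore_dir is not None and os.path.isdir(checkpoint_restore_dir):
        data_store.save_to_store(checkpoint_restore_dir)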


@@ -103,21 +103,22 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
     data_store_params.checkpoint_dir = ckpt_inside_container
     graph_manager.data_store_params = data_store_params
 
+    data_store = None
+    if args.data_store_params:
+        data_store = get_data_store(data_store_params)
+
     if args.distributed_coach_run_type == RunType.TRAINER:
         task_parameters.checkpoint_save_dir = ckpt_inside_container
         training_worker(
             graph_manager=graph_manager,
             task_parameters=task_parameters,
+            data_store=data_store,
             is_multi_node_test=args.is_multi_node_test
         )
 
     if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
         task_parameters.checkpoint_restore_path = ckpt_inside_container
 
-        data_store = None
-        if args.data_store_params:
-            data_store = get_data_store(data_store_params)
-
         rollout_worker(
             graph_manager=graph_manager,
             data_store=data_store,
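The practical effect of the hunk above is that the data store is now created once, up front, and handed to training_worker as well, so the trainer can publish checkpoints rather than only the rollout workers reading them. A rough sketch of how a trainer loop might use it; train, save_checkpoint, and save_to_store are assumed method names, not verified against Coach's API:

def training_worker(graph_manager, task_parameters, data_store=None,
                    steps=1000, save_every=100):
    # Hypothetical skeleton: run training iterations, checkpoint
    # periodically, and upload each checkpoint via the data store so
    # rollout workers can pick up fresh policies.
    for step in range(1, steps + 1):
        graph_manager.train()                    # one training iteration
        if step % save_every == 0:
            graph_manager.save_checkpoint()      # write to checkpoint_save_dir
            if data_store is not None:
                data_store.save_to_store()       # upload the new checkpoint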
@@ -169,7 +170,7 @@ def handle_distributed_coach_orchestrator(args):
         memory_backend_parameters=memory_backend_params,
         data_store_params=ds_params_instance)
     orchestrator = Kubernetes(orchestration_params)
-    if not orchestrator.setup():
+    if not orchestrator.setup(args.checkpoint_restore_dir):
         print("Could not setup.")
         return 1
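setup() now takes the restore dir, presumably so the orchestrator can seed the data store with existing checkpoints before any worker starts; a None default would keep other callers working unchanged. A hypothetical sketch of that pattern, not the real Kubernetes orchestrator code:

class Kubernetes:
    """Minimal stand-in for the real cluster orchestrator."""
    def __init__(self, params):
        self.params = params

    def setup(self, checkpoint_restore_dir=None):
        # Defaulting to None keeps callers that pass nothing working.
        if checkpoint_restore_dir is not None:
            self._upload_checkpoints(checkpoint_restore_dir)
        return self._deploy_workers()

    def _upload_checkpoints(self, path):
        print("uploading checkpoints from", path)  # placeholder

    def _deploy_workers(self):
        return True  # placeholder for the real deployment logic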
@@ -394,7 +395,9 @@ class CoachLauncher(object):
 
         # validate the checkpoints args
         if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
-            screen.error("The requested checkpoint folder to load from does not exist.")
+            # If distributed trainer, the checkpoint dir is not yet available, so skip the check in that case.
+            if not (args.distributed_coach and args.distributed_coach_run_type in [RunType.TRAINER, RunType.ROLLOUT_WORKER]):
+                screen.error("The requested checkpoint folder to load from does not exist.")
 
         # validate the checkpoints args
         if args.checkpoint_restore_file is not None and not glob(args.checkpoint_restore_file + '*'):