Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)
Uploading checkpoint if crd provided (#191)
* Uploading checkpoint if crd provided
* Changing the calculation of total steps because of a recent change in core_types

Fixes #195

Committed by Scott Leishman
parent b3db9ce77d
commit 33dc29ee99
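In short: when a checkpoint restore dir ("crd") is passed to the launcher, the orchestrator now uploads it to the shared data store during setup, so distributed workers can start from an existing checkpoint. A minimal sketch of that pattern, where `DataStore` and `upload_dir` are hypothetical stand-ins rather than Coach's actual data-store interface:

```python
import os

# Illustrative sketch only: DataStore and upload_dir are hypothetical
# stand-ins, not Coach's actual data-store interface.
class DataStore:
    def upload_dir(self, local_dir: str) -> None:
        """Push a local checkpoint directory to shared storage."""
        raise NotImplementedError

def setup(data_store: DataStore, checkpoint_restore_dir: str = None) -> bool:
    # Upload only when a restore dir was provided and exists locally.
    if checkpoint_restore_dir and os.path.isdir(checkpoint_restore_dir):
        data_store.upload_dir(checkpoint_restore_dir)
    return True
```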
```diff
@@ -103,21 +103,22 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
     data_store_params.checkpoint_dir = ckpt_inside_container
     graph_manager.data_store_params = data_store_params
 
+    data_store = None
+    if args.data_store_params:
+        data_store = get_data_store(data_store_params)
+
     if args.distributed_coach_run_type == RunType.TRAINER:
         task_parameters.checkpoint_save_dir = ckpt_inside_container
         training_worker(
             graph_manager=graph_manager,
             task_parameters=task_parameters,
+            data_store=data_store,
             is_multi_node_test=args.is_multi_node_test
         )
 
     if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
         task_parameters.checkpoint_restore_path = ckpt_inside_container
 
-        data_store = None
-        if args.data_store_params:
-            data_store = get_data_store(data_store_params)
-
         rollout_worker(
             graph_manager=graph_manager,
             data_store=data_store,
```
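Note on the hunk above: the data store is now built once, before the run-type branches, and passed to the trainer as well as the rollout worker (previously only the rollout worker constructed one). A hedged sketch of why the trainer wants it, assuming the data store exposes a `save_to_store()` upload method (the hook itself is illustrative, not code from this commit):

```python
# Illustrative trainer-side hook, not code from this commit: after the
# trainer writes a checkpoint, push it to shared storage so rollout
# workers can pick it up.
def after_checkpoint_save(data_store):
    if data_store is not None:
        data_store.save_to_store()  # assumed to upload the checkpoint dir
```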
```diff
@@ -169,7 +170,7 @@ def handle_distributed_coach_orchestrator(args):
                                                memory_backend_parameters=memory_backend_params,
                                                data_store_params=ds_params_instance)
     orchestrator = Kubernetes(orchestration_params)
-    if not orchestrator.setup():
+    if not orchestrator.setup(args.checkpoint_restore_dir):
         print("Could not setup.")
         return 1
 
```
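This is the change the commit title refers to: `Kubernetes.setup()` now receives `args.checkpoint_restore_dir`, letting the orchestrator seed the data store with the provided checkpoint before workers deploy. The matching change on the orchestrator side is outside this diff; its assumed shape, with a hypothetical `_upload_checkpoint` helper, would be:

```python
# Assumed shape of the matching change in the Kubernetes orchestrator;
# that file is outside this diff, so everything here is illustrative,
# including the _upload_checkpoint helper.
class Kubernetes:
    def setup(self, crd: str = None) -> bool:
        # ... deploy memory backend and data store ...
        if crd:
            self._upload_checkpoint(crd)  # hypothetical helper
        return True

    def _upload_checkpoint(self, crd: str) -> None:
        print("uploading checkpoint from", crd)
```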
```diff
@@ -394,7 +395,9 @@ class CoachLauncher(object):
 
         # validate the checkpoints args
         if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
-            screen.error("The requested checkpoint folder to load from does not exist.")
+            # If distributed trainer, the checkpoint dir is not yet available so skipping the check in that case.
+            if not (args.distributed_coach and args.distributed_coach_run_type in [RunType.TRAINER, RunType.ROLLOUT_WORKER]):
+                screen.error("The requested checkpoint folder to load from does not exist.")
 
         # validate the checkpoints args
         if args.checkpoint_restore_file is not None and not glob(args.checkpoint_restore_file + '*'):
```
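Finally, the launcher-side validation is relaxed: the local existence check on `checkpoint_restore_dir` used to fail distributed runs, where the directory only materializes inside the container. The new guard is equivalent to a predicate like the one below (an illustrative extraction, not part of the commit; the `RunType` values mirror the ones used in the diff):

```python
from enum import Enum

# Mirrors the run types referenced in the diff; illustrative definition.
class RunType(Enum):
    TRAINER = "trainer"
    ROLLOUT_WORKER = "rollout-worker"

def should_check_restore_dir_locally(args) -> bool:
    # Distributed trainer/rollout workers restore from a path inside the
    # container, so the directory need not exist on the launch machine.
    return not (args.distributed_coach and
                args.distributed_coach_run_type in (RunType.TRAINER,
                                                    RunType.ROLLOUT_WORKER))
```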