1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 19:50:17 +01:00

Adding should_train helper and should_train in graph_manager

This commit is contained in:
Ajay Deshpande
2018-10-05 14:22:15 -07:00
committed by zach dwiel
parent a2e57a44f1
commit a7f5442015
7 changed files with 126 additions and 20 deletions

View File

@@ -20,12 +20,12 @@ def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, n
if data_store == "s3":
ds_params = DataStoreParameters("s3", "", "")
ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=s3_end_point, bucket_name=s3_bucket_name,
checkpoint_dir="/checkpoint")
checkpoint_dir="/checkpoint")
elif data_store == "nfs":
ds_params = DataStoreParameters("nfs", "kubernetes", {"namespace": "default"})
ds_params_instance = NFSDataStoreParameters(ds_params)
worker_run_type_params = RunTypeParameters(image, rollout_command, run_type="worker")
worker_run_type_params = RunTypeParameters(image, rollout_command, run_type="worker", num_replicas=num_workers)
trainer_run_type_params = RunTypeParameters(image, training_command, run_type="trainer")
orchestration_params = KubernetesParameters([worker_run_type_params, trainer_run_type_params],
@@ -53,7 +53,7 @@ def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, n
orchestrator.trainer_logs()
except KeyboardInterrupt:
pass
# orchestrator.undeploy()
orchestrator.undeploy()
if __name__ == '__main__':
@@ -90,6 +90,11 @@ if __name__ == '__main__':
help="(string) S3 bucket name to use when S3 data store is used.",
type=str,
required=True)
# Number of rollout worker replicas to launch; forwarded to main() as num_workers.
# The value is parsed with type=int, so the help text is tagged "(int)" to match
# the "(<type>) description" convention used by the other options.
parser.add_argument('--num-workers',
                    help="(int) Number of rollout workers",
                    type=int,
                    required=False,
                    default=1)
# parser.add_argument('--checkpoint_dir',
# help='(string) Path to a folder containing a checkpoint to write the model to.',
@@ -99,4 +104,4 @@ if __name__ == '__main__':
main(preset=args.preset, image=args.image, nfs_server=args.nfs_server, nfs_path=args.nfs_path,
memory_backend=args.memory_backend, data_store=args.data_store, s3_end_point=args.s3_end_point,
s3_bucket_name=args.s3_bucket_name)
s3_bucket_name=args.s3_bucket_name, num_workers=args.num_workers)