mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 19:50:17 +01:00
Adding should_train helper and should_train in graph_manager
This commit is contained in:
committed by
zach dwiel
parent
a2e57a44f1
commit
a7f5442015
@@ -20,12 +20,12 @@ def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, n
|
||||
if data_store == "s3":
|
||||
ds_params = DataStoreParameters("s3", "", "")
|
||||
ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=s3_end_point, bucket_name=s3_bucket_name,
|
||||
checkpoint_dir="/checkpoint")
|
||||
checkpoint_dir="/checkpoint")
|
||||
elif data_store == "nfs":
|
||||
ds_params = DataStoreParameters("nfs", "kubernetes", {"namespace": "default"})
|
||||
ds_params_instance = NFSDataStoreParameters(ds_params)
|
||||
|
||||
worker_run_type_params = RunTypeParameters(image, rollout_command, run_type="worker")
|
||||
worker_run_type_params = RunTypeParameters(image, rollout_command, run_type="worker", num_replicas=num_workers)
|
||||
trainer_run_type_params = RunTypeParameters(image, training_command, run_type="trainer")
|
||||
|
||||
orchestration_params = KubernetesParameters([worker_run_type_params, trainer_run_type_params],
|
||||
@@ -53,7 +53,7 @@ def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, n
|
||||
orchestrator.trainer_logs()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
# orchestrator.undeploy()
|
||||
orchestrator.undeploy()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -90,6 +90,11 @@ if __name__ == '__main__':
|
||||
help="(string) S3 bucket name to use when S3 data store is used.",
|
||||
type=str,
|
||||
required=True)
|
||||
parser.add_argument('--num-workers',
|
||||
help="(string) Number of rollout workers",
|
||||
type=int,
|
||||
required=False,
|
||||
default=1)
|
||||
|
||||
# parser.add_argument('--checkpoint_dir',
|
||||
# help='(string) Path to a folder containing a checkpoint to write the model to.',
|
||||
@@ -99,4 +104,4 @@ if __name__ == '__main__':
|
||||
|
||||
main(preset=args.preset, image=args.image, nfs_server=args.nfs_server, nfs_path=args.nfs_path,
|
||||
memory_backend=args.memory_backend, data_store=args.data_store, s3_end_point=args.s3_end_point,
|
||||
s3_bucket_name=args.s3_bucket_name)
|
||||
s3_bucket_name=args.s3_bucket_name, num_workers=args.num_workers)
|
||||
|
||||
Reference in New Issue
Block a user