mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Fix cmd line arguments handling (#68)
* refactoring the merging of the task parameters and the command line parameters
* removing some unused command line arguments
* fix for saving checkpoints when not passing through coach.py
This commit is contained in:
@@ -213,13 +213,10 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
|
||||
screen.error("The requested checkpoint folder to load from does not exist.")
|
||||
|
||||
# no preset was given. check if the user requested to play some environment on its own
|
||||
if args.preset is None and args.play:
|
||||
if args.environment_type:
|
||||
args.agent_type = 'Human'
|
||||
else:
|
||||
screen.error('When no preset is given for Coach to run, and the user requests human control over '
|
||||
'the environment, the user is expected to input the desired environment_type and level.'
|
||||
'\nAt least one of these parameters was not given.')
|
||||
if args.preset is None and args.play and not args.environment_type:
|
||||
screen.error('When no preset is given for Coach to run, and the user requests human control over '
|
||||
'the environment, the user is expected to input the desired environment_type and level.'
|
||||
'\nAt least one of these parameters was not given.')
|
||||
elif args.preset and args.play:
|
||||
screen.error("Both the --preset and the --play flags were set. These flags can not be used together. "
|
||||
"For human control, please use the --play flag together with the environment type flag (-et)")
|
||||
@@ -428,24 +425,8 @@ def main():
|
||||
parser.add_argument('-dm', '--dump_mp4',
|
||||
help="(flag) Enable the mp4 saving functionality.",
|
||||
action='store_true')
|
||||
parser.add_argument('-at', '--agent_type',
|
||||
help="(string) Choose an agent type class to override on top of the selected preset. "
|
||||
"If no preset is defined, a preset can be set from the command-line by combining settings "
|
||||
"which are set by using --agent_type, --experiment_type, --environemnt_type",
|
||||
default=None,
|
||||
type=str)
|
||||
parser.add_argument('-et', '--environment_type',
|
||||
help="(string) Choose an environment type class to override on top of the selected preset."
|
||||
"If no preset is defined, a preset can be set from the command-line by combining settings "
|
||||
"which are set by using --agent_type, --experiment_type, --environemnt_type",
|
||||
default=None,
|
||||
type=str)
|
||||
parser.add_argument('-ept', '--exploration_policy_type',
|
||||
help="(string) Choose an exploration policy type class to override on top of the selected "
|
||||
"preset."
|
||||
"If no preset is defined, a preset can be set from the command-line by combining settings "
|
||||
"which are set by using --agent_type, --experiment_type, --environemnt_type"
|
||||
,
|
||||
help="(string) Choose an environment type class to override on top of the selected preset.",
|
||||
default=None,
|
||||
type=str)
|
||||
parser.add_argument('-lvl', '--level',
|
||||
@@ -546,13 +527,16 @@ def main():
|
||||
# Single-threaded runs
|
||||
if args.num_workers == 1:
|
||||
# Start the training or evaluation
|
||||
task_parameters = TaskParameters(framework_type=args.framework,
|
||||
evaluate_only=args.evaluate,
|
||||
experiment_path=args.experiment_path,
|
||||
seed=args.seed,
|
||||
use_cpu=args.use_cpu,
|
||||
checkpoint_save_secs=args.checkpoint_save_secs)
|
||||
task_parameters.__dict__ = add_items_to_dict(task_parameters.__dict__, args.__dict__)
|
||||
task_parameters = TaskParameters(
|
||||
framework_type=args.framework,
|
||||
evaluate_only=args.evaluate,
|
||||
experiment_path=args.experiment_path,
|
||||
seed=args.seed,
|
||||
use_cpu=args.use_cpu,
|
||||
checkpoint_save_secs=args.checkpoint_save_secs,
|
||||
checkpoint_restore_dir=args.checkpoint_restore_dir,
|
||||
checkpoint_save_dir=args.checkpoint_save_dir
|
||||
)
|
||||
|
||||
start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
|
||||
|
||||
@@ -575,19 +559,24 @@ def main():
|
||||
|
||||
def start_distributed_task(job_type, task_index, evaluation_worker=False,
|
||||
shared_memory_scratchpad=shared_memory_scratchpad):
|
||||
task_parameters = DistributedTaskParameters(framework_type=args.framework,
|
||||
parameters_server_hosts=ps_hosts,
|
||||
worker_hosts=worker_hosts,
|
||||
job_type=job_type,
|
||||
task_index=task_index,
|
||||
evaluate_only=evaluation_worker,
|
||||
use_cpu=args.use_cpu,
|
||||
num_tasks=total_tasks, # training tasks + 1 evaluation task
|
||||
num_training_tasks=args.num_workers,
|
||||
experiment_path=args.experiment_path,
|
||||
shared_memory_scratchpad=shared_memory_scratchpad,
|
||||
seed=args.seed+task_index if args.seed is not None else None) # each worker gets a different seed
|
||||
task_parameters.__dict__ = add_items_to_dict(task_parameters.__dict__, args.__dict__)
|
||||
task_parameters = DistributedTaskParameters(
|
||||
framework_type=args.framework,
|
||||
parameters_server_hosts=ps_hosts,
|
||||
worker_hosts=worker_hosts,
|
||||
job_type=job_type,
|
||||
task_index=task_index,
|
||||
evaluate_only=evaluation_worker,
|
||||
use_cpu=args.use_cpu,
|
||||
num_tasks=total_tasks, # training tasks + 1 evaluation task
|
||||
num_training_tasks=args.num_workers,
|
||||
experiment_path=args.experiment_path,
|
||||
shared_memory_scratchpad=shared_memory_scratchpad,
|
||||
seed=args.seed+task_index if args.seed is not None else None, # each worker gets a different seed
|
||||
checkpoint_save_secs=args.checkpoint_save_secs,
|
||||
checkpoint_restore_dir=args.checkpoint_restore_dir,
|
||||
checkpoint_save_dir=args.checkpoint_save_dir
|
||||
)
|
||||
|
||||
# we assume that only the evaluation workers are rendering
|
||||
graph_manager.visualization_parameters.render = args.render and evaluation_worker
|
||||
p = Process(target=start_graph, args=(graph_manager, task_parameters))
|
||||
@@ -607,7 +596,7 @@ def main():
|
||||
workers.append(start_distributed_task("worker", task_index))
|
||||
|
||||
# evaluation worker
|
||||
if args.evaluation_worker:
|
||||
if args.evaluation_worker or args.render:
|
||||
evaluation_worker = start_distributed_task("worker", args.num_workers, evaluation_worker=True)
|
||||
|
||||
# wait for all workers
|
||||
|
||||
Reference in New Issue
Block a user