Mirror of https://github.com/gryf/coach.git, synced 2025-12-17 19:20:19 +01:00
Fix cmd line arguments handling (#68)
* refactoring the merging of the task parameters and the command line parameters
* removing some unused command line arguments
* fix for saving checkpoints when not passing through coach.py
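The third bullet is the user-visible one: checkpoint settings now live on TaskParameters itself instead of being merged in from coach.py's argparse namespace, so a run that never touches coach.py can still save checkpoints. A minimal sketch of such a programmatic run, assuming the post-commit API; the preset module and the paths are illustrative, not taken from this commit:

    from base_parameters import Frameworks, TaskParameters
    # Any preset exposing a graph_manager would do; this import path is made up.
    from presets.CartPole_DQN import graph_manager

    task_parameters = TaskParameters(
        framework_type=Frameworks.tensorflow,
        experiment_path='/tmp/my_experiment',
        checkpoint_save_secs=60,             # save a checkpoint every minute
        checkpoint_save_dir='/tmp/my_experiment/checkpoint',
        checkpoint_restore_dir=None,         # or a directory holding earlier checkpoints
    )
    graph_manager.create_graph(task_parameters)
    graph_manager.improve()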
@@ -94,7 +94,7 @@ def create_monitored_session(target: tf.train.Server, task_index: int,
         is_chief=is_chief,
         hooks=[],
         checkpoint_dir=checkpoint_dir,
-        checkpoint_save_secs=checkpoint_save_secs,
+        save_checkpoint_secs=checkpoint_save_secs,
         config=config
     )
 
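This one-word rename matters because tf.train.MonitoredTrainingSession (TF 1.x) takes save_checkpoint_secs, not checkpoint_save_secs, so the old keyword raised a TypeError the moment a monitored session was created. A self-contained sketch of the corrected call, with stand-in values:

    import tensorflow as tf

    is_chief, checkpoint_dir, checkpoint_save_secs, config = True, '/tmp/ckpt', 600, None

    # The TF 1.x keyword is save_checkpoint_secs; Coach's own variable keeps its old name.
    sess = tf.train.MonitoredTrainingSession(
        is_chief=is_chief,
        hooks=[],
        checkpoint_dir=checkpoint_dir,
        save_checkpoint_secs=checkpoint_save_secs,
        config=config
    )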
@@ -432,14 +432,17 @@ class AgentParameters(Parameters):
 
 class TaskParameters(Parameters):
     def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False,
-                 experiment_path='/tmp', seed=None, checkpoint_save_secs=None):
+                 experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
+                 checkpoint_save_dir=None):
         """
         :param framework_type: deep learning framework type. currently only tensorflow is supported
         :param evaluate_only: the task will be used only for evaluating the model
         :param use_cpu: use the cpu for this task
         :param experiment_path: the path to the directory which will store all the experiment outputs
-        :param checkpoint_save_secs: the number of seconds between each checkpoint saving
         :param seed: a seed to use for the random numbers generator
+        :param checkpoint_save_secs: the number of seconds between each checkpoint saving
+        :param checkpoint_restore_dir: the directory to restore the checkpoints from
+        :param checkpoint_save_dir: the directory to store the checkpoints in
         """
         self.framework_type = framework_type
         self.task_index = 0  # TODO: not really needed
@@ -447,6 +450,8 @@ class TaskParameters(Parameters):
         self.use_cpu = use_cpu
         self.experiment_path = experiment_path
         self.checkpoint_save_secs = checkpoint_save_secs
+        self.checkpoint_restore_dir = checkpoint_restore_dir
+        self.checkpoint_save_dir = checkpoint_save_dir
         self.seed = seed
 
 
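For context on the first bullet of the commit message: these fields used to reach the task parameters only when running through coach.py, which grafted the whole argparse namespace onto the object via add_items_to_dict(task_parameters.__dict__, args.__dict__) (the calls removed in the coach.py hunks below). Declaring the fields here makes them part of the class contract instead. A rough sketch of what the removed merge did; the helper's semantics are assumed here, not verified:

    import copy

    def add_items_to_dict(target_dict, source_dict):
        # Assumed behavior: keys already present in target_dict win,
        # everything else is copied over from source_dict.
        updated = copy.copy(source_dict)
        updated.update(copy.copy(target_dict))
        return updated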
@@ -454,7 +459,8 @@ class DistributedTaskParameters(TaskParameters):
     def __init__(self, framework_type: Frameworks, parameters_server_hosts: str, worker_hosts: str, job_type: str,
                  task_index: int, evaluate_only: bool=False, num_tasks: int=None,
                  num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
-                 shared_memory_scratchpad=None, seed=None):
+                 shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
+                 checkpoint_save_dir=None):
         """
         :param framework_type: deep learning framework type. currently only tensorflow is supported
         :param evaluate_only: the task will be used only for evaluating the model
@@ -469,9 +475,13 @@ class DistributedTaskParameters(TaskParameters):
         :param experiment_path: the path to the directory which will store all the experiment outputs
         :param dnd: an external DND to use for NEC. This is a workaround needed for a shared DND not using the scratchpad.
         :param seed: a seed to use for the random numbers generator
+        :param checkpoint_save_secs: the number of seconds between each checkpoint saving
+        :param checkpoint_restore_dir: the directory to restore the checkpoints from
+        :param checkpoint_save_dir: the directory to store the checkpoints in
         """
         super().__init__(framework_type=framework_type, evaluate_only=evaluate_only, use_cpu=use_cpu,
-                         experiment_path=experiment_path, seed=seed)
+                         experiment_path=experiment_path, seed=seed, checkpoint_save_secs=checkpoint_save_secs,
+                         checkpoint_restore_dir=checkpoint_restore_dir, checkpoint_save_dir=checkpoint_save_dir)
         self.parameters_server_hosts = parameters_server_hosts
         self.worker_hosts = worker_hosts
         self.job_type = job_type
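A sketch of per-worker construction under the extended signature; the module path, host strings, and counts are illustrative. The seed offset mirrors what coach.py does below, so that workers do not share a random stream:

    from base_parameters import DistributedTaskParameters, Frameworks

    base_seed, num_workers = 1234, 2                  # illustrative values
    workers = [
        DistributedTaskParameters(
            framework_type=Frameworks.tensorflow,
            parameters_server_hosts='localhost:2222',
            worker_hosts='localhost:2223,localhost:2224',
            job_type='worker',
            task_index=task_index,
            seed=base_seed + task_index if base_seed is not None else None,
            checkpoint_save_secs=600,
            checkpoint_restore_dir=None,
            checkpoint_save_dir='/tmp/my_experiment/checkpoint',
        )
        for task_index in range(num_workers)
    ]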
@@ -213,10 +213,7 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
         screen.error("The requested checkpoint folder to load from does not exist.")
 
     # no preset was given. check if the user requested to play some environment on its own
-    if args.preset is None and args.play:
-        if args.environment_type:
-            args.agent_type = 'Human'
-        else:
-            screen.error('When no preset is given for Coach to run, and the user requests human control over '
-                         'the environment, the user is expected to input the desired environment_type and level.'
-                         '\nAt least one of these parameters was not given.')
+    if args.preset is None and args.play and not args.environment_type:
+        screen.error('When no preset is given for Coach to run, and the user requests human control over '
+                     'the environment, the user is expected to input the desired environment_type and level.'
+                     '\nAt least one of these parameters was not given.')
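The rewritten guard folds the inner branch into the condition: the error now fires only when --play is requested with neither a preset nor an environment type, and the args.agent_type = 'Human' assignment disappears here together with the --agent_type flag removed in the next hunk. Playing an environment directly should still work along these lines (placeholders, not a tested invocation):

    python3 coach.py --play -et <environment_type> -lvl <level>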
@@ -428,24 +425,8 @@ def main():
     parser.add_argument('-dm', '--dump_mp4',
                         help="(flag) Enable the mp4 saving functionality.",
                         action='store_true')
-    parser.add_argument('-at', '--agent_type',
-                        help="(string) Choose an agent type class to override on top of the selected preset. "
-                             "If no preset is defined, a preset can be set from the command-line by combining settings "
-                             "which are set by using --agent_type, --experiment_type, --environemnt_type",
-                        default=None,
-                        type=str)
     parser.add_argument('-et', '--environment_type',
-                        help="(string) Choose an environment type class to override on top of the selected preset."
-                             "If no preset is defined, a preset can be set from the command-line by combining settings "
-                             "which are set by using --agent_type, --experiment_type, --environemnt_type",
-                        default=None,
-                        type=str)
-    parser.add_argument('-ept', '--exploration_policy_type',
-                        help="(string) Choose an exploration policy type class to override on top of the selected "
-                             "preset."
-                             "If no preset is defined, a preset can be set from the command-line by combining settings "
-                             "which are set by using --agent_type, --experiment_type, --environemnt_type"
-                        ,
+                        help="(string) Choose an environment type class to override on top of the selected preset.",
                         default=None,
                         type=str)
     parser.add_argument('-lvl', '--level',
@@ -546,13 +527,16 @@ def main():
     # Single-threaded runs
     if args.num_workers == 1:
         # Start the training or evaluation
-        task_parameters = TaskParameters(framework_type=args.framework,
+        task_parameters = TaskParameters(
+                                         framework_type=args.framework,
                                          evaluate_only=args.evaluate,
                                          experiment_path=args.experiment_path,
                                          seed=args.seed,
                                          use_cpu=args.use_cpu,
-                                         checkpoint_save_secs=args.checkpoint_save_secs)
-        task_parameters.__dict__ = add_items_to_dict(task_parameters.__dict__, args.__dict__)
+                                         checkpoint_save_secs=args.checkpoint_save_secs,
+                                         checkpoint_restore_dir=args.checkpoint_restore_dir,
+                                         checkpoint_save_dir=args.checkpoint_save_dir
+                                         )
 
         start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
 
@@ -575,7 +559,8 @@ def main():
 
         def start_distributed_task(job_type, task_index, evaluation_worker=False,
                                    shared_memory_scratchpad=shared_memory_scratchpad):
-            task_parameters = DistributedTaskParameters(framework_type=args.framework,
+            task_parameters = DistributedTaskParameters(
+                                                        framework_type=args.framework,
                                                         parameters_server_hosts=ps_hosts,
                                                         worker_hosts=worker_hosts,
                                                         job_type=job_type,
@@ -586,8 +571,12 @@ def main():
                                                         num_training_tasks=args.num_workers,
                                                         experiment_path=args.experiment_path,
                                                         shared_memory_scratchpad=shared_memory_scratchpad,
-                                                        seed=args.seed+task_index if args.seed is not None else None)  # each worker gets a different seed
-            task_parameters.__dict__ = add_items_to_dict(task_parameters.__dict__, args.__dict__)
+                                                        seed=args.seed+task_index if args.seed is not None else None,  # each worker gets a different seed
+                                                        checkpoint_save_secs=args.checkpoint_save_secs,
+                                                        checkpoint_restore_dir=args.checkpoint_restore_dir,
+                                                        checkpoint_save_dir=args.checkpoint_save_dir
+                                                        )
+
             # we assume that only the evaluation workers are rendering
             graph_manager.visualization_parameters.render = args.render and evaluation_worker
             p = Process(target=start_graph, args=(graph_manager, task_parameters))
@@ -607,7 +596,7 @@ def main():
             workers.append(start_distributed_task("worker", task_index))
 
         # evaluation worker
-        if args.evaluation_worker:
+        if args.evaluation_worker or args.render:
             evaluation_worker = start_distributed_task("worker", args.num_workers, evaluation_worker=True)
 
         # wait for all workers
@@ -100,6 +100,8 @@ class GraphManager(object):
         self.preset_validation_params = PresetValidationParameters()
         self.reset_required = False
 
+        # timers
+        self.graph_creation_time = None
         self.last_checkpoint_saving_time = time.time()
 
         # counters
@@ -520,6 +522,8 @@ class GraphManager(object):
             self.save_checkpoint()
 
     def save_checkpoint(self):
+        if self.task_parameters.checkpoint_save_dir is None:
+            self.task_parameters.checkpoint_save_dir = os.path.join(self.task_parameters.experiment_path, 'checkpoint')
         checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
                                        "{}_Step-{}.ckpt".format(
                                            self.checkpoint_id,
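The new guard gives save_checkpoint a usable default when nothing was configured: checkpoints land in a checkpoint directory under the experiment path. Traced by hand below; the second argument of the format call is cut off by the hunk, so a step counter is assumed:

    import os

    experiment_path = '/tmp/my_experiment'   # illustrative
    checkpoint_save_dir = None               # nothing configured explicitly

    if checkpoint_save_dir is None:
        checkpoint_save_dir = os.path.join(experiment_path, 'checkpoint')

    checkpoint_id, step = 3, 12000           # illustrative counters
    print(os.path.join(checkpoint_save_dir,
                       "{}_Step-{}.ckpt".format(checkpoint_id, step)))
    # -> /tmp/my_experiment/checkpoint/3_Step-12000.ckpt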