mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Fix cmd line arguments handling (#68)
* refactoring the merging of the task parameters and the command line parameters * removing some unused command line arguments * fix for saving checkpoints when not passing through coach.py
This commit is contained in:
@@ -432,14 +432,17 @@ class AgentParameters(Parameters):
|
||||
|
||||
class TaskParameters(Parameters):
|
||||
def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False,
|
||||
experiment_path='/tmp', seed=None, checkpoint_save_secs=None):
|
||||
experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
|
||||
checkpoint_save_dir=None):
|
||||
"""
|
||||
:param framework_type: deep learning framework type. currently only tensorflow is supported
|
||||
:param evaluate_only: the task will be used only for evaluating the model
|
||||
:param use_cpu: use the cpu for this task
|
||||
:param experiment_path: the path to the directory which will store all the experiment outputs
|
||||
:param checkpoint_save_secs: the number of seconds between each checkpoint saving
|
||||
:param seed: a seed to use for the random numbers generator
|
||||
:param checkpoint_save_secs: the number of seconds between each checkpoint saving
|
||||
:param checkpoint_restore_dir: the directory to restore the checkpoints from
|
||||
:param checkpoint_save_dir: the directory to store the checkpoints in
|
||||
"""
|
||||
self.framework_type = framework_type
|
||||
self.task_index = 0 # TODO: not really needed
|
||||
@@ -447,6 +450,8 @@ class TaskParameters(Parameters):
|
||||
self.use_cpu = use_cpu
|
||||
self.experiment_path = experiment_path
|
||||
self.checkpoint_save_secs = checkpoint_save_secs
|
||||
self.checkpoint_restore_dir = checkpoint_restore_dir
|
||||
self.checkpoint_save_dir = checkpoint_save_dir
|
||||
self.seed = seed
|
||||
|
||||
|
||||
@@ -454,7 +459,8 @@ class DistributedTaskParameters(TaskParameters):
|
||||
def __init__(self, framework_type: Frameworks, parameters_server_hosts: str, worker_hosts: str, job_type: str,
|
||||
task_index: int, evaluate_only: bool=False, num_tasks: int=None,
|
||||
num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
|
||||
shared_memory_scratchpad=None, seed=None):
|
||||
shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
|
||||
checkpoint_save_dir=None):
|
||||
"""
|
||||
:param framework_type: deep learning framework type. currently only tensorflow is supported
|
||||
:param evaluate_only: the task will be used only for evaluating the model
|
||||
@@ -469,9 +475,13 @@ class DistributedTaskParameters(TaskParameters):
|
||||
:param experiment_path: the path to the directory which will store all the experiment outputs
|
||||
:param dnd: an external DND to use for NEC. This is a workaround needed for a shared DND not using the scratchpad.
|
||||
:param seed: a seed to use for the random numbers generator
|
||||
:param checkpoint_save_secs: the number of seconds between each checkpoint saving
|
||||
:param checkpoint_restore_dir: the directory to restore the checkpoints from
|
||||
:param checkpoint_save_dir: the directory to store the checkpoints in
|
||||
"""
|
||||
super().__init__(framework_type=framework_type, evaluate_only=evaluate_only, use_cpu=use_cpu,
|
||||
experiment_path=experiment_path, seed=seed)
|
||||
experiment_path=experiment_path, seed=seed, checkpoint_save_secs=checkpoint_save_secs,
|
||||
checkpoint_restore_dir=checkpoint_restore_dir, checkpoint_save_dir=checkpoint_save_dir)
|
||||
self.parameters_server_hosts = parameters_server_hosts
|
||||
self.worker_hosts = worker_hosts
|
||||
self.job_type = job_type
|
||||
|
||||
Reference in New Issue
Block a user