mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Adding support for evaluation only mode with predefined number of steps (#225)
This commit is contained in:
@@ -550,13 +550,14 @@ class AgentParameters(Parameters):
|
||||
|
||||
|
||||
class TaskParameters(Parameters):
|
||||
def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False,
|
||||
def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: int=None, use_cpu: bool=False,
|
||||
experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
|
||||
checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False,
|
||||
num_gpu: int=1):
|
||||
"""
|
||||
:param framework_type: deep learning framework type. currently only tensorflow is supported
|
||||
:param evaluate_only: the task will be used only for evaluating the model
|
||||
:param evaluate_only: if not None, the task will be used only for evaluating the model for the given number of steps.
|
||||
A value of 0 means that task will be evaluated for an infinite number of steps.
|
||||
:param use_cpu: use the cpu for this task
|
||||
:param experiment_path: the path to the directory which will store all the experiment outputs
|
||||
:param seed: a seed to use for the random numbers generator
|
||||
@@ -583,13 +584,14 @@ class TaskParameters(Parameters):
|
||||
|
||||
class DistributedTaskParameters(TaskParameters):
|
||||
def __init__(self, framework_type: Frameworks, parameters_server_hosts: str, worker_hosts: str, job_type: str,
|
||||
task_index: int, evaluate_only: bool=False, num_tasks: int=None,
|
||||
task_index: int, evaluate_only: int=None, num_tasks: int=None,
|
||||
num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
|
||||
shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
|
||||
checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
|
||||
"""
|
||||
:param framework_type: deep learning framework type. currently only tensorflow is supported
|
||||
:param evaluate_only: the task will be used only for evaluating the model
|
||||
:param evaluate_only: if not None, the task will be used only for evaluating the model for the given number of steps.
|
||||
A value of 0 means that task will be evaluated for an infinite number of steps.
|
||||
:param parameters_server_hosts: comma-separated list of hostname:port pairs to which the parameter servers are
|
||||
assigned
|
||||
:param worker_hosts: comma-separated list of hostname:port pairs to which the workers are assigned
|
||||
|
||||
Reference in New Issue
Block a user