1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Adding support for evaluation only mode with predefined number of steps (#225)

This commit is contained in:
Gal Novik
2019-03-03 10:03:45 +02:00
committed by Gal Leibovich
parent 2c1a9dbf20
commit 10220be9be
3 changed files with 24 additions and 16 deletions

View File

@@ -550,13 +550,14 @@ class AgentParameters(Parameters):
class TaskParameters(Parameters):
def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False,
def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: int=None, use_cpu: bool=False,
experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False,
num_gpu: int=1):
"""
:param framework_type: deep learning framework type. currently only tensorflow is supported
:param evaluate_only: the task will be used only for evaluating the model
:param evaluate_only: if not None, the task will be used only for evaluating the model for the given number of steps.
A value of 0 means that task will be evaluated for an infinite number of steps.
:param use_cpu: use the cpu for this task
:param experiment_path: the path to the directory which will store all the experiment outputs
:param seed: a seed to use for the random numbers generator
@@ -583,13 +584,14 @@ class TaskParameters(Parameters):
class DistributedTaskParameters(TaskParameters):
def __init__(self, framework_type: Frameworks, parameters_server_hosts: str, worker_hosts: str, job_type: str,
task_index: int, evaluate_only: bool=False, num_tasks: int=None,
task_index: int, evaluate_only: int=None, num_tasks: int=None,
num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
"""
:param framework_type: deep learning framework type. currently only tensorflow is supported
:param evaluate_only: the task will be used only for evaluating the model
:param evaluate_only: if not None, the task will be used only for evaluating the model for the given number of steps.
A value of 0 means that task will be evaluated for an infinite number of steps.
:param parameters_server_hosts: comma-separated list of hostname:port pairs to which the parameter servers are
assigned
:param worker_hosts: comma-separated list of hostname:port pairs to which the workers are assigned