1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Making stop condition optional by using a flag (#113)

* apply stop condition flag (default: ignore the stop condition)
This commit is contained in:
Gal Leibovich
2018-11-18 13:37:39 +02:00
committed by Gal Novik
parent 449bcfb4e1
commit 9fd4d55623
3 changed files with 18 additions and 6 deletions

View File

@@ -512,7 +512,7 @@ class AgentParameters(Parameters):
class TaskParameters(Parameters):
def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False,
experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
checkpoint_save_dir=None, export_onnx_graph: bool=False):
checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
"""
:param framework_type: deep learning framework type. currently only tensorflow is supported
:param evaluate_only: the task will be used only for evaluating the model
@@ -523,6 +523,8 @@ class TaskParameters(Parameters):
:param checkpoint_restore_dir: the directory to restore the checkpoints from
:param checkpoint_save_dir: the directory to store the checkpoints in
:param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
:param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate
"""
self.framework_type = framework_type
self.task_index = 0 # TODO: not really needed
@@ -534,6 +536,7 @@ class TaskParameters(Parameters):
self.checkpoint_save_dir = checkpoint_save_dir
self.seed = seed
self.export_onnx_graph = export_onnx_graph
self.apply_stop_condition = apply_stop_condition
class DistributedTaskParameters(TaskParameters):
@@ -541,7 +544,7 @@ class DistributedTaskParameters(TaskParameters):
task_index: int, evaluate_only: bool=False, num_tasks: int=None,
num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
checkpoint_save_dir=None, export_onnx_graph: bool=False):
checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
"""
:param framework_type: deep learning framework type. currently only tensorflow is supported
:param evaluate_only: the task will be used only for evaluating the model
@@ -560,11 +563,13 @@ class DistributedTaskParameters(TaskParameters):
:param checkpoint_restore_dir: the directory to restore the checkpoints from
:param checkpoint_save_dir: the directory to store the checkpoints in
:param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
:param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate
"""
super().__init__(framework_type=framework_type, evaluate_only=evaluate_only, use_cpu=use_cpu,
experiment_path=experiment_path, seed=seed, checkpoint_save_secs=checkpoint_save_secs,
checkpoint_restore_dir=checkpoint_restore_dir, checkpoint_save_dir=checkpoint_save_dir,
export_onnx_graph=export_onnx_graph)
export_onnx_graph=export_onnx_graph, apply_stop_condition=apply_stop_condition)
self.parameters_server_hosts = parameters_server_hosts
self.worker_hosts = worker_hosts
self.job_type = job_type