mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Making stop condition optional by using a flag (#113)
* apply stop condition flag (default: ignore the stop condition)
This commit is contained in:
@@ -512,7 +512,7 @@ class AgentParameters(Parameters):
|
|||||||
class TaskParameters(Parameters):
|
class TaskParameters(Parameters):
|
||||||
def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False,
|
def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False,
|
||||||
experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
|
experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
|
||||||
checkpoint_save_dir=None, export_onnx_graph: bool=False):
|
checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
|
||||||
"""
|
"""
|
||||||
:param framework_type: deep learning framework type. currently only tensorflow is supported
|
:param framework_type: deep learning framework type. currently only tensorflow is supported
|
||||||
:param evaluate_only: the task will be used only for evaluating the model
|
:param evaluate_only: the task will be used only for evaluating the model
|
||||||
@@ -523,6 +523,8 @@ class TaskParameters(Parameters):
|
|||||||
:param checkpoint_restore_dir: the directory to restore the checkpoints from
|
:param checkpoint_restore_dir: the directory to restore the checkpoints from
|
||||||
:param checkpoint_save_dir: the directory to store the checkpoints in
|
:param checkpoint_save_dir: the directory to store the checkpoints in
|
||||||
:param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
|
:param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
|
||||||
|
:param apply_stop_condition: If set to True, the run stops once the stop condition (reaching a target success rate) is met. If False (the default), the stop condition is ignored
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.framework_type = framework_type
|
self.framework_type = framework_type
|
||||||
self.task_index = 0 # TODO: not really needed
|
self.task_index = 0 # TODO: not really needed
|
||||||
@@ -534,6 +536,7 @@ class TaskParameters(Parameters):
|
|||||||
self.checkpoint_save_dir = checkpoint_save_dir
|
self.checkpoint_save_dir = checkpoint_save_dir
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.export_onnx_graph = export_onnx_graph
|
self.export_onnx_graph = export_onnx_graph
|
||||||
|
self.apply_stop_condition = apply_stop_condition
|
||||||
|
|
||||||
|
|
||||||
class DistributedTaskParameters(TaskParameters):
|
class DistributedTaskParameters(TaskParameters):
|
||||||
@@ -541,7 +544,7 @@ class DistributedTaskParameters(TaskParameters):
|
|||||||
task_index: int, evaluate_only: bool=False, num_tasks: int=None,
|
task_index: int, evaluate_only: bool=False, num_tasks: int=None,
|
||||||
num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
|
num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
|
||||||
shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
|
shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
|
||||||
checkpoint_save_dir=None, export_onnx_graph: bool=False):
|
checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
|
||||||
"""
|
"""
|
||||||
:param framework_type: deep learning framework type. currently only tensorflow is supported
|
:param framework_type: deep learning framework type. currently only tensorflow is supported
|
||||||
:param evaluate_only: the task will be used only for evaluating the model
|
:param evaluate_only: the task will be used only for evaluating the model
|
||||||
@@ -560,11 +563,13 @@ class DistributedTaskParameters(TaskParameters):
|
|||||||
:param checkpoint_restore_dir: the directory to restore the checkpoints from
|
:param checkpoint_restore_dir: the directory to restore the checkpoints from
|
||||||
:param checkpoint_save_dir: the directory to store the checkpoints in
|
:param checkpoint_save_dir: the directory to store the checkpoints in
|
||||||
:param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
|
:param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
|
||||||
|
:param apply_stop_condition: If set to True, the run stops once the stop condition (reaching a target success rate) is met. If False (the default), the stop condition is ignored
|
||||||
|
|
||||||
"""
|
"""
|
||||||
super().__init__(framework_type=framework_type, evaluate_only=evaluate_only, use_cpu=use_cpu,
|
super().__init__(framework_type=framework_type, evaluate_only=evaluate_only, use_cpu=use_cpu,
|
||||||
experiment_path=experiment_path, seed=seed, checkpoint_save_secs=checkpoint_save_secs,
|
experiment_path=experiment_path, seed=seed, checkpoint_save_secs=checkpoint_save_secs,
|
||||||
checkpoint_restore_dir=checkpoint_restore_dir, checkpoint_save_dir=checkpoint_save_dir,
|
checkpoint_restore_dir=checkpoint_restore_dir, checkpoint_save_dir=checkpoint_save_dir,
|
||||||
export_onnx_graph=export_onnx_graph)
|
export_onnx_graph=export_onnx_graph, apply_stop_condition=apply_stop_condition)
|
||||||
self.parameters_server_hosts = parameters_server_hosts
|
self.parameters_server_hosts = parameters_server_hosts
|
||||||
self.worker_hosts = worker_hosts
|
self.worker_hosts = worker_hosts
|
||||||
self.job_type = job_type
|
self.job_type = job_type
|
||||||
|
|||||||
@@ -541,6 +541,11 @@ class CoachLauncher(object):
|
|||||||
type=RunType,
|
type=RunType,
|
||||||
default=RunType.ORCHESTRATOR,
|
default=RunType.ORCHESTRATOR,
|
||||||
choices=list(RunType))
|
choices=list(RunType))
|
||||||
|
parser.add_argument('-asc', '--apply_stop_condition',
|
||||||
|
help="(flag) If set, this will apply a stop condition on the run, defined by reaching a "
|
||||||
|
"target success rate as set by the environment or a custom success rate as defined "
|
||||||
|
"in the preset. ",
|
||||||
|
action='store_true')
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@@ -590,7 +595,8 @@ class CoachLauncher(object):
|
|||||||
checkpoint_save_secs=args.checkpoint_save_secs,
|
checkpoint_save_secs=args.checkpoint_save_secs,
|
||||||
checkpoint_restore_dir=args.checkpoint_restore_dir,
|
checkpoint_restore_dir=args.checkpoint_restore_dir,
|
||||||
checkpoint_save_dir=args.checkpoint_save_dir,
|
checkpoint_save_dir=args.checkpoint_save_dir,
|
||||||
export_onnx_graph=args.export_onnx_graph
|
export_onnx_graph=args.export_onnx_graph,
|
||||||
|
apply_stop_condition=args.apply_stop_condition
|
||||||
)
|
)
|
||||||
|
|
||||||
start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
|
start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
|
||||||
@@ -629,7 +635,8 @@ class CoachLauncher(object):
|
|||||||
checkpoint_save_secs=args.checkpoint_save_secs,
|
checkpoint_save_secs=args.checkpoint_save_secs,
|
||||||
checkpoint_restore_dir=args.checkpoint_restore_dir,
|
checkpoint_restore_dir=args.checkpoint_restore_dir,
|
||||||
checkpoint_save_dir=args.checkpoint_save_dir,
|
checkpoint_save_dir=args.checkpoint_save_dir,
|
||||||
export_onnx_graph=args.export_onnx_graph
|
export_onnx_graph=args.export_onnx_graph,
|
||||||
|
apply_stop_condition=args.apply_stop_condition
|
||||||
)
|
)
|
||||||
# we assume that only the evaluation workers are rendering
|
# we assume that only the evaluation workers are rendering
|
||||||
graph_manager.visualization_parameters.render = args.render and evaluation_worker
|
graph_manager.visualization_parameters.render = args.render and evaluation_worker
|
||||||
|
|||||||
@@ -669,4 +669,4 @@ class GraphManager(object):
|
|||||||
self.memory_backend = get_memory_backend(self.agent_params.memory.memory_backend_params)
|
self.memory_backend = get_memory_backend(self.agent_params.memory.memory_backend_params)
|
||||||
|
|
||||||
def should_stop(self) -> bool:
|
def should_stop(self) -> bool:
|
||||||
return all([manager.should_stop() for manager in self.level_managers])
|
return self.task_parameters.apply_stop_condition and all([manager.should_stop() for manager in self.level_managers])
|
||||||
|
|||||||
Reference in New Issue
Block a user