diff --git a/rl_coach/base_parameters.py b/rl_coach/base_parameters.py
index 03b27e9..d813ed6 100644
--- a/rl_coach/base_parameters.py
+++ b/rl_coach/base_parameters.py
@@ -512,7 +512,7 @@ class AgentParameters(Parameters):
 class TaskParameters(Parameters):
     def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False,
                  experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
-                 checkpoint_save_dir=None, export_onnx_graph: bool=False):
+                 checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
         """
         :param framework_type: deep learning framework type. currently only tensorflow is supported
         :param evaluate_only: the task will be used only for evaluating the model
@@ -523,6 +523,8 @@ class TaskParameters(Parameters):
         :param checkpoint_restore_dir: the directory to restore the checkpoints from
         :param checkpoint_save_dir: the directory to store the checkpoints in
         :param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
+        :param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate
+
         """
         self.framework_type = framework_type
         self.task_index = 0  # TODO: not really needed
@@ -534,6 +536,7 @@ class TaskParameters(Parameters):
         self.checkpoint_save_dir = checkpoint_save_dir
         self.seed = seed
         self.export_onnx_graph = export_onnx_graph
+        self.apply_stop_condition = apply_stop_condition
 
 
 class DistributedTaskParameters(TaskParameters):
@@ -541,7 +544,7 @@ class DistributedTaskParameters(TaskParameters):
                  task_index: int, evaluate_only: bool=False, num_tasks: int=None,
                  num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
                  shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
-                 checkpoint_save_dir=None, export_onnx_graph: bool=False):
+                 checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
         """
         :param framework_type: deep learning framework type. currently only tensorflow is supported
         :param evaluate_only: the task will be used only for evaluating the model
@@ -560,11 +563,13 @@ class DistributedTaskParameters(TaskParameters):
         :param checkpoint_restore_dir: the directory to restore the checkpoints from
         :param checkpoint_save_dir: the directory to store the checkpoints in
         :param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
+        :param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate
+
         """
         super().__init__(framework_type=framework_type, evaluate_only=evaluate_only, use_cpu=use_cpu,
                          experiment_path=experiment_path, seed=seed, checkpoint_save_secs=checkpoint_save_secs,
                          checkpoint_restore_dir=checkpoint_restore_dir, checkpoint_save_dir=checkpoint_save_dir,
-                         export_onnx_graph=export_onnx_graph)
+                         export_onnx_graph=export_onnx_graph, apply_stop_condition=apply_stop_condition)
         self.parameters_server_hosts = parameters_server_hosts
         self.worker_hosts = worker_hosts
         self.job_type = job_type
diff --git a/rl_coach/coach.py b/rl_coach/coach.py
index 8404a38..0cc982f 100644
--- a/rl_coach/coach.py
+++ b/rl_coach/coach.py
@@ -541,6 +541,11 @@ class CoachLauncher(object):
                             type=RunType,
                             default=RunType.ORCHESTRATOR,
                             choices=list(RunType))
+        parser.add_argument('-asc', '--apply_stop_condition',
+                            help="(flag) If set, this will apply a stop condition on the run, defined by reaching a "
+                                 "target success rate as set by the environment or a custom success rate as defined "
+                                 "in the preset.",
+                            action='store_true')
 
         return parser
 
@@ -590,7 +595,8 @@ class CoachLauncher(object):
                                          checkpoint_save_secs=args.checkpoint_save_secs,
                                          checkpoint_restore_dir=args.checkpoint_restore_dir,
                                          checkpoint_save_dir=args.checkpoint_save_dir,
-                                         export_onnx_graph=args.export_onnx_graph
+                                         export_onnx_graph=args.export_onnx_graph,
+                                         apply_stop_condition=args.apply_stop_condition
                                          )
 
         start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
@@ -629,7 +635,8 @@ class CoachLauncher(object):
                                              checkpoint_save_secs=args.checkpoint_save_secs,
                                              checkpoint_restore_dir=args.checkpoint_restore_dir,
                                              checkpoint_save_dir=args.checkpoint_save_dir,
-                                             export_onnx_graph=args.export_onnx_graph
+                                             export_onnx_graph=args.export_onnx_graph,
+                                             apply_stop_condition=args.apply_stop_condition
                                              )
         # we assume that only the evaluation workers are rendering
         graph_manager.visualization_parameters.render = args.render and evaluation_worker
diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py
index 10ee60a..ef67776 100644
--- a/rl_coach/graph_managers/graph_manager.py
+++ b/rl_coach/graph_managers/graph_manager.py
@@ -669,4 +669,4 @@ class GraphManager(object):
         self.memory_backend = get_memory_backend(self.agent_params.memory.memory_backend_params)
 
     def should_stop(self) -> bool:
-        return all([manager.should_stop() for manager in self.level_managers])
+        return self.task_parameters.apply_stop_condition and all([manager.should_stop() for manager in self.level_managers])
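Usage note (not part of the patch itself): the sketch below shows the new parameter being enabled when constructing `TaskParameters` directly, as an alternative to passing the new `-asc`/`--apply_stop_condition` CLI flag. The `experiment_path` value is a placeholder, and the block assumes a graph manager has already been built from a preset.

```python
# Minimal sketch, assuming a graph_manager already exists (e.g. loaded from a preset).
# TaskParameters and Frameworks come from the file patched above.
from rl_coach.base_parameters import TaskParameters, Frameworks

task_parameters = TaskParameters(
    framework_type=Frameworks.tensorflow,
    experiment_path='/tmp/asc_demo',    # placeholder path, for illustration only
    apply_stop_condition=True,          # new parameter introduced by this patch
)

# With apply_stop_condition=True, GraphManager.should_stop() returns True only once
# every level manager reports that its success-rate target has been reached; with the
# default of False it always returns False.
```

Because the parameter defaults to `False`, the early-stop check in `GraphManager.should_stop()` becomes strictly opt-in: existing runs that do not pass the flag will never be stopped by this condition.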