From 83e0b09a6a6d08b6ac6896418bded853355c55ee Mon Sep 17 00:00:00 2001 From: Itai Caspi <30383381+itaicaspi-intel@users.noreply.github.com> Date: Thu, 8 Nov 2018 12:52:42 +0200 Subject: [PATCH] adding the missing export_onnx_graph parameter to task parameters (#73) --- .../tensorflow_components/heads/ppo_head.py | 2 +- rl_coach/base_parameters.py | 10 +++++++--- rl_coach/coach.py | 6 ++++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/rl_coach/architectures/tensorflow_components/heads/ppo_head.py b/rl_coach/architectures/tensorflow_components/heads/ppo_head.py index f15e6f2..2dacaea 100644 --- a/rl_coach/architectures/tensorflow_components/heads/ppo_head.py +++ b/rl_coach/architectures/tensorflow_components/heads/ppo_head.py @@ -19,7 +19,7 @@ import tensorflow as tf from rl_coach.architectures.tensorflow_components.layers import Dense from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer -from rl_coach.base_parameters import AgentParameters +from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters from rl_coach.core_types import ActionProbabilities from rl_coach.spaces import BoxActionSpace, DiscreteActionSpace from rl_coach.spaces import SpacesDefinition diff --git a/rl_coach/base_parameters.py b/rl_coach/base_parameters.py index 03ed774..2aa09c9 100644 --- a/rl_coach/base_parameters.py +++ b/rl_coach/base_parameters.py @@ -436,7 +436,7 @@ class AgentParameters(Parameters): class TaskParameters(Parameters): def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False, experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None, - checkpoint_save_dir=None): + checkpoint_save_dir=None, export_onnx_graph: bool=False): """ :param framework_type: deep learning framework type. currently only tensorflow is supported :param evaluate_only: the task will be used only for evaluating the model @@ -446,6 +446,7 @@ class TaskParameters(Parameters): :param checkpoint_save_secs: the number of seconds between each checkpoint saving :param checkpoint_restore_dir: the directory to restore the checkpoints from :param checkpoint_save_dir: the directory to store the checkpoints in + :param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved """ self.framework_type = framework_type self.task_index = 0 # TODO: not really needed @@ -456,6 +457,7 @@ class TaskParameters(Parameters): self.checkpoint_restore_dir = checkpoint_restore_dir self.checkpoint_save_dir = checkpoint_save_dir self.seed = seed + self.export_onnx_graph = export_onnx_graph class DistributedTaskParameters(TaskParameters): @@ -463,7 +465,7 @@ class DistributedTaskParameters(TaskParameters): task_index: int, evaluate_only: bool=False, num_tasks: int=None, num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None, shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None, - checkpoint_save_dir=None): + checkpoint_save_dir=None, export_onnx_graph: bool=False): """ :param framework_type: deep learning framework type. currently only tensorflow is supported :param evaluate_only: the task will be used only for evaluating the model @@ -481,10 +483,12 @@ class DistributedTaskParameters(TaskParameters): :param checkpoint_save_secs: the number of seconds between each checkpoint saving :param checkpoint_restore_dir: the directory to restore the checkpoints from :param checkpoint_save_dir: the directory to store the checkpoints in + :param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved """ super().__init__(framework_type=framework_type, evaluate_only=evaluate_only, use_cpu=use_cpu, experiment_path=experiment_path, seed=seed, checkpoint_save_secs=checkpoint_save_secs, - checkpoint_restore_dir=checkpoint_restore_dir, checkpoint_save_dir=checkpoint_save_dir) + checkpoint_restore_dir=checkpoint_restore_dir, checkpoint_save_dir=checkpoint_save_dir, + export_onnx_graph=export_onnx_graph) self.parameters_server_hosts = parameters_server_hosts self.worker_hosts = worker_hosts self.job_type = job_type diff --git a/rl_coach/coach.py b/rl_coach/coach.py index 19daf40..f750d2d 100644 --- a/rl_coach/coach.py +++ b/rl_coach/coach.py @@ -535,7 +535,8 @@ def main(): use_cpu=args.use_cpu, checkpoint_save_secs=args.checkpoint_save_secs, checkpoint_restore_dir=args.checkpoint_restore_dir, - checkpoint_save_dir=args.checkpoint_save_dir + checkpoint_save_dir=args.checkpoint_save_dir, + export_onnx_graph=args.export_onnx_graph ) start_graph(graph_manager=graph_manager, task_parameters=task_parameters) @@ -574,7 +575,8 @@ def main(): seed=args.seed+task_index if args.seed is not None else None, # each worker gets a different seed checkpoint_save_secs=args.checkpoint_save_secs, checkpoint_restore_dir=args.checkpoint_restore_dir, - checkpoint_save_dir=args.checkpoint_save_dir + checkpoint_save_dir=args.checkpoint_save_dir, + export_onnx_graph=args.export_onnx_graph ) # we assume that only the evaluation workers are rendering