From 811152126c841c1e100efb28ae5e6a5e5e329b6c Mon Sep 17 00:00:00 2001
From: Itai Caspi <30383381+itaicaspi-intel@users.noreply.github.com>
Date: Tue, 6 Nov 2018 10:55:21 +0200
Subject: [PATCH] Export graph to ONNX (#61)

Implements the ONNX graph exporting feature. Currently does not work for
NAF, C51, and A3C_LSTM due to TF layers that are not yet supported by the
tf2onnx library.
---
 .../tensorflow_components/architecture.py    | 41 ++++++++++-
 .../tensorflow_components/heads/ppo_head.py  | 10 ++-
 rl_coach/coach.py                            | 12 ++++
 rl_coach/graph_managers/graph_manager.py     | 71 +++++++++++++------
 4 files changed, 107 insertions(+), 27 deletions(-)

diff --git a/rl_coach/architectures/tensorflow_components/architecture.py b/rl_coach/architectures/tensorflow_components/architecture.py
index 7c5c248..3353b52 100644
--- a/rl_coach/architectures/tensorflow_components/architecture.py
+++ b/rl_coach/architectures/tensorflow_components/architecture.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 import time
+import os
 
 import numpy as np
 import tensorflow as tf
@@ -22,7 +23,7 @@ from rl_coach.architectures.architecture import Architecture
 from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
 from rl_coach.core_types import GradientClippingMethod
 from rl_coach.spaces import SpacesDefinition
-from rl_coach.utils import force_list, squeeze_list
+from rl_coach.utils import force_list, squeeze_list, start_shell_command_and_wait
 
 
 def variable_summaries(var):
@@ -614,3 +615,41 @@ class TensorFlowArchitecture(Architecture):
         if self.middleware.__class__.__name__ == 'LSTMMiddleware':
             self.curr_rnn_c_in = self.middleware.c_init
             self.curr_rnn_h_in = self.middleware.h_init
+
+
+def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None:
+    """
+    Given the input and output nodes of the TF graph, save it as an ONNX graph.
+    This requires the TF graph and the weights checkpoint to be stored in the experiment directory.
+    It then freezes the graph (merging the graph and the weights checkpoint) and converts it to ONNX.
+
+    :param input_nodes: A list of input nodes for the TF graph
+    :param output_nodes: A list of output nodes for the TF graph
+    :param checkpoint_save_dir: The directory to save the ONNX graph to
+    :return: None
+    """
+    import tf2onnx  # just to verify that tf2onnx is installed
+
+    # freeze graph
+    frozen_graph_path = os.path.join(checkpoint_save_dir, "frozen_graph.pb")
+    freeze_graph_command = [
+        "python -m tensorflow.python.tools.freeze_graph",
+        "--input_graph={}".format(os.path.join(checkpoint_save_dir, "graphdef.pb")),
+        "--input_binary=true",
+        "--output_node_names='{}'".format(','.join([o.split(":")[0] for o in output_nodes])),
+        "--input_checkpoint={}".format(tf.train.latest_checkpoint(checkpoint_save_dir)),
+        "--output_graph={}".format(frozen_graph_path)
+    ]
+    start_shell_command_and_wait(" ".join(freeze_graph_command))
+
+    # convert graph to onnx
+    onnx_graph_path = os.path.join(checkpoint_save_dir, "model.onnx")
+    convert_to_onnx_command = [
+        "python -m tf2onnx.convert",
+        "--input {}".format(frozen_graph_path),
+        "--inputs '{}'".format(','.join(input_nodes)),
+        "--outputs '{}'".format(','.join(output_nodes)),
+        "--output {}".format(onnx_graph_path),
+        "--verbose"
+    ]
+    start_shell_command_and_wait(" ".join(convert_to_onnx_command))
diff --git a/rl_coach/architectures/tensorflow_components/heads/ppo_head.py b/rl_coach/architectures/tensorflow_components/heads/ppo_head.py
index 6ce7898..f15e6f2 100644
--- a/rl_coach/architectures/tensorflow_components/heads/ppo_head.py
+++ b/rl_coach/architectures/tensorflow_components/heads/ppo_head.py
@@ -127,11 +127,15 @@ class PPOHead(Head):
             self.input = [self.actions, self.old_policy_mean, self.old_policy_std]
 
         self.policy_mean = self.dense_layer(num_actions)(input_layer, name='policy_mean',
                                                          kernel_initializer=normalized_columns_initializer(0.01))
-        if self.is_local:
+
+        # for local networks in distributed settings, we need to move variables we create manually to the
+        # tf.GraphKeys.LOCAL_VARIABLES collection, since the variable scope custom getter which is set in
+        # Architecture does not apply to them
+        if self.is_local and isinstance(self.ap.task_parameters, DistributedTaskParameters):
             self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32',
-                                             collections=[tf.GraphKeys.LOCAL_VARIABLES])
+                                             collections=[tf.GraphKeys.LOCAL_VARIABLES], name="policy_log_std")
         else:
-            self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32')
+            self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32', name="policy_log_std")
 
         self.policy_std = tf.tile(tf.exp(self.policy_logstd), [tf.shape(input_layer)[0], 1], name='policy_std')
diff --git a/rl_coach/coach.py b/rl_coach/coach.py
index 48d8076..c9cd304 100644
--- a/rl_coach/coach.py
+++ b/rl_coach/coach.py
@@ -241,6 +241,10 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
     # checkpoints
     args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.checkpoint_save_secs is not None else None
 
+    if args.export_onnx_graph and not args.checkpoint_save_secs:
+        screen.warning("Exporting ONNX graphs requires setting the --checkpoint_save_secs flag. "
+                       "The --export_onnx_graph flag will have no effect.")
+
     return args
 
 
@@ -474,6 +478,14 @@
                         help="(int) A seed to use for running the experiment",
                         default=None,
                         type=int)
+    parser.add_argument('-onnx', '--export_onnx_graph',
+                        help="(flag) Export the ONNX graph to the experiment directory. "
" + "This will have effect only if the --checkpoint_save_secs flag is used in order to store " + "checkpoints, since the weights checkpoint are needed for the ONNX graph. " + "Keep in mind that this can cause major overhead on the experiment. " + "Exporting ONNX graphs requires manually installing the tf2onnx package " + "(https://github.com/onnx/tensorflow-onnx).", + action='store_true') parser.add_argument('-dc', '--distributed_coach', help="(flag) Use distributed Coach.", action='store_true') diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index b373c1c..7743885 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -100,6 +100,8 @@ class GraphManager(object): self.preset_validation_params = PresetValidationParameters() self.reset_required = False + self.last_checkpoint_saving_time = time.time() + # counters self.total_steps_counters = { RunPhase.HEATUP: TotalStepsCounter(), @@ -227,28 +229,47 @@ class GraphManager(object): # restore from checkpoint if given self.restore_checkpoint() - # tf.train.write_graph(tf.get_default_graph(), - # logdir=self.task_parameters.save_checkpoint_dir, - # name='graphdef.pb', - # as_text=False) - # self.save_checkpoint() - # - # output_nodes = [] - # for level in self.level_managers: - # for agent in level.agents.values(): - # for network in agent.networks.values(): - # for output in network.online_network.outputs: - # output_nodes.append(output.name.split(":")[0]) - # - # freeze_graph_command = [ - # "python -m tensorflow.python.tools.freeze_graph", - # "--input_graph={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "graphdef.pb")), - # "--input_binary=true", - # "--output_node_names='{}'".format(','.join(output_nodes)), - # "--input_checkpoint={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "0_Step-0.ckpt")), - # "--output_graph={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "frozen_graph.pb")) - # ] - # start_shell_command_and_wait(" ".join(freeze_graph_command)) + # the TF graph is static, and therefore is saved once - in the beginning of the experiment + if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir: + self.save_graph() + + def save_graph(self) -> None: + """ + Save the TF graph to a protobuf description file in the experiment directory + :return: None + """ + import tensorflow as tf + + # write graph + tf.train.write_graph(tf.get_default_graph(), + logdir=self.task_parameters.checkpoint_save_dir, + name='graphdef.pb', + as_text=False) + + def save_onnx_graph(self) -> None: + """ + Save the graph as an ONNX graph. + This requires the graph and the weights checkpoint to be stored in the experiment directory. + It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX. 
+        :return: None
+        """
+
+        # collect input and output nodes
+        input_nodes = []
+        output_nodes = []
+        for level in self.level_managers:
+            for agent in level.agents.values():
+                for network in agent.networks.values():
+                    for input_key, input in network.online_network.inputs.items():
+                        if not input_key.startswith("output_"):
+                            input_nodes.append(input.name)
+                    for output in network.online_network.outputs:
+                        output_nodes.append(output.name)
+
+        # TODO: make this framework agnostic
+        from rl_coach.architectures.tensorflow_components.architecture import save_onnx_graph
+
+        save_onnx_graph(input_nodes, output_nodes, self.task_parameters.checkpoint_save_dir)
 
     def setup_logger(self) -> None:
         # dump documentation
@@ -496,7 +517,7 @@
                 and (self.task_parameters.task_index == 0  # distributed
                      or self.task_parameters.task_index is None  # single-worker
                      ):
-            self.save_checkpoint()
+                self.save_checkpoint()
 
     def save_checkpoint(self):
         checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
@@ -511,6 +532,10 @@
         # this is required in order for agents to save additional information like a DND for example
         [manager.save_checkpoint(self.checkpoint_id) for manager in self.level_managers]
 
+        # the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
+        if self.task_parameters.export_onnx_graph:
+            self.save_onnx_graph()
+
         screen.log_dict(
             OrderedDict([
                 ("Saving in path", saved_checkpoint_path),
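Usage note (not part of the patch): after an experiment is run with checkpoint saving and the new flag enabled, for example something like "coach -p <preset> --checkpoint_save_secs 60 -onnx", the converted graph is written as model.onnx inside the checkpoint directory. The snippet below is a minimal sketch of one way to sanity-check that file with onnxruntime; the experiment path is a placeholder, onnxruntime is assumed to be installed separately, and the zero-filled float32 inputs are an assumption, since the actual input names, shapes, and dtypes depend on the preset's network.

# Minimal sketch: load an exported model.onnx and run a dummy forward pass.
# Assumptions: onnxruntime is installed, the path below is a placeholder, and
# all graph inputs accept float32 (true for typical observation inputs).
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("experiments/<experiment name>/checkpoint/model.onnx")

# Build dummy inputs from the model's declared input metadata; dynamic
# dimensions (batch size etc.) are replaced with 1.
feed = {}
for node in session.get_inputs():
    shape = [dim if isinstance(dim, int) else 1 for dim in node.shape]
    feed[node.name] = np.zeros(shape, dtype=np.float32)

# Run all outputs and print their names and shapes.
outputs = session.run(None, feed)
for meta, value in zip(session.get_outputs(), outputs):
    print(meta.name, value.shape)

If the exported graph is valid, this prints one line per output head of the online network; a failure here usually points at the unsupported-layer cases mentioned in the commit message (NAF, C51, A3C_LSTM).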