
Export graph to ONNX (#61)

Implements the ONNX graph export feature.
Currently does not work for NAF, C51 and A3C_LSTM, since their networks use TF layers that the tf2onnx library does not support.
Itai Caspi
2018-11-06 10:55:21 +02:00
committed by GitHub
parent d75df17d97
commit 811152126c
4 changed files with 107 additions and 27 deletions
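
For reference (not part of the diff), a minimal sketch of consuming the exported graph after a run has produced it. It assumes onnxruntime is installed separately and that model.onnx was written to the experiment's checkpoint directory; the path below is hypothetical.

import numpy as np
import onnxruntime as ort

# Hypothetical path: wherever --checkpoint_save_secs stored the checkpoints
session = ort.InferenceSession("./experiments/my_experiment/checkpoint/model.onnx")
input_meta = session.get_inputs()[0]

# Build a dummy observation, replacing dynamic dimensions (None / symbolic) with 1
dummy_shape = [d if isinstance(d, int) else 1 for d in input_meta.shape]
dummy_observation = np.zeros(dummy_shape, dtype=np.float32)

outputs = session.run(None, {input_meta.name: dummy_observation})
print([o.shape for o in outputs])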

View File

@@ -14,6 +14,7 @@
# limitations under the License.
#
import time
import os
import numpy as np
import tensorflow as tf
@@ -22,7 +23,7 @@ from rl_coach.architectures.architecture import Architecture
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
from rl_coach.core_types import GradientClippingMethod
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import force_list, squeeze_list
from rl_coach.utils import force_list, squeeze_list, start_shell_command_and_wait
def variable_summaries(var):
@@ -614,3 +615,41 @@ class TensorFlowArchitecture(Architecture):
if self.middleware.__class__.__name__ == 'LSTMMiddleware':
self.curr_rnn_c_in = self.middleware.c_init
self.curr_rnn_h_in = self.middleware.h_init
def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None:
    """
    Given the input nodes and output nodes of the TF graph, save it as an onnx graph
    This requires the TF graph and the weights checkpoint to be stored in the experiment directory.
    It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
    :param input_nodes: A list of input nodes for the TF graph
    :param output_nodes: A list of output nodes for the TF graph
    :param checkpoint_save_dir: The directory to save the ONNX graph to
    :return: None
    """
    import tf2onnx  # just to verify that tf2onnx is installed

    # freeze graph
    frozen_graph_path = os.path.join(checkpoint_save_dir, "frozen_graph.pb")
    freeze_graph_command = [
        "python -m tensorflow.python.tools.freeze_graph",
        "--input_graph={}".format(os.path.join(checkpoint_save_dir, "graphdef.pb")),
        "--input_binary=true",
        "--output_node_names='{}'".format(','.join([o.split(":")[0] for o in output_nodes])),
        "--input_checkpoint={}".format(tf.train.latest_checkpoint(checkpoint_save_dir)),
        "--output_graph={}".format(frozen_graph_path)
    ]
    start_shell_command_and_wait(" ".join(freeze_graph_command))

    # convert graph to onnx
    onnx_graph_path = os.path.join(checkpoint_save_dir, "model.onnx")
    convert_to_onnx_command = [
        "python -m tf2onnx.convert",
        "--input {}".format(frozen_graph_path),
        "--inputs '{}'".format(','.join(input_nodes)),
        "--outputs '{}'".format(','.join(output_nodes)),
        "--output {}".format(onnx_graph_path),
        "--verbose"
    ]
    start_shell_command_and_wait(" ".join(convert_to_onnx_command))

View File

@@ -127,11 +127,15 @@ class PPOHead(Head):
self.input = [self.actions, self.old_policy_mean, self.old_policy_std]
self.policy_mean = self.dense_layer(num_actions)(input_layer, name='policy_mean',
kernel_initializer=normalized_columns_initializer(0.01))
if self.is_local:
# for local networks in distributed settings, we need to move variables we create manually to the
# tf.GraphKeys.LOCAL_VARIABLES collection, since the variable scope custom getter which is set in
# Architecture does not apply to them
if self.is_local and isinstance(self.ap.task_parameters, DistributedTaskParameters):
self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32',
collections=[tf.GraphKeys.LOCAL_VARIABLES])
collections=[tf.GraphKeys.LOCAL_VARIABLES], name="policy_log_std")
else:
self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32')
self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32', name="policy_log_std")
self.policy_std = tf.tile(tf.exp(self.policy_logstd), [tf.shape(input_layer)[0], 1], name='policy_std')
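
For context, a standalone TF1 sketch (not Coach code) of the behaviour the new comment describes: a tf.Variable created directly is not routed through the variable-scope custom getter, so its collection has to be set by hand, and the explicit name makes the corresponding node easy to locate in the exported graph.

import numpy as np
import tensorflow as tf

# In a fresh TF1 default graph: a manually created variable, placed in
# LOCAL_VARIABLES and given an explicit name
local_var = tf.Variable(np.zeros((1, 4)), dtype='float32',
                        collections=[tf.GraphKeys.LOCAL_VARIABLES], name='policy_log_std')

print([v.name for v in tf.local_variables()])   # ['policy_log_std:0']
print([v.name for v in tf.global_variables()])  # [] - a default tf.train.Saver would not store it
print(local_var.name)                           # 'policy_log_std:0'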

View File

@@ -241,6 +241,10 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
# checkpoints
args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.checkpoint_save_secs is not None else None
if args.export_onnx_graph and not args.checkpoint_save_secs:
    screen.warning("Exporting ONNX graphs requires setting the --checkpoint_save_secs flag. "
                   "The --export_onnx_graph will have no effect.")
return args
@@ -474,6 +478,14 @@ def main():
help="(int) A seed to use for running the experiment",
default=None,
type=int)
parser.add_argument('-onnx', '--export_onnx_graph',
                    help="(flag) Export the ONNX graph to the experiment directory. "
                         "This will have effect only if the --checkpoint_save_secs flag is used in order to store "
                         "checkpoints, since the weights checkpoint is needed for the ONNX graph. "
                         "Keep in mind that this can cause major overhead on the experiment. "
                         "Exporting ONNX graphs requires manually installing the tf2onnx package "
                         "(https://github.com/onnx/tensorflow-onnx).",
                    action='store_true')
parser.add_argument('-dc', '--distributed_coach',
                    help="(flag) Use distributed Coach.",
                    action='store_true')
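
Putting the flags together, a sketch of a launch that actually produces the export, assuming Coach is installed so the coach entry point is on the PATH; the preset name and checkpoint interval are only examples, not prescribed by the diff.

import subprocess

subprocess.run(
    ["coach",
     "-p", "CartPole_DQN",             # any preset; CartPole_DQN is only an example
     "--checkpoint_save_secs", "60",    # required - without it -onnx is a no-op (see the warning above)
     "--export_onnx_graph"],
    check=True)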

View File

@@ -100,6 +100,8 @@ class GraphManager(object):
self.preset_validation_params = PresetValidationParameters()
self.reset_required = False
self.last_checkpoint_saving_time = time.time()
# counters
self.total_steps_counters = {
RunPhase.HEATUP: TotalStepsCounter(),
@@ -227,28 +229,47 @@ class GraphManager(object):
# restore from checkpoint if given
self.restore_checkpoint()
# tf.train.write_graph(tf.get_default_graph(),
# logdir=self.task_parameters.save_checkpoint_dir,
# name='graphdef.pb',
# as_text=False)
# self.save_checkpoint()
#
# output_nodes = []
# for level in self.level_managers:
# for agent in level.agents.values():
# for network in agent.networks.values():
# for output in network.online_network.outputs:
# output_nodes.append(output.name.split(":")[0])
#
# freeze_graph_command = [
# "python -m tensorflow.python.tools.freeze_graph",
# "--input_graph={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "graphdef.pb")),
# "--input_binary=true",
# "--output_node_names='{}'".format(','.join(output_nodes)),
# "--input_checkpoint={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "0_Step-0.ckpt")),
# "--output_graph={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "frozen_graph.pb"))
# ]
# start_shell_command_and_wait(" ".join(freeze_graph_command))
# the TF graph is static, and therefore is saved once - in the beginning of the experiment
if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
    self.save_graph()

def save_graph(self) -> None:
    """
    Save the TF graph to a protobuf description file in the experiment directory
    :return: None
    """
    import tensorflow as tf

    # write graph
    tf.train.write_graph(tf.get_default_graph(),
                         logdir=self.task_parameters.checkpoint_save_dir,
                         name='graphdef.pb',
                         as_text=False)

def save_onnx_graph(self) -> None:
    """
    Save the graph as an ONNX graph.
    This requires the graph and the weights checkpoint to be stored in the experiment directory.
    It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
    :return: None
    """
    # collect input and output nodes
    input_nodes = []
    output_nodes = []
    for level in self.level_managers:
        for agent in level.agents.values():
            for network in agent.networks.values():
                for input_key, input in network.online_network.inputs.items():
                    if not input_key.startswith("output_"):
                        input_nodes.append(input.name)
                for output in network.online_network.outputs:
                    output_nodes.append(output.name)

    # TODO: make this framework agnostic
    from rl_coach.architectures.tensorflow_components.architecture import save_onnx_graph
    save_onnx_graph(input_nodes, output_nodes, self.task_parameters.checkpoint_save_dir)
def setup_logger(self) -> None:
# dump documentation
@@ -496,7 +517,7 @@ class GraphManager(object):
and (self.task_parameters.task_index == 0 # distributed
or self.task_parameters.task_index is None # single-worker
):
self.save_checkpoint()
self.save_checkpoint()
def save_checkpoint(self):
checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
@@ -511,6 +532,10 @@ class GraphManager(object):
# this is required in order for agents to save additional information like a DND for example
[manager.save_checkpoint(self.checkpoint_id) for manager in self.level_managers]
# the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
if self.task_parameters.export_onnx_graph:
    self.save_onnx_graph()
screen.log_dict(
OrderedDict([
("Saving in path", saved_checkpoint_path),