mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Export graph to ONNX (#61)
Implements the ONNX graph exporting feature. Currently does not work for NAF, C51 and A3C_LSTM due to unsupported TF layers in the tf2onnx library.
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import time
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
@@ -22,7 +23,7 @@ from rl_coach.architectures.architecture import Architecture
|
||||
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
|
||||
from rl_coach.core_types import GradientClippingMethod
|
||||
from rl_coach.spaces import SpacesDefinition
|
||||
from rl_coach.utils import force_list, squeeze_list
|
||||
from rl_coach.utils import force_list, squeeze_list, start_shell_command_and_wait
|
||||
|
||||
|
||||
def variable_summaries(var):
|
||||
@@ -614,3 +615,41 @@ class TensorFlowArchitecture(Architecture):
|
||||
if self.middleware.__class__.__name__ == 'LSTMMiddleware':
|
||||
self.curr_rnn_c_in = self.middleware.c_init
|
||||
self.curr_rnn_h_in = self.middleware.h_init
|
||||
|
||||
|
||||
def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None:
|
||||
"""
|
||||
Given the input nodes and output nodes of the TF graph, save it as an onnx graph
|
||||
This requires the TF graph and the weights checkpoint to be stored in the experiment directory.
|
||||
It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
|
||||
|
||||
:param input_nodes: A list of input nodes for the TF graph
|
||||
:param output_nodes: A list of output nodes for the TF graph
|
||||
:param checkpoint_save_dir: The directory to save the ONNX graph to
|
||||
:return: None
|
||||
"""
|
||||
import tf2onnx # just to verify that tf2onnx is installed
|
||||
|
||||
# freeze graph
|
||||
frozen_graph_path = os.path.join(checkpoint_save_dir, "frozen_graph.pb")
|
||||
freeze_graph_command = [
|
||||
"python -m tensorflow.python.tools.freeze_graph",
|
||||
"--input_graph={}".format(os.path.join(checkpoint_save_dir, "graphdef.pb")),
|
||||
"--input_binary=true",
|
||||
"--output_node_names='{}'".format(','.join([o.split(":")[0] for o in output_nodes])),
|
||||
"--input_checkpoint={}".format(tf.train.latest_checkpoint(checkpoint_save_dir)),
|
||||
"--output_graph={}".format(frozen_graph_path)
|
||||
]
|
||||
start_shell_command_and_wait(" ".join(freeze_graph_command))
|
||||
|
||||
# convert graph to onnx
|
||||
onnx_graph_path = os.path.join(checkpoint_save_dir, "model.onnx")
|
||||
convert_to_onnx_command = [
|
||||
"python -m tf2onnx.convert",
|
||||
"--input {}".format(frozen_graph_path),
|
||||
"--inputs '{}'".format(','.join(input_nodes)),
|
||||
"--outputs '{}'".format(','.join(output_nodes)),
|
||||
"--output {}".format(onnx_graph_path),
|
||||
"--verbose"
|
||||
]
|
||||
start_shell_command_and_wait(" ".join(convert_to_onnx_command))
|
||||
|
||||
@@ -127,11 +127,15 @@ class PPOHead(Head):
|
||||
self.input = [self.actions, self.old_policy_mean, self.old_policy_std]
|
||||
self.policy_mean = self.dense_layer(num_actions)(input_layer, name='policy_mean',
|
||||
kernel_initializer=normalized_columns_initializer(0.01))
|
||||
if self.is_local:
|
||||
|
||||
# for local networks in distributed settings, we need to move variables we create manually to the
|
||||
# tf.GraphKeys.LOCAL_VARIABLES collection, since the variable scope custom getter which is set in
|
||||
# Architecture does not apply to them
|
||||
if self.is_local and isinstance(self.ap.task_parameters, DistributedTaskParameters):
|
||||
self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32',
|
||||
collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
||||
collections=[tf.GraphKeys.LOCAL_VARIABLES], name="policy_log_std")
|
||||
else:
|
||||
self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32')
|
||||
self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32', name="policy_log_std")
|
||||
|
||||
self.policy_std = tf.tile(tf.exp(self.policy_logstd), [tf.shape(input_layer)[0], 1], name='policy_std')
|
||||
|
||||
|
||||
@@ -241,6 +241,10 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
|
||||
# checkpoints
|
||||
args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.checkpoint_save_secs is not None else None
|
||||
|
||||
if args.export_onnx_graph and not args.checkpoint_save_secs:
|
||||
screen.warning("Exporting ONNX graphs requires setting the --checkpoint_save_secs flag. "
|
||||
"The --export_onnx_graph will have no effect.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
@@ -474,6 +478,14 @@ def main():
|
||||
help="(int) A seed to use for running the experiment",
|
||||
default=None,
|
||||
type=int)
|
||||
parser.add_argument('-onnx', '--export_onnx_graph',
|
||||
help="(flag) Export the ONNX graph to the experiment directory. "
|
||||
"This will have effect only if the --checkpoint_save_secs flag is used in order to store "
|
||||
"checkpoints, since the weights checkpoint are needed for the ONNX graph. "
|
||||
"Keep in mind that this can cause major overhead on the experiment. "
|
||||
"Exporting ONNX graphs requires manually installing the tf2onnx package "
|
||||
"(https://github.com/onnx/tensorflow-onnx).",
|
||||
action='store_true')
|
||||
parser.add_argument('-dc', '--distributed_coach',
|
||||
help="(flag) Use distributed Coach.",
|
||||
action='store_true')
|
||||
|
||||
@@ -100,6 +100,8 @@ class GraphManager(object):
|
||||
self.preset_validation_params = PresetValidationParameters()
|
||||
self.reset_required = False
|
||||
|
||||
self.last_checkpoint_saving_time = time.time()
|
||||
|
||||
# counters
|
||||
self.total_steps_counters = {
|
||||
RunPhase.HEATUP: TotalStepsCounter(),
|
||||
@@ -227,28 +229,47 @@ class GraphManager(object):
|
||||
# restore from checkpoint if given
|
||||
self.restore_checkpoint()
|
||||
|
||||
# tf.train.write_graph(tf.get_default_graph(),
|
||||
# logdir=self.task_parameters.save_checkpoint_dir,
|
||||
# name='graphdef.pb',
|
||||
# as_text=False)
|
||||
# self.save_checkpoint()
|
||||
#
|
||||
# output_nodes = []
|
||||
# for level in self.level_managers:
|
||||
# for agent in level.agents.values():
|
||||
# for network in agent.networks.values():
|
||||
# for output in network.online_network.outputs:
|
||||
# output_nodes.append(output.name.split(":")[0])
|
||||
#
|
||||
# freeze_graph_command = [
|
||||
# "python -m tensorflow.python.tools.freeze_graph",
|
||||
# "--input_graph={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "graphdef.pb")),
|
||||
# "--input_binary=true",
|
||||
# "--output_node_names='{}'".format(','.join(output_nodes)),
|
||||
# "--input_checkpoint={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "0_Step-0.ckpt")),
|
||||
# "--output_graph={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "frozen_graph.pb"))
|
||||
# ]
|
||||
# start_shell_command_and_wait(" ".join(freeze_graph_command))
|
||||
# the TF graph is static, and therefore is saved once - in the beginning of the experiment
|
||||
if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
|
||||
self.save_graph()
|
||||
|
||||
def save_graph(self) -> None:
|
||||
"""
|
||||
Save the TF graph to a protobuf description file in the experiment directory
|
||||
:return: None
|
||||
"""
|
||||
import tensorflow as tf
|
||||
|
||||
# write graph
|
||||
tf.train.write_graph(tf.get_default_graph(),
|
||||
logdir=self.task_parameters.checkpoint_save_dir,
|
||||
name='graphdef.pb',
|
||||
as_text=False)
|
||||
|
||||
def save_onnx_graph(self) -> None:
|
||||
"""
|
||||
Save the graph as an ONNX graph.
|
||||
This requires the graph and the weights checkpoint to be stored in the experiment directory.
|
||||
It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
|
||||
:return: None
|
||||
"""
|
||||
|
||||
# collect input and output nodes
|
||||
input_nodes = []
|
||||
output_nodes = []
|
||||
for level in self.level_managers:
|
||||
for agent in level.agents.values():
|
||||
for network in agent.networks.values():
|
||||
for input_key, input in network.online_network.inputs.items():
|
||||
if not input_key.startswith("output_"):
|
||||
input_nodes.append(input.name)
|
||||
for output in network.online_network.outputs:
|
||||
output_nodes.append(output.name)
|
||||
|
||||
# TODO: make this framework agnostic
|
||||
from rl_coach.architectures.tensorflow_components.architecture import save_onnx_graph
|
||||
|
||||
save_onnx_graph(input_nodes, output_nodes, self.task_parameters.checkpoint_save_dir)
|
||||
|
||||
def setup_logger(self) -> None:
|
||||
# dump documentation
|
||||
@@ -496,7 +517,7 @@ class GraphManager(object):
|
||||
and (self.task_parameters.task_index == 0 # distributed
|
||||
or self.task_parameters.task_index is None # single-worker
|
||||
):
|
||||
self.save_checkpoint()
|
||||
self.save_checkpoint()
|
||||
|
||||
def save_checkpoint(self):
|
||||
checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
|
||||
@@ -511,6 +532,10 @@ class GraphManager(object):
|
||||
# this is required in order for agents to save additional information like a DND for example
|
||||
[manager.save_checkpoint(self.checkpoint_id) for manager in self.level_managers]
|
||||
|
||||
# the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
|
||||
if self.task_parameters.export_onnx_graph:
|
||||
self.save_onnx_graph()
|
||||
|
||||
screen.log_dict(
|
||||
OrderedDict([
|
||||
("Saving in path", saved_checkpoint_path),
|
||||
|
||||
Reference in New Issue
Block a user