Mirror of https://github.com/gryf/coach.git
Export graph to ONNX (#61)
Implements the ONNX graph exporting feature. Currently does not work for NAF, C51 and A3C_LSTM due to unsupported TF layers in the tf2onnx library.
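As wired up in the hunks below, the export is enabled with the new -onnx / --export_onnx_graph flag and only takes effect when checkpoints are being saved (--checkpoint_save_secs), since the frozen graph is built from the latest weights checkpoint. A hypothetical invocation (the preset name and interval are illustrative, assuming the standard coach entry point):

    coach -p CartPole_DQN --checkpoint_save_secs 60 -onnx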
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 import time
+import os
 
 import numpy as np
 import tensorflow as tf
@@ -22,7 +23,7 @@ from rl_coach.architectures.architecture import Architecture
 from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
 from rl_coach.core_types import GradientClippingMethod
 from rl_coach.spaces import SpacesDefinition
-from rl_coach.utils import force_list, squeeze_list
+from rl_coach.utils import force_list, squeeze_list, start_shell_command_and_wait
 
 
 def variable_summaries(var):
@@ -614,3 +615,41 @@ class TensorFlowArchitecture(Architecture):
         if self.middleware.__class__.__name__ == 'LSTMMiddleware':
             self.curr_rnn_c_in = self.middleware.c_init
             self.curr_rnn_h_in = self.middleware.h_init
+
+
+def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None:
+    """
+    Given the input nodes and output nodes of the TF graph, save it as an onnx graph.
+    This requires the TF graph and the weights checkpoint to be stored in the experiment directory.
+    It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
+
+    :param input_nodes: A list of input nodes for the TF graph
+    :param output_nodes: A list of output nodes for the TF graph
+    :param checkpoint_save_dir: The directory to save the ONNX graph to
+    :return: None
+    """
+    import tf2onnx  # just to verify that tf2onnx is installed
+
+    # freeze graph
+    frozen_graph_path = os.path.join(checkpoint_save_dir, "frozen_graph.pb")
+    freeze_graph_command = [
+        "python -m tensorflow.python.tools.freeze_graph",
+        "--input_graph={}".format(os.path.join(checkpoint_save_dir, "graphdef.pb")),
+        "--input_binary=true",
+        "--output_node_names='{}'".format(','.join([o.split(":")[0] for o in output_nodes])),
+        "--input_checkpoint={}".format(tf.train.latest_checkpoint(checkpoint_save_dir)),
+        "--output_graph={}".format(frozen_graph_path)
+    ]
+    start_shell_command_and_wait(" ".join(freeze_graph_command))
+
+    # convert graph to onnx
+    onnx_graph_path = os.path.join(checkpoint_save_dir, "model.onnx")
+    convert_to_onnx_command = [
+        "python -m tf2onnx.convert",
+        "--input {}".format(frozen_graph_path),
+        "--inputs '{}'".format(','.join(input_nodes)),
+        "--outputs '{}'".format(','.join(output_nodes)),
+        "--output {}".format(onnx_graph_path),
+        "--verbose"
+    ]
+    start_shell_command_and_wait(" ".join(convert_to_onnx_command))
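For orientation, a usage sketch of the new helper. The node names and directory below are hypothetical; in practice they are collected by GraphManager.save_onnx_graph() further down in this diff, and the call assumes graphdef.pb plus a weights checkpoint already exist in that directory:

    from rl_coach.architectures.tensorflow_components.architecture import save_onnx_graph

    # Hypothetical tensor names - the real names depend on the preset/network being trained.
    input_nodes = ["main_level/agent/main/online/observation/observation:0"]
    output_nodes = ["main_level/agent/main/online/network_0/q_values_head_0/q_values:0"]

    # Writes frozen_graph.pb and model.onnx next to graphdef.pb and the latest checkpoint.
    save_onnx_graph(input_nodes, output_nodes, "./experiments/my_experiment/checkpoint")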
@@ -127,11 +127,15 @@ class PPOHead(Head):
         self.input = [self.actions, self.old_policy_mean, self.old_policy_std]
         self.policy_mean = self.dense_layer(num_actions)(input_layer, name='policy_mean',
                                                          kernel_initializer=normalized_columns_initializer(0.01))
-        if self.is_local:
+        # for local networks in distributed settings, we need to move variables we create manually to the
+        # tf.GraphKeys.LOCAL_VARIABLES collection, since the variable scope custom getter which is set in
+        # Architecture does not apply to them
+        if self.is_local and isinstance(self.ap.task_parameters, DistributedTaskParameters):
             self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32',
-                                             collections=[tf.GraphKeys.LOCAL_VARIABLES])
+                                             collections=[tf.GraphKeys.LOCAL_VARIABLES], name="policy_log_std")
         else:
-            self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32')
+            self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32', name="policy_log_std")
 
         self.policy_std = tf.tile(tf.exp(self.policy_logstd), [tf.shape(input_layer)[0], 1], name='policy_std')
 
@@ -241,6 +241,10 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
     # checkpoints
     args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.checkpoint_save_secs is not None else None
 
+    if args.export_onnx_graph and not args.checkpoint_save_secs:
+        screen.warning("Exporting ONNX graphs requires setting the --checkpoint_save_secs flag. "
+                       "The --export_onnx_graph will have no effect.")
+
     return args
 
 
@@ -474,6 +478,14 @@ def main():
                         help="(int) A seed to use for running the experiment",
                         default=None,
                         type=int)
+    parser.add_argument('-onnx', '--export_onnx_graph',
+                        help="(flag) Export the ONNX graph to the experiment directory. "
+                             "This will have an effect only if the --checkpoint_save_secs flag is used in order to store "
+                             "checkpoints, since the weights checkpoint is needed for the ONNX graph. "
+                             "Keep in mind that this can cause major overhead on the experiment. "
+                             "Exporting ONNX graphs requires manually installing the tf2onnx package "
+                             "(https://github.com/onnx/tensorflow-onnx).",
+                        action='store_true')
     parser.add_argument('-dc', '--distributed_coach',
                         help="(flag) Use distributed Coach.",
                         action='store_true')
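Note that tf2onnx is not bundled with Coach and has to be installed separately (e.g. pip install tf2onnx, or from the repository linked in the help text above) before the -onnx flag does anything useful.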
@@ -100,6 +100,8 @@ class GraphManager(object):
         self.preset_validation_params = PresetValidationParameters()
         self.reset_required = False
 
+        self.last_checkpoint_saving_time = time.time()
+
         # counters
         self.total_steps_counters = {
             RunPhase.HEATUP: TotalStepsCounter(),
@@ -227,28 +229,47 @@ class GraphManager(object):
         # restore from checkpoint if given
         self.restore_checkpoint()
 
-        # tf.train.write_graph(tf.get_default_graph(),
-        #                      logdir=self.task_parameters.save_checkpoint_dir,
-        #                      name='graphdef.pb',
-        #                      as_text=False)
-        # self.save_checkpoint()
-        #
-        # output_nodes = []
-        # for level in self.level_managers:
-        #     for agent in level.agents.values():
-        #         for network in agent.networks.values():
-        #             for output in network.online_network.outputs:
-        #                 output_nodes.append(output.name.split(":")[0])
-        #
-        # freeze_graph_command = [
-        #     "python -m tensorflow.python.tools.freeze_graph",
-        #     "--input_graph={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "graphdef.pb")),
-        #     "--input_binary=true",
-        #     "--output_node_names='{}'".format(','.join(output_nodes)),
-        #     "--input_checkpoint={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "0_Step-0.ckpt")),
-        #     "--output_graph={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "frozen_graph.pb"))
-        # ]
-        # start_shell_command_and_wait(" ".join(freeze_graph_command))
+        # the TF graph is static, and therefore is saved once - in the beginning of the experiment
+        if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
+            self.save_graph()
+
+    def save_graph(self) -> None:
+        """
+        Save the TF graph to a protobuf description file in the experiment directory
+        :return: None
+        """
+        import tensorflow as tf
+
+        # write graph
+        tf.train.write_graph(tf.get_default_graph(),
+                             logdir=self.task_parameters.checkpoint_save_dir,
+                             name='graphdef.pb',
+                             as_text=False)
+
+    def save_onnx_graph(self) -> None:
+        """
+        Save the graph as an ONNX graph.
+        This requires the graph and the weights checkpoint to be stored in the experiment directory.
+        It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
+        :return: None
+        """
+
+        # collect input and output nodes
+        input_nodes = []
+        output_nodes = []
+        for level in self.level_managers:
+            for agent in level.agents.values():
+                for network in agent.networks.values():
+                    for input_key, input in network.online_network.inputs.items():
+                        if not input_key.startswith("output_"):
+                            input_nodes.append(input.name)
+                    for output in network.online_network.outputs:
+                        output_nodes.append(output.name)
+
+        # TODO: make this framework agnostic
+        from rl_coach.architectures.tensorflow_components.architecture import save_onnx_graph
+
+        save_onnx_graph(input_nodes, output_nodes, self.task_parameters.checkpoint_save_dir)
 
     def setup_logger(self) -> None:
         # dump documentation
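Once save_onnx_graph() has written model.onnx into the checkpoint directory, the exported graph can be consumed outside of Coach. A minimal sketch assuming onnxruntime is installed; the path, the (1, 4) observation shape, and picking the first input/output are illustrative assumptions, since tensor names and shapes depend on the preset:

    import numpy as np
    import onnxruntime as ort

    # Path assumption: checkpoint_save_dir defaults to <experiment_path>/checkpoint (see parse_arguments above).
    session = ort.InferenceSession("experiments/my_experiment/checkpoint/model.onnx")

    # Input/output names mirror the TF tensor names collected from network.online_network,
    # so inspect them instead of hard-coding.
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    # Run a forward pass on a dummy observation (the shape here is illustrative).
    observation = np.zeros((1, 4), dtype=np.float32)
    prediction = session.run([output_name], {input_name: observation})
    print(output_name, prediction)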
@@ -496,7 +517,7 @@ class GraphManager(object):
                 and (self.task_parameters.task_index == 0  # distributed
                      or self.task_parameters.task_index is None  # single-worker
                      ):
             self.save_checkpoint()
 
     def save_checkpoint(self):
         checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
@@ -511,6 +532,10 @@ class GraphManager(object):
         # this is required in order for agents to save additional information like a DND for example
         [manager.save_checkpoint(self.checkpoint_id) for manager in self.level_managers]
 
+        # the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
+        if self.task_parameters.export_onnx_graph:
+            self.save_onnx_graph()
+
         screen.log_dict(
             OrderedDict([
                 ("Saving in path", saved_checkpoint_path),