Export graph to ONNX (#61)
Implements the ONNX graph export feature. It currently does not work for NAF, C51, and A3C_LSTM, due to TF layers that the tf2onnx library does not yet support.
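Once the export succeeds, the resulting model.onnx can be sanity-checked with the onnx Python package, independently of Coach. A minimal sketch; the experiment path is a made-up example, while model.onnx is the file name that save_onnx_graph (below) writes into the checkpoint directory:

    import onnx

    # load the exported graph; the directory is a hypothetical example --
    # save_onnx_graph() writes "model.onnx" into the checkpoint directory
    model = onnx.load("experiments/my_experiment/checkpoints/model.onnx")

    # run the standard ONNX structural validation
    onnx.checker.check_model(model)

    # list the graph's declared input and output tensors
    print([i.name for i in model.graph.input])
    print([o.name for o in model.graph.output])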
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 import time
+import os
 
 import numpy as np
 import tensorflow as tf
@@ -22,7 +23,7 @@ from rl_coach.architectures.architecture import Architecture
 from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
 from rl_coach.core_types import GradientClippingMethod
 from rl_coach.spaces import SpacesDefinition
-from rl_coach.utils import force_list, squeeze_list
+from rl_coach.utils import force_list, squeeze_list, start_shell_command_and_wait
 
 
 def variable_summaries(var):
@@ -614,3 +615,41 @@ class TensorFlowArchitecture(Architecture):
         if self.middleware.__class__.__name__ == 'LSTMMiddleware':
             self.curr_rnn_c_in = self.middleware.c_init
             self.curr_rnn_h_in = self.middleware.h_init
+
+
+def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None:
+    """
+    Given the input nodes and output nodes of the TF graph, save it as an ONNX graph.
+    This requires the TF graph and the weights checkpoint to be stored in the experiment directory.
+    It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
+
+    :param input_nodes: a list of input nodes for the TF graph
+    :param output_nodes: a list of output nodes for the TF graph
+    :param checkpoint_save_dir: the directory to save the ONNX graph to
+    :return: None
+    """
+    import tf2onnx  # imported here just to verify that tf2onnx is installed
+
+    # freeze the graph, merging the stored graph definition with the latest weights checkpoint
+    frozen_graph_path = os.path.join(checkpoint_save_dir, "frozen_graph.pb")
+    freeze_graph_command = [
+        "python -m tensorflow.python.tools.freeze_graph",
+        "--input_graph={}".format(os.path.join(checkpoint_save_dir, "graphdef.pb")),
+        "--input_binary=true",
+        "--output_node_names='{}'".format(','.join([o.split(":")[0] for o in output_nodes])),
+        "--input_checkpoint={}".format(tf.train.latest_checkpoint(checkpoint_save_dir)),
+        "--output_graph={}".format(frozen_graph_path)
+    ]
+    start_shell_command_and_wait(" ".join(freeze_graph_command))
+
+    # convert the frozen graph to ONNX
+    onnx_graph_path = os.path.join(checkpoint_save_dir, "model.onnx")
+    convert_to_onnx_command = [
+        "python -m tf2onnx.convert",
+        "--input {}".format(frozen_graph_path),
+        "--inputs '{}'".format(','.join(input_nodes)),
+        "--outputs '{}'".format(','.join(output_nodes)),
+        "--output {}".format(onnx_graph_path),
+        "--verbose"
+    ]
+    start_shell_command_and_wait(" ".join(convert_to_onnx_command))
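A hypothetical call site for the new function, assuming the import path below matches where TensorFlowArchitecture lives and that graphdef.pb plus a weights checkpoint were already written to the checkpoint directory; the tensor names are invented for illustration:

    from rl_coach.architectures.tensorflow_components.architecture import save_onnx_graph

    # made-up tensor names; real ones depend on the network's input embedders
    # and output heads (note the ":0" suffix, which save_onnx_graph strips
    # when building the frozen graph's output node names)
    input_nodes = ["main_level/agent/main/online/observation:0"]
    output_nodes = ["main_level/agent/main/online/q_values:0"]

    # expects graphdef.pb and a checkpoint in this directory; writes
    # frozen_graph.pb and model.onnx next to them
    save_onnx_graph(input_nodes, output_nodes, "experiments/my_experiment/checkpoints")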
@@ -127,11 +127,15 @@ class PPOHead(Head):
         self.input = [self.actions, self.old_policy_mean, self.old_policy_std]
         self.policy_mean = self.dense_layer(num_actions)(input_layer, name='policy_mean',
                                                           kernel_initializer=normalized_columns_initializer(0.01))
-        if self.is_local:
+
+        # for local networks in distributed settings, we need to move variables we create manually to the
+        # tf.GraphKeys.LOCAL_VARIABLES collection, since the variable scope custom getter which is set in
+        # Architecture does not apply to them
+        if self.is_local and isinstance(self.ap.task_parameters, DistributedTaskParameters):
             self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32',
-                                             collections=[tf.GraphKeys.LOCAL_VARIABLES])
+                                             collections=[tf.GraphKeys.LOCAL_VARIABLES], name="policy_log_std")
         else:
-            self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32')
+            self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32', name="policy_log_std")
 
         self.policy_std = tf.tile(tf.exp(self.policy_logstd), [tf.shape(input_layer)[0], 1], name='policy_std')
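As the new comment notes, variables created manually bypass the variable scope custom getter, so they must be placed in tf.GraphKeys.LOCAL_VARIABLES explicitly. A standalone TF1 sketch (not Coach code) of what the collections argument does:

    import numpy as np
    import tensorflow as tf

    # a variable created with an explicit collections list is registered only
    # in those collections, so it stays out of GLOBAL_VARIABLES and is ignored
    # by mechanisms that synchronize global variables across workers
    local_var = tf.Variable(np.zeros((1, 4)), dtype='float32', name="policy_log_std",
                            collections=[tf.GraphKeys.LOCAL_VARIABLES])

    print(any(v is local_var for v in tf.local_variables()))   # True
    print(any(v is local_var for v in tf.global_variables()))  # False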