1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Export graph to ONNX (#61)

Implements the ONNX graph exporting feature. 
Currently does not work for NAF, C51 and A3C_LSTM due to unsupported TF layers in the tf2onnx library.
This commit is contained in:
Itai Caspi
2018-11-06 10:55:21 +02:00
committed by GitHub
parent d75df17d97
commit 811152126c
4 changed files with 107 additions and 27 deletions

View File

@@ -100,6 +100,8 @@ class GraphManager(object):
self.preset_validation_params = PresetValidationParameters()
self.reset_required = False
self.last_checkpoint_saving_time = time.time()
# counters
self.total_steps_counters = {
RunPhase.HEATUP: TotalStepsCounter(),
@@ -227,28 +229,47 @@ class GraphManager(object):
# restore from checkpoint if given
self.restore_checkpoint()
# tf.train.write_graph(tf.get_default_graph(),
# logdir=self.task_parameters.save_checkpoint_dir,
# name='graphdef.pb',
# as_text=False)
# self.save_checkpoint()
#
# output_nodes = []
# for level in self.level_managers:
# for agent in level.agents.values():
# for network in agent.networks.values():
# for output in network.online_network.outputs:
# output_nodes.append(output.name.split(":")[0])
#
# freeze_graph_command = [
# "python -m tensorflow.python.tools.freeze_graph",
# "--input_graph={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "graphdef.pb")),
# "--input_binary=true",
# "--output_node_names='{}'".format(','.join(output_nodes)),
# "--input_checkpoint={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "0_Step-0.ckpt")),
# "--output_graph={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "frozen_graph.pb"))
# ]
# start_shell_command_and_wait(" ".join(freeze_graph_command))
# the TF graph is static, and therefore is saved once - in the beginning of the experiment
if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
self.save_graph()
def save_graph(self) -> None:
"""
Save the TF graph to a protobuf description file in the experiment directory
:return: None
"""
import tensorflow as tf
# write graph
tf.train.write_graph(tf.get_default_graph(),
logdir=self.task_parameters.checkpoint_save_dir,
name='graphdef.pb',
as_text=False)
def save_onnx_graph(self) -> None:
"""
Save the graph as an ONNX graph.
This requires the graph and the weights checkpoint to be stored in the experiment directory.
It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
:return: None
"""
# collect input and output nodes
input_nodes = []
output_nodes = []
for level in self.level_managers:
for agent in level.agents.values():
for network in agent.networks.values():
for input_key, input in network.online_network.inputs.items():
if not input_key.startswith("output_"):
input_nodes.append(input.name)
for output in network.online_network.outputs:
output_nodes.append(output.name)
# TODO: make this framework agnostic
from rl_coach.architectures.tensorflow_components.architecture import save_onnx_graph
save_onnx_graph(input_nodes, output_nodes, self.task_parameters.checkpoint_save_dir)
def setup_logger(self) -> None:
# dump documentation
@@ -496,7 +517,7 @@ class GraphManager(object):
and (self.task_parameters.task_index == 0 # distributed
or self.task_parameters.task_index is None # single-worker
):
self.save_checkpoint()
self.save_checkpoint()
def save_checkpoint(self):
checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
@@ -511,6 +532,10 @@ class GraphManager(object):
# this is required in order for agents to save additional information like a DND for example
[manager.save_checkpoint(self.checkpoint_id) for manager in self.level_managers]
# the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
if self.task_parameters.export_onnx_graph:
self.save_onnx_graph()
screen.log_dict(
OrderedDict([
("Saving in path", saved_checkpoint_path),