mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Export graph to ONNX (#61)
Implements the ONNX graph exporting feature. Currently does not work for NAF, C51 and A3C_LSTM due to unsupported TF layers in the tf2onnx library.
This commit is contained in:
@@ -100,6 +100,8 @@ class GraphManager(object):
|
||||
self.preset_validation_params = PresetValidationParameters()
|
||||
self.reset_required = False
|
||||
|
||||
self.last_checkpoint_saving_time = time.time()
|
||||
|
||||
# counters
|
||||
self.total_steps_counters = {
|
||||
RunPhase.HEATUP: TotalStepsCounter(),
|
||||
@@ -227,28 +229,47 @@ class GraphManager(object):
|
||||
# restore from checkpoint if given
|
||||
self.restore_checkpoint()
|
||||
|
||||
# tf.train.write_graph(tf.get_default_graph(),
|
||||
# logdir=self.task_parameters.save_checkpoint_dir,
|
||||
# name='graphdef.pb',
|
||||
# as_text=False)
|
||||
# self.save_checkpoint()
|
||||
#
|
||||
# output_nodes = []
|
||||
# for level in self.level_managers:
|
||||
# for agent in level.agents.values():
|
||||
# for network in agent.networks.values():
|
||||
# for output in network.online_network.outputs:
|
||||
# output_nodes.append(output.name.split(":")[0])
|
||||
#
|
||||
# freeze_graph_command = [
|
||||
# "python -m tensorflow.python.tools.freeze_graph",
|
||||
# "--input_graph={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "graphdef.pb")),
|
||||
# "--input_binary=true",
|
||||
# "--output_node_names='{}'".format(','.join(output_nodes)),
|
||||
# "--input_checkpoint={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "0_Step-0.ckpt")),
|
||||
# "--output_graph={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "frozen_graph.pb"))
|
||||
# ]
|
||||
# start_shell_command_and_wait(" ".join(freeze_graph_command))
|
||||
# the TF graph is static, and therefore is saved once - in the beginning of the experiment
|
||||
if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
|
||||
self.save_graph()
|
||||
|
||||
def save_graph(self) -> None:
|
||||
"""
|
||||
Save the TF graph to a protobuf description file in the experiment directory
|
||||
:return: None
|
||||
"""
|
||||
import tensorflow as tf
|
||||
|
||||
# write graph
|
||||
tf.train.write_graph(tf.get_default_graph(),
|
||||
logdir=self.task_parameters.checkpoint_save_dir,
|
||||
name='graphdef.pb',
|
||||
as_text=False)
|
||||
|
||||
def save_onnx_graph(self) -> None:
|
||||
"""
|
||||
Save the graph as an ONNX graph.
|
||||
This requires the graph and the weights checkpoint to be stored in the experiment directory.
|
||||
It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
|
||||
:return: None
|
||||
"""
|
||||
|
||||
# collect input and output nodes
|
||||
input_nodes = []
|
||||
output_nodes = []
|
||||
for level in self.level_managers:
|
||||
for agent in level.agents.values():
|
||||
for network in agent.networks.values():
|
||||
for input_key, input in network.online_network.inputs.items():
|
||||
if not input_key.startswith("output_"):
|
||||
input_nodes.append(input.name)
|
||||
for output in network.online_network.outputs:
|
||||
output_nodes.append(output.name)
|
||||
|
||||
# TODO: make this framework agnostic
|
||||
from rl_coach.architectures.tensorflow_components.architecture import save_onnx_graph
|
||||
|
||||
save_onnx_graph(input_nodes, output_nodes, self.task_parameters.checkpoint_save_dir)
|
||||
|
||||
def setup_logger(self) -> None:
|
||||
# dump documentation
|
||||
@@ -496,7 +517,7 @@ class GraphManager(object):
|
||||
and (self.task_parameters.task_index == 0 # distributed
|
||||
or self.task_parameters.task_index is None # single-worker
|
||||
):
|
||||
self.save_checkpoint()
|
||||
self.save_checkpoint()
|
||||
|
||||
def save_checkpoint(self):
|
||||
checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
|
||||
@@ -511,6 +532,10 @@ class GraphManager(object):
|
||||
# this is required in order for agents to save additional information like a DND for example
|
||||
[manager.save_checkpoint(self.checkpoint_id) for manager in self.level_managers]
|
||||
|
||||
# the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
|
||||
if self.task_parameters.export_onnx_graph:
|
||||
self.save_onnx_graph()
|
||||
|
||||
screen.log_dict(
|
||||
OrderedDict([
|
||||
("Saving in path", saved_checkpoint_path),
|
||||
|
||||
Reference in New Issue
Block a user