1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Export graph to ONNX (#61)

Implements the ONNX graph exporting feature. 
Currently does not work for NAF, C51 and A3C_LSTM due to unsupported TF layers in the tf2onnx library.
This commit is contained in:
Itai Caspi
2018-11-06 10:55:21 +02:00
committed by GitHub
parent d75df17d97
commit 811152126c
4 changed files with 107 additions and 27 deletions

View File

@@ -241,6 +241,10 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
# checkpoints
args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.checkpoint_save_secs is not None else None
if args.export_onnx_graph and not args.checkpoint_save_secs:
screen.warning("Exporting ONNX graphs requires setting the --checkpoint_save_secs flag. "
"The --export_onnx_graph will have no effect.")
return args
@@ -474,6 +478,14 @@ def main():
help="(int) A seed to use for running the experiment",
default=None,
type=int)
parser.add_argument('-onnx', '--export_onnx_graph',
help="(flag) Export the ONNX graph to the experiment directory. "
"This will have effect only if the --checkpoint_save_secs flag is used in order to store "
"checkpoints, since the weights checkpoint are needed for the ONNX graph. "
"Keep in mind that this can cause major overhead on the experiment. "
"Exporting ONNX graphs requires manually installing the tf2onnx package "
"(https://github.com/onnx/tensorflow-onnx).",
action='store_true')
parser.add_argument('-dc', '--distributed_coach',
help="(flag) Use distributed Coach.",
action='store_true')