mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Export graph to ONNX (#61)
Implements the ONNX graph exporting feature. Currently does not work for NAF, C51 and A3C_LSTM due to unsupported TF layers in the tf2onnx library.
This commit is contained in:
@@ -241,6 +241,10 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
|
||||
# checkpoints
|
||||
args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.checkpoint_save_secs is not None else None
|
||||
|
||||
if args.export_onnx_graph and not args.checkpoint_save_secs:
|
||||
screen.warning("Exporting ONNX graphs requires setting the --checkpoint_save_secs flag. "
|
||||
"The --export_onnx_graph will have no effect.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
@@ -474,6 +478,14 @@ def main():
|
||||
help="(int) A seed to use for running the experiment",
|
||||
default=None,
|
||||
type=int)
|
||||
parser.add_argument('-onnx', '--export_onnx_graph',
|
||||
help="(flag) Export the ONNX graph to the experiment directory. "
|
||||
"This will have effect only if the --checkpoint_save_secs flag is used in order to store "
|
||||
"checkpoints, since the weights checkpoint are needed for the ONNX graph. "
|
||||
"Keep in mind that this can cause major overhead on the experiment. "
|
||||
"Exporting ONNX graphs requires manually installing the tf2onnx package "
|
||||
"(https://github.com/onnx/tensorflow-onnx).",
|
||||
action='store_true')
|
||||
parser.add_argument('-dc', '--distributed_coach',
|
||||
help="(flag) Use distributed Coach.",
|
||||
action='store_true')
|
||||
|
||||
Reference in New Issue
Block a user