Prepare ONNX export for torch v1.11 (#15270)

* Prepare ONNX export for torch v1.11
This commit is contained in:
lewtun 2022-01-21 14:28:19 +01:00 committed by GitHub
parent 126bddd1ba
commit b4ce313e6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -112,19 +112,34 @@ def export(
config.patch_ops()
# export can works with named args but the dict containing named args as to be last element of the args tuple
export(
model,
(model_inputs,),
f=output.as_posix(),
input_names=list(config.inputs.keys()),
output_names=onnx_outputs,
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
do_constant_folding=True,
use_external_data_format=config.use_external_data_format(model.num_parameters()),
enable_onnx_checker=True,
opset_version=opset,
)
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
# so we check the torch version for backwards compatibility
if parse(torch.__version__) <= parse("1.10.99"):
# export can work with named args but the dict containing named args
# has to be the last element of the args tuple.
export(
model,
(model_inputs,),
f=output.as_posix(),
input_names=list(config.inputs.keys()),
output_names=onnx_outputs,
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
do_constant_folding=True,
use_external_data_format=config.use_external_data_format(model.num_parameters()),
enable_onnx_checker=True,
opset_version=opset,
)
else:
export(
model,
(model_inputs,),
f=output.as_posix(),
input_names=list(config.inputs.keys()),
output_names=onnx_outputs,
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
do_constant_folding=True,
opset_version=opset,
)
config.restore_ops()