mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Prepare ONNX export for torch v1.11 (#15270)
* Prepare ONNX export for torch v1.11
This commit is contained in:
parent
126bddd1ba
commit
b4ce313e6c
@ -112,19 +112,34 @@ def export(
|
||||
|
||||
config.patch_ops()
|
||||
|
||||
# export can works with named args but the dict containing named args as to be last element of the args tuple
|
||||
export(
|
||||
model,
|
||||
(model_inputs,),
|
||||
f=output.as_posix(),
|
||||
input_names=list(config.inputs.keys()),
|
||||
output_names=onnx_outputs,
|
||||
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
|
||||
do_constant_folding=True,
|
||||
use_external_data_format=config.use_external_data_format(model.num_parameters()),
|
||||
enable_onnx_checker=True,
|
||||
opset_version=opset,
|
||||
)
|
||||
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
|
||||
# so we check the torch version for backwards compatibility
|
||||
if parse(torch.__version__) <= parse("1.10.99"):
|
||||
# export can work with named args but the dict containing named args
|
||||
# has to be the last element of the args tuple.
|
||||
export(
|
||||
model,
|
||||
(model_inputs,),
|
||||
f=output.as_posix(),
|
||||
input_names=list(config.inputs.keys()),
|
||||
output_names=onnx_outputs,
|
||||
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
|
||||
do_constant_folding=True,
|
||||
use_external_data_format=config.use_external_data_format(model.num_parameters()),
|
||||
enable_onnx_checker=True,
|
||||
opset_version=opset,
|
||||
)
|
||||
else:
|
||||
export(
|
||||
model,
|
||||
(model_inputs,),
|
||||
f=output.as_posix(),
|
||||
input_names=list(config.inputs.keys()),
|
||||
output_names=onnx_outputs,
|
||||
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
|
||||
do_constant_folding=True,
|
||||
opset_version=opset,
|
||||
)
|
||||
|
||||
config.restore_ops()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user