Prepare deprecated ONNX exporter for torch v1.11 (#15388)

* Prepare deprecated ONNX exporter for PyTorch v1.11

* Add deprecation warning
This commit is contained in:
lewtun 2022-01-28 16:32:47 +01:00 committed by GitHub
parent 4996922b6d
commit 507601a5cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from argparse import ArgumentParser
from os import listdir, makedirs
from pathlib import Path
@ -278,18 +279,32 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)
export(
nlp.model,
model_args,
f=output.as_posix(),
input_names=ordered_input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
use_external_data_format=use_external_format,
enable_onnx_checker=True,
opset_version=opset,
)
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
# so we check the torch version for backwards compatibility
if parse(torch.__version__) <= parse("1.10.99"):
export(
nlp.model,
model_args,
f=output.as_posix(),
input_names=ordered_input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
use_external_data_format=use_external_format,
enable_onnx_checker=True,
opset_version=opset,
)
else:
export(
nlp.model,
model_args,
f=output.as_posix(),
input_names=ordered_input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
opset_version=opset,
)
def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
@ -356,6 +371,10 @@ def convert(
Returns:
"""
warnings.warn(
"The `transformers.convert_graph_to_onnx` package is deprecated and will be removed in version 5 of Transformers",
FutureWarning,
)
print(f"ONNX opset version set to: {opset}")
# Load the pipeline