mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Prepare deprecated ONNX exporter for torch v1.11 (#15388)
* Prepare deprecated ONNX exporter for PyTorch v1.11 * Add deprecation warning
This commit is contained in:
parent
4996922b6d
commit
507601a5cf
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from argparse import ArgumentParser
|
||||
from os import listdir, makedirs
|
||||
from pathlib import Path
|
||||
@ -278,18 +279,32 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format
|
||||
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
|
||||
ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)
|
||||
|
||||
export(
|
||||
nlp.model,
|
||||
model_args,
|
||||
f=output.as_posix(),
|
||||
input_names=ordered_input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
do_constant_folding=True,
|
||||
use_external_data_format=use_external_format,
|
||||
enable_onnx_checker=True,
|
||||
opset_version=opset,
|
||||
)
|
||||
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
|
||||
# so we check the torch version for backwards compatibility
|
||||
if parse(torch.__version__) <= parse("1.10.99"):
|
||||
export(
|
||||
nlp.model,
|
||||
model_args,
|
||||
f=output.as_posix(),
|
||||
input_names=ordered_input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
do_constant_folding=True,
|
||||
use_external_data_format=use_external_format,
|
||||
enable_onnx_checker=True,
|
||||
opset_version=opset,
|
||||
)
|
||||
else:
|
||||
export(
|
||||
nlp.model,
|
||||
model_args,
|
||||
f=output.as_posix(),
|
||||
input_names=ordered_input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
do_constant_folding=True,
|
||||
opset_version=opset,
|
||||
)
|
||||
|
||||
|
||||
def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
|
||||
@ -356,6 +371,10 @@ def convert(
|
||||
Returns:
|
||||
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `transformers.convert_graph_to_onnx` package is deprecated and will be removed in version 5 of Transformers",
|
||||
FutureWarning,
|
||||
)
|
||||
print(f"ONNX opset version set to: {opset}")
|
||||
|
||||
# Load the pipeline
|
||||
|
Loading…
Reference in New Issue
Block a user