ONNX v2 raises an Exception when using PyTorch < 1.8.0 (#12933)

* Raise an exception if the PyTorch version is < 1.8.0

* Attempt to add a test to ensure it correctly raises.

* Missing docstring.

* Second attempt, patch with string absolute import.

* Let's do the call before checking it was called ...

* use the correct function ... 🤦

* Raise ImportError when torch cannot be found, and AssertionError when the torch version is not sufficient (a rough sketch of this combined check follows these notes).

* Correct path mock patching

* Relax the version constraint for torch_onnx_dict_inputs to ge (>=) instead of eq (==).

* Style.

* Split each version requirements for torch.

* Let's compare version directly.

* Import torch_version after checking pytorch is installed.

* @require_torch
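
The notes above describe a two-step gate: an ImportError when torch is missing entirely, and an AssertionError when the installed torch is older than 1.8.0, with the version test relaxed from an exact match (eq) to a minimum bound (ge). Below is a rough, self-contained sketch of that combined check, not the library code itself; the function name ensure_torch_onnx_dict_inputs_supported is made up for illustration, while the constant and the error messages mirror the diff further down.

import importlib.metadata

from packaging import version

TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8")


def ensure_torch_onnx_dict_inputs_supported() -> None:
    """Hypothetical helper: ImportError if torch is absent, AssertionError if it is older than 1.8.0."""
    try:
        torch_version = version.parse(importlib.metadata.version("torch"))
    except importlib.metadata.PackageNotFoundError:
        raise ImportError("Cannot convert because PyTorch is not installed. Please install torch first.")
    # ge (>=) rather than eq (==): any torch newer than 1.8 also supports dictionary inputs for ONNX export
    if torch_version < TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION:
        raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")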
Funtowicz Morgan 2021-07-29 18:02:29 +02:00 committed by GitHub
parent 9160d81c98
commit 640421c0ec
3 changed files with 32 additions and 4 deletions


@@ -274,8 +274,9 @@ PRESET_MIRROR_DICT = {
     "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
 }

-# This is the version of torch required to run torch.fx features.
+# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
 TORCH_FX_REQUIRED_VERSION = version.parse("1.8")
+TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8")

 _is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False
@@ -297,7 +298,7 @@ def is_torch_cuda_available():
         return False


-_torch_fx_available = False
+_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
 if _torch_available:
     torch_version = version.parse(importlib_metadata.version("torch"))
     _torch_fx_available = (torch_version.major, torch_version.minor) == (
@@ -305,11 +306,17 @@
         TORCH_FX_REQUIRED_VERSION.minor,
     )
+    _torch_onnx_dict_inputs_support_available = torch_version >= TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION


 def is_torch_fx_available():
     return _torch_fx_available


+def is_torch_onnx_dict_inputs_support_available():
+    return _torch_onnx_dict_inputs_support_available
+
+
 def is_tf_available():
     return _tf_available
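
For downstream code, the net effect of this first file's changes is a new public helper next to is_torch_fx_available. A minimal usage sketch, assuming a transformers install that already contains this commit and that the helper is importable from transformers.file_utils exactly as added above:

from transformers.file_utils import is_torch_onnx_dict_inputs_support_available

if is_torch_onnx_dict_inputs_support_available():
    print("torch >= 1.8 detected: ONNX export with dictionary inputs is available.")
else:
    print("Install or upgrade torch to 1.8.0 or newer before exporting models that take dictionary inputs.")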


@@ -21,6 +21,7 @@ import numpy as np
 from packaging.version import Version, parse

 from .. import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available
+from ..file_utils import is_torch_onnx_dict_inputs_support_available
 from ..utils import logging
 from .config import OnnxConfig
 from .utils import flatten_output_collection_property
@@ -79,11 +80,16 @@ def export(
     """
     if not is_torch_available():
-        raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
+        raise ImportError("Cannot convert because PyTorch is not installed. Please install torch first.")

     import torch
     from torch.onnx import export

+    from ..file_utils import torch_version
+
+    if not is_torch_onnx_dict_inputs_support_available():
+        raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
+
     logger.info(f"Using framework PyTorch: {torch.__version__}")
     torch.set_grad_enabled(False)
     model.config.return_dict = True
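
From a caller's perspective, the guard added to export above means the function can now fail fast in two distinct ways before touching the model. A hypothetical caller-side sketch follows; the wrapper name try_export is invented, and the argument list is left as *export_args because this diff does not show export's full signature:

from transformers.onnx import export


def try_export(*export_args):
    """Hypothetical wrapper that surfaces the two failure modes introduced above."""
    try:
        return export(*export_args)
    except ImportError:
        # PyTorch is not installed at all
        print("Install torch before exporting to ONNX.")
    except AssertionError as error:
        # PyTorch is installed but older than 1.8.0, so dictionary inputs are not supported
        print(f"Upgrade PyTorch: {error}")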


@@ -24,7 +24,13 @@ from transformers.models.roberta import RobertaOnnxConfig
 # from transformers.models.t5 import T5OnnxConfig
 from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
-from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat, validate_model_outputs
+from transformers.onnx import (
+    EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
+    OnnxConfig,
+    ParameterFormat,
+    export,
+    validate_model_outputs,
+)
 from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
 from transformers.onnx.utils import (
     compute_effective_axis_dimension,
@@ -40,6 +46,15 @@ class OnnxUtilsTestCaseV2(TestCase):
     Cover all the utilities involved to export ONNX models
     """

+    @require_torch
+    @patch("transformers.onnx.convert.is_torch_onnx_dict_inputs_support_available", return_value=False)
+    def test_ensure_pytorch_version_ge_1_8_0(self, mock_is_torch_onnx_dict_inputs_support_available):
+        """
+        Ensure we raise an Exception if the pytorch version is unsupported (< 1.8.0)
+        """
+        self.assertRaises(AssertionError, export, None, None, None, None, None)
+        mock_is_torch_onnx_dict_inputs_support_available.assert_called()
+
     def test_compute_effective_axis_dimension(self):
         """
         When exporting ONNX model with dynamic axis (batch or sequence) we set batch_size and/or sequence_length = -1.
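
The "Second attempt, patch with string absolute import" and "Correct path mock patching" notes earlier boil down to a general unittest.mock rule: patch the name where it is looked up, not where it is defined. Because convert.py imports is_torch_onnx_dict_inputs_support_available into its own namespace, the test above has to target "transformers.onnx.convert.is_torch_onnx_dict_inputs_support_available" rather than "transformers.file_utils.is_torch_onnx_dict_inputs_support_available". A small self-contained demonstration of the same principle, using json.dumps instead of the transformers helper (assumes it is run as a standalone script; the serialize function is made up for illustration):

from unittest.mock import patch

import json              # json.dumps is defined in the json module...
from json import dumps   # ...but this script now also holds its own reference to it


def serialize(obj):
    return dumps(obj)  # resolved through this script's namespace, not through json


# Patching "json.dumps" does not affect serialize(): its lookup goes through __main__.dumps.
with patch("json.dumps", return_value="patched"):
    assert serialize({}) == "{}"

# Patching the name in the importing module's namespace is what actually works.
with patch("__main__.dumps", return_value="patched"):
    assert serialize({}) == "patched"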