mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
ONNX v2 raises an Exception when using PyTorch < 1.8.0 (#12933)
* Raise an issue if the pytorch version is < 1.8.0
* Attempt to add a test to ensure it correctly raises.
* Missing docstring.
* Second attempt, patch with string absolute import.
* Let's do the call before checking it was called ...
* use the correct function ... 🤦
* Raise ImportError and AssertionError respectively when unable to find torch and torch version is not sufficient.
* Correct path mock patching
* relax constraint for torch_onnx_dict_inputs to ge instead of eq.
* Style.
* Split each version requirements for torch.
* Let's compare version directly.
* Import torch_version after checking pytorch is installed.
* @require_torch
This commit is contained in:
parent
9160d81c98
commit
640421c0ec
@ -274,8 +274,9 @@ PRESET_MIRROR_DICT = {
|
||||
"bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
|
||||
}
|
||||
|
||||
# This is the version of torch required to run torch.fx features.
|
||||
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
|
||||
TORCH_FX_REQUIRED_VERSION = version.parse("1.8")
|
||||
TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8")
|
||||
|
||||
_is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False
|
||||
|
||||
@ -297,7 +298,7 @@ def is_torch_cuda_available():
|
||||
return False
|
||||
|
||||
|
||||
_torch_fx_available = False
|
||||
_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
|
||||
if _torch_available:
|
||||
torch_version = version.parse(importlib_metadata.version("torch"))
|
||||
_torch_fx_available = (torch_version.major, torch_version.minor) == (
|
||||
@ -305,11 +306,17 @@ if _torch_available:
|
||||
TORCH_FX_REQUIRED_VERSION.minor,
|
||||
)
|
||||
|
||||
_torch_onnx_dict_inputs_support_available = torch_version >= TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION
|
||||
|
||||
|
||||
def is_torch_fx_available():
|
||||
return _torch_fx_available
|
||||
|
||||
|
||||
def is_torch_onnx_dict_inputs_support_available():
|
||||
return _torch_onnx_dict_inputs_support_available
|
||||
|
||||
|
||||
def is_tf_available():
|
||||
return _tf_available
|
||||
|
||||
|
@ -21,6 +21,7 @@ import numpy as np
|
||||
from packaging.version import Version, parse
|
||||
|
||||
from .. import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available
|
||||
from ..file_utils import is_torch_onnx_dict_inputs_support_available
|
||||
from ..utils import logging
|
||||
from .config import OnnxConfig
|
||||
from .utils import flatten_output_collection_property
|
||||
@ -79,11 +80,16 @@ def export(
|
||||
|
||||
"""
|
||||
if not is_torch_available():
|
||||
raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
|
||||
raise ImportError("Cannot convert because PyTorch is not installed. Please install torch first.")
|
||||
|
||||
import torch
|
||||
from torch.onnx import export
|
||||
|
||||
from ..file_utils import torch_version
|
||||
|
||||
if not is_torch_onnx_dict_inputs_support_available():
|
||||
raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
|
||||
|
||||
logger.info(f"Using framework PyTorch: {torch.__version__}")
|
||||
torch.set_grad_enabled(False)
|
||||
model.config.return_dict = True
|
||||
|
@ -24,7 +24,13 @@ from transformers.models.roberta import RobertaOnnxConfig
|
||||
|
||||
# from transformers.models.t5 import T5OnnxConfig
|
||||
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
|
||||
from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat, validate_model_outputs
|
||||
from transformers.onnx import (
|
||||
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
|
||||
OnnxConfig,
|
||||
ParameterFormat,
|
||||
export,
|
||||
validate_model_outputs,
|
||||
)
|
||||
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
|
||||
from transformers.onnx.utils import (
|
||||
compute_effective_axis_dimension,
|
||||
@ -40,6 +46,15 @@ class OnnxUtilsTestCaseV2(TestCase):
|
||||
Cover all the utilities involved to export ONNX models
|
||||
"""
|
||||
|
||||
@require_torch
|
||||
@patch("transformers.onnx.convert.is_torch_onnx_dict_inputs_support_available", return_value=False)
|
||||
def test_ensure_pytorch_version_ge_1_8_0(self, mock_is_torch_onnx_dict_inputs_support_available):
|
||||
"""
|
||||
Ensure we raise an Exception if the pytorch version is unsupported (< 1.8.0)
|
||||
"""
|
||||
self.assertRaises(AssertionError, export, None, None, None, None, None)
|
||||
mock_is_torch_onnx_dict_inputs_support_available.assert_called()
|
||||
|
||||
def test_compute_effective_axis_dimension(self):
|
||||
"""
|
||||
When exporting ONNX model with dynamic axis (batch or sequence) we set batch_size and/or sequence_length = -1.
|
||||
|
Loading…
Reference in New Issue
Block a user