Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)
Add ONNX support for LayoutLMv3 (#17953)

* Add ONNX support for LayoutLMv3
* Update docstrings
* Update empty description in docstring
* Fix imports and type hints
This commit is contained in:
parent
fe14046421
commit
9cb7cef285
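
With this commit in place, a LayoutLMv3 checkpoint can be exported through the existing transformers.onnx command line tool, for example (the feature name and output directory here are illustrative): python -m transformers.onnx --model=microsoft/layoutlmv3-base --feature=question-answering onnx/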
docs/source/en/serialization.mdx
@@ -70,6 +70,7 @@ Ready-made configurations include the following architectures:

 - GPT-J
 - I-BERT
 - LayoutLM
+- LayoutLMv3
 - LongT5
 - M2M100
 - Marian
src/transformers/models/layoutlmv3/__init__.py
@@ -28,7 +28,11 @@ from ...utils import (

 _import_structure = {
-    "configuration_layoutlmv3": ["LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMv3Config"],
+    "configuration_layoutlmv3": [
+        "LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "LayoutLMv3Config",
+        "LayoutLMv3OnnxConfig",
+    ],
     "processing_layoutlmv3": ["LayoutLMv3Processor"],
     "tokenization_layoutlmv3": ["LayoutLMv3Tokenizer"],
 }
@@ -66,7 +70,11 @@ else:

 if TYPE_CHECKING:
-    from .configuration_layoutlmv3 import LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMv3Config
+    from .configuration_layoutlmv3 import (
+        LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        LayoutLMv3Config,
+        LayoutLMv3OnnxConfig,
+    )
     from .processing_layoutlmv3 import LayoutLMv3Processor
     from .tokenization_layoutlmv3 import LayoutLMv3Tokenizer
src/transformers/models/layoutlmv3/configuration_layoutlmv3.py
@@ -14,10 +14,22 @@

 # limitations under the License.
 """ LayoutLMv3 model configuration"""

+from collections import OrderedDict
+from typing import TYPE_CHECKING, Any, Mapping, Optional
+
+from packaging import version
+
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...onnx.utils import compute_effective_axis_dimension
 from ...utils import logging


+if TYPE_CHECKING:
+    from ...processing_utils import ProcessorMixin
+    from ...utils import TensorType
+
+
 logger = logging.get_logger(__name__)

 LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
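
Note that the new ProcessorMixin and TensorType imports are guarded under TYPE_CHECKING, so they are only resolved by static type checkers; at runtime the configuration module's import graph is unchanged.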
@@ -176,3 +188,107 @@ class LayoutLMv3Config(PretrainedConfig):

         self.num_channels = num_channels
         self.patch_size = patch_size
         self.classifier_dropout = classifier_dropout
+
+
+class LayoutLMv3OnnxConfig(OnnxConfig):
+    torch_onnx_minimum_version = version.parse("1.12")
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        # The order of inputs is different for question answering and sequence classification
+        if self.task in ["question-answering", "sequence-classification"]:
+            return OrderedDict(
+                [
+                    ("input_ids", {0: "batch", 1: "sequence"}),
+                    ("attention_mask", {0: "batch", 1: "sequence"}),
+                    ("bbox", {0: "batch", 1: "sequence"}),
+                    ("pixel_values", {0: "batch", 1: "sequence"}),
+                ]
+            )
+        else:
+            return OrderedDict(
+                [
+                    ("input_ids", {0: "batch", 1: "sequence"}),
+                    ("bbox", {0: "batch", 1: "sequence"}),
+                    ("attention_mask", {0: "batch", 1: "sequence"}),
+                    ("pixel_values", {0: "batch", 1: "sequence"}),
+                ]
+            )
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-5
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 12
+
+    def generate_dummy_inputs(
+        self,
+        processor: "ProcessorMixin",
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional["TensorType"] = None,
+        num_channels: int = 3,
+        image_width: int = 40,
+        image_height: int = 40,
+    ) -> Mapping[str, Any]:
+        """
+        Generate inputs to provide to the ONNX exporter for the specific framework
+
+        Args:
+            processor ([`ProcessorMixin`]):
+                The processor associated with this model configuration.
+            batch_size (`int`, *optional*, defaults to -1):
+                The batch size to export the model for (-1 means dynamic axis).
+            seq_length (`int`, *optional*, defaults to -1):
+                The sequence length to export the model for (-1 means dynamic axis).
+            is_pair (`bool`, *optional*, defaults to `False`):
+                Indicate if the input is a pair (sentence 1, sentence 2).
+            framework (`TensorType`, *optional*, defaults to `None`):
+                The framework (PyTorch or TensorFlow) that the processor will generate tensors for.
+            num_channels (`int`, *optional*, defaults to 3):
+                The number of channels of the generated images.
+            image_width (`int`, *optional*, defaults to 40):
+                The width of the generated images.
+            image_height (`int`, *optional*, defaults to 40):
+                The height of the generated images.
+
+        Returns:
+            Mapping[str, Any]: holding the kwargs to provide to the model's forward function
+        """
+
+        # A dummy image is used so OCR should not be applied
+        setattr(processor.feature_extractor, "apply_ocr", False)
+
+        # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+        batch_size = compute_effective_axis_dimension(
+            batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
+        )
+        # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
+        token_to_add = processor.tokenizer.num_special_tokens_to_add(is_pair)
+        seq_length = compute_effective_axis_dimension(
+            seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
+        )
+        # Generate dummy inputs according to compute batch and sequence
+        dummy_text = [[" ".join([processor.tokenizer.unk_token]) * seq_length]] * batch_size
+
+        # Generate dummy bounding boxes
+        dummy_bboxes = [[[48, 84, 73, 128]]] * batch_size
+
+        # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+        # batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
+        dummy_image = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
+
+        inputs = dict(
+            processor(
+                dummy_image,
+                text=dummy_text,
+                boxes=dummy_bboxes,
+                return_tensors=framework,
+            )
+        )
+
+        return inputs
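
To show how the new class is consumed end to end, here is a minimal sketch of a programmatic export; the checkpoint name and output path are illustrative, and the export call follows the signature widened in src/transformers/onnx/convert.py below:

    # Sketch only: export LayoutLMv3 through the new LayoutLMv3OnnxConfig.
    from pathlib import Path

    from transformers import LayoutLMv3Model, LayoutLMv3Processor
    from transformers.models.layoutlmv3 import LayoutLMv3OnnxConfig
    from transformers.onnx import export

    checkpoint = "microsoft/layoutlmv3-base"
    processor = LayoutLMv3Processor.from_pretrained(checkpoint, apply_ocr=False)
    model = LayoutLMv3Model.from_pretrained(checkpoint)

    # task="default" selects the (input_ids, bbox, attention_mask, pixel_values) ordering
    onnx_config = LayoutLMv3OnnxConfig(model.config, task="default")
    onnx_inputs, onnx_outputs = export(
        preprocessor=processor,  # a ProcessorMixin, accepted after this commit
        model=model,
        config=onnx_config,
        opset=onnx_config.default_onnx_opset,
        output=Path("layoutlmv3.onnx"),
    )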
src/transformers/onnx/convert.py
@@ -40,6 +40,7 @@ if is_tf_available():

 if TYPE_CHECKING:
     from ..feature_extraction_utils import FeatureExtractionMixin
+    from ..processing_utils import ProcessorMixin
     from ..tokenization_utils import PreTrainedTokenizer
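
The remaining hunks in this file widen the preprocessor parameter of export_pytorch, export, and validate_model_outputs so that a ProcessorMixin is accepted alongside tokenizers and feature extractors, letting processor-based models such as LayoutLMv3 flow through the same export and validation path.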
@@ -80,7 +81,7 @@ def check_onnxruntime_requirements(minimum_version: Version):

 def export_pytorch(
-    preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin"],
+    preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin", "ProcessorMixin"],
     model: "PreTrainedModel",
     config: OnnxConfig,
     opset: int,
@@ -92,7 +93,7 @@ def export_pytorch(

     Export a PyTorch model to an ONNX Intermediate Representation (IR)

     Args:
-        preprocessor: ([`PreTrainedTokenizer`] or [`FeatureExtractionMixin`]):
+        preprocessor: ([`PreTrainedTokenizer`], [`FeatureExtractionMixin`] or [`ProcessorMixin`]):
             The preprocessor used for encoding the data.
         model ([`PreTrainedModel`]):
             The model to export.
@@ -269,7 +270,7 @@ def export_tensorflow(

 def export(
-    preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin"],
+    preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin", "ProcessorMixin"],
     model: Union["PreTrainedModel", "TFPreTrainedModel"],
     config: OnnxConfig,
     opset: int,
@@ -281,7 +282,7 @@ def export(

     Export a Pytorch or TensorFlow model to an ONNX Intermediate Representation (IR)

     Args:
-        preprocessor: ([`PreTrainedTokenizer`] or [`FeatureExtractionMixin`]):
+        preprocessor: ([`PreTrainedTokenizer`], [`FeatureExtractionMixin`] or [`ProcessorMixin`]):
             The preprocessor used for encoding the data.
         model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
             The model to export.
@@ -339,7 +340,7 @@ def export(

 def validate_model_outputs(
     config: OnnxConfig,
-    preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin"],
+    preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin", "ProcessorMixin"],
     reference_model: Union["PreTrainedModel", "TFPreTrainedModel"],
     onnx_model: Path,
     onnx_named_outputs: List[str],
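
Continuing the export sketch above, the resulting file can then be checked against the reference model; the argument order below follows the validate_model_outputs signature in this hunk, and the tolerance comes from the new atol_for_validation property:

    # Sketch only: compare ONNX Runtime outputs with the PyTorch reference.
    from transformers.onnx import validate_model_outputs

    validate_model_outputs(
        onnx_config,
        processor,
        model,
        Path("layoutlmv3.onnx"),
        onnx_outputs,
        onnx_config.atol_for_validation,  # 1e-5 for LayoutLMv3
    )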
src/transformers/onnx/features.py
@@ -317,6 +317,13 @@ class FeaturesManager:

             "token-classification",
             onnx_config_cls="models.layoutlm.LayoutLMOnnxConfig",
         ),
+        "layoutlmv3": supported_features_mapping(
+            "default",
+            "question-answering",
+            "sequence-classification",
+            "token-classification",
+            onnx_config_cls="models.layoutlmv3.LayoutLMv3OnnxConfig",
+        ),
         "longt5": supported_features_mapping(
             "default",
             "default-with-past",
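
This mapping is what the exporter consults when resolving a model to its ONNX config; a sketch of the lookup, reusing the model object from the export sketch above:

    # Sketch only: resolve the ONNX config class via FeaturesManager.
    from transformers.onnx.features import FeaturesManager

    model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
        model, feature="question-answering"
    )
    onnx_config = model_onnx_config(model.config)  # a LayoutLMv3OnnxConfig bound to the task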
tests/onnx/test_onnx_v2.py
@@ -195,6 +195,7 @@ PYTORCH_EXPORT_MODELS = {

     ("xlm", "xlm-clm-ende-1024"),
     ("xlm-roberta", "xlm-roberta-base"),
     ("layoutlm", "microsoft/layoutlm-base-uncased"),
+    ("layoutlmv3", "microsoft/layoutlmv3-base"),
     ("vit", "google/vit-base-patch16-224"),
     ("deit", "facebook/deit-small-patch16-224"),
     ("beit", "microsoft/beit-base-patch16-224"),
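
Assuming the table above lives in tests/onnx/test_onnx_v2.py, the new entry can be exercised with a slow-test run such as: RUN_SLOW=1 python -m pytest tests/onnx/test_onnx_v2.py -k "layoutlmv3"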