Add ONNX support for LayoutLMv3 (#17953)

* Add ONNX support for LayoutLMv3

* Update docstrings

* Update empty description in docstring

* Fix imports and type hints
Author: regisss, 2022-06-30 18:09:52 +02:00 (committed by GitHub)
Parent: fe14046421
Commit: 9cb7cef285
6 changed files with 141 additions and 7 deletions
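Taken together, the changes below let a LayoutLMv3 checkpoint go through the stock ONNX exporter. As a sketch of typical usage at this revision (not part of the commit), `python -m transformers.onnx --model=microsoft/layoutlmv3-base --feature=question-answering onnx/` should produce and validate an ONNX graph, with the four supported features registered in the features.py hunk further down.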

docs/source/en/serialization.mdx

@@ -70,6 +70,7 @@ Ready-made configurations include the following architectures:
 - GPT-J
 - I-BERT
 - LayoutLM
+- LayoutLMv3
 - LongT5
 - M2M100
 - Marian

src/transformers/models/layoutlmv3/__init__.py

@@ -28,7 +28,11 @@ from ...utils import (

 _import_structure = {
-    "configuration_layoutlmv3": ["LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMv3Config"],
+    "configuration_layoutlmv3": [
+        "LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "LayoutLMv3Config",
+        "LayoutLMv3OnnxConfig",
+    ],
     "processing_layoutlmv3": ["LayoutLMv3Processor"],
     "tokenization_layoutlmv3": ["LayoutLMv3Tokenizer"],
 }
@@ -66,7 +70,11 @@ else:

 if TYPE_CHECKING:
-    from .configuration_layoutlmv3 import LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMv3Config
+    from .configuration_layoutlmv3 import (
+        LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        LayoutLMv3Config,
+        LayoutLMv3OnnxConfig,
+    )
     from .processing_layoutlmv3 import LayoutLMv3Processor
     from .tokenization_layoutlmv3 import LayoutLMv3Tokenizer
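
As a quick illustration (not part of the commit): once the lazy _import_structure and TYPE_CHECKING entries above are in place, the ONNX config is importable from the sub-package. A minimal sketch, assuming a transformers install at this revision:

    # LayoutLMv3OnnxConfig now resolves through the sub-package's lazy import machinery
    from transformers.models.layoutlmv3 import LayoutLMv3Config, LayoutLMv3OnnxConfig

    onnx_config = LayoutLMv3OnnxConfig(LayoutLMv3Config(), task="default")
    print(list(onnx_config.inputs))  # ['input_ids', 'bbox', 'attention_mask', 'pixel_values']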

src/transformers/models/layoutlmv3/configuration_layoutlmv3.py

@@ -14,10 +14,22 @@
 # limitations under the License.
 """ LayoutLMv3 model configuration"""

+from collections import OrderedDict
+from typing import TYPE_CHECKING, Any, Mapping, Optional
+
+from packaging import version
+
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...onnx.utils import compute_effective_axis_dimension
 from ...utils import logging

+
+if TYPE_CHECKING:
+    from ...processing_utils import ProcessorMixin
+    from ...utils import TensorType
+

 logger = logging.get_logger(__name__)

 LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
@@ -176,3 +188,107 @@ class LayoutLMv3Config(PretrainedConfig):
         self.num_channels = num_channels
         self.patch_size = patch_size
         self.classifier_dropout = classifier_dropout
+
+
+class LayoutLMv3OnnxConfig(OnnxConfig):
+    torch_onnx_minimum_version = version.parse("1.12")
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        # The order of inputs is different for question answering and sequence classification
+        if self.task in ["question-answering", "sequence-classification"]:
+            return OrderedDict(
+                [
+                    ("input_ids", {0: "batch", 1: "sequence"}),
+                    ("attention_mask", {0: "batch", 1: "sequence"}),
+                    ("bbox", {0: "batch", 1: "sequence"}),
+                    ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+                ]
+            )
+        else:
+            return OrderedDict(
+                [
+                    ("input_ids", {0: "batch", 1: "sequence"}),
+                    ("bbox", {0: "batch", 1: "sequence"}),
+                    ("attention_mask", {0: "batch", 1: "sequence"}),
+                    ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+                ]
+            )
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-5
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 12
+
+    def generate_dummy_inputs(
+        self,
+        processor: "ProcessorMixin",
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional["TensorType"] = None,
+        num_channels: int = 3,
+        image_width: int = 40,
+        image_height: int = 40,
+    ) -> Mapping[str, Any]:
+        """
+        Generate inputs to provide to the ONNX exporter for the specific framework.
+
+        Args:
+            processor ([`ProcessorMixin`]):
+                The processor associated with this model configuration.
+            batch_size (`int`, *optional*, defaults to -1):
+                The batch size to export the model for (-1 means dynamic axis).
+            seq_length (`int`, *optional*, defaults to -1):
+                The sequence length to export the model for (-1 means dynamic axis).
+            is_pair (`bool`, *optional*, defaults to `False`):
+                Indicate if the input is a pair (sentence 1, sentence 2).
+            framework (`TensorType`, *optional*, defaults to `None`):
+                The framework (PyTorch or TensorFlow) that the processor will generate tensors for.
+            num_channels (`int`, *optional*, defaults to 3):
+                The number of channels of the generated images.
+            image_width (`int`, *optional*, defaults to 40):
+                The width of the generated images.
+            image_height (`int`, *optional*, defaults to 40):
+                The height of the generated images.
+
+        Returns:
+            Mapping[str, Any]: holding the kwargs to provide to the model's forward function
+        """
+        # A dummy image is used, so OCR should not be applied
+        setattr(processor.feature_extractor, "apply_ocr", False)
+
+        # If dynamic axis (-1), forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+        batch_size = compute_effective_axis_dimension(
+            batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
+        )
+        # If dynamic axis (-1), forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
+        token_to_add = processor.tokenizer.num_special_tokens_to_add(is_pair)
+        seq_length = compute_effective_axis_dimension(
+            seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
+        )
+        # Generate dummy text (space-separated unk tokens) according to the computed batch and sequence length
+        dummy_text = [[" ".join([processor.tokenizer.unk_token] * seq_length)]] * batch_size
+
+        # Generate dummy bounding boxes
+        dummy_bboxes = [[[48, 84, 73, 128]]] * batch_size
+
+        # Generate dummy images and run everything through the processor
+        dummy_image = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
+        inputs = dict(
+            processor(
+                dummy_image,
+                text=dummy_text,
+                boxes=dummy_bboxes,
+                return_tensors=framework,
+            )
+        )
+
+        return inputs
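
To make the dummy-input path concrete, a hedged sketch of driving generate_dummy_inputs directly (not from the commit; it assumes the microsoft/layoutlmv3-base checkpoint used by the tests below and that TensorType is importable from transformers.utils):

    from transformers import LayoutLMv3Config, LayoutLMv3Processor
    from transformers.models.layoutlmv3 import LayoutLMv3OnnxConfig
    from transformers.utils import TensorType

    processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
    onnx_config = LayoutLMv3OnnxConfig(LayoutLMv3Config(), task="question-answering")

    # apply_ocr is switched off internally, so the generated text/boxes/images are consumed as-is
    dummy_inputs = onnx_config.generate_dummy_inputs(processor, framework=TensorType.PYTORCH)
    for name, tensor in dummy_inputs.items():
        print(name, tuple(tensor.shape))  # e.g. pixel_values should come out as (2, 3, 224, 224)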

src/transformers/onnx/convert.py

@@ -40,6 +40,7 @@ if is_tf_available():

 if TYPE_CHECKING:
     from ..feature_extraction_utils import FeatureExtractionMixin
+    from ..processing_utils import ProcessorMixin
     from ..tokenization_utils import PreTrainedTokenizer

@@ -80,7 +81,7 @@ def check_onnxruntime_requirements(minimum_version: Version):

 def export_pytorch(
-    preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin"],
+    preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin", "ProcessorMixin"],
     model: "PreTrainedModel",
     config: OnnxConfig,
     opset: int,

@@ -92,7 +93,7 @@ def export_pytorch(
     Export a PyTorch model to an ONNX Intermediate Representation (IR)

     Args:
-        preprocessor: ([`PreTrainedTokenizer`] or [`FeatureExtractionMixin`]):
+        preprocessor: ([`PreTrainedTokenizer`], [`FeatureExtractionMixin`] or [`ProcessorMixin`]):
             The preprocessor used for encoding the data.
         model ([`PreTrainedModel`]):
             The model to export.

@@ -269,7 +270,7 @@ def export_tensorflow(

 def export(
-    preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin"],
+    preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin", "ProcessorMixin"],
     model: Union["PreTrainedModel", "TFPreTrainedModel"],
     config: OnnxConfig,
     opset: int,

@@ -281,7 +282,7 @@ def export(
     Export a Pytorch or TensorFlow model to an ONNX Intermediate Representation (IR)

     Args:
-        preprocessor: ([`PreTrainedTokenizer`] or [`FeatureExtractionMixin`]):
+        preprocessor: ([`PreTrainedTokenizer`], [`FeatureExtractionMixin`] or [`ProcessorMixin`]):
             The preprocessor used for encoding the data.
         model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
             The model to export.

@@ -339,7 +340,7 @@ def export(

 def validate_model_outputs(
     config: OnnxConfig,
-    preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin"],
+    preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin", "ProcessorMixin"],
     reference_model: Union["PreTrainedModel", "TFPreTrainedModel"],
     onnx_model: Path,
     onnx_named_outputs: List[str],
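
End to end, widening these signatures is what lets a LayoutLMv3Processor stand in where only a tokenizer or feature extractor was accepted before. A hedged sketch of the programmatic export path (not part of the commit; the QA head of the base checkpoint is randomly initialized, so this only exercises the export mechanics):

    from pathlib import Path

    from transformers import LayoutLMv3ForQuestionAnswering, LayoutLMv3Processor
    from transformers.models.layoutlmv3 import LayoutLMv3OnnxConfig
    from transformers.onnx import export, validate_model_outputs

    ckpt = "microsoft/layoutlmv3-base"
    processor = LayoutLMv3Processor.from_pretrained(ckpt)
    model = LayoutLMv3ForQuestionAnswering.from_pretrained(ckpt)

    onnx_config = LayoutLMv3OnnxConfig(model.config, task="question-answering")
    onnx_path = Path("layoutlmv3.onnx")

    # export() now accepts the ProcessorMixin added to the Union types above
    onnx_inputs, onnx_outputs = export(processor, model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
    validate_model_outputs(onnx_config, processor, model, onnx_path, onnx_outputs, onnx_config.atol_for_validation)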

src/transformers/onnx/features.py

@@ -317,6 +317,13 @@ class FeaturesManager:
             "token-classification",
             onnx_config_cls="models.layoutlm.LayoutLMOnnxConfig",
         ),
+        "layoutlmv3": supported_features_mapping(
+            "default",
+            "question-answering",
+            "sequence-classification",
+            "token-classification",
+            onnx_config_cls="models.layoutlmv3.LayoutLMv3OnnxConfig",
+        ),
         "longt5": supported_features_mapping(
             "default",
             "default-with-past",

tests/onnx/test_onnx_v2.py

@@ -195,6 +195,7 @@ PYTORCH_EXPORT_MODELS = {
     ("xlm", "xlm-clm-ende-1024"),
     ("xlm-roberta", "xlm-roberta-base"),
     ("layoutlm", "microsoft/layoutlm-base-uncased"),
+    ("layoutlmv3", "microsoft/layoutlmv3-base"),
     ("vit", "google/vit-base-patch16-224"),
     ("deit", "facebook/deit-small-patch16-224"),
     ("beit", "microsoft/beit-base-patch16-224"),