Add support for Perceiver ONNX export (#17213)

* Start adding perceiver support for ONNX

* Fix pad token bug for fast tokenizers

* Fix formatting

* Make get_preprocessor more opinionated (processor priority, otherwise tokenizer/feature extractor)

* Clean docs format

* Minor cleanup following @sgugger's comments

* Fix typo in docs

* Fix another docs typo

* Fix one more typo in docs

* Update src/transformers/onnx/utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/onnx/utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/onnx/utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Authored by Patrick Deutschmann on 2022-06-03 13:40:22 +02:00; committed by GitHub
parent 5c17918fe4
commit babeff5524
7 changed files with 178 additions and 30 deletions

View File

@@ -70,6 +70,7 @@ Ready-made configurations include the following architectures:
- mBART
- MobileBERT
- OpenAI GPT-2
- Perceiver
- PLBart
- RoBERTa
- RoFormer

View File

@@ -14,8 +14,15 @@
# limitations under the License.
""" Perceiver model configuration"""
from collections import OrderedDict
from typing import Any, Mapping, Optional, Union
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...feature_extraction_utils import FeatureExtractionMixin
from ...onnx import OnnxConfig
from ...onnx.utils import compute_effective_axis_dimension
from ...tokenization_utils_base import PreTrainedTokenizerBase
from ...utils import TensorType, logging
logger = logging.get_logger(__name__)
@@ -172,3 +179,63 @@ class PerceiverConfig(PretrainedConfig):
self.audio_samples_per_frame = audio_samples_per_frame
self.samples_per_patch = samples_per_patch
self.output_shape = output_shape
class PerceiverOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("inputs", dynamic_axis),
("attention_mask", dynamic_axis),
]
)
@property
def atol_for_validation(self) -> float:
return 1e-4
def generate_dummy_inputs(
self,
preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
batch_size: int = -1,
seq_length: int = -1,
num_choices: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
num_channels: int = 3,
image_width: int = 40,
image_height: int = 40,
) -> Mapping[str, Any]:
# copied from `transformers.onnx.config.OnnxConfig` and slightly altered/simplified
if isinstance(preprocessor, PreTrainedTokenizerBase):
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
batch_size = compute_effective_axis_dimension(
batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
)
# If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
token_to_add = preprocessor.num_special_tokens_to_add(is_pair)
seq_length = compute_effective_axis_dimension(
seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
)
# Generate dummy inputs according to compute batch and sequence
dummy_input = [" ".join(["a"]) * seq_length] * batch_size
inputs = dict(preprocessor(dummy_input, return_tensors=framework))
inputs["inputs"] = inputs.pop("input_ids")
return inputs
elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values":
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
inputs = dict(preprocessor(images=dummy_input, return_tensors=framework))
inputs["inputs"] = inputs.pop("pixel_values")
return inputs
else:
raise ValueError(
"Unable to generate dummy inputs for the model. Please provide a tokenizer or a preprocessor."
)
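
To make the new export path concrete, here is a minimal sketch (not part of the diff) of exercising `PerceiverOnnxConfig.generate_dummy_inputs` with a tokenizer. The checkpoint and task names are taken from the test entries added further down in this commit; everything else follows the class shown above.

```python
# Minimal sketch: dummy-input generation for the Perceiver ONNX config added above.
from transformers import AutoTokenizer, PerceiverConfig
from transformers.models.perceiver.configuration_perceiver import PerceiverOnnxConfig

tokenizer = AutoTokenizer.from_pretrained("deepmind/language-perceiver")
onnx_config = PerceiverOnnxConfig(PerceiverConfig(), task="masked-lm")

dummy = onnx_config.generate_dummy_inputs(tokenizer)
# "input_ids" is renamed to "inputs", matching the axis name declared in `inputs` above.
print(sorted(dummy))  # expected: ['attention_mask', 'inputs']
```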

View File

@@ -2735,7 +2735,9 @@ def _check_or_build_spatial_positions(pos, index_dims, batch_size):
"""
if pos is None:
pos = build_linear_positions(index_dims)
pos = torch.broadcast_to(pos[None], (batch_size,) + pos.shape)
# equivalent to `torch.broadcast_to(pos[None], (batch_size,) + pos.shape)`
# but `torch.broadcast_to` cannot be converted to ONNX
pos = pos[None].expand((batch_size,) + pos.shape)
pos = torch.reshape(pos, [batch_size, np.prod(index_dims), -1])
else:
# Just a warning label: you probably don't want your spatial features to
@@ -2840,7 +2842,8 @@ class PerceiverEmbeddingDecoder(nn.Module):
def forward(self, hidden_states: torch.Tensor, embedding_layer: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, d_model = hidden_states.shape
output = torch.matmul(hidden_states.reshape([-1, d_model]), embedding_layer.weight.T) # Flatten batch dim
# Flatten batch dim
output = torch.matmul(hidden_states.reshape([-1, d_model]), embedding_layer.weight.transpose(0, 1))
output = output + self.bias
return output.reshape([batch_size, seq_len, self.vocab_size])
@@ -3166,9 +3169,9 @@ class PerceiverImagePreprocessor(AbstractPreprocessor):
if self.prep_type != "patches":
# move channels to last dimension, as the _build_network_inputs method below expects this
if inputs.ndim == 4:
inputs = torch.moveaxis(inputs, 1, -1)
inputs = torch.permute(inputs, (0, 2, 3, 1))
elif inputs.ndim == 5:
inputs = torch.moveaxis(inputs, 2, -1)
inputs = torch.permute(inputs, (0, 1, 3, 4, 2))
else:
raise ValueError("Unsupported data format for conv1x1.")
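
The three substitutions in this file (`expand` for `torch.broadcast_to`, `transpose(0, 1)` for `.T`, and `permute` for `torch.moveaxis`) are made only because the original ops lack ONNX export support. A small standalone check, with assumed tensor shapes, that the replacements are numerically identical:

```python
# Standalone sanity check (shapes are assumed for illustration): the ONNX-friendly
# replacements above produce exactly the same tensors as the ops they replace.
import torch

pos = torch.randn(50, 2)
batch_size = 4
assert torch.equal(
    pos[None].expand((batch_size,) + pos.shape),
    torch.broadcast_to(pos[None], (batch_size,) + pos.shape),
)

weight = torch.randn(262, 768)
assert torch.equal(weight.transpose(0, 1), weight.T)

pixel_values = torch.randn(2, 3, 40, 40)  # (batch, channels, height, width)
assert torch.equal(
    torch.permute(pixel_values, (0, 2, 3, 1)),
    torch.moveaxis(pixel_values, 1, -1),
)
```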

View File

@@ -15,9 +15,8 @@
from argparse import ArgumentParser
from pathlib import Path
from ..models.auto import AutoConfig, AutoFeatureExtractor, AutoTokenizer
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES
from ..models.auto import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
from ..onnx.utils import get_preprocessor
from ..utils import logging
from .convert import export, validate_model_outputs
from .features import FeaturesManager
@@ -43,6 +42,13 @@ def main():
)
parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")
parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.")
parser.add_argument(
"--preprocessor",
type=str,
choices=["auto", "tokenizer", "feature_extractor", "processor"],
default="auto",
help="Which type of preprocessor to use. 'auto' tries to automatically detect it.",
)
# Retrieve CLI arguments
args = parser.parse_args()
@@ -51,15 +57,17 @@ def main():
if not args.output.parent.exists():
args.output.parent.mkdir(parents=True)
# Check the modality of the inputs and instantiate the appropriate preprocessor
# TODO(lewtun): Refactor this as a function if we need to check modalities elsewhere as well
config = AutoConfig.from_pretrained(args.model)
if config.model_type in TOKENIZER_MAPPING_NAMES:
# Instantiate the appropriate preprocessor
if args.preprocessor == "auto":
preprocessor = get_preprocessor(args.model)
elif args.preprocessor == "tokenizer":
preprocessor = AutoTokenizer.from_pretrained(args.model)
elif config.model_type in FEATURE_EXTRACTOR_MAPPING_NAMES:
elif args.preprocessor == "feature_extractor":
preprocessor = AutoFeatureExtractor.from_pretrained(args.model)
elif args.preprocessor == "processor":
preprocessor = AutoProcessor.from_pretrained(args.model)
else:
raise ValueError(f"Unsupported model type: {config.model_type}")
raise ValueError(f"Unknown preprocessor type '{args.preprocessor}'")
# Allocate the model
model = FeaturesManager.get_model_from_feature(
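
For reference, a hedged sketch (not part of the diff) of the export flow the updated CLI drives, roughly what `python -m transformers.onnx --model=deepmind/language-perceiver onnx/` does with `--preprocessor=auto`. The checkpoint, feature and output path are illustrative; the helper and export calls follow the APIs referenced in this file.

```python
# Hedged sketch of the programmatic equivalent of the CLI, using the new get_preprocessor helper.
from pathlib import Path

from transformers.onnx import export, validate_model_outputs
from transformers.onnx.features import FeaturesManager
from transformers.onnx.utils import get_preprocessor

model_name = "deepmind/language-perceiver"   # illustrative checkpoint
feature = "masked-lm"                        # illustrative feature
output = Path("onnx/model.onnx")
output.parent.mkdir(parents=True, exist_ok=True)

# Equivalent of --preprocessor=auto: try AutoProcessor first, fall back to tokenizer/feature extractor.
preprocessor = get_preprocessor(model_name)

model = FeaturesManager.get_model_from_feature(feature, model_name)
model_type, onnx_config_constructor = FeaturesManager.check_supported_model_or_raise(model, feature=feature)
onnx_config = onnx_config_constructor(model.config)

onnx_inputs, onnx_outputs = export(preprocessor, model, onnx_config, onnx_config.default_onnx_opset, output)
validate_model_outputs(onnx_config, preprocessor, model, output, onnx_outputs, onnx_config.atol_for_validation)
```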

View File

@@ -26,6 +26,7 @@ from ..models.m2m_100 import M2M100OnnxConfig
from ..models.marian import MarianOnnxConfig
from ..models.mbart import MBartOnnxConfig
from ..models.mobilebert import MobileBertOnnxConfig
from ..models.perceiver.configuration_perceiver import PerceiverOnnxConfig
from ..models.roberta import RobertaOnnxConfig
from ..models.roformer import RoFormerOnnxConfig
from ..models.squeezebert import SqueezeBertOnnxConfig
@@ -332,6 +333,12 @@ class FeaturesManager:
"m2m-100": supported_features_mapping(
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig
),
"perceiver": supported_features_mapping(
"image-classification",
"masked-lm",
"sequence-classification",
onnx_config_cls=PerceiverOnnxConfig,
),
"roberta": supported_features_mapping(
"default",
"masked-lm",
@@ -516,3 +523,18 @@ class FeaturesManager:
)
return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature]
def get_config(model_type: str, feature: str) -> OnnxConfig:
"""
Gets the OnnxConfig for a model_type and feature combination.
Args:
model_type (`str`):
The model type to retrieve the config for.
feature (`str`):
The feature to retrieve the config for.
Returns:
`OnnxConfig`: config for the combination
"""
return FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature]
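
A short, assumed usage of the new lookup (the model type and feature names come from the mapping added above); note that what comes back is the registered `OnnxConfig` constructor, ready to be called with a model config.

```python
# Assumed usage of the new helper; "perceiver" and its features are registered above.
from transformers import PerceiverConfig
from transformers.onnx.features import FeaturesManager

print(list(FeaturesManager.get_supported_features_for_model_type("perceiver")))
# expected: ['image-classification', 'masked-lm', 'sequence-classification']

onnx_config_constructor = FeaturesManager.get_config("perceiver", "masked-lm")
onnx_config = onnx_config_constructor(PerceiverConfig())  # PerceiverOnnxConfig with task="masked-lm"
```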

View File

@@ -14,6 +14,9 @@
from ctypes import c_float, sizeof
from enum import Enum
from typing import Optional, Union
from .. import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
class ParameterFormat(Enum):
@@ -61,3 +64,41 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterForm
Size (in byte) taken to save all the parameters
"""
return num_parameters * dtype.size
def get_preprocessor(model_name: str) -> Optional[Union[AutoTokenizer, AutoFeatureExtractor, AutoProcessor]]:
"""
Gets a preprocessor (tokenizer, feature extractor or processor) that is available for `model_name`.
Args:
model_name (`str`): Name of the model for which a preprocessor is loaded.
Returns:
`Optional[Union[AutoTokenizer, AutoFeatureExtractor, AutoProcessor]]`:
If a processor is found, it is returned. Otherwise, if a tokenizer or a feature extractor exists, it is
returned. If both a tokenizer and a feature extractor exist, an error is raised. The function returns
`None` if no preprocessor is found.
"""
try:
return AutoProcessor.from_pretrained(model_name)
except (ValueError, OSError, KeyError):
tokenizer, feature_extractor = None, None
try:
tokenizer = AutoTokenizer.from_pretrained(model_name)
except (OSError, KeyError):
pass
try:
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
except (OSError, KeyError):
pass
if tokenizer is not None and feature_extractor is not None:
raise ValueError(
f"Couldn't auto-detect preprocessor for {model_name}. Found both a tokenizer and a feature extractor."
)
elif tokenizer is None and feature_extractor is None:
return None
elif tokenizer is not None:
return tokenizer
else:
return feature_extractor
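
A brief, assumed usage of the helper; the checkpoints below are the ones added to the export test list in this commit and should resolve to a tokenizer and a feature extractor respectively (Hub access is assumed).

```python
# Assumed usage of get_preprocessor; checkpoint names match the new test entries.
from transformers.onnx.utils import get_preprocessor

preprocessor_text = get_preprocessor("deepmind/language-perceiver")       # expected: the checkpoint's tokenizer
preprocessor_vision = get_preprocessor("deepmind/vision-perceiver-conv")  # expected: the checkpoint's feature extractor
```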

View File

@@ -6,7 +6,7 @@ from unittest.mock import patch
import pytest
from parameterized import parameterized
from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
from transformers import AutoConfig, PreTrainedTokenizerBase, is_tf_available, is_torch_available
from transformers.onnx import (
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
OnnxConfig,
@@ -15,7 +15,11 @@ from transformers.onnx import (
export,
validate_model_outputs,
)
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.onnx.utils import (
compute_effective_axis_dimension,
compute_serialized_parameters_size,
get_preprocessor,
)
from transformers.testing_utils import require_onnx, require_rjieba, require_tf, require_torch, require_vision, slow
@@ -189,6 +193,8 @@ PYTORCH_EXPORT_MODELS = {
("deit", "facebook/deit-small-patch16-224"),
("beit", "microsoft/beit-base-patch16-224"),
("data2vec-text", "facebook/data2vec-text-base"),
("perceiver", "deepmind/language-perceiver", ("masked-lm", "sequence-classification")),
("perceiver", "deepmind/vision-perceiver-conv", ("image-classification",)),
}
PYTORCH_EXPORT_WITH_PAST_MODELS = {
@@ -226,10 +232,15 @@ TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {}
def _get_models_to_test(export_models_list):
models_to_test = []
if is_torch_available() or is_tf_available():
for name, model in export_models_list:
for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type(
name
).items():
for name, model, *features in export_models_list:
if features:
feature_config_mapping = {
feature: FeaturesManager.get_config(name, feature) for _ in features for feature in _
}
else:
feature_config_mapping = FeaturesManager.get_supported_features_for_model_type(name)
for feature, onnx_config_class_constructor in feature_config_mapping.items():
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
return sorted(models_to_test)
else:
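
The `*features` unpacking above is what lets the new 3-tuple entries in `PYTORCH_EXPORT_MODELS` restrict a model to specific features, while the existing 2-tuple entries keep testing everything supported. A small standalone illustration of that unpacking (not part of the test suite; entries copied from the list above):

```python
# Standalone illustration of the starred unpacking used above: for 3-tuples, `features`
# is a one-element list holding the tuple of allowed feature names; for 2-tuples it is empty.
export_models_list = [
    ("deit", "facebook/deit-small-patch16-224"),
    ("perceiver", "deepmind/language-perceiver", ("masked-lm", "sequence-classification")),
]

for name, model, *features in export_models_list:
    allowed = [feature for feature_group in features for feature in feature_group]
    print(name, allowed if allowed else "all supported features")
```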
@@ -261,16 +272,11 @@ class OnnxExportTestCaseV2(TestCase):
f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
)
# Check the modality of the inputs and instantiate the appropriate preprocessor
if model.main_input_name == "input_ids":
preprocessor = AutoTokenizer.from_pretrained(model_name)
# Useful for causal lm models that do not use pad tokens.
if not getattr(config, "pad_token_id", None):
config.pad_token_id = preprocessor.eos_token_id
elif model.main_input_name == "pixel_values":
preprocessor = AutoFeatureExtractor.from_pretrained(model_name)
else:
raise ValueError(f"Unsupported model input name: {model.main_input_name}")
preprocessor = get_preprocessor(model_name)
# Useful for causal lm models that do not use pad tokens.
if isinstance(preprocessor, PreTrainedTokenizerBase) and not getattr(config, "pad_token_id", None):
config.pad_token_id = preprocessor.eos_token_id
with NamedTemporaryFile("w") as output:
try: