Add ONNX export for BeiT (#16498)

* Add beit onnx conversion support

* Updated docs

* Added cross reference to ViT ONNX config
This commit is contained in:
Jim Rohrer 2022-04-01 03:52:42 -05:00 committed by GitHub
parent bfeff6cc6a
commit 9de70f213e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 31 additions and 5 deletions

View File

@ -47,6 +47,7 @@ Ready-made configurations include the following architectures:
- ALBERT
- BART
- BEiT
- BERT
- Blenderbot
- BlenderbotSmall

View File

@ -22,7 +22,7 @@ from ...utils import _LazyModule, is_flax_available, is_torch_available, is_visi
_import_structure = {
"configuration_beit": ["BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BeitConfig"],
"configuration_beit": ["BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BeitConfig", "BeitOnnxConfig"],
}
if is_vision_available():
@ -48,7 +48,7 @@ if is_flax_available():
]
if TYPE_CHECKING:
from .configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig
from .configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig, BeitOnnxConfig
if is_vision_available():
from .feature_extraction_beit import BeitFeatureExtractor

View File

@ -13,8 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" BEiT model configuration"""
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
@ -176,3 +181,21 @@ class BeitConfig(PretrainedConfig):
self.auxiliary_num_convs = auxiliary_num_convs
self.auxiliary_concat_input = auxiliary_concat_input
self.semantic_loss_ignore_index = semantic_loss_ignore_index
# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig
class BeitOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "sequence"}),
]
)
@property
def atol_for_validation(self) -> float:
return 1e-4

View File

@ -4,6 +4,7 @@ from typing import Callable, Dict, Optional, Tuple, Type, Union
from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available
from ..models.albert import AlbertOnnxConfig
from ..models.bart import BartOnnxConfig
from ..models.beit import BeitOnnxConfig
from ..models.bert import BertOnnxConfig
from ..models.blenderbot import BlenderbotOnnxConfig
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
@ -270,6 +271,7 @@ class FeaturesManager:
onnx_config_cls=ElectraOnnxConfig,
),
"vit": supported_features_mapping("default", "image-classification", onnx_config_cls=ViTOnnxConfig),
"beit": supported_features_mapping("default", "image-classification", onnx_config_cls=BeitOnnxConfig),
"blenderbot": supported_features_mapping(
"default",
"default-with-past",

View File

@ -15,14 +15,13 @@ from transformers.onnx import (
export,
validate_model_outputs,
)
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.testing_utils import require_onnx, require_tf, require_torch, require_vision, slow
if is_torch_available() or is_tf_available():
from transformers.onnx.features import FeaturesManager
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.testing_utils import require_onnx, require_tf, require_torch, require_vision, slow
@require_onnx
class OnnxUtilsTestCaseV2(TestCase):
@ -181,6 +180,7 @@ PYTORCH_EXPORT_MODELS = {
("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"),
("vit", "google/vit-base-patch16-224"),
("beit", "microsoft/beit-base-patch16-224"),
}
PYTORCH_EXPORT_WITH_PAST_MODELS = {