mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Adds GroupViT to models exportable with ONNX (#18628)
* groupvit to onnx
* dynamic shape for pixel values dim
This commit is contained in:
parent
46d0e26a27
commit
220da3b8a1
@ -70,6 +70,7 @@ Ready-made configurations include the following architectures:
|
||||
- FlauBERT
|
||||
- GPT Neo
|
||||
- GPT-J
|
||||
- GroupViT
|
||||
- I-BERT
|
||||
- LayoutLM
|
||||
- LayoutLMv3
|
||||
|
@ -24,6 +24,7 @@ _import_structure = {
|
||||
"configuration_groupvit": [
|
||||
"GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
"GroupViTConfig",
|
||||
"GroupViTOnnxConfig",
|
||||
"GroupViTTextConfig",
|
||||
"GroupViTVisionConfig",
|
||||
],
|
||||
@ -47,6 +48,7 @@ if TYPE_CHECKING:
|
||||
from .configuration_groupvit import (
|
||||
GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
GroupViTConfig,
|
||||
GroupViTOnnxConfig,
|
||||
GroupViTTextConfig,
|
||||
GroupViTVisionConfig,
|
||||
)
|
||||
|
@ -16,12 +16,19 @@
|
||||
|
||||
import copy
|
||||
import os
|
||||
from typing import Union
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...onnx import OnnxConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...utils import TensorType
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
@ -343,3 +350,44 @@ class GroupViTConfig(PretrainedConfig):
|
||||
output["vision_config"] = self.vision_config.to_dict()
|
||||
output["model_type"] = self.__class__.model_type
|
||||
return output
|
||||
|
||||
|
||||
class GroupViTOnnxConfig(OnnxConfig):
    """ONNX export configuration for GroupViT.

    Declares the named inputs and outputs together with their dynamic axes,
    and overrides dummy-input generation so that both the text and the image
    halves of a processor contribute inputs.
    """

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return the ordered inputs, each mapped to ``{axis index: axis name}``.

        Every axis of ``pixel_values`` is listed, so none of its dimensions
        is declared with a fixed size here.
        """
        dynamic_axes = OrderedDict()
        dynamic_axes["input_ids"] = {0: "batch", 1: "sequence"}
        dynamic_axes["pixel_values"] = {0: "batch", 1: "num_channels", 2: "height", 3: "width"}
        dynamic_axes["attention_mask"] = {0: "batch", 1: "sequence"}
        return dynamic_axes

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return the ordered outputs; only axis 0 (``"batch"``) is named for each."""
        output_names = ("logits_per_image", "logits_per_text", "text_embeds", "image_embeds")
        # A fresh axis dict per output, so the entries never alias one another.
        return OrderedDict((name, {0: "batch"}) for name in output_names)

    @property
    def atol_for_validation(self) -> float:
        """Return the tolerance value ``1e-4``.

        NOTE(review): presumably the absolute tolerance used when comparing
        exported and reference outputs -- confirm against ``OnnxConfig``.
        """
        tolerance = 1e-4
        return tolerance

    def generate_dummy_inputs(
        self,
        processor: "ProcessorMixin",
        framework: Optional["TensorType"] = None,
    ) -> Mapping[str, Any]:
        """Build dummy inputs from both components of ``processor``.

        The parent implementation is invoked twice: first with
        ``processor.tokenizer``, then with ``processor.feature_extractor``.

        Args:
            processor: Object exposing ``tokenizer`` and ``feature_extractor``
                attributes.
            framework: Forwarded unchanged to the parent implementation.

        Returns:
            A new dict holding the text entries followed by the image entries;
            if a key occurs in both, the image entry wins.
        """
        dummy_inputs = dict(super().generate_dummy_inputs(processor.tokenizer, framework=framework))
        dummy_inputs.update(super().generate_dummy_inputs(processor.feature_extractor, framework=framework))
        return dummy_inputs

    @property
    def default_onnx_opset(self) -> int:
        """Return the ONNX opset version (14) declared as the default for this config."""
        opset_version = 14
        return opset_version
|
||||
|
@ -1542,7 +1542,7 @@ class GroupViTModel(GroupViTPreTrainedModel):
|
||||
# cosine similarity as logits
|
||||
logit_scale = self.logit_scale.exp()
|
||||
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
||||
logits_per_image = logits_per_text.T
|
||||
logits_per_image = logits_per_text.t()
|
||||
|
||||
seg_logits = None
|
||||
if output_segmentation:
|
||||
|
@ -326,6 +326,10 @@ class FeaturesManager:
|
||||
"sequence-classification",
|
||||
onnx_config_cls="models.gpt_neo.GPTNeoOnnxConfig",
|
||||
),
|
||||
"groupvit": supported_features_mapping(
|
||||
"default",
|
||||
onnx_config_cls="models.groupvit.GroupViTOnnxConfig",
|
||||
),
|
||||
"ibert": supported_features_mapping(
|
||||
"default",
|
||||
"masked-lm",
|
||||
|
@ -204,6 +204,7 @@ PYTORCH_EXPORT_MODELS = {
|
||||
("xlm-roberta", "xlm-roberta-base"),
|
||||
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
||||
("layoutlmv3", "microsoft/layoutlmv3-base"),
|
||||
("groupvit", "nvidia/groupvit-gcc-yfcc"),
|
||||
("levit", "facebook/levit-128S"),
|
||||
("owlvit", "google/owlvit-base-patch32"),
|
||||
("vit", "google/vit-base-patch16-224"),
|
||||
|
Loading…
Reference in New Issue
Block a user