Add OnnxConfig for ConvBERT (#16859)

* add OnnxConfig for ConvBert

Co-authored-by: ChainYo <t.chaigneau.tc@gmail.com>
This commit is contained in:
Thomas Chaigneau 2022-04-22 18:19:15 +02:00 committed by GitHub
parent 0d1cff1195
commit ec81c11a18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 35 additions and 2 deletions

View File

@ -53,6 +53,7 @@ Ready-made configurations include the following architectures:
- Blenderbot
- BlenderbotSmall
- CamemBERT
- ConvBERT
- Data2VecText
- Data2VecVision
- DistilBERT

View File

@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_t
_import_structure = {
"configuration_convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig"],
"configuration_convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertOnnxConfig"],
"tokenization_convbert": ["ConvBertTokenizer"],
}
@ -58,7 +58,7 @@ if is_tf_available():
if TYPE_CHECKING:
from .configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig
from .configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig, ConvBertOnnxConfig
from .tokenization_convbert import ConvBertTokenizer
if is_tokenizers_available():

View File

@ -14,7 +14,11 @@
# limitations under the License.
""" ConvBERT model configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
@ -138,3 +142,20 @@ class ConvBertConfig(PretrainedConfig):
self.conv_kernel_size = conv_kernel_size
self.num_groups = num_groups
self.classifier_dropout = classifier_dropout
# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig
class ConvBertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
("token_type_ids", dynamic_axis),
]
)

View File

@ -10,6 +10,7 @@ from ..models.big_bird import BigBirdOnnxConfig
from ..models.blenderbot import BlenderbotOnnxConfig
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
from ..models.camembert import CamembertOnnxConfig
from ..models.convbert import ConvBertOnnxConfig
from ..models.data2vec import Data2VecTextOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig
from ..models.electra import ElectraOnnxConfig
@ -187,6 +188,15 @@ class FeaturesManager:
"question-answering",
onnx_config_cls=CamembertOnnxConfig,
),
"convbert": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=ConvBertOnnxConfig,
),
"distilbert": supported_features_mapping(
"default",
"masked-lm",

View File

@ -175,6 +175,7 @@ PYTORCH_EXPORT_MODELS = {
("bigbird", "google/bigbird-roberta-base"),
("ibert", "kssteven/ibert-roberta-base"),
("camembert", "camembert-base"),
("convbert", "YituTech/conv-bert-base"),
("distilbert", "distilbert-base-cased"),
("electra", "google/electra-base-generator"),
("roberta", "roberta-base"),