Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Add OnnxConfig for ConvBERT (#16859)
* add OnnxConfig for ConvBert

Co-authored-by: ChainYo <t.chaigneau.tc@gmail.com>
parent 0d1cff1195
commit ec81c11a18
@@ -53,6 +53,7 @@ Ready-made configurations include the following architectures:
 - Blenderbot
 - BlenderbotSmall
 - CamemBERT
+- ConvBERT
 - Data2VecText
 - Data2VecVision
 - DistilBERT
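For context, a ready-made configuration like the one added in this commit is what the Python export helper consumes. The snippet below is a minimal sketch, assuming the transformers.onnx.export helper and the default_onnx_opset property of the base OnnxConfig (neither appears in this diff); the output filename is only illustrative, and the checkpoint is the one registered in the test table at the bottom of the diff.

from pathlib import Path

from transformers import AutoModel, AutoTokenizer
from transformers.models.convbert import ConvBertOnnxConfig
from transformers.onnx import export

ckpt = "YituTech/conv-bert-base"  # same checkpoint as in PYTORCH_EXPORT_MODELS below
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModel.from_pretrained(ckpt)

# Build the ready-made ONNX config from the model's configuration, then export.
onnx_config = ConvBertOnnxConfig(model.config)
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("convbert.onnx")  # illustrative output path
)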
@@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_t
 
 
 _import_structure = {
-    "configuration_convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig"],
+    "configuration_convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertOnnxConfig"],
     "tokenization_convbert": ["ConvBertTokenizer"],
 }
 
@@ -58,7 +58,7 @@ if is_tf_available():
 
 
 if TYPE_CHECKING:
-    from .configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig
+    from .configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig, ConvBertOnnxConfig
     from .tokenization_convbert import ConvBertTokenizer
 
     if is_tokenizers_available():
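Both hunks keep the lazy (_import_structure / _LazyModule) and eager (TYPE_CHECKING) import paths in sync, so ConvBertOnnxConfig is reachable from the subpackage either way. A quick sanity check, assuming a build that includes this commit:

import transformers
from transformers.models.convbert import ConvBertOnnxConfig  # eager import

# Attribute access on the subpackage goes through _LazyModule and resolves to the same class.
assert transformers.models.convbert.ConvBertOnnxConfig is ConvBertOnnxConfig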
@@ -14,7 +14,11 @@
 # limitations under the License.
 """ ConvBERT model configuration"""
 
+from collections import OrderedDict
+from typing import Mapping
+
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
 from ...utils import logging
 
 
@@ -138,3 +142,20 @@ class ConvBertConfig(PretrainedConfig):
         self.conv_kernel_size = conv_kernel_size
         self.num_groups = num_groups
         self.classifier_dropout = classifier_dropout
+
+
+# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig
+class ConvBertOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
+        return OrderedDict(
+            [
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
+                ("token_type_ids", dynamic_axis),
+            ]
+        )
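To make the dynamic axes above concrete, here is what the inputs property returns for the default task versus multiple-choice (a sketch; the task argument and its validation live in the base OnnxConfig constructor):

from transformers import ConvBertConfig
from transformers.models.convbert import ConvBertOnnxConfig

config = ConvBertConfig()

# Default task: batch and sequence are the dynamic axes for every input tensor.
print(ConvBertOnnxConfig(config).inputs)
# OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}),
#              ('attention_mask', {0: 'batch', 1: 'sequence'}),
#              ('token_type_ids', {0: 'batch', 1: 'sequence'})])

# The multiple-choice task inserts a "choice" axis between batch and sequence.
print(ConvBertOnnxConfig(config, task="multiple-choice").inputs["input_ids"])
# {0: 'batch', 1: 'choice', 2: 'sequence'}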
@@ -10,6 +10,7 @@ from ..models.big_bird import BigBirdOnnxConfig
 from ..models.blenderbot import BlenderbotOnnxConfig
 from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
 from ..models.camembert import CamembertOnnxConfig
+from ..models.convbert import ConvBertOnnxConfig
 from ..models.data2vec import Data2VecTextOnnxConfig
 from ..models.distilbert import DistilBertOnnxConfig
 from ..models.electra import ElectraOnnxConfig
@@ -187,6 +188,15 @@ class FeaturesManager:
             "question-answering",
             onnx_config_cls=CamembertOnnxConfig,
         ),
+        "convbert": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "sequence-classification",
+            "multiple-choice",
+            "token-classification",
+            "question-answering",
+            onnx_config_cls=ConvBertOnnxConfig,
+        ),
         "distilbert": supported_features_mapping(
             "default",
             "masked-lm",
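This registration is what the export entry point consults when a feature is requested for a given model type. A small sketch of the lookup, assuming the existing FeaturesManager helpers already used for the other architectures in this table:

from transformers.onnx.features import FeaturesManager

# Features now advertised for ConvBERT.
print(list(FeaturesManager.get_supported_features_for_model_type("convbert")))
# ['default', 'masked-lm', 'sequence-classification', 'multiple-choice',
#  'token-classification', 'question-answering']

# The returned constructor builds a ConvBertOnnxConfig bound to the requested task.
config_constructor = FeaturesManager.get_config("convbert", "sequence-classification")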
@@ -175,6 +175,7 @@ PYTORCH_EXPORT_MODELS = {
     ("bigbird", "google/bigbird-roberta-base"),
     ("ibert", "kssteven/ibert-roberta-base"),
     ("camembert", "camembert-base"),
+    ("convbert", "YituTech/conv-bert-base"),
     ("distilbert", "distilbert-base-cased"),
     ("electra", "google/electra-base-generator"),
     ("roberta", "roberta-base"),