Adds IBERT to models exportable with ONNX (#14868)

* Add IBertOnnxConfig and tests

* add all the supported features for IBERT and remove outputs in IbertOnnxConfig

* use OnnxConfig

* fix codestyle

* remove serialization.rst

* codestyle
This commit is contained in:
Virus 2022-01-11 14:17:08 +03:00 committed by GitHub
parent efb35a4107
commit c4fa908fa9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 29 additions and 2 deletions

View File

@ -40,6 +40,7 @@ Ready-made configurations include the following models:
- CamemBERT
- DistilBERT
- GPT Neo
- I-BERT
- LayoutLM
- Longformer
- Marian

View File

@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_torch_available
_import_structure = {
"configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
"configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig", "IBertOnnxConfig"],
}
if is_torch_available():
@ -38,7 +38,7 @@ if is_torch_available():
]
if TYPE_CHECKING:
from .configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
from .configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig, IBertOnnxConfig
if is_torch_available():
from .modeling_ibert import (

View File

@ -15,6 +15,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" I-BERT configuration"""
from collections import OrderedDict
from typing import Mapping
from transformers.onnx import OnnxConfig
from ...configuration_utils import PretrainedConfig
from ...utils import logging
@ -122,3 +126,14 @@ class IBertConfig(PretrainedConfig):
self.position_embedding_type = position_embedding_type
self.quant_mode = quant_mode
self.force_dequant = force_dequant
class IBertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
]
)

View File

@ -9,6 +9,7 @@ from ..models.camembert import CamembertOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig
from ..models.gpt2 import GPT2OnnxConfig
from ..models.gpt_neo import GPTNeoOnnxConfig
from ..models.ibert import IBertOnnxConfig
from ..models.layoutlm import LayoutLMOnnxConfig
from ..models.longformer import LongformerOnnxConfig
from ..models.marian import MarianOnnxConfig
@ -125,6 +126,15 @@ class FeaturesManager:
"question-answering",
onnx_config_cls=BertOnnxConfig,
),
"ibert": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
# "multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=IBertOnnxConfig,
),
"camembert": supported_features_mapping(
"default",
"masked-lm",

View File

@ -171,6 +171,7 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
PYTORCH_EXPORT_MODELS = {
("albert", "hf-internal-testing/tiny-albert"),
("bert", "bert-base-cased"),
("ibert", "kssteven/ibert-roberta-base"),
("camembert", "camembert-base"),
("distilbert", "distilbert-base-cased"),
# ("longFormer", "longformer-base-4096"),