Added XLM onnx config (#17030)

* Add onnx configuration for xlm

* Add supported features for xlm

* Add xlm to models exportable with onnx

* Add xlm architecture to test file

* Modify docs

* Make code quality fixes
This commit is contained in:
Ritik Nandwal 2022-05-31 18:56:06 +05:30 committed by GitHub
parent 567d9c061d
commit 5af38953bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 35 additions and 2 deletions

View File

@ -75,6 +75,7 @@ Ready-made configurations include the following architectures:
- RoFormer
- T5
- ViT
- XLM
- XLM-RoBERTa
- XLM-RoBERTa-XL

View File

@ -22,7 +22,7 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_availabl
_import_structure = {
"configuration_xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig"],
"configuration_xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMOnnxConfig"],
"tokenization_xlm": ["XLMTokenizer"],
}
@ -64,7 +64,7 @@ else:
if TYPE_CHECKING:
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMOnnxConfig
from .tokenization_xlm import XLMTokenizer
try:

View File

@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" XLM configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
@ -228,3 +231,20 @@ class XLMConfig(PretrainedConfig):
self.n_words = kwargs["n_words"]
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig
class XLMOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
("token_type_ids", dynamic_axis),
]
)

View File

@ -30,6 +30,7 @@ from ..models.roberta import RobertaOnnxConfig
from ..models.roformer import RoFormerOnnxConfig
from ..models.t5 import T5OnnxConfig
from ..models.vit import ViTOnnxConfig
from ..models.xlm import XLMOnnxConfig
from ..models.xlm_roberta import XLMRobertaOnnxConfig
from ..utils import logging
from .config import OnnxConfig
@ -357,6 +358,16 @@ class FeaturesManager:
"vit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig
),
"xlm": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=XLMOnnxConfig,
),
"xlm-roberta": supported_features_mapping(
"default",
"masked-lm",

View File

@ -181,6 +181,7 @@ PYTORCH_EXPORT_MODELS = {
("roberta", "roberta-base"),
("roformer", "junnyu/roformer_chinese_base"),
("mobilebert", "google/mobilebert-uncased"),
("xlm", "xlm-clm-ende-1024"),
("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"),
("vit", "google/vit-base-patch16-224"),