Add RemBERT ONNX config (#20520)

* rembert onnx config

* formatting

Co-authored-by: Ho <erincho@bcd0745f972b.ant.amazon.com>
This commit is contained in:
Erin 2022-12-05 08:39:09 -08:00 committed by GitHub
parent afe2a466bb
commit 87282cb73c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 39 additions and 2 deletions

View File

@ -93,6 +93,7 @@ Ready-made configurations include the following architectures:
- OWL-ViT
- Perceiver
- PLBart
- RemBERT
- ResNet
- RoBERTa
- RoFormer

View File

@ -28,7 +28,9 @@ from ...utils import (
)
_import_structure = {"configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"]}
_import_structure = {
"configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig", "RemBertOnnxConfig"]
}
try:
if not is_sentencepiece_available():
@ -88,7 +90,7 @@ else:
if TYPE_CHECKING:
from .configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig
from .configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig, RemBertOnnxConfig
try:
if not is_sentencepiece_available():

View File

@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" RemBERT model configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
@ -135,3 +138,23 @@ class RemBertConfig(PretrainedConfig):
self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache
self.tie_word_embeddings = False
class RemBertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
("token_type_ids", dynamic_axis),
]
)
@property
def atol_for_validation(self) -> float:
return 1e-4

View File

@ -447,6 +447,16 @@ class FeaturesManager:
"sequence-classification",
onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
),
"rembert": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls="models.rembert.RemBertOnnxConfig",
),
"resnet": supported_features_mapping(
"default",
"image-classification",

View File

@ -210,6 +210,7 @@ PYTORCH_EXPORT_MODELS = {
("owlvit", "google/owlvit-base-patch32"),
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("masked-lm", "sequence-classification")),
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("image-classification",)),
("rembert", "google/rembert"),
("resnet", "microsoft/resnet-50"),
("roberta", "hf-internal-testing/tiny-random-RobertaModel"),
("roformer", "hf-internal-testing/tiny-random-RoFormerModel"),