Fix Data2VecVision ONNX test (#18587)

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-08-22 11:28:23 +02:00 committed by GitHub
parent 30992ef0d9
commit 3fa45dbd91
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 1 deletions

View File

@ -99,6 +99,7 @@ class OnnxConfig(ABC):
"end_logits": {0: "batch", 1: "sequence"},
}
),
"semantic-segmentation": OrderedDict({"logits": {0: "batch", 1: "num_labels", 2: "height", 3: "width"}}),
"seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
"sequence-classification": OrderedDict({"logits": {0: "batch"}}),
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),

View File

@ -25,6 +25,7 @@ if is_torch_available():
AutoModelForMultipleChoice,
AutoModelForObjectDetection,
AutoModelForQuestionAnswering,
AutoModelForSemanticSegmentation,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
@ -36,6 +37,7 @@ if is_tf_available():
TFAutoModelForMaskedLM,
TFAutoModelForMultipleChoice,
TFAutoModelForQuestionAnswering,
TFAutoModelForSemanticSegmentation,
TFAutoModelForSeq2SeqLM,
TFAutoModelForSequenceClassification,
TFAutoModelForTokenClassification,
@ -94,6 +96,7 @@ class FeaturesManager:
"image-classification": AutoModelForImageClassification,
"image-segmentation": AutoModelForImageSegmentation,
"masked-im": AutoModelForMaskedImageModeling,
"semantic-segmentation": AutoModelForSemanticSegmentation,
}
if is_tf_available():
_TASKS_TO_TF_AUTOMODELS = {
@ -105,6 +108,7 @@ class FeaturesManager:
"token-classification": TFAutoModelForTokenClassification,
"multiple-choice": TFAutoModelForMultipleChoice,
"question-answering": TFAutoModelForQuestionAnswering,
"semantic-segmentation": TFAutoModelForSemanticSegmentation,
}
# Set of model topologies we support associated to the features supported by each topology and the factory
@ -236,7 +240,8 @@ class FeaturesManager:
"data2vec-vision": supported_features_mapping(
"default",
"image-classification",
"image-segmentation",
# ONNX doesn't support `adaptive_avg_pool2d` yet
# "semantic-segmentation",
onnx_config_cls="models.data2vec.Data2VecVisionOnnxConfig",
),
"deberta": supported_features_mapping(