From 3fa45dbd91dcaa8cd8e4278da9ca3b4fced677a4 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 22 Aug 2022 11:28:23 +0200 Subject: [PATCH] Fix Data2VecVision ONNX test (#18587) Co-authored-by: lewtun Co-authored-by: ydshieh --- src/transformers/onnx/config.py | 1 + src/transformers/onnx/features.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py index fdcc12bdcd1..3b789051a22 100644 --- a/src/transformers/onnx/config.py +++ b/src/transformers/onnx/config.py @@ -99,6 +99,7 @@ class OnnxConfig(ABC): "end_logits": {0: "batch", 1: "sequence"}, } ), + "semantic-segmentation": OrderedDict({"logits": {0: "batch", 1: "num_labels", 2: "height", 3: "width"}}), "seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}), "sequence-classification": OrderedDict({"logits": {0: "batch"}}), "token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}), diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index fbfeb47250e..3596fe18400 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -25,6 +25,7 @@ if is_torch_available(): AutoModelForMultipleChoice, AutoModelForObjectDetection, AutoModelForQuestionAnswering, + AutoModelForSemanticSegmentation, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, @@ -36,6 +37,7 @@ if is_tf_available(): TFAutoModelForMaskedLM, TFAutoModelForMultipleChoice, TFAutoModelForQuestionAnswering, + TFAutoModelForSemanticSegmentation, TFAutoModelForSeq2SeqLM, TFAutoModelForSequenceClassification, TFAutoModelForTokenClassification, @@ -94,6 +96,7 @@ class FeaturesManager: "image-classification": AutoModelForImageClassification, "image-segmentation": AutoModelForImageSegmentation, "masked-im": AutoModelForMaskedImageModeling, + "semantic-segmentation": AutoModelForSemanticSegmentation, } if is_tf_available(): _TASKS_TO_TF_AUTOMODELS = { @@ -105,6 +108,7 @@ class FeaturesManager: "token-classification": TFAutoModelForTokenClassification, "multiple-choice": TFAutoModelForMultipleChoice, "question-answering": TFAutoModelForQuestionAnswering, + "semantic-segmentation": TFAutoModelForSemanticSegmentation, } # Set of model topologies we support associated to the features supported by each topology and the factory @@ -236,7 +240,8 @@ class FeaturesManager: "data2vec-vision": supported_features_mapping( "default", "image-classification", - "image-segmentation", + # ONNX doesn't support `adaptive_avg_pool2d` yet + # "semantic-segmentation", onnx_config_cls="models.data2vec.Data2VecVisionOnnxConfig", ), "deberta": supported_features_mapping(