mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Fix Data2VecVision ONNX test (#18587)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
30992ef0d9
commit
3fa45dbd91
@ -99,6 +99,7 @@ class OnnxConfig(ABC):
|
||||
"end_logits": {0: "batch", 1: "sequence"},
|
||||
}
|
||||
),
|
||||
"semantic-segmentation": OrderedDict({"logits": {0: "batch", 1: "num_labels", 2: "height", 3: "width"}}),
|
||||
"seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
|
||||
"sequence-classification": OrderedDict({"logits": {0: "batch"}}),
|
||||
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||
|
@ -25,6 +25,7 @@ if is_torch_available():
|
||||
AutoModelForMultipleChoice,
|
||||
AutoModelForObjectDetection,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForSemanticSegmentation,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification,
|
||||
@ -36,6 +37,7 @@ if is_tf_available():
|
||||
TFAutoModelForMaskedLM,
|
||||
TFAutoModelForMultipleChoice,
|
||||
TFAutoModelForQuestionAnswering,
|
||||
TFAutoModelForSemanticSegmentation,
|
||||
TFAutoModelForSeq2SeqLM,
|
||||
TFAutoModelForSequenceClassification,
|
||||
TFAutoModelForTokenClassification,
|
||||
@ -94,6 +96,7 @@ class FeaturesManager:
|
||||
"image-classification": AutoModelForImageClassification,
|
||||
"image-segmentation": AutoModelForImageSegmentation,
|
||||
"masked-im": AutoModelForMaskedImageModeling,
|
||||
"semantic-segmentation": AutoModelForSemanticSegmentation,
|
||||
}
|
||||
if is_tf_available():
|
||||
_TASKS_TO_TF_AUTOMODELS = {
|
||||
@ -105,6 +108,7 @@ class FeaturesManager:
|
||||
"token-classification": TFAutoModelForTokenClassification,
|
||||
"multiple-choice": TFAutoModelForMultipleChoice,
|
||||
"question-answering": TFAutoModelForQuestionAnswering,
|
||||
"semantic-segmentation": TFAutoModelForSemanticSegmentation,
|
||||
}
|
||||
|
||||
# Set of model topologies we support associated to the features supported by each topology and the factory
|
||||
@ -236,7 +240,8 @@ class FeaturesManager:
|
||||
"data2vec-vision": supported_features_mapping(
|
||||
"default",
|
||||
"image-classification",
|
||||
"image-segmentation",
|
||||
# ONNX doesn't support `adaptive_avg_pool2d` yet
|
||||
# "semantic-segmentation",
|
||||
onnx_config_cls="models.data2vec.Data2VecVisionOnnxConfig",
|
||||
),
|
||||
"deberta": supported_features_mapping(
|
||||
|
Loading…
Reference in New Issue
Block a user