mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Merge branch 'main' into set-supports-static-cache-false-on-moes
This commit is contained in:
commit
3b20325201
@ -14,66 +14,111 @@ rendered properly in your Markdown viewer.
|
|||||||
|
|
||||||
-->
|
-->
|
||||||
|
|
||||||
# DeBERTa-v2
|
<div style="float: right;">
|
||||||
|
<div class="flex flex-wrap space-x-1">
|
||||||
<div class="flex flex-wrap space-x-1">
|
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white" >
|
||||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://huggingface.co/papers/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen It is based on Google's
|
# DeBERTa-v2
|
||||||
BERT model released in 2018 and Facebook's RoBERTa model released in 2019.
|
|
||||||
|
|
||||||
It builds on RoBERTa with disentangled attention and enhanced mask decoder training with half of the data used in
|
[DeBERTa-v2](https://huggingface.co/papers/2006.03654) improves on the original [DeBERTa](./deberta) architecture by using a SentencePiece-based tokenizer and a new vocabulary size of 128K. It also adds an additional convolutional layer within the first transformer layer to better learn local dependencies of input tokens. Finally, the position projection and content projection matrices are shared in the attention layer to reduce the number of parameters.
|
||||||
RoBERTa.
|
|
||||||
|
|
||||||
The abstract from the paper is the following:
|
You can find all the original [DeBERTa-v2] checkpoints under the [Microsoft](https://huggingface.co/microsoft?search_models=deberta-v2) organization.
|
||||||
|
|
||||||
*Recent progress in pre-trained neural language models has significantly improved the performance of many natural
|
|
||||||
language processing (NLP) tasks. In this paper we propose a new model architecture DeBERTa (Decoding-enhanced BERT with
|
|
||||||
disentangled attention) that improves the BERT and RoBERTa models using two novel techniques. The first is the
|
|
||||||
disentangled attention mechanism, where each word is represented using two vectors that encode its content and
|
|
||||||
position, respectively, and the attention weights among words are computed using disentangled matrices on their
|
|
||||||
contents and relative positions. Second, an enhanced mask decoder is used to replace the output softmax layer to
|
|
||||||
predict the masked tokens for model pretraining. We show that these two techniques significantly improve the efficiency
|
|
||||||
of model pretraining and performance of downstream tasks. Compared to RoBERTa-Large, a DeBERTa model trained on half of
|
|
||||||
the training data performs consistently better on a wide range of NLP tasks, achieving improvements on MNLI by +0.9%
|
|
||||||
(90.2% vs. 91.1%), on SQuAD v2.0 by +2.3% (88.4% vs. 90.7%) and RACE by +3.6% (83.2% vs. 86.8%). The DeBERTa code and
|
|
||||||
pre-trained models will be made publicly available at https://github.com/microsoft/DeBERTa.*
|
|
||||||
|
|
||||||
|
|
||||||
The following information is visible directly on the [original implementation
|
> [!TIP]
|
||||||
repository](https://github.com/microsoft/DeBERTa). DeBERTa v2 is the second version of the DeBERTa model. It includes
|
> This model was contributed by [Pengcheng He](https://huggingface.co/DeBERTa).
|
||||||
the 1.5B model used for the SuperGLUE single-model submission and achieving 89.9, versus human baseline 89.8. You can
|
>
|
||||||
find more details about this submission in the authors'
|
> Click on the DeBERTa-v2 models in the right sidebar for more examples of how to apply DeBERTa-v2 to different language tasks.
|
||||||
[blog](https://www.microsoft.com/en-us/research/blog/microsoft-deberta-surpasses-human-performance-on-the-superglue-benchmark/)
|
|
||||||
|
|
||||||
New in v2:
|
The example below demonstrates how to classify text with [`Pipeline`] or the [`AutoModel`] class.
|
||||||
|
|
||||||
- **Vocabulary** In v2 the tokenizer is changed to use a new vocabulary of size 128K built from the training data.
|
<hfoptions id="usage">
|
||||||
Instead of a GPT2-based tokenizer, the tokenizer is now
|
<hfoption id="Pipeline">
|
||||||
[sentencepiece-based](https://github.com/google/sentencepiece) tokenizer.
|
|
||||||
- **nGiE(nGram Induced Input Encoding)** The DeBERTa-v2 model uses an additional convolution layer aside with the first
|
|
||||||
transformer layer to better learn the local dependency of input tokens.
|
|
||||||
- **Sharing position projection matrix with content projection matrix in attention layer** Based on previous
|
|
||||||
experiments, this can save parameters without affecting the performance.
|
|
||||||
- **Apply bucket to encode relative positions** The DeBERTa-v2 model uses log bucket to encode relative positions
|
|
||||||
similar to T5.
|
|
||||||
- **900M model & 1.5B model** Two additional model sizes are available: 900M and 1.5B, which significantly improves the
|
|
||||||
performance of downstream tasks.
|
|
||||||
|
|
||||||
This model was contributed by [DeBERTa](https://huggingface.co/DeBERTa). This model TF 2.0 implementation was
|
```py
|
||||||
contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code can be found [here](https://github.com/microsoft/DeBERTa).
|
import torch
|
||||||
|
from transformers import pipeline
|
||||||
|
|
||||||
## Resources
|
pipeline = pipeline(
|
||||||
|
task="text-classification",
|
||||||
|
model="microsoft/deberta-v2-xlarge-mnli",
|
||||||
|
device=0,
|
||||||
|
torch_dtype=torch.float16
|
||||||
|
)
|
||||||
|
result = pipeline("DeBERTa-v2 is great at understanding context!")
|
||||||
|
print(result)
|
||||||
|
```
|
||||||
|
|
||||||
|
</hfoption>
|
||||||
|
<hfoption id="AutoModel">
|
||||||
|
|
||||||
|
```py
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
"microsoft/deberta-v2-xlarge-mnli"
|
||||||
|
)
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
|
"microsoft/deberta-v2-xlarge-mnli",
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device_map="auto"
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs = tokenizer("DeBERTa-v2 is great at understanding context!", return_tensors="pt").to("cuda")
|
||||||
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
logits = outputs.logits
|
||||||
|
predicted_class_id = logits.argmax().item()
|
||||||
|
predicted_label = model.config.id2label[predicted_class_id]
|
||||||
|
print(f"Predicted label: {predicted_label}")
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
</hfoption>
|
||||||
|
|
||||||
|
<hfoption id="transformers CLI">
|
||||||
|
|
||||||
|
```bash
|
||||||
|
echo -e "DeBERTa-v2 is great at understanding context!" | transformers-cli run --task fill-mask --model microsoft/deberta-v2-xlarge-mnli --device 0
|
||||||
|
```
|
||||||
|
</hfoption>
|
||||||
|
</hfoptions>
|
||||||
|
|
||||||
|
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||||
|
|
||||||
|
The example below uses [bitsandbytes quantization](../quantization/bitsandbytes) to only quantize the weights to 4-bit.
|
||||||
|
|
||||||
|
```py
|
||||||
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
|
||||||
|
|
||||||
|
model_id = "microsoft/deberta-v2-xlarge-mnli"
|
||||||
|
quantization_config = BitsAndBytesConfig(
|
||||||
|
load_in_4bit=True,
|
||||||
|
bnb_4bit_quant_type="nf4",
|
||||||
|
bnb_4bit_compute_dtype="float16",
|
||||||
|
bnb_4bit_use_double_quant=True,
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
quantization_config=quantization_config,
|
||||||
|
torch_dtype="float16"
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs = tokenizer("DeBERTa-v2 is great at understanding context!", return_tensors="pt").to("cuda")
|
||||||
|
outputs = model(**inputs)
|
||||||
|
logits = outputs.logits
|
||||||
|
predicted_class_id = logits.argmax().item()
|
||||||
|
predicted_label = model.config.id2label[predicted_class_id]
|
||||||
|
print(f"Predicted label: {predicted_label}")
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
- [Text classification task guide](../tasks/sequence_classification)
|
|
||||||
- [Token classification task guide](../tasks/token_classification)
|
|
||||||
- [Question answering task guide](../tasks/question_answering)
|
|
||||||
- [Masked language modeling task guide](../tasks/masked_language_modeling)
|
|
||||||
- [Multiple choice task guide](../tasks/multiple_choice)
|
|
||||||
|
|
||||||
## DebertaV2Config
|
## DebertaV2Config
|
||||||
|
|
||||||
|
@ -130,6 +130,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
|||||||
("falcon_h1", "FalconH1Config"),
|
("falcon_h1", "FalconH1Config"),
|
||||||
("falcon_mamba", "FalconMambaConfig"),
|
("falcon_mamba", "FalconMambaConfig"),
|
||||||
("fastspeech2_conformer", "FastSpeech2ConformerConfig"),
|
("fastspeech2_conformer", "FastSpeech2ConformerConfig"),
|
||||||
|
("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGanConfig"),
|
||||||
("flaubert", "FlaubertConfig"),
|
("flaubert", "FlaubertConfig"),
|
||||||
("flava", "FlavaConfig"),
|
("flava", "FlavaConfig"),
|
||||||
("fnet", "FNetConfig"),
|
("fnet", "FNetConfig"),
|
||||||
@ -511,6 +512,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
|||||||
("falcon_h1", "FalconH1"),
|
("falcon_h1", "FalconH1"),
|
||||||
("falcon_mamba", "FalconMamba"),
|
("falcon_mamba", "FalconMamba"),
|
||||||
("fastspeech2_conformer", "FastSpeech2Conformer"),
|
("fastspeech2_conformer", "FastSpeech2Conformer"),
|
||||||
|
("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
|
||||||
("flan-t5", "FLAN-T5"),
|
("flan-t5", "FLAN-T5"),
|
||||||
("flan-ul2", "FLAN-UL2"),
|
("flan-ul2", "FLAN-UL2"),
|
||||||
("flaubert", "FlauBERT"),
|
("flaubert", "FlauBERT"),
|
||||||
@ -866,6 +868,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
|
|||||||
("sam_hq_vision_model", "sam_hq"),
|
("sam_hq_vision_model", "sam_hq"),
|
||||||
("llama4_text", "llama4"),
|
("llama4_text", "llama4"),
|
||||||
("blip_2_qformer", "blip_2"),
|
("blip_2_qformer", "blip_2"),
|
||||||
|
("fastspeech2_conformer_with_hifigan", "fastspeech2_conformer"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1178,7 +1181,8 @@ class AutoConfig:
|
|||||||
|
|
||||||
>>> unused_kwargs
|
>>> unused_kwargs
|
||||||
{'foo': False}
|
{'foo': False}
|
||||||
```"""
|
```
|
||||||
|
"""
|
||||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||||
if use_auth_token is not None:
|
if use_auth_token is not None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
@ -121,6 +121,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("falcon_h1", "FalconH1Model"),
|
("falcon_h1", "FalconH1Model"),
|
||||||
("falcon_mamba", "FalconMambaModel"),
|
("falcon_mamba", "FalconMambaModel"),
|
||||||
("fastspeech2_conformer", "FastSpeech2ConformerModel"),
|
("fastspeech2_conformer", "FastSpeech2ConformerModel"),
|
||||||
|
("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
|
||||||
("flaubert", "FlaubertModel"),
|
("flaubert", "FlaubertModel"),
|
||||||
("flava", "FlavaModel"),
|
("flava", "FlavaModel"),
|
||||||
("fnet", "FNetModel"),
|
("fnet", "FNetModel"),
|
||||||
@ -1512,6 +1513,7 @@ MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
|
|||||||
("bark", "BarkModel"),
|
("bark", "BarkModel"),
|
||||||
("csm", "CsmForConditionalGeneration"),
|
("csm", "CsmForConditionalGeneration"),
|
||||||
("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"),
|
("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"),
|
||||||
|
("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
|
||||||
("musicgen", "MusicgenForConditionalGeneration"),
|
("musicgen", "MusicgenForConditionalGeneration"),
|
||||||
("musicgen_melody", "MusicgenMelodyForConditionalGeneration"),
|
("musicgen_melody", "MusicgenMelodyForConditionalGeneration"),
|
||||||
("qwen2_5_omni", "Qwen2_5OmniForConditionalGeneration"),
|
("qwen2_5_omni", "Qwen2_5OmniForConditionalGeneration"),
|
||||||
|
@ -25,6 +25,7 @@ from transformers import (
|
|||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
|
BitsAndBytesConfig,
|
||||||
OPTForCausalLM,
|
OPTForCausalLM,
|
||||||
Trainer,
|
Trainer,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
@ -76,6 +77,12 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
|
|
||||||
return is_peft_loaded
|
return is_peft_loaded
|
||||||
|
|
||||||
|
def _get_bnb_4bit_config(self):
|
||||||
|
return BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
|
||||||
|
|
||||||
|
def _get_bnb_8bit_config(self):
|
||||||
|
return BitsAndBytesConfig(load_in_8bit=True)
|
||||||
|
|
||||||
def test_peft_from_pretrained(self):
|
def test_peft_from_pretrained(self):
|
||||||
"""
|
"""
|
||||||
Simple test that tests the basic usage of PEFT model through `from_pretrained`.
|
Simple test that tests the basic usage of PEFT model through `from_pretrained`.
|
||||||
@ -431,7 +438,10 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
"""
|
"""
|
||||||
for model_id in self.peft_test_model_ids:
|
for model_id in self.peft_test_model_ids:
|
||||||
for transformers_class in self.transformers_test_model_classes:
|
for transformers_class in self.transformers_test_model_classes:
|
||||||
peft_model = transformers_class.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
|
bnb_config = self._get_bnb_8bit_config()
|
||||||
|
peft_model = transformers_class.from_pretrained(
|
||||||
|
model_id, device_map="auto", quantization_config=bnb_config
|
||||||
|
)
|
||||||
|
|
||||||
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
||||||
self.assertTrue(module.__class__.__name__ == "Linear8bitLt")
|
self.assertTrue(module.__class__.__name__ == "Linear8bitLt")
|
||||||
@ -449,7 +459,10 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
# 4bit
|
# 4bit
|
||||||
for model_id in self.peft_test_model_ids:
|
for model_id in self.peft_test_model_ids:
|
||||||
for transformers_class in self.transformers_test_model_classes:
|
for transformers_class in self.transformers_test_model_classes:
|
||||||
peft_model = transformers_class.from_pretrained(model_id, load_in_4bit=True, device_map="auto")
|
bnb_config = self._get_bnb_4bit_config()
|
||||||
|
peft_model = transformers_class.from_pretrained(
|
||||||
|
model_id, device_map="auto", quantization_config=bnb_config
|
||||||
|
)
|
||||||
|
|
||||||
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
||||||
self.assertTrue(module.__class__.__name__ == "Linear4bit")
|
self.assertTrue(module.__class__.__name__ == "Linear4bit")
|
||||||
@ -465,7 +478,10 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
# 8-bit
|
# 8-bit
|
||||||
for model_id in self.peft_test_model_ids:
|
for model_id in self.peft_test_model_ids:
|
||||||
for transformers_class in self.transformers_test_model_classes:
|
for transformers_class in self.transformers_test_model_classes:
|
||||||
peft_model = transformers_class.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
|
bnb_config = self._get_bnb_8bit_config()
|
||||||
|
peft_model = transformers_class.from_pretrained(
|
||||||
|
model_id, device_map="auto", quantization_config=bnb_config
|
||||||
|
)
|
||||||
|
|
||||||
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
||||||
self.assertTrue(module.__class__.__name__ == "Linear8bitLt")
|
self.assertTrue(module.__class__.__name__ == "Linear8bitLt")
|
||||||
@ -489,7 +505,10 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
# 4bit
|
# 4bit
|
||||||
for model_id in self.peft_test_model_ids:
|
for model_id in self.peft_test_model_ids:
|
||||||
for transformers_class in self.transformers_test_model_classes:
|
for transformers_class in self.transformers_test_model_classes:
|
||||||
peft_model = transformers_class.from_pretrained(model_id, load_in_4bit=True, device_map="auto")
|
bnb_config = self._get_bnb_4bit_config()
|
||||||
|
peft_model = transformers_class.from_pretrained(
|
||||||
|
model_id, device_map="auto", quantization_config=bnb_config
|
||||||
|
)
|
||||||
|
|
||||||
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
||||||
self.assertTrue(module.__class__.__name__ == "Linear4bit")
|
self.assertTrue(module.__class__.__name__ == "Linear4bit")
|
||||||
@ -505,7 +524,10 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
# 8-bit
|
# 8-bit
|
||||||
for model_id in self.peft_test_model_ids:
|
for model_id in self.peft_test_model_ids:
|
||||||
for transformers_class in self.transformers_test_model_classes:
|
for transformers_class in self.transformers_test_model_classes:
|
||||||
peft_model = transformers_class.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
|
bnb_config = self._get_bnb_8bit_config()
|
||||||
|
peft_model = transformers_class.from_pretrained(
|
||||||
|
model_id, device_map="auto", quantization_config=bnb_config
|
||||||
|
)
|
||||||
|
|
||||||
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
||||||
self.assertTrue(module.__class__.__name__ == "Linear8bitLt")
|
self.assertTrue(module.__class__.__name__ == "Linear8bitLt")
|
||||||
|
@ -328,8 +328,10 @@ class TokenClassificationPipelineTests(unittest.TestCase):
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
nested_simplify(output),
|
nested_simplify(output),
|
||||||
[
|
[
|
||||||
{"entity_group": "PER", "score": ANY(float), "word": "Sarah", "start": 6, "end": 11},
|
[
|
||||||
{"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29},
|
{"entity_group": "PER", "score": ANY(float), "word": "Sarah", "start": 6, "end": 11},
|
||||||
|
{"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29},
|
||||||
|
]
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -349,8 +351,8 @@ class TokenClassificationPipelineTests(unittest.TestCase):
|
|||||||
{"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29},
|
{"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29},
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
{"entity_group": "PER", "score": ANY(float), "word": "Wolfgang", "start": 12, "end": 20},
|
{"entity_group": "PER", "score": ANY(float), "word": "Wolfgang", "start": 11, "end": 19},
|
||||||
{"entity_group": "LOC", "score": ANY(float), "word": "Berlin", "start": 36, "end": 42},
|
{"entity_group": "LOC", "score": ANY(float), "word": "Berlin", "start": 34, "end": 40},
|
||||||
],
|
],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -3748,7 +3748,24 @@ class ModelTesterMixin:
|
|||||||
self.skipTest(
|
self.skipTest(
|
||||||
"PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input"
|
"PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input"
|
||||||
)
|
)
|
||||||
if config.model_type in ["modernbert", "gemma3", "t5gemma"]:
|
if config.model_type in [
|
||||||
|
"modernbert",
|
||||||
|
"gemma3",
|
||||||
|
"t5gemma",
|
||||||
|
"diffllama",
|
||||||
|
"dpr",
|
||||||
|
"eomt",
|
||||||
|
"gpt_bigcode",
|
||||||
|
"jamba",
|
||||||
|
"kosmos-2",
|
||||||
|
"mllama",
|
||||||
|
"pixtral",
|
||||||
|
"sam",
|
||||||
|
"sam_hq",
|
||||||
|
"zamba2",
|
||||||
|
"sam_vision_model",
|
||||||
|
"sam_hq_vision_model",
|
||||||
|
]:
|
||||||
self.skipTest(
|
self.skipTest(
|
||||||
reason=f"{config.model_type} currently (transformers==4.52.0) automatically adds an attention_mask input"
|
reason=f"{config.model_type} currently (transformers==4.52.0) automatically adds an attention_mask input"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user