Attn implementation for composite models (#32238)

* first try

* codestyle

* idefics2 is happy

* [run-slow] llava, llava_next, video_llava, vipllava, llava_next_video, idefics, idefics2, kosmos2, fuyu, blip, blip_2, instructblip, instructblipvideo, paligemma

* fix-copies

* [run-slow] llava, llava_next, video_llava, vipllava, llava_next_video, idefics, idefics2, kosmos2, fuyu, blip, blip_2, instructblip, instructblipvideo

* blip-2 needs to init vision from config

* when was this removed O_o

* minor fix

* tests

* this way?

* tests

* model-agnostic code

* codestyle

* add tests for idefics

* modify general test for VLMs

* no generation test for vlm yet!

* no generation test here also

* warn in ViT SDPA if output attn

* add more tests

* user can pass dict as attn impl

* repo consistency

* update

* musicgen

* no prints

* forgot speech enc-dec and clip

* how many composite models do we have?

* musicgen melody is same as musicgen

* +siglip

* fix tests + add some more

* remove idefics custom overridden code

* make idefics2 automappable

* nits

* skip tests

* doctests

* Update src/transformers/models/idefics2/configuration_idefics2.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/clip/test_modeling_clip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/idefics2/test_modeling_idefics2.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/idefics2/test_modeling_idefics2.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/configuration_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* major update, no need for automap

* clean up

* add FA2 test

* more tests

* style

* skip tests

* why did these start failing now?

* no attributes for FA2 needed

* one tiny test

* address comment about FA2 false warning

* style

* add new models and resolve conflicts

* fix copies

* let it be this way for now, come back tomorrow to review

* some more fixes

* update

* more updates

* update

* fix copies

* style and tests

* another big update

* fix tests

* fix tests

* update

* another update

* fix tests

* fix copies

* fix tests

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Raushan Turganbay 2024-10-22 06:54:44 +02:00 committed by GitHub
parent 32590b5ecb
commit 21d5025826
64 changed files with 1925 additions and 713 deletions
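
The headline change: `attn_implementation` can now be a single string or a dict keyed by sub-config names. A minimal usage sketch (checkpoint name illustrative; `flash_attention_2` assumes flash-attn is installed and a half-precision dtype):

```python
from transformers import LlavaForConditionalGeneration

# One string dispatches the same implementation to every sub-model.
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", attn_implementation="sdpa"
)

# A dict assigns one implementation per sub-model; keys are the sub-config
# attribute names on the composite config.
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    attn_implementation={"vision_config": "sdpa", "text_config": "flash_attention_2"},
)
```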

View File: docs/source/en/perf_infer_gpu_one.md

@ -79,6 +79,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [OLMoE](https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
* [PhiMoE](https://huggingface.co/docs/transformers/model_doc/phimoe#transformers.PhimoeModel)
@ -88,6 +89,10 @@ FlashAttention-2 is currently supported for the following architectures:
* [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder)
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [Qwen2VL](https://huggingface.co/docs/transformers/model_doc/qwen2_vl#transformers.Qwen2VLModel)
* [RAG](https://huggingface.co/docs/transformers/model_doc/rag#transformers.RagModel)
* [SpeechEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/speech_encoder_decoder#transformers.SpeechEncoderDecoderModel)
* [VisionEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/vision_encoder_decoder#transformers.VisionEncoderDecoderModel)
* [VisionTextDualEncoder](https://huggingface.co/docs/transformers/model_doc/vision_text_dual_encoder#transformers.VisionTextDualEncoderModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
* [Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
@ -225,6 +230,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
* [EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder_decoder#transformers.EncoderDecoderModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
@ -233,11 +239,16 @@ For now, Transformers supports SDPA inference and training for the following arc
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model)
* [Idefics3](https://huggingface.co/docs/transformers/model_doc/idefics3#transformers.Idefics3Model)
* [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel)
* [GraniteMoe](https://huggingface.co/docs/transformers/model_doc/granitemoe#transformers.GraniteMoeModel)
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
* [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video)
* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision)
* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100#transformers.M2M100Model)
* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi)
@ -277,10 +288,15 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron)
* [SpeechEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/speech_encoder_decoder#transformers.SpeechEncoderDecoderModel)
* [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava)
* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
* [VisionEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/vision_encoder_decoder#transformers.VisionEncoderDecoderModel)
* [ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTModel)
* [ViTHybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid#transformers.ViTHybridModel)
* [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel)
* [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel)
* [VisionTextDualEncoder](https://huggingface.co/docs/transformers/model_doc/vision_text_dual_encoder#transformers.VisionTextDualEncoderModel)
* [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModel)
* [ViViT](https://huggingface.co/docs/transformers/model_doc/vivit#transformers.VivitModel)
* [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
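
The support flags these lists reflect can be read straight off the classes (private attributes used internally by Transformers, shown only for illustration):

```python
from transformers import EncoderDecoderModel, PaliGemmaForConditionalGeneration

print(PaliGemmaForConditionalGeneration._supports_flash_attn_2)  # True after this change
print(PaliGemmaForConditionalGeneration._supports_sdpa)          # True
print(EncoderDecoderModel._supports_sdpa)                        # True
```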

View File: src/transformers/configuration_utils.py

@ -296,6 +296,7 @@ class PretrainedConfig(PushToHubMixin):
# Attention implementation to use, if relevant.
self._attn_implementation_internal = kwargs.pop("attn_implementation", None)
self._attn_implementation_autoset = False
# Drop the transformers version info
self.transformers_version = kwargs.pop("transformers_version", None)
@ -776,6 +777,10 @@ class PretrainedConfig(PushToHubMixin):
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
def __iter__(self):
for attr in self.__dict__:
yield attr
def to_diff_dict(self) -> Dict[str, Any]:
"""
Removes all attributes from config which correspond to the default config attributes for better readability and
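
The new `__iter__` is what lets the dispatch code walk a config and find nested sub-configs generically. A small sketch of the idiom:

```python
from transformers import LlavaConfig, PretrainedConfig

config = LlavaConfig()
sub_configs = {
    key: getattr(config, key)
    for key in config  # relies on the __iter__ added above
    if isinstance(getattr(config, key), PretrainedConfig)
}
print(sorted(sub_configs))  # ['text_config', 'vision_config']
```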

View File: src/transformers/modeling_utils.py

@ -1420,6 +1420,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
# Save config and origin of the pretrained weights if given in model
if not getattr(config, "_attn_implementation_autoset", False):
config = self._autoset_attn_implementation(
config, torch_dtype=torch.get_default_dtype(), check_device_map=False
)
@ -1500,6 +1501,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
torch_dtype (`torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype.
"""
# when we init a model from within another model (e.g. VLMs) and dispatch on FA2
# a warning is raised that dtype should be fp16. Since we never pass dtype from within
# modeling code, we can try to infer it here same way as done in `from_pretrained`
torch_dtype = kwargs.pop("torch_dtype", torch.get_default_dtype())
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
@ -1518,6 +1522,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
attn_implementation = None
config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation)
if not getattr(config, "_attn_implementation_autoset", False):
config = cls._autoset_attn_implementation(
config,
use_flash_attention_2=use_flash_attention_2,
@ -1570,7 +1575,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
)
if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2"]:
if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [
"eager",
"sdpa",
"flash_attention_2",
]:
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_2:
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
@ -1581,6 +1590,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
requested_attn_implementation = config._attn_implementation_internal
# Composite models consisting of several PretrainedModels have to specify attention impl as a dict
# where keys are sub-config names. But most people will specify one `str` which means that should dispatch it
# for all sub-models.
# Below we check if a config is composite and manually prepare a dict of attn impl if not already passed as a dict.
# Later each sub-module will dispatch with its own attn impl, by calling `XXXModel._from_config(config.text_config)`
# If any of sub-modules doesn't support requested attn, an error will be raised. See https://github.com/huggingface/transformers/pull/32238
for key in config:
if isinstance(getattr(config, key), PretrainedConfig):
sub_config = getattr(config, key)
curr_attn_implementation = (
requested_attn_implementation
if not isinstance(requested_attn_implementation, dict)
else requested_attn_implementation.get(key, None)
)
sub_config._attn_implementation_internal = curr_attn_implementation
if use_flash_attention_2:
logger.warning_once(
'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.'
@ -1611,9 +1636,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
)
torch.backends.cuda.enable_flash_sdp(False)
elif isinstance(requested_attn_implementation, dict):
config._attn_implementation = None
else:
config._attn_implementation = "eager"
config._attn_implementation_autoset = True
return config
@classmethod
@ -2771,6 +2799,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Attach architecture to the config
model_to_save.config.architectures = [model_to_save.__class__.__name__]
# Unset attn implementation so it can be set to another one when loading back
model_to_save.config._attn_implementation_autoset = False
# If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
if self._auto_class is not None:
@ -4055,6 +4086,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
init_contexts.append(init_empty_weights())
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
if not getattr(config, "_attn_implementation_autoset", False):
config = cls._autoset_attn_implementation(
config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
)
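
Net effect of the changes above: each sub-config records its own resolved implementation, and the top-level internal value is cleared for dict requests. A hedged sketch of what a caller observes (checkpoint illustrative):

```python
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    attn_implementation={"vision_config": "eager", "text_config": "sdpa"},
)
print(model.config.vision_config._attn_implementation)  # "eager"
print(model.config.text_config._attn_implementation)    # "sdpa"
```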

View File: src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py

@ -176,8 +176,24 @@ class ASTSdpaSelfAttention(ASTSelfAttention):
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
self,
hidden_states: torch.FloatTensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions or head_mask is not None:
logger.warning_once(
"`ASTSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
)
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
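
Effect of the fallback, sketched on AST (checkpoint name illustrative; AST inputs are `(batch, time_frames, mel_bins)`):

```python
import torch
from transformers import ASTModel

model = ASTModel.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593", attn_implementation="sdpa"
)
inputs = torch.randn(1, 1024, 128)
# With SDPA selected, asking for attentions warns once and transparently
# routes through the eager (manual) attention path instead of erroring.
outputs = model(input_values=inputs, output_attentions=True)
print(outputs.attentions[0].shape)  # (1, num_heads, seq_len, seq_len)
```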

View File: src/transformers/models/blip_2/modeling_blip_2.py

@ -410,6 +410,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
config_class = Blip2Config
base_model_prefix = "blip"
supports_gradient_checkpointing = True
_no_split_modules = [
"Blip2Attention",
"Blip2QFormerMultiHeadAttention",
@ -1455,13 +1456,9 @@ class Blip2Model(Blip2PreTrainedModel):
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
if config.use_decoder_only_language_model:
language_model = AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
language_model = AutoModelForCausalLM.from_config(config.text_config)
else:
language_model = AutoModelForSeq2SeqLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
# Update _tied_weights_keys using the base model used.
if language_model._tied_weights_keys is not None:
@ -2020,13 +2017,9 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
if config.use_decoder_only_language_model:
language_model = AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
language_model = AutoModelForCausalLM.from_config(config.text_config)
else:
language_model = AutoModelForSeq2SeqLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
# Update _tied_weights_keys using the base model used.
if language_model._tied_weights_keys is not None:

View File: src/transformers/models/clip/modeling_clip.py

@ -1204,10 +1204,10 @@ class CLIPModel(CLIPPreTrainedModel):
self.text_embed_dim = text_config.hidden_size
self.vision_embed_dim = vision_config.hidden_size
text_model = CLIPTextModel._from_config(text_config, attn_implementation=config._attn_implementation)
text_model = CLIPTextModel._from_config(text_config)
self.text_model = text_model.text_model
vision_model = CLIPVisionModel._from_config(vision_config, attn_implementation=config._attn_implementation)
vision_model = CLIPVisionModel._from_config(vision_config)
self.vision_model = vision_model.vision_model
self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
@ -1590,9 +1590,7 @@ class CLIPForImageClassification(CLIPPreTrainedModel):
super().__init__(config)
self.num_labels = config.num_labels
vision_model = CLIPVisionModel._from_config(
config.vision_config, attn_implementation=config._attn_implementation
)
vision_model = CLIPVisionModel._from_config(config.vision_config)
self.vision_model = vision_model.vision_model
# Classifier head

View File: src/transformers/models/deit/modeling_deit.py

@ -248,8 +248,24 @@ class DeiTSdpaSelfAttention(DeiTSelfAttention):
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
self,
hidden_states: torch.FloatTensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions or head_mask is not None:
logger.warning_once(
"`DeiTSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
)
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))

View File: src/transformers/models/encoder_decoder/modeling_encoder_decoder.py

@ -180,6 +180,8 @@ class EncoderDecoderModel(PreTrainedModel, GenerationMixin):
main_input_name = "input_ids"
supports_gradient_checkpointing = True
_supports_param_buffer_assignment = False
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(
self,
@ -210,12 +212,12 @@ class EncoderDecoderModel(PreTrainedModel, GenerationMixin):
if encoder is None:
from ..auto.modeling_auto import AutoModel
encoder = AutoModel.from_config(config.encoder, attn_implementation=config._attn_implementation)
encoder = AutoModel.from_config(config.encoder)
if decoder is None:
from ..auto.modeling_auto import AutoModelForCausalLM
decoder = AutoModelForCausalLM.from_config(config.decoder, attn_implementation=config._attn_implementation)
decoder = AutoModelForCausalLM.from_config(config.decoder)
self.encoder = encoder
self.decoder = decoder
@ -233,6 +235,9 @@ class EncoderDecoderModel(PreTrainedModel, GenerationMixin):
# make sure that the individual model's config refers to the shared config
# so that the updates to the config will be synced
# update `_attn_implementation` because the attn is set in a deepcopied config within PreTrainedModel
self.config.encoder._attn_implementation = self.encoder.config._attn_implementation
self.config.decoder._attn_implementation = self.decoder.config._attn_implementation
self.encoder.config = self.config.encoder
self.decoder.config = self.config.decoder
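
With the flags above, encoder-decoder composites accept the same dict form, keyed by their sub-config names `encoder` and `decoder`. A hedged sketch (checkpoint illustrative):

```python
from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_pretrained(
    "patrickvonplaten/bert2bert_cnn_daily_mail",
    attn_implementation={"encoder": "sdpa", "decoder": "sdpa"},
)
```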

View File: src/transformers/models/idefics/modeling_idefics.py

@ -933,18 +933,6 @@ class IdeficsPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
# Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
# We remove the checks on `is_torch_sdpa_available()` and `cls._supports_sdpa` as Falcon supports SDPA from torch==2.0.0 (no requirement on 2.1).
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
if _is_bettertransformer:
return config
if not hard_check_only:
config._attn_implementation = "sdpa"
return config
LLAMA_INPUTS_DOCSTRING = r"""
Args:

View File: src/transformers/models/idefics2/configuration_idefics2.py

@ -57,7 +57,7 @@ class Idefics2VisionConfig(PretrainedConfig):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
intializer_range (`float`, *optional*, defaults to 0.02):
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation for initializing all weight matrices in the model.
Example:
@ -134,6 +134,10 @@ class Idefics2PerceiverConfig(PretrainedConfig):
Args:
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the perceiver block.
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
resampler_n_latents (`int`, *optional*, defaults to 64):
Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
resampler_depth (`int`, *optional*, defaults to 3):
@ -153,6 +157,8 @@ class Idefics2PerceiverConfig(PretrainedConfig):
def __init__(
self,
hidden_act="silu",
hidden_size=4096,
rms_norm_eps=1e-06,
resampler_n_latents=64,
resampler_depth=3,
resampler_n_heads=16,
@ -162,6 +168,8 @@ class Idefics2PerceiverConfig(PretrainedConfig):
**kwargs,
):
self.hidden_act = hidden_act
self.hidden_size = hidden_size
self.rms_norm_eps = rms_norm_eps
self.resampler_n_latents = resampler_n_latents
self.resampler_depth = resampler_depth
self.resampler_n_heads = resampler_n_heads
@ -258,5 +266,12 @@ class Idefics2Config(PretrainedConfig):
)
self.text_config = text_config
if self.text_config.hidden_size != self.perceiver_config.hidden_size:
self.perceiver_config.hidden_size = self.text_config.hidden_size
self.perceiver_config.rms_norm_eps = self.text_config.rms_norm_eps
logger.warning_once(
"Perceiver config has a different `hidden_size` than text config, which means default values were used. "
"In your model's config on the hub, add `hidden_size` and `rms_norm_eps` keys under the `perceiver_config` dict. "
)
super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
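
The sync above keeps older hub configs, which predate the new perceiver keys, consistent with the text config. A sketch of the resulting invariant:

```python
from transformers import Idefics2Config

config = Idefics2Config()
# After __init__ the perceiver config mirrors the text config's values
# unless a hub config sets them explicitly.
assert config.perceiver_config.hidden_size == config.text_config.hidden_size
assert config.perceiver_config.rms_norm_eps == config.text_config.rms_norm_eps
```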

View File: src/transformers/models/idefics2/modeling_idefics2.py

@ -16,7 +16,7 @@
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@ -38,7 +38,7 @@ from ...utils import (
replace_return_docstrings,
)
from ..auto import AutoModel
from .configuration_idefics2 import Idefics2Config, Idefics2VisionConfig
from .configuration_idefics2 import Idefics2Config, Idefics2PerceiverConfig, Idefics2VisionConfig
if is_flash_attn_2_available():
@ -572,9 +572,86 @@ class Idefics2Encoder(nn.Module):
)
class Idefics2VisionTransformer(nn.Module):
IDEFICS2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`Idefics2Config`] or [`Idefics2VisionConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare Idefics2 Model outputting raw hidden-states without any specific head on top.",
IDEFICS2_START_DOCSTRING,
)
class Idefics2PreTrainedModel(PreTrainedModel):
config_class = Idefics2Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module):
std = (
self.config.text_config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
IDEFICS2_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
The tensors corresponding to the input images. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`LlavaProcessor`] uses
[`CLIPImageProcessor`] for processing images).
pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
Mask to avoid performing attention on padding pixel indices.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"""Idefics2 vision encoder model that returnss raw image embeddings.""",
IDEFICS2_START_DOCSTRING,
)
class Idefics2VisionTransformer(Idefics2PreTrainedModel):
_supports_sdpa = False
config_class = Idefics2VisionConfig
def __init__(self, config: Idefics2VisionConfig):
super().__init__()
super().__init__(config)
embed_dim = config.hidden_size
self.config = config
@ -687,12 +764,12 @@ class Idefics2PerceiverAttention(nn.Module):
super().__init__()
self.layer_idx = None
self.hidden_size = config.text_config.hidden_size
self.num_heads = config.perceiver_config.resampler_n_heads
self.head_dim = config.perceiver_config.resampler_head_dim
self.num_key_value_heads = config.perceiver_config.num_key_value_heads
self.hidden_size = config.hidden_size
self.num_heads = config.resampler_n_heads
self.head_dim = config.resampler_head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.attention_dropout = config.perceiver_config.attention_dropout
self.attention_dropout = config.attention_dropout
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
@ -918,20 +995,20 @@ IDEFICS2_PERCEIVER_ATTENTION_CLASSES = {
class Idefics2PerceiverLayer(nn.Module):
def __init__(self, config, layer_idx: int):
super().__init__()
self.hidden_size = config.text_config.hidden_size
self.n_latents = config.perceiver_config.resampler_n_latents
self.depth = config.perceiver_config.resampler_depth
self.rms_norm_eps = config.text_config.rms_norm_eps
self.hidden_size = config.hidden_size
self.n_latents = config.resampler_n_latents
self.depth = config.resampler_depth
self.rms_norm_eps = config.rms_norm_eps
self.input_latents_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
self.input_context_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
self.self_attn = IDEFICS2_PERCEIVER_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
self.post_attention_layernorm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
self.mlp = Idefics2MLP(
hidden_size=config.text_config.hidden_size,
intermediate_size=config.text_config.hidden_size * 4,
output_size=config.text_config.hidden_size,
hidden_act=config.perceiver_config.hidden_act,
hidden_size=config.hidden_size,
intermediate_size=config.hidden_size * 4,
output_size=config.hidden_size,
hidden_act=config.hidden_act,
)
def forward(
@ -987,20 +1064,37 @@ class Idefics2PerceiverLayer(nn.Module):
return outputs
class Idefics2PerceiverResampler(nn.Module):
IDEFICS2_INPUTS_DOCSTRING = r"""
Args:
context (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`):
The hidden states of the image after vision encoder and modality projection.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
"""
@add_start_docstrings(
"Idefics2 perceiver resampler model that performs `depth` blocks of cross-attention with a fixed ",
"`n_latents` inputs to decrease embedding sequence length. The Resampler acts as a form of learned pooling and ",
"is derived from [Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206)",
IDEFICS2_START_DOCSTRING,
)
class Idefics2PerceiverResampler(Idefics2PreTrainedModel):
_supports_sdpa = False
config_class = Idefics2PerceiverConfig
def __init__(self, config) -> None:
"""
Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or
MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then
returns a Tensor of shape [bsz, n_latents, embed_dim]. The Resampler acts as a form of learned pooling and
is derived from [Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206).
"""
super().__init__()
self.hidden_size = config.text_config.hidden_size
self.hidden_act = config.perceiver_config.hidden_act
self.n_latents = config.perceiver_config.resampler_n_latents
self.depth = config.perceiver_config.resampler_depth
self.rms_norm_eps = config.text_config.rms_norm_eps
super().__init__(config)
self.hidden_size = config.hidden_size
self.hidden_act = config.hidden_act
self.n_latents = config.resampler_n_latents
self.depth = config.resampler_depth
self.rms_norm_eps = config.rms_norm_eps
# Create Latents for Perceiver
self.latents = nn.Parameter(torch.ones(self.n_latents, self.hidden_size))
@ -1014,7 +1108,7 @@ class Idefics2PerceiverResampler(nn.Module):
def forward(
self,
context: torch.Tensor,
attention_mask,
attention_mask: torch.Tensor,
) -> torch.Tensor:
# seq embed -> bsz seq embed
latents = self.latents.unsqueeze(0).expand((context.shape[0], *self.latents.size()))
@ -1057,7 +1151,7 @@ class Idefics2Connector(nn.Module):
output_size=config.text_config.hidden_size,
hidden_act=config.text_config.hidden_act,
)
self.perceiver_resampler = Idefics2PerceiverResampler(config)
self.perceiver_resampler = Idefics2PerceiverResampler._from_config(config.perceiver_config)
def forward(self, image_hidden_states, attention_mask):
image_hidden_states = self.modality_projection(image_hidden_states)
@ -1065,80 +1159,6 @@ class Idefics2Connector(nn.Module):
return image_hidden_states
IDEFICS2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`Idefics2Config`] or [`Idefics2VisionConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare Idefics2 Model outputting raw hidden-states without any specific head on top.",
IDEFICS2_START_DOCSTRING,
)
class Idefics2PreTrainedModel(PreTrainedModel):
config_class = Idefics2Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
def _init_weights(self, module):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@classmethod
def _autoset_attn_implementation(
cls,
config,
use_flash_attention_2: bool = False,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None,
check_device_map: bool = True,
**kwargs,
):
"""
Overrides the method in `PreTrainedModel` to update the vision config with the correct attention implementation
"""
config = super()._autoset_attn_implementation(
config=config,
use_flash_attention_2=use_flash_attention_2,
torch_dtype=torch_dtype,
device_map=device_map,
check_device_map=check_device_map,
**kwargs,
)
config.vision_config._attn_implementation = config._attn_implementation
return config
IDEFICS2_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@ -1219,14 +1239,14 @@ class Idefics2Model(Idefics2PreTrainedModel):
self.padding_idx = self.config.text_config.pad_token_id
self.vocab_size = self.config.text_config.vocab_size
self.vision_model = Idefics2VisionTransformer(config.vision_config)
self.vision_model = Idefics2VisionTransformer._from_config(config.vision_config)
self.connector = Idefics2Connector(config)
self.text_model = AutoModel.from_config(config.text_config, attn_implementation=config._attn_implementation)
self.text_model = AutoModel.from_config(config.text_config)
self.image_seq_len = config.perceiver_config.resampler_n_latents
self.image_token_id = self.config.image_token_id
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2"
self.post_init()
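
Since the resampler is now a `PreTrainedModel` built from its own sub-config, it resolves its own attention implementation in `_from_config`. A minimal sketch using the internal classes from this file:

```python
from transformers.models.idefics2.configuration_idefics2 import Idefics2PerceiverConfig
from transformers.models.idefics2.modeling_idefics2 import Idefics2PerceiverResampler

perceiver_config = Idefics2PerceiverConfig()  # now carries hidden_size / rms_norm_eps
resampler = Idefics2PerceiverResampler._from_config(perceiver_config)
print(resampler.config._attn_implementation)  # "eager", since _supports_sdpa is False
```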

View File: src/transformers/models/idefics3/modeling_idefics3.py

@ -621,12 +621,13 @@ class Idefics3PreTrainedModel(PreTrainedModel):
_no_split_modules = ["Idefics3VisionAttention", "Idefics3DecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2PreTrainedModel._init_weights
def _init_weights(self, module):
std = (
self.config.initializer_range
self.config.text_config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
@ -667,6 +668,7 @@ IDEFICS3_VISION_START_DOCSTRING = r"""
)
class Idefics3VisionTransformer(Idefics3PreTrainedModel):
config_class = Idefics3VisionConfig
_supports_sdpa = False
def __init__(self, config: Idefics3VisionConfig):
super().__init__(config)
@ -824,18 +826,16 @@ class Idefics3Model(Idefics3PreTrainedModel):
self.padding_idx = self.config.text_config.pad_token_id
self.vocab_size = self.config.text_config.vocab_size
self.vision_model = Idefics3VisionTransformer._from_config(
config.vision_config, attn_implementation=config._attn_implementation
)
self.vision_model = Idefics3VisionTransformer._from_config(config.vision_config)
self.connector = Idefics3Connector(config)
self.text_model = AutoModel.from_config(config.text_config, attn_implementation=config._attn_implementation)
self.text_model = AutoModel.from_config(config.text_config)
self.image_seq_len = int(
((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
)
self.image_token_id = self.config.image_token_id
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2"
self.post_init()

View File: src/transformers/models/instructblip/modeling_instructblip.py

@ -315,6 +315,7 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
config_class = InstructBlipConfig
base_model_prefix = "blip"
supports_gradient_checkpointing = True
_no_split_modules = [
"InstructBlipQFormerEmbeddings",
"InstructBlipAttention",
@ -1298,13 +1299,9 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
if config.use_decoder_only_language_model:
language_model = AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
language_model = AutoModelForCausalLM.from_config(config.text_config)
else:
language_model = AutoModelForSeq2SeqLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
if language_model._no_split_modules is not None:
self._no_split_modules.extend(language_model._no_split_modules)

View File: src/transformers/models/instructblipvideo/modeling_instructblipvideo.py

@ -317,6 +317,7 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel):
config_class = InstructBlipVideoConfig
base_model_prefix = "blip"
supports_gradient_checkpointing = True
_no_split_modules = [
"InstructBlipVideoQFormerEmbeddings",
"InstructBlipVideoAttention",
@ -1292,13 +1293,9 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
if config.use_decoder_only_language_model:
language_model = AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
language_model = AutoModelForCausalLM.from_config(config.text_config)
else:
language_model = AutoModelForSeq2SeqLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
if language_model._no_split_modules is not None:
self._no_split_modules.extend(language_model._no_split_modules)

View File: src/transformers/models/llava/modeling_llava.py

@ -125,8 +125,9 @@ class LlavaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["LlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
# important: this ported version of Llava isn't meant for training from scratch - only
@ -150,14 +151,6 @@ class LlavaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def _supports_sdpa(self):
"""
Retrieve language_model's attribute to check whether the model supports
SDPA or not.
"""
return self.language_model._supports_sdpa
LLAVA_INPUTS_DOCSTRING = r"""
Args:
@ -245,9 +238,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()

View File: src/transformers/models/llava_next/modeling_llava_next.py

@ -234,8 +234,9 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["LlavaNextVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
# important: this ported version of LlavaNext isn't meant for training from scratch - only
@ -259,14 +260,6 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def _supports_sdpa(self):
"""
Retrieve language_model's attribute to check whether the model supports
SDPA or not.
"""
return self.language_model._supports_sdpa
LLAVA_NEXT_INPUTS_DOCSTRING = r"""
Args:
@ -360,9 +353,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.post_init()

View File: src/transformers/models/llava_next_video/modeling_llava_next_video.py

@ -277,8 +277,9 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["LlavaNextVideoVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
# important: this ported version of LlavaNextVideo isn't meant for training from scratch - only
@ -302,14 +303,6 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def _supports_sdpa(self):
"""
Retrieve language_model's attribute to check whether the model supports
SDPA or not.
"""
return self.language_model._supports_sdpa
LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r"""
Args:
@ -406,9 +399,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.vision_resampler = LlavaNextVideoPooler(config)

View File: src/transformers/models/llava_onevision/modeling_llava_onevision.py

@ -363,18 +363,14 @@ LLAVA_ONEVISION_INPUTS_DOCSTRING = r"""
class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):
def __init__(self, config: LlavaOnevisionConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(
config.vision_config, attn_implementation=config._attn_implementation
)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
embed_std = 1 / math.sqrt(config.text_config.hidden_size)
self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
self.post_init()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings

View File: src/transformers/models/mllama/modeling_mllama.py

@ -1979,12 +1979,8 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
self.vision_output_dim = config.vision_config.vision_output_dim
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.vision_model = MllamaVisionModel._from_config(
config.vision_config, attn_implementation=config._attn_implementation
)
self.language_model = MllamaForCausalLM._from_config(
config.text_config, attn_implementation=config._attn_implementation
)
self.vision_model = MllamaVisionModel._from_config(config.vision_config)
self.language_model = MllamaForCausalLM._from_config(config.text_config)
self.multi_modal_projector = nn.Linear(
config.vision_config.vision_output_dim,
config.text_config.hidden_size,

View File: src/transformers/models/musicgen/configuration_musicgen.py

@ -236,20 +236,3 @@ class MusicgenConfig(PretrainedConfig):
# This is a property because you might want to change the codec model on the fly
def sampling_rate(self):
return self.audio_encoder.sampling_rate
@property
def _attn_implementation(self):
# This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
if hasattr(self, "_attn_implementation_internal"):
if self._attn_implementation_internal is None:
# `config.attn_implementation` should never be None, for backward compatibility.
return "eager"
else:
return self._attn_implementation_internal
else:
return "eager"
@_attn_implementation.setter
def _attn_implementation(self, value):
self._attn_implementation_internal = value
self.decoder._attn_implementation = value

View File: src/transformers/models/musicgen/modeling_musicgen.py

@ -1713,7 +1713,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin):
audio_encoder = AutoModel.from_config(config.audio_encoder)
if decoder is None:
decoder = MusicgenForCausalLM(config.decoder)
decoder = MusicgenForCausalLM._from_config(config.decoder)
self.text_encoder = text_encoder
self.audio_encoder = audio_encoder
@ -1737,6 +1737,9 @@ class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin):
# make sure that the individual model's config refers to the shared config
# so that the updates to the config will be synced
self.config.text_encoder._attn_implementation = self.text_encoder.config._attn_implementation
self.config.audio_encoder._attn_implementation = self.audio_encoder.config._attn_implementation
self.config.decoder._attn_implementation = self.decoder.config._attn_implementation
self.text_encoder.config = self.config.text_encoder
self.audio_encoder.config = self.config.audio_encoder
self.decoder.config = self.config.decoder

View File: src/transformers/models/musicgen_melody/configuration_musicgen_melody.py

@ -250,20 +250,3 @@ class MusicgenMelodyConfig(PretrainedConfig):
# This is a property because you might want to change the codec model on the fly
def sampling_rate(self):
return self.audio_encoder.sampling_rate
@property
def _attn_implementation(self):
# This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
if hasattr(self, "_attn_implementation_internal"):
if self._attn_implementation_internal is None:
# `config.attn_implementation` should never be None, for backward compatibility.
return "eager"
else:
return self._attn_implementation_internal
else:
return "eager"
@_attn_implementation.setter
def _attn_implementation(self, value):
self._attn_implementation_internal = value
self.decoder._attn_implementation = value

View File: src/transformers/models/musicgen_melody/modeling_musicgen_melody.py

@ -1628,7 +1628,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):
audio_encoder = AutoModel.from_config(config.audio_encoder)
if decoder is None:
decoder = MusicgenMelodyForCausalLM(config.decoder)
decoder = MusicgenMelodyForCausalLM._from_config(config.decoder)
self.text_encoder = text_encoder
self.audio_encoder = audio_encoder
@ -1636,6 +1636,9 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):
# make sure that the individual model's config refers to the shared config
# so that the updates to the config will be synced
self.config.text_encoder._attn_implementation = self.text_encoder.config._attn_implementation
self.config.audio_encoder._attn_implementation = self.audio_encoder.config._attn_implementation
self.config.decoder._attn_implementation = self.decoder.config._attn_implementation
self.text_encoder.config = self.config.text_encoder
self.audio_encoder.config = self.config.audio_encoder
self.decoder.config = self.config.decoder

View File: src/transformers/models/omdet_turbo/modeling_omdet_turbo.py

@ -288,7 +288,7 @@ class OmDetTurboLRUCache:
class OmDetTurboLanguageBackbone(nn.Module):
def __init__(self, config: OmDetTurboConfig):
super().__init__()
self.model = AutoModel.from_config(config.text_config, attn_implementation=config._attn_implementation)
self.model = AutoModel.from_config(config.text_config)
self.text_projection = nn.Parameter(torch.zeros(config.text_projection_in_dim, config.text_projection_out_dim))
def forward(self, hidden_states, mask=None, encode_type="task"):

View File: src/transformers/models/paligemma/modeling_paligemma.py

@ -193,12 +193,12 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["PaliGemmaMultiModalProjector"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = False
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_sdpa = True
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
# important: this ported version of PaliGemma isn't meant for training from scratch - only
@ -221,14 +221,6 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def _supports_sdpa(self):
"""
Retrieve language_model's attribute to check whether the model supports
SDPA or not.
"""
return self.language_model._supports_sdpa
PALIGEMMA_INPUTS_DOCSTRING = r"""
Args:
@ -310,11 +302,8 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self._attn_implementation = config._attn_implementation
language_model = AutoModelForCausalLM.from_config(
config=config.text_config, attn_implementation=self._attn_implementation
)
language_model = AutoModelForCausalLM.from_config(config=config.text_config)
if language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
@ -354,6 +343,11 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
def _update_causal_mask(
self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False
):
if self.config.text_config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
using_static_cache = isinstance(past_key_values, StaticCache)
dtype = inputs_embeds.dtype
min_dtype = torch.finfo(dtype).min
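
The FA2 early return follows the usual flash-attention convention: pass the raw 2D padding mask through when padding exists, otherwise pass `None` and let FA2 apply causal masking internally. A micro-demo of the padding check used above:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])  # one sequence, last position padded
print(0.0 in attention_mask)    # True -> padding present, keep the 2D mask
print(0.0 in torch.ones(1, 4))  # False -> no padding, return None for FA2
```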

View File: src/transformers/models/qwen2_audio/modeling_qwen2_audio.py

@ -544,6 +544,7 @@ class Qwen2AudioPreTrainedModel(PreTrainedModel):
_no_split_modules = ["Qwen2AudioAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
# important: this ported version of Qwen2Audio isn't meant for training from scratch - only
@ -559,14 +560,6 @@ class Qwen2AudioPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def _supports_sdpa(self):
"""
Retrieve language_model's attribute to check whether the model supports
SDPA or not.
"""
return self.language_model._supports_sdpa
QWEN2AUDIOENCODER_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -859,13 +852,11 @@ QWEN2AUDIO_INPUTS_DOCSTRING = r"""
class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin):
def __init__(self, config: Qwen2AudioConfig):
super().__init__(config)
self.audio_tower = AutoModel.from_config(config.audio_config, attn_implementation=config._attn_implementation)
self.audio_tower = AutoModel.from_config(config.audio_config)
self.multi_modal_projector = Qwen2AudioMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.post_init()

View File: src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

@ -1443,9 +1443,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
def __init__(self, config):
super().__init__(config)
self.visual = Qwen2VisionTransformerPretrainedModel._from_config(
config.vision_config, attn_implementation=config._attn_implementation
)
self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
self.model = Qwen2VLModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

View File: src/transformers/models/rag/modeling_rag.py

@ -232,6 +232,8 @@ class RagPreTrainedModel(PreTrainedModel):
config_class = RagConfig
base_model_prefix = "rag"
_supports_flash_attn_2 = True
_supports_sdpa = True
@classmethod
def from_pretrained(cls, *args, **kwargs):
@ -506,16 +508,12 @@ class RagModel(RagPreTrainedModel):
if question_encoder is None:
from ..auto.modeling_auto import AutoModel
question_encoder = AutoModel.from_config(
config.question_encoder, attn_implementation=config._attn_implementation
)
question_encoder = AutoModel.from_config(config.question_encoder)
if generator is None:
from ..auto.modeling_auto import AutoModelForSeq2SeqLM
generator = AutoModelForSeq2SeqLM.from_config(
config.generator, attn_implementation=config._attn_implementation
)
generator = AutoModelForSeq2SeqLM.from_config(config.generator)
self.retriever = retriever
if self.retriever is not None:

View File: src/transformers/models/siglip/modeling_siglip.py

@ -669,6 +669,7 @@ class SiglipPreTrainedModel(PreTrainedModel):
config_class = SiglipConfig
base_model_prefix = "siglip"
supports_gradient_checkpointing = True
_no_split_modules = [
"SiglipTextEmbeddings",
"SiglipEncoderLayer",
@ -1218,8 +1219,8 @@ class SiglipModel(SiglipPreTrainedModel):
vision_config = config.vision_config
# First, initialize the text and vision models with proper attention implementation
text_model = SiglipTextModel._from_config(text_config, attn_implementation=config._attn_implementation)
vision_model = SiglipVisionModel._from_config(vision_config, attn_implementation=config._attn_implementation)
text_model = SiglipTextModel._from_config(text_config)
vision_model = SiglipVisionModel._from_config(vision_config)
# Second, get the text and vision submodules (for backward compatibility)
self.text_model = text_model.text_model
@ -1454,9 +1455,7 @@ class SiglipForImageClassification(SiglipPreTrainedModel):
# Create the vision model with proper attention
# and take only vision_model submodule (for backward compatibility)
vision_model = SiglipVisionModel._from_config(
config.vision_config, attn_implementation=config._attn_implementation
)
vision_model = SiglipVisionModel._from_config(config.vision_config)
self.vision_model = vision_model.vision_model
# Classifier head


@ -183,6 +183,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel, GenerationMixin):
main_input_name = "inputs"
supports_gradient_checkpointing = True
_supports_param_buffer_assignment = False
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(
self,
@ -213,10 +215,10 @@ class SpeechEncoderDecoderModel(PreTrainedModel, GenerationMixin):
super().__init__(config)
if encoder is None:
encoder = AutoModel.from_config(config.encoder, attn_implementation=config._attn_implementation)
encoder = AutoModel.from_config(config.encoder)
if decoder is None:
decoder = AutoModelForCausalLM.from_config(config.decoder, attn_implementation=config._attn_implementation)
decoder = AutoModelForCausalLM.from_config(config.decoder)
self.encoder = encoder
self.decoder = decoder
@ -234,6 +236,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel, GenerationMixin):
# make sure that the individual model's config refers to the shared config
# so that the updates to the config will be synced
self.config.encoder._attn_implementation = self.encoder.config._attn_implementation
self.config.decoder._attn_implementation = self.decoder.config._attn_implementation
self.encoder.config = self.config.encoder
self.decoder.config = self.config.decoder

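With the support flags declared on the wrapper and the sub-configs kept in sync as above, mixed requests become possible for speech encoder-decoder models too. A sketch, assuming this checkpoint and the "encoder"/"decoder" key names:

from transformers import SpeechEncoderDecoderModel

model = SpeechEncoderDecoderModel.from_pretrained(
    "facebook/wav2vec2-xls-r-300m-en-to-15",
    attn_implementation={"encoder": "sdpa", "decoder": "eager"},
)
# The sync above keeps the shared config and the sub-models' configs in agreement.
assert model.config.encoder._attn_implementation == model.encoder.config._attn_implementation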

@ -126,8 +126,9 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["VideoLlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
std = (
@ -148,14 +149,6 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def _supports_sdpa(self):
"""
Retrieve language_model's attribute to check whether the model supports
SDPA or not.
"""
return self.language_model._supports_sdpa
VIDEO_LLAVA_INPUTS_DOCSTRING = r"""
Args:
@ -248,9 +241,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
self.multi_modal_projector = VideoLlavaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()


@ -132,8 +132,9 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["VipLlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
# important: this ported version of VipLlava isn't meant for training from scratch - only
@ -157,14 +158,6 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def _supports_sdpa(self):
"""
Retrieve language_model's attribute to check whether the model supports
SDPA or not.
"""
return self.language_model._supports_sdpa
VIPLLAVA_INPUTS_DOCSTRING = r"""
Args:
@ -248,9 +241,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
self.multi_modal_projector = VipLlavaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()


@ -161,6 +161,8 @@ class VisionEncoderDecoderModel(PreTrainedModel, GenerationMixin):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_supports_param_buffer_assignment = False
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(
self,
@ -191,10 +193,10 @@ class VisionEncoderDecoderModel(PreTrainedModel, GenerationMixin):
super().__init__(config)
if encoder is None:
encoder = AutoModel.from_config(config.encoder, attn_implementation=config._attn_implementation)
encoder = AutoModel.from_config(config.encoder)
if decoder is None:
decoder = AutoModelForCausalLM.from_config(config.decoder, attn_implementation=config._attn_implementation)
decoder = AutoModelForCausalLM.from_config(config.decoder)
self.encoder = encoder
self.decoder = decoder
@ -212,6 +214,8 @@ class VisionEncoderDecoderModel(PreTrainedModel, GenerationMixin):
# make sure that the individual model's config refers to the shared config
# so that the updates to the config will be synced
self.config.encoder._attn_implementation = self.encoder.config._attn_implementation
self.config.decoder._attn_implementation = self.decoder.config._attn_implementation
self.encoder.config = self.config.encoder
self.decoder.config = self.config.decoder


@ -161,6 +161,8 @@ def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
class VisionTextDualEncoderModel(PreTrainedModel):
config_class = VisionTextDualEncoderConfig
base_model_prefix = "vision_text_dual_encoder"
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(
self,
@ -184,18 +186,18 @@ class VisionTextDualEncoderModel(PreTrainedModel):
if isinstance(config.vision_config, CLIPVisionConfig):
vision_model = CLIPVisionModel(config.vision_config)
else:
vision_model = AutoModel.from_config(
config.vision_config, attn_implementation=config._attn_implementation
)
vision_model = AutoModel.from_config(config.vision_config)
if text_model is None:
text_model = AutoModel.from_config(config.text_config, attn_implementation=config._attn_implementation)
text_model = AutoModel.from_config(config.text_config)
self.vision_model = vision_model
self.text_model = text_model
# make sure that the individual model's config refers to the shared config
# so that the updates to the config will be synced
self.config.vision_config._attn_implementation = self.vision_model.config._attn_implementation
self.config.text_config._attn_implementation = self.text_model.config._attn_implementation
self.vision_model.config = self.config.vision_config
self.text_model.config = self.config.text_config

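The effect of this sync, illustrated (checkpoints are illustrative):

from transformers import VisionTextDualEncoderModel

model = VisionTextDualEncoderModel.from_vision_text_pretrained(
    "openai/clip-vit-base-patch32", "google-bert/bert-base-uncased"
)
# After __init__, the shared config's sub-configs and the sub-models' own
# configs report the same attention implementation.
assert model.config.vision_config._attn_implementation == model.vision_model.config._attn_implementation
assert model.config.text_config._attn_implementation == model.text_model.config._attn_implementation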

@ -250,8 +250,24 @@ class ViTSdpaSelfAttention(ViTSelfAttention):
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
self,
hidden_states: torch.FloatTensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions or head_mask is not None:
logger.warning_once(
"`ViTSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
)
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))

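As the warning says, users who need attention probabilities (or a head mask) can avoid the fallback message by loading with eager attention up front; a small sketch with a dummy input:

import torch
from transformers import ViTModel

# Eager attention supports `output_attentions=True` without the SDPA fallback warning.
model = ViTModel.from_pretrained("google/vit-base-patch16-224", attn_implementation="eager")
pixel_values = torch.randn(1, 3, 224, 224)  # dummy image batch
outputs = model(pixel_values=pixel_values, output_attentions=True)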

@ -424,8 +424,24 @@ class ViTMAESdpaSelfAttention(ViTMAESelfAttention):
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
self,
hidden_states: torch.FloatTensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions or head_mask is not None:
logger.warning_once(
"`ViTMAESdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
)
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))


@ -241,8 +241,24 @@ class ViTMSNSdpaSelfAttention(ViTMSNSelfAttention):
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
self,
hidden_states: torch.FloatTensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions or head_mask is not None:
logger.warning_once(
"`ViTMSNSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
)
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))


@ -299,8 +299,24 @@ class YolosSdpaSelfAttention(YolosSelfAttention):
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
self,
hidden_states: torch.FloatTensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions or head_mask is not None:
logger.warning_once(
"`YolosSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
)
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))


@ -27,6 +27,7 @@ from transformers.testing_utils import (
require_torch_fp16,
require_torch_gpu,
require_torch_multi_accelerator,
require_torch_sdpa,
require_vision,
slow,
torch_device,
@ -456,6 +457,7 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
test_resize_embeddings = False
test_attention_outputs = False
test_torchscript = False
_is_composite = True
def setUp(self):
self.model_tester = Blip2ForConditionalGenerationDecoderOnlyModelTester(self)
@ -488,6 +490,66 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
def test_save_load_fast_init_to_base(self):
pass
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
"""
Tests whether composite models dispatch correctly to SDPA/eager when requested to do so at load time.
This test checks only layer names, as SDPA layers are usually called "SdpaAttention".
In contrast to the test above, this one checks that "config._attn_implementation" is a dict after the model
is loaded, because we manually replicate the requested attn implementation on each sub-config when loading.
See https://github.com/huggingface/transformers/pull/32238 for more info.
The test tries to cover the most general cases of composite models: VLMs with vision and text configs. Any model
with a different set of sub-configs has to overwrite this test.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self._is_composite:
self.skipTest(f"{self.all_model_classes[0].__name__} is not a composite model")
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager"
qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager"
# `None` is the requested implementation here, which gets resolved and assigned to each sub-config
# Each sub-model will dispatch to SDPA if it can (checked below that `Sdpa` layers are present)
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
self.assertTrue(model.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model.qformer.config._attn_implementation == qformer_attn)
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
self.assertTrue(model_eager.language_model.config._attn_implementation == "eager")
self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
self.assertTrue(model_eager.qformer.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and any(
module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn]
):
raise ValueError("The SDPA model should have SDPA attention layers")
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@ -715,6 +777,7 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
test_resize_embeddings = False
test_attention_outputs = False
test_torchscript = False
_is_composite = True
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
@ -768,6 +831,66 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
def test_cpu_offload(self):
pass
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
"""
Tests whether composite models dispatch correctly to SDPA/eager when requested to do so at load time.
This test checks only layer names, as SDPA layers are usually called "SdpaAttention".
In contrast to the test above, this one checks that "config._attn_implementation" is a dict after the model
is loaded, because we manually replicate the requested attn implementation on each sub-config when loading.
See https://github.com/huggingface/transformers/pull/32238 for more info.
The test tries to cover the most general cases of composite models: VLMs with vision and text configs. Any model
with a different set of sub-configs has to overwrite this test.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self._is_composite:
self.skipTest(f"{self.all_model_classes[0].__name__} is not a composite model")
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager"
qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager"
# `None` is the requested implementation here, which gets resolved and assigned to each sub-config
# Each sub-model will dispatch to SDPA if it can (checked below that `Sdpa` layers are present)
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
self.assertTrue(model.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model.qformer.config._attn_implementation == qformer_attn)
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
self.assertTrue(model_eager.language_model.config._attn_implementation == "eager")
self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
self.assertTrue(model_eager.qformer.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and any(
module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn]
):
raise ValueError("The SDPA model should have SDPA attention layers")
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()

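What this test exercises, as a user-facing sketch (the checkpoint is illustrative and the dict keys are assumed to mirror Blip2Config's sub-config names):

from transformers import Blip2ForConditionalGeneration

# Per-sub-model dispatch: SDPA where supported, eager for the Q-Former.
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    attn_implementation={"vision_config": "sdpa", "qformer_config": "eager", "text_config": "sdpa"},
)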

@ -191,6 +191,53 @@ class CLIPModelTesterMixin(ModelTesterMixin):
different output logits, and are not supposed to be used or tested with padding_side="left".
"""
def test_sdpa_can_dispatch_composite_models(self):
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# Load the model with SDPA
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
# Load model with eager attention
model_eager = model_class.from_pretrained(
tmpdirname,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
# SigLip has one shared class attribute for all models, so we assign both submodels here
vision_attn = text_attn = "sdpa" if model._supports_sdpa else "eager"
# `None` is the requested implementation here, which gets resolved and assigned to each sub-config
# Each sub-model will dispatch to SDPA if it can (checked below that `Sdpa` layers are present)
if hasattr(model_sdpa, "vision_model") and hasattr(model_sdpa, "text_model"):
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model_sdpa.text_model.config._attn_implementation == text_attn)
self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
self.assertTrue(model_eager.text_model.config._attn_implementation == "eager")
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
def test_eager_matches_sdpa_inference(
self,
torch_dtype: str,
@ -252,24 +299,6 @@ class CLIPModelTesterMixin(ModelTesterMixin):
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just in the interest of avoiding loading/saving the model each time,
# but it would be nicer to have an efficient way to use parameterized.expand
cases = [
@ -461,6 +490,10 @@ class CLIPVisionModelTest(CLIPModelTesterMixin, unittest.TestCase):
use_attention_mask_options=(None,),
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
super().test_sdpa_can_dispatch_composite_models()
class CLIPTextModelTester:
def __init__(
@ -639,6 +672,10 @@ class CLIPTextModelTest(CLIPModelTesterMixin, unittest.TestCase):
use_attention_mask_options=(None, "right"), # "left" is not supported for text model
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
super().test_sdpa_can_dispatch_composite_models()
@require_torch_sdpa
def test_sdpa_can_dispatch_on_flash(self):
self.skipTest(reason="CLIPTextModel has two attention masks: `causal_attention_mask` and `attention_mask`")
@ -704,6 +741,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
test_pruning = False
test_resize_embeddings = False
test_attention_outputs = False
_is_composite = True
def setUp(self):
self.model_tester = CLIPModelTester(self)
@ -975,6 +1013,10 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
use_attention_mask_options=(None, "right"), # "left" is not supported for text model
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
super().test_sdpa_can_dispatch_composite_models()
@require_torch_sdpa
def test_sdpa_can_dispatch_on_flash(self):
self.skipTest(reason="CLIP text tower has two attention masks: `causal_attention_mask` and `attention_mask`")
@ -1104,6 +1146,7 @@ class CLIPForImageClassificationModelTest(CLIPModelTesterMixin, PipelineTesterMi
test_pruning = False
test_resize_embeddings = False
test_attention_outputs = False
_is_composite = True
def setUp(self):
self.model_tester = CLIPForImageClassificationModelTester(self)
@ -1143,6 +1186,10 @@ class CLIPForImageClassificationModelTest(CLIPModelTesterMixin, PipelineTesterMi
use_attention_mask_options=(None,),
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
super().test_sdpa_can_dispatch_composite_models()
# We will verify our results on an image of cute cats
def prepare_img():


@ -18,7 +18,14 @@ import tempfile
import unittest
from transformers import is_torch_available, logging
from transformers.testing_utils import CaptureLogger, require_deterministic_for_xpu, require_torch, slow, torch_device
from transformers.testing_utils import (
CaptureLogger,
require_deterministic_for_xpu,
require_torch,
require_torch_sdpa,
slow,
torch_device,
)
from ...test_modeling_common import ids_tensor
from ..bart.test_modeling_bart import BartStandaloneDecoderModelTester
@ -54,6 +61,8 @@ if is_torch_available():
@require_torch
class EncoderDecoderMixin:
supports_sdpa = False
def get_encoder_decoder_model(self, config, decoder_config):
raise NotImplementedError
@ -670,6 +679,67 @@ class EncoderDecoderMixin:
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
if not self.supports_sdpa:
self.skipTest("SDPA is not supported")
inputs_dict = self.prepare_config_and_inputs()
encoder_config, decoder_config = inputs_dict["config"], inputs_dict["decoder_config"]
config = EncoderDecoderConfig.from_encoder_decoder_configs(
encoder_config=encoder_config, decoder_config=decoder_config
)
model = EncoderDecoderModel(config=config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = EncoderDecoderModel.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
# see https://github.com/huggingface/transformers/pull/32238
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
encoder_attn = "sdpa" if model.encoder._supports_sdpa else "eager"
decoder_attn = "sdpa" if model.decoder._supports_sdpa else "eager"
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_sdpa.encoder.config._attn_implementation == encoder_attn)
self.assertTrue(model_sdpa.decoder.config._attn_implementation == decoder_attn)
# Also test that nothing breaks if we request SDPA explicitly, when both sub-parts support it.
# If the model supports SDPA (i.e. all of its sub-models support it) we'll dispatch safely;
# otherwise we should raise an error that SDPA is not supported, as some of the sub-models don't support it.
if encoder_attn == "sdpa" and decoder_attn == "sdpa":
model_sdpa_explicit = EncoderDecoderModel.from_pretrained(tmpdirname, attn_implementation="sdpa")
model_sdpa_explicit = model_sdpa_explicit.eval().to(torch_device)
self.assertTrue(model_sdpa_explicit.config._attn_implementation == "sdpa")
else:
with self.assertRaises(ValueError):
model_sdpa_explicit = EncoderDecoderModel.from_pretrained(tmpdirname, attn_implementation="sdpa")
model_eager = EncoderDecoderModel.from_pretrained(
tmpdirname,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
self.assertTrue(model_eager.encoder.config._attn_implementation == "eager")
self.assertTrue(model_eager.decoder.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
@require_torch
class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
@ -949,6 +1019,8 @@ class RoBertaEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
@require_torch
class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
supports_sdpa = True
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = BertModel(config)
decoder_model = GPT2LMHeadModel(decoder_config)

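The explicit-SDPA branch of the test above, as a standalone sketch mirroring this BERT + GPT-2 pairing (the local path is illustrative): requesting "sdpa" on the wrapper succeeds only when both sub-models support it, and raises a ValueError otherwise.

from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "google-bert/bert-base-uncased", "openai-community/gpt2"
)
model.save_pretrained("bert2gpt2")  # illustrative local path
# Both BERT and GPT-2 support SDPA, so an explicit request dispatches safely.
model = EncoderDecoderModel.from_pretrained("bert2gpt2", attn_implementation="sdpa")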

@ -88,6 +88,10 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
def test_model_outputs_equivalence(self, **kwargs):
pass
@unittest.skip("Gemma2 forcefully disables sdpa due to softcapping")
def test_sdpa_can_dispatch_non_composite_models(self):
pass
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
def test_eager_matches_sdpa_inference(self):


@ -580,11 +580,9 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
model = IdeficsModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_torch_sdpa
@slow
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test")
@unittest.skip("Idefics has a hard requirement on SDPA")
def test_sdpa_can_dispatch_non_composite_models(self):
pass
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@ -806,6 +804,10 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip("Idefics has a hard requirement on SDPA")
def test_sdpa_can_dispatch_non_composite_models(self):
pass
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@require_torch


@ -16,6 +16,7 @@
import copy
import gc
import tempfile
import unittest
from io import BytesIO
@ -36,6 +37,7 @@ from transformers.testing_utils import (
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
require_torch_sdpa,
slow,
torch_device,
)
@ -180,6 +182,7 @@ class Idefics2ModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False
test_resize_embeddings = True
test_head_masking = False
_is_composite = True
def setUp(self):
self.model_tester = Idefics2VisionText2TextModelTester(self)
@ -327,6 +330,43 @@ class Idefics2ModelTest(ModelTesterMixin, unittest.TestCase):
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
vision_attn = None if model.vision_model._supports_sdpa else "eager"
perceiver_attn = None if model.connector.perceiver_resampler._supports_sdpa else "eager"
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model_sdpa.connector.perceiver_resampler.config._attn_implementation == perceiver_attn)
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
self.assertTrue(model_eager.connector.perceiver_resampler.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
@require_torch
class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTesterMixin, unittest.TestCase):


@ -32,6 +32,7 @@ from transformers.testing_utils import (
require_accelerate,
require_bitsandbytes,
require_torch,
require_torch_sdpa,
require_vision,
slow,
torch_device,
@ -460,6 +461,7 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
test_resize_embeddings = False
test_attention_outputs = False
test_torchscript = False
_is_composite = True
def setUp(self):
self.model_tester = InstructBlipForConditionalGenerationDecoderOnlyModelTester(self)
@ -529,6 +531,66 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
model = InstructBlipForConditionalGeneration.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
"""
Tests whether composite models dispatch correctly to SDPA/eager when requested to do so at load time.
This test checks only layer names, as SDPA layers are usually called "SdpaAttention".
In contrast to the test above, this one checks that "config._attn_implementation" is a dict after the model
is loaded, because we manually replicate the requested attn implementation on each sub-config when loading.
See https://github.com/huggingface/transformers/pull/32238 for more info.
The test tries to cover the most general cases of composite models: VLMs with vision and text configs. Any model
with a different set of sub-configs has to overwrite this test.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self._is_composite:
self.skipTest(f"{self.all_model_classes[0].__name__} is not a composite model")
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager"
qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager"
# `None` is the requested implementation here, which gets resolved and assigned to each sub-config
# Each sub-model will dispatch to SDPA if it can (checked below that `Sdpa` layers are present)
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
self.assertTrue(model.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model.qformer.config._attn_implementation == qformer_attn)
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
self.assertTrue(model_eager.language_model.config._attn_implementation == "eager")
self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
self.assertTrue(model_eager.qformer.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and any(
module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn]
):
raise ValueError("The SDPA model should have SDPA attention layers")
# We will verify our results on an image of cute cats
def prepare_img():


@ -32,6 +32,7 @@ from transformers.testing_utils import (
require_accelerate,
require_bitsandbytes,
require_torch,
require_torch_sdpa,
require_vision,
slow,
torch_device,
@ -481,6 +482,7 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
test_resize_embeddings = False
test_attention_outputs = False
test_torchscript = False
_is_composite = True
def setUp(self):
self.model_tester = InstructBlipVideoForConditionalGenerationDecoderOnlyModelTester(self)
@ -550,6 +552,66 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
model = InstructBlipVideoForConditionalGeneration.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
"""
Tests whether composite models dispatch correctly to SDPA/eager when requested to do so at load time.
This test checks only layer names, as SDPA layers are usually called "SdpaAttention".
In contrast to the test above, this one checks that "config._attn_implementation" is a dict after the model
is loaded, because we manually replicate the requested attn implementation on each sub-config when loading.
See https://github.com/huggingface/transformers/pull/32238 for more info.
The test tries to cover the most general cases of composite models: VLMs with vision and text configs. Any model
with a different set of sub-configs has to overwrite this test.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self._is_composite:
self.skipTest(f"{self.all_model_classes[0].__name__} is not a composite model")
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager"
qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager"
# `None` is the requested implementation here, which gets resolved and assigned to each sub-config
# Each sub-model will dispatch to SDPA if it can (checked below that `Sdpa` layers are present)
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
self.assertTrue(model.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model.qformer.config._attn_implementation == qformer_attn)
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
self.assertTrue(model_eager.language_model.config._attn_implementation == "eager")
self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
self.assertTrue(model_eager.qformer.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and any(
module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn]
):
raise ValueError("The SDPA model should have SDPA attention layers")
# We will verify our results on an image of cute cats
def prepare_video():


@ -25,8 +25,17 @@ import requests
from transformers import AutoModelForImageTextToText, AutoProcessor, Kosmos2Config
from transformers.models.kosmos2.configuration_kosmos2 import Kosmos2TextConfig, Kosmos2VisionConfig
from transformers.testing_utils import IS_ROCM_SYSTEM, require_torch, require_vision, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available
from transformers.testing_utils import (
IS_ROCM_SYSTEM,
require_torch,
require_vision,
slow,
torch_device,
)
from transformers.utils import (
is_torch_available,
is_vision_available,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
@ -257,6 +266,7 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
test_pruning = False
test_resize_embeddings = False
test_attention_outputs = False
_is_composite = True
# TODO: `image-to-text` pipeline for this model needs Processor.
def is_pipeline_test_to_skip(


@ -186,6 +186,7 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
pipeline_model_mapping = {"image-to-text": LlavaForConditionalGeneration} if is_torch_available() else {}
test_pruning = False
test_head_masking = False
_is_composite = True
def setUp(self):
self.model_tester = LlavaVisionText2TextModelTester(self)
@ -260,6 +261,16 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip("FlashAttention only supports fp16 and bf16 data types")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@require_torch
class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):


@ -218,6 +218,7 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
all_generative_model_classes = (LlavaNextForConditionalGeneration,) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
_is_composite = True
def setUp(self):
self.model_tester = LlavaNextVisionText2TextModelTester(self)
@ -316,6 +317,16 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip("FlashAttention only supports fp16 and bf16 data types")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@require_torch
class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):


@ -236,6 +236,7 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
all_generative_model_classes = (LlavaNextVideoForConditionalGeneration,) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
_is_composite = True
def setUp(self):
self.model_tester = LlavaNextVideoVisionText2TextModelTester(self)
@ -340,6 +341,16 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip("FlashAttention only supports fp16 and bf16 data types")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@require_torch
class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):


@ -219,6 +219,7 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
all_generative_model_classes = (LlavaOnevisionForConditionalGeneration,) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
_is_composite = True
def setUp(self):
self.model_tester = LlavaOnevisionVisionText2TextModelTester(self)
@ -306,6 +307,16 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
def test_assisted_decoding_with_num_logits_to_keep(self):
pass
@unittest.skip("FlashAttention only supports fp16 and bf16 data types")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@require_torch
class LlavaOnevisionForConditionalGenerationIntegrationTest(unittest.TestCase):


@ -274,6 +274,7 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
test_pruning = False
test_head_masking = False
test_torchscript = False
_is_composite = True
def setUp(self):
self.model_tester = MllamaVisionText2TextModelTester(self)


@ -654,8 +654,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
@ -663,20 +661,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just in the interest of avoiding loading/saving the model 8 times,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
@ -1042,6 +1026,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# not to test torchscript as the model tester doesn't prepare `input_values` and `padding_mask`
# (and `torchscript` hates `None` values).
test_torchscript = False
_is_composite = True
def setUp(self):
self.model_tester = MusicgenTester(self)
@ -1420,7 +1405,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
def test_flash_attn_2_inference_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
@ -1432,7 +1417,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
tmpdirname,
torch_dtype=torch.bfloat16,
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
)
model_fa.to(torch_device)
@ -1505,7 +1492,88 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding
def test_flash_attn_2_conversion(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
).to(torch_device)
for _, module in model.named_modules():
if "FlashAttention" in module.__class__.__name__:
return
self.fail("FlashAttention2 modules not found in model")
@require_torch_sdpa
@require_torch_gpu
@slow
def test_sdpa_can_dispatch_on_flash(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
torch.compiler.reset()
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability
if not torch.version.cuda or major < 8:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
for model_class in self.all_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
if config.model_type in ["llava", "llava_next", "vipllava", "video_llava"]:
self.skipTest(
reason="Llava-like models currently (transformers==4.39.1) require an attention_mask input"
)
if config.model_type in ["paligemma"]:
self.skipTest(
"PaliGemma-like models currently (transformers==4.41.0) require an attention_mask input"
)
if config.model_type in ["idefics", "idefics2", "idefics3"]:
self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation={"decoder": "sdpa", "audio_encoder": None, "text_encoder": None},
)
model.to(torch_device)
inputs_dict.pop("attention_mask", None)
inputs_dict.pop("decoder_attention_mask", None)
for name, inp in inputs_dict.items():
if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
inputs_dict[name] = inp.to(torch.float16)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
_ = model(**inputs_dict)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding
def test_flash_attn_2_inference_equivalence_right_padding(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
@ -1517,7 +1585,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
tmpdirname,
torch_dtype=torch.bfloat16,
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
)
model_fa.to(torch_device)
@ -1587,7 +1657,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding
def test_flash_attn_2_generate_left_padding(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
@ -1622,7 +1692,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
low_cpu_mem_usage=True,
).to(torch_device)
@ -1636,7 +1706,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right
def test_flash_attn_2_generate_padding_right(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
@ -1670,7 +1740,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
low_cpu_mem_usage=True,
).to(torch_device)
@ -1684,7 +1754,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache
def test_flash_attn_2_generate_use_cache(self):
max_new_tokens = 30
@ -1713,7 +1783,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
low_cpu_mem_usage=True,
).to(torch_device)
@ -1726,6 +1796,53 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
use_cache=True,
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self._is_composite:
self.skipTest(f"{self.all_model_classes[0].__name__} is not a composite model")
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
audio_encoder_attn = "sdpa" if model.audio_encoder._supports_sdpa else "eager"
text_encoder_attn = "sdpa" if model.text_encoder._supports_sdpa else "eager"
decoder_attn = "sdpa" if model.decoder._supports_sdpa else "eager"
# `None` is the requested implementation here, which gets resolved and assigned to each sub-config
# Each sub-model will dispatch to SDPA if it can (checked below that `Sdpa` layers are present)
self.assertTrue(model_sdpa.audio_encoder.config._attn_implementation == audio_encoder_attn)
self.assertTrue(model_sdpa.text_encoder.config._attn_implementation == text_encoder_attn)
self.assertTrue(model_sdpa.decoder.config._attn_implementation == decoder_attn)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.audio_encoder.config._attn_implementation == "eager")
self.assertTrue(model_eager.text_encoder.config._attn_implementation == "eager")
self.assertTrue(model_eager.decoder.config._attn_implementation == "eager")
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@ -1792,8 +1909,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
@ -1801,20 +1916,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just in the interest of avoiding loading/saving the model 8 times,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []

View File

@ -311,7 +311,9 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
tmpdirname,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
model_fa.to(torch_device)
@ -391,7 +393,9 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
tmpdirname,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
model_fa.to(torch_device)
@ -454,148 +458,10 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding
def test_flash_attn_2_generate_left_padding(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# make sure we do left padding
dummy_attention_mask[:, :-1] = 0
dummy_attention_mask[:, -1:] = 1
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right
def test_flash_attn_2_generate_padding_right(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# make sure we do right padding
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_2_generate_use_cache
def test_flash_attn_2_generate_use_cache(self):
max_new_tokens = 30
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_eager_matches_sdpa_inference
# Copied from tests.test_modeling_common.ModelTesterMixin.test_eager_matches_sdpa_inference
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
@ -658,8 +524,6 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
@ -667,20 +531,6 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just to avoid loading/saving the model 8 times,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
@ -839,74 +689,6 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
@require_torch_sdpa
@slow
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_eager_matches_sdpa_generate
def test_eager_matches_sdpa_generate(self):
max_new_tokens = 30
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model_sdpa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
# Just test that a large cache works as expected
res_eager = model_eager.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
res_sdpa = model_sdpa.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
def prepare_musicgen_melody_inputs_dict(
config,
@ -1048,6 +830,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
# not to test torchscript as the model tester doesn't prepare `input_features` and `padding_mask`
# (and `torchscript` hates `None` values).
test_torchscript = False
_is_composite = True
def setUp(self):
self.model_tester = MusicgenMelodyTester(self)
@ -1406,7 +1189,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
def test_flash_attn_2_inference_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
@ -1418,7 +1201,9 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
tmpdirname,
torch_dtype=torch.bfloat16,
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
)
model_fa.to(torch_device)
@ -1491,7 +1276,88 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding
def test_flash_attn_2_conversion(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
).to(torch_device)
for _, module in model.named_modules():
if "FlashAttention" in module.__class__.__name__:
return
self.fail("FlashAttention2 modules not found in model")
@require_torch_sdpa
@require_torch_gpu
@slow
def test_sdpa_can_dispatch_on_flash(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
torch.compiler.reset()
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability
if not torch.version.cuda or major < 8:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
for model_class in self.all_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
if config.model_type in ["llava", "llava_next", "vipllava", "video_llava"]:
self.skipTest(
reason="Llava-like models currently (transformers==4.39.1) requires an attention_mask input"
)
if config.model_type in ["paligemma"]:
self.skipTest(
"PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input"
)
if config.model_type in ["idefics", "idefics2", "idefics3"]:
self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation={"decoder": "sdpa", "audio_encoder": None, "text_encoder": None},
)
model.to(torch_device)
inputs_dict.pop("attention_mask", None)
inputs_dict.pop("decoder_attention_mask", None)
for name, inp in inputs_dict.items():
if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
inputs_dict[name] = inp.to(torch.float16)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
_ = model(**inputs_dict)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding
def test_flash_attn_2_inference_equivalence_right_padding(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
@ -1503,7 +1369,9 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
tmpdirname,
torch_dtype=torch.bfloat16,
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
)
model_fa.to(torch_device)
@ -1573,7 +1441,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding
def test_flash_attn_2_generate_left_padding(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
@ -1608,7 +1476,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
low_cpu_mem_usage=True,
).to(torch_device)
@ -1622,7 +1490,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right
def test_flash_attn_2_generate_padding_right(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
@ -1656,7 +1524,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
low_cpu_mem_usage=True,
).to(torch_device)
@ -1670,7 +1538,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache
def test_flash_attn_2_generate_use_cache(self):
max_new_tokens = 30
@ -1699,7 +1567,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
low_cpu_mem_usage=True,
).to(torch_device)
@ -1712,6 +1580,53 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
use_cache=True,
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self._is_composite:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
audio_encoder_attn = "sdpa" if model.audio_encoder._supports_sdpa else "eager"
text_encoder_attn = "sdpa" if model.text_encoder._supports_sdpa else "eager"
decoder_attn = "sdpa" if model.decoder._supports_sdpa else "eager"
# Loading without `attn_implementation` (i.e. `None`) propagates the request to each sub-config;
# each sub-model then dispatches to SDPA if it can (we check below that `SdpaAttention` layers are present)
self.assertTrue(model_sdpa.audio_encoder.config._attn_implementation == audio_encoder_attn)
self.assertTrue(model_sdpa.text_encoder.config._attn_implementation == text_encoder_attn)
self.assertTrue(model_sdpa.decoder.config._attn_implementation == decoder_attn)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.audio_encoder.config._attn_implementation == "eager")
self.assertTrue(model_eager.text_encoder.config._attn_implementation == "eager")
self.assertTrue(model_eager.decoder.config._attn_implementation == "eager")
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@ -1775,8 +1690,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
@ -1784,20 +1697,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just to avoid loading/saving the model 8 times,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []

View File

@ -187,6 +187,7 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
test_pruning = False
test_torchscript = False
test_head_masking = False
_is_composite = True
def setUp(self):
self.model_tester = PaliGemmaVisionText2TextModelTester(self)
@ -319,6 +320,16 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
def test_static_cache_matches_dynamic(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@slow
@require_torch

View File

@ -15,6 +15,7 @@
"""Testing suite for the PyTorch Qwen2Audio model."""
import gc
import tempfile
import unittest
from io import BytesIO
from urllib.request import urlopen
@ -29,6 +30,7 @@ from transformers import (
)
from transformers.testing_utils import (
require_torch,
require_torch_sdpa,
slow,
torch_device,
)
@ -152,6 +154,7 @@ class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, unittest.Tes
all_model_classes = (Qwen2AudioForConditionalGeneration,) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
_is_composite = True
def setUp(self):
self.model_tester = Qwen2AudioModelTester(self)
@ -165,6 +168,53 @@ class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, unittest.Tes
def test_sdpa_can_dispatch_on_flash(self):
pass
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
# overwrite because Qwen2Audio is an audio+text model (not vision+text)
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self._is_composite:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
vision_attn = "sdpa" if model.audio_tower._supports_sdpa else "eager"
# `None` as it is the requested one which will be assigned to each sub-config
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
self.assertTrue(model.audio_tower.config._attn_implementation == vision_attn)
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
self.assertTrue(model_eager.language_model.config._attn_implementation == "eager")
self.assertTrue(model_eager.audio_tower.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
@require_torch
class Qwen2AudioForConditionalGenerationIntegrationTest(unittest.TestCase):

View File

@ -71,6 +71,51 @@ if is_vision_available():
class SiglipModelTesterMixin(ModelTesterMixin):
def test_sdpa_can_dispatch_composite_models(self):
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# Load the model with SDPA
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
# Load model with eager attention
model_eager = model_class.from_pretrained(
tmpdirname,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
# Siglip has one shared cls attr for all models, so we assign it to both submodels here
vision_attn = text_attn = "sdpa" if model._supports_sdpa else "eager"
if hasattr(model_sdpa, "vision_model") and hasattr(model_sdpa, "text_model"):
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model_sdpa.text_model.config._attn_implementation == text_attn)
self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
self.assertTrue(model_eager.text_model.config._attn_implementation == "eager")
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
def test_eager_matches_sdpa_inference(
self,
torch_dtype: str,
@ -132,23 +177,6 @@ class SiglipModelTesterMixin(ModelTesterMixin):
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just to avoid loading/saving the model each time,
# but it would be nicer to have an efficient way to use parameterized.expand
cases = [
@ -400,6 +428,10 @@ class SiglipVisionModelTest(SiglipModelTesterMixin, unittest.TestCase):
use_attention_mask_options=(False,),
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
super().test_sdpa_can_dispatch_composite_models()
class SiglipTextModelTester:
def __init__(
@ -562,6 +594,10 @@ class SiglipTextModelTest(SiglipModelTesterMixin, unittest.TestCase):
use_attention_mask_options=(False, True),
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
super().test_sdpa_can_dispatch_composite_models()
class SiglipModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
@ -629,6 +665,7 @@ class SiglipModelTest(SiglipModelTesterMixin, PipelineTesterMixin, unittest.Test
test_cpu_offload = False
test_disk_offload_safetensors = False
test_disk_offload_bin = False
_is_composite = True
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.setUp with CLIP->Siglip
def setUp(self):
@ -851,6 +888,10 @@ class SiglipModelTest(SiglipModelTesterMixin, PipelineTesterMixin, unittest.Test
use_attention_mask_options=(False, True),
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
super().test_sdpa_can_dispatch_composite_models()
class SiglipForImageClassificationModelTester(SiglipModelTester):
def __init__(self, parent):
@ -888,6 +929,7 @@ class SiglipForImageClassificationModelTest(SiglipModelTesterMixin, PipelineTest
test_cpu_offload = False
test_disk_offload_safetensors = False
test_disk_offload_bin = False
_is_composite = True
def setUp(self):
self.model_tester = SiglipForImageClassificationModelTester(self)
@ -925,6 +967,10 @@ class SiglipForImageClassificationModelTest(SiglipModelTesterMixin, PipelineTest
torch_dtype=torch_dtype, logit_keys=("logits",), use_attention_mask_options=(False,)
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
super().test_sdpa_can_dispatch_composite_models()
# We will verify our results on an image of cute cats
def prepare_img():

View File

@ -18,7 +18,13 @@ import tempfile
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_deterministic_for_xpu, require_torch, slow, torch_device
from transformers.testing_utils import (
require_deterministic_for_xpu,
require_torch,
require_torch_sdpa,
slow,
torch_device,
)
from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from ..bert.test_modeling_bert import BertModelTester
@ -441,6 +447,66 @@ class EncoderDecoderMixin:
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
inputs_dict = self.prepare_config_and_inputs()
encoder_config, decoder_config = inputs_dict["config"], inputs_dict["decoder_config"]
config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(
encoder_config=encoder_config, decoder_config=decoder_config
)
model = SpeechEncoderDecoderModel(config=config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = SpeechEncoderDecoderModel.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
# see https://github.com/huggingface/transformers/pull/32238
# each sub-model dispatches to SDPA if it can (we check below that `SdpaAttention` layers are present)
encoder_attn = "sdpa" if model.encoder._supports_sdpa else "eager"
decoder_attn = "sdpa" if model.decoder._supports_sdpa else "eager"
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_sdpa.encoder.config._attn_implementation == encoder_attn)
self.assertTrue(model_sdpa.decoder.config._attn_implementation == decoder_attn)
# Also test that nothing breaks if we request SDPA explicitly, when both sub-parts support it.
# If the model supports SDPA (i.e. all sub-models support it), we dispatch safely;
# otherwise a ValueError should be raised, since some of the sub-models don't support SDPA
if encoder_attn == "sdpa" and decoder_attn == "sdpa":
model_sdpa_explicit = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, attn_implementation="sdpa")
model_sdpa_explicit = model_sdpa_explicit.eval().to(torch_device)
self.assertTrue(model_sdpa_explicit.config._attn_implementation == "sdpa")
else:
with self.assertRaises(ValueError):
model_sdpa_explicit = SpeechEncoderDecoderModel.from_pretrained(
tmpdirname, attn_implementation="sdpa"
)
model_eager = SpeechEncoderDecoderModel.from_pretrained(
tmpdirname,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
self.assertTrue(model_eager.encoder.config._attn_implementation == "eager")
self.assertTrue(model_eager.decoder.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
@require_torch
class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase):

View File

@ -206,6 +206,7 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
test_pruning = False
test_resize_embeddings = True
test_head_masking = False
_is_composite = True
def setUp(self):
self.model_tester = VideoLlavaVisionText2TextModelTester(self)
@ -237,6 +238,16 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip(
reason="After #33533, this still passes, but many subsequential tests fail with `device-side assert triggered`"
)

View File

@ -168,6 +168,7 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
test_pruning = False
test_resize_embeddings = True
test_head_masking = False
_is_composite = True
def setUp(self):
self.model_tester = VipLlavaVisionText2TextModelTester(self)
@ -242,6 +243,16 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@require_torch
class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):

View File

@ -27,17 +27,24 @@ from transformers.testing_utils import (
require_nltk,
require_sentencepiece,
require_torch,
require_torch_sdpa,
require_vision,
slow,
to_2tuple,
torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from transformers.utils import (
cached_property,
is_torch_available,
is_vision_available,
)
from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from ..bart.test_modeling_bart import BartModelTester
from ..bert.test_modeling_bert import BertModelTester
from ..deit.test_modeling_deit import DeiTModelTester
from ..donut.test_modeling_donut_swin import DonutSwinModelTester
from ..gpt2.test_modeling_gpt2 import GPT2ModelTester
from ..layoutlmv3.test_modeling_layoutlmv3 import LayoutLMv3ModelTester
from ..swin.test_modeling_swin import SwinModelTester
from ..trocr.test_modeling_trocr import TrOCRStandaloneDecoderModelTester
@ -53,6 +60,8 @@ if is_torch_available():
BartForCausalLM,
BertLMHeadModel,
DeiTModel,
DonutSwinModel,
GPT2LMHeadModel,
LayoutLMv3Model,
SwinModel,
TrOCRForCausalLM,
@ -72,6 +81,8 @@ if is_vision_available():
@require_torch
class EncoderDecoderMixin:
supports_sdpa = False
def get_encoder_decoder_model(self, config, decoder_config):
pass
@ -374,6 +385,69 @@ class EncoderDecoderMixin:
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
if not self.supports_sdpa:
self.skipTest("SDPA is not supported")
inputs_dict = self.prepare_config_and_inputs()
encoder_config, decoder_config = inputs_dict["config"], inputs_dict["decoder_config"]
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
encoder_config=encoder_config, decoder_config=decoder_config
)
model = VisionEncoderDecoderModel(config=config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = VisionEncoderDecoderModel.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
# see https://github.com/huggingface/transformers/pull/32238
# each sub-model dispatches to SDPA if it can (we check below that `SdpaAttention` layers are present)
encoder_attn = "sdpa" if model.encoder._supports_sdpa else "eager"
decoder_attn = "sdpa" if model.decoder._supports_sdpa else "eager"
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_sdpa.encoder.config._attn_implementation == encoder_attn)
self.assertTrue(model_sdpa.decoder.config._attn_implementation == decoder_attn)
# Also test that nothing breaks if we request SDPA explicitly, when both sub-parts support it.
# If the model supports SDPA (i.e. all sub-models support it), we dispatch safely;
# otherwise a ValueError should be raised, since some of the sub-models don't support SDPA
if encoder_attn == "sdpa" and decoder_attn == "sdpa":
model_sdpa_explicit = VisionEncoderDecoderModel.from_pretrained(tmpdirname, attn_implementation="sdpa")
model_sdpa_explicit = model_sdpa_explicit.eval().to(torch_device)
self.assertTrue(model_sdpa_explicit.config._attn_implementation == "sdpa")
else:
with self.assertRaises(ValueError):
model_sdpa_explicit = VisionEncoderDecoderModel.from_pretrained(
tmpdirname, attn_implementation="sdpa"
)
model_eager = VisionEncoderDecoderModel.from_pretrained(
tmpdirname,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
self.assertTrue(model_eager.encoder.config._attn_implementation == "eager")
self.assertTrue(model_eager.decoder.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
@require_torch
class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
@ -497,6 +571,8 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
@require_torch
class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
supports_sdpa = True  # one submodel supports SDPA
def get_pretrained_model_and_inputs(self):
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
"hf-internal-testing/tiny-random-vit", "hf-internal-testing/tiny-bert"
@ -649,6 +725,8 @@ class Swin2BartModelTest(EncoderDecoderMixin, unittest.TestCase):
@require_torch
class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
supports_sdpa = True  # one submodel supports SDPA
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = ViTModel(config).eval()
decoder_model = TrOCRForCausalLM(decoder_config).eval()
@ -804,6 +882,240 @@ class LayoutLMv32TrOCR(EncoderDecoderMixin, unittest.TestCase):
pass
@require_torch
class VIT2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
supports_sdpa = True # both submodels support SDPA
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = ViTModel(config).eval()
decoder_model = GPT2LMHeadModel(decoder_config).eval()
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = ViTModelTester(self, batch_size=13)
model_tester_decoder = GPT2ModelTester(self, batch_size=13, hidden_size=32, max_position_embeddings=512)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
config, pixel_values, labels = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_attention_mask,
decoder_head_mask,
decoder_token_type_ids,
mc_token_ids,
sequence_labels,
token_labels,
choice_labels,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
# disable cache for now
decoder_config.use_cache = False
return {
"config": config,
"pixel_values": pixel_values,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"decoder_head_mask": decoder_head_mask,
"labels": decoder_input_ids,
}
def check_encoder_decoder_model_output_attentions(
self,
config,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
pixel_values,
labels=None,
**kwargs,
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
**kwargs,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
seq_len = (encoder_model.config.image_size // encoder_model.config.patch_size) ** 2 + 1
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
num_decoder_layers = (
decoder_config.num_decoder_layers
if hasattr(decoder_config, "num_decoder_layers")
else decoder_config.num_hidden_layers
)
self.assertEqual(len(decoder_attentions), num_decoder_layers)
self.assertEqual(
decoder_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
)
cross_attentions = outputs_encoder_decoder["cross_attentions"]
self.assertEqual(len(cross_attentions), num_decoder_layers)
cross_attention_input_seq_len = decoder_input_ids.shape[-1]
self.assertEqual(
cross_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, cross_attention_input_seq_len, seq_len), # 4 6 16
)
def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_values=None, **kwargs):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
# Generate until max length
if hasattr(enc_dec_model.config, "eos_token_id"):
enc_dec_model.config.eos_token_id = None
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
enc_dec_model.config.decoder.eos_token_id = None
if hasattr(enc_dec_model.generation_config, "eos_token_id"):
enc_dec_model.generation_config.eos_token_id = None
enc_dec_model.to(torch_device)
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
@unittest.skip(reason="VIT2GPT2 also has an integration test for testinf save-load")
def test_real_model_save_load_from_pretrained(self):
pass
@require_torch
class Donut2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
supports_sdpa = True  # one submodel (GPT2) supports SDPA
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = DonutSwinModel(config).eval()
decoder_model = GPT2LMHeadModel(decoder_config).eval()
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = DonutSwinModelTester(self, batch_size=13)
model_tester_decoder = GPT2ModelTester(self, batch_size=13, hidden_size=32, max_position_embeddings=512)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
config, pixel_values, labels = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_attention_mask,
decoder_head_mask,
decoder_token_type_ids,
mc_token_ids,
sequence_labels,
token_labels,
choice_labels,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
# disable cache for now
decoder_config.use_cache = False
return {
"config": config,
"pixel_values": pixel_values,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"decoder_head_mask": decoder_head_mask,
"labels": decoder_input_ids,
}
def check_encoder_decoder_model_output_attentions(
self,
config,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
pixel_values,
labels=None,
**kwargs,
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
**kwargs,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
seq_len = encoder_model.config.image_size // encoder_model.config.patch_size
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
num_decoder_layers = (
decoder_config.num_decoder_layers
if hasattr(decoder_config, "num_decoder_layers")
else decoder_config.num_hidden_layers
)
self.assertEqual(len(decoder_attentions), num_decoder_layers)
self.assertEqual(
decoder_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
)
cross_attentions = outputs_encoder_decoder["cross_attentions"]
self.assertEqual(len(cross_attentions), num_decoder_layers)
cross_attention_input_seq_len = decoder_input_ids.shape[-1]
self.assertEqual(
cross_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, cross_attention_input_seq_len, seq_len), # 4 6 16
)
def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_values=None, **kwargs):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
# Generate until max length
if hasattr(enc_dec_model.config, "eos_token_id"):
enc_dec_model.config.eos_token_id = None
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
enc_dec_model.config.decoder.eos_token_id = None
if hasattr(enc_dec_model.generation_config, "eos_token_id"):
enc_dec_model.generation_config.eos_token_id = None
enc_dec_model.to(torch_device)
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
@unittest.skip(reason="Donut has an Integration test for that")
def test_real_model_save_load_from_pretrained(self):
pass
@require_vision
@require_torch
class TrOCRModelIntegrationTest(unittest.TestCase):

View File

@ -207,6 +207,7 @@ class ModelTesterMixin:
test_model_parallel = False
is_encoder_decoder = False
has_attentions = True
_is_composite = False
model_split_percents = [0.5, 0.7, 0.9]
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@ -3006,6 +3007,7 @@ class ModelTesterMixin:
*get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),
]:
continue
model = model_class(config)
model.to(torch_device)
model.eval()
@ -3950,6 +3952,147 @@ class ModelTesterMixin:
self.assertTrue(torch.allclose(out, out_fa))
def test_attn_implementation_composite_models(self):
"""
Tests whether composite models can receive a dict as `attn_implementation`, where each key should be
the name of one of the sub-configs from the model's config.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
for model_class in self.all_model_classes:
if not self._is_composite:
self.skipTest("Model is not a composite model.")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
sub_configs = {
key: getattr(config, key) for key in config if isinstance(getattr(config, key), PretrainedConfig)
}
# use eager, as it is the one implementation supported by all models;
# we only need to check that passing `attn_implementation` as a dict does not fail
attn_implementation_per_subconfig = {key: "eager" for key in sub_configs}
config._attn_implementation = attn_implementation_per_subconfig
model = model_class(config)
for key in model.config:
if isinstance(getattr(model.config, key), PretrainedConfig):
sub_config = getattr(model.config, key)
self.assertTrue(sub_config._attn_implementation == "eager")
for name, submodule in model.named_modules():
class_name = submodule.__class__.__name__
if (
"SdpaAttention" in class_name
or "SdpaSelfAttention" in class_name
or "FlashAttention" in class_name
):
raise ValueError("The eager model should not have SDPA/FA2 attention layers")
@require_torch_sdpa
def test_sdpa_can_dispatch_non_composite_models(self):
"""
Tests whether non-composite models dispatch correctly to SDPA/eager when requested at load time.
The check relies only on layer names, as SDPA layers are conventionally called `*SdpaAttention`.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self.all_model_classes[0]._supports_sdpa or self._is_composite:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
"""
Tests whether composite models dispatch correctly to SDPA/eager when requested at load time.
The check relies only on layer names, as SDPA layers are conventionally called `*SdpaAttention`.
In contrast to the test above, this one also checks `config._attn_implementation` on each sub-config after
the model is loaded, because the requested attention implementation is manually replicated onto every
sub-config at load time.
See https://github.com/huggingface/transformers/pull/32238 for more info
The test covers the most general case of composite models: VLMs with vision and text configs. Any model
with a different set of sub-configs has to overwrite this test.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self._is_composite:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
vision_model_names = {"visual", "image_tower", "vision_tower", "vision_model"}
language_model_names = {"language_model", "model", "text_model"}
vision_model_name = [name for name in vision_model_names if hasattr(model_sdpa, name)][0]
language_model_name = [name for name in language_model_names if hasattr(model_sdpa, name)][0]
vision_model_sdpa = getattr(model, vision_model_name)
language_model_sdpa = getattr(model, language_model_name)
text_attn = "sdpa" if language_model_sdpa._supports_sdpa else "eager"
vision_attn = "sdpa" if vision_model_sdpa._supports_sdpa else "eager"
# Loading without `attn_implementation` (i.e. `None`) propagates the request to each sub-config;
# each sub-model then dispatches to SDPA if it can (we check below that `SdpaAttention` layers are present)
self.assertTrue(language_model_sdpa.config._attn_implementation == text_attn)
self.assertTrue(vision_model_sdpa.config._attn_implementation == vision_attn)
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(getattr(model_eager, language_model_name).config._attn_implementation == "eager")
self.assertTrue(getattr(model_eager, vision_model_name).config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and any(module_attn == "sdpa" for module_attn in [text_attn, vision_attn]):
raise ValueError("The SDPA model should have SDPA attention layers")
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@ -4012,7 +4155,6 @@ class ModelTesterMixin:
# This means that the class needs to be instantiated much later, after `use_mask` is set, which means a significant refactor of the code.
# However masking there is not done at any layers that matters (i.e self-attention), therefore we can safely deactivate it.
deactivate_mask = "use_mask_token" in inspect.signature(model_class).parameters
is_encoder_decoder = model.config.is_encoder_decoder
with tempfile.TemporaryDirectory() as tmpdirname:
@ -4020,8 +4162,6 @@ class ModelTesterMixin:
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
@ -4029,22 +4169,6 @@ class ModelTesterMixin:
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just to avoid loading/saving the model 16 times,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
@ -4279,7 +4403,7 @@ class ModelTesterMixin:
self.skipTest(
"PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input"
)
if config.model_type in ["idefics"]:
if config.model_type in ["idefics", "idefics2", "idefics3"]:
self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
model = model_class(config)
@ -4382,8 +4506,6 @@ class ModelTesterMixin:
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
@ -4391,22 +4513,6 @@ class ModelTesterMixin:
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
# Just test that a large cache works as expected
res_eager = model_eager.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
@ -4429,6 +4535,8 @@ class ModelTesterMixin:
self.skipTest(f"No generative model classes for {self.__class__.__name__}")
for model_class in self.all_generative_model_classes:
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if model_class._supports_sdpa:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.model_type not in WINDOW_ATTENTION_MODELS:
@ -4531,6 +4639,62 @@ class ModelTesterMixin:
use_cache=True,
)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
def test_flash_attn_2_can_dispatch_composite_models(self):
"""
Tests that composite models can dispatch to FA2 when their sub-models support FA2.
The test is needed because composite models are handled differently and cannot be covered
by the tests above. If any of the sub-models does not support FA2, we expect an error to be
raised when dispatching to that particular sub-model. Otherwise we dispatch safely to all
sub-models, where "sub-models" are the specific backbone models (LM/vision/audio/etc.).
See the usage sketch after this test.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
torch_dtype = torch.float16
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
if not self._is_composite:
self.skipTest("This model is not a composte model!")
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
supports_fa2_all_modules = all(
module._supports_flash_attn_2
for name, module in model.named_modules()
if isinstance(module, PreTrainedModel) and name != ""
)
if not supports_fa2_all_modules:
with self.assertRaises(ValueError):
model_fa2 = model_class.from_pretrained(
tmpdirname, torch_dtype=torch_dtype, attn_implementation="flash_attention_2"
)
else:
model_fa2 = model_class.from_pretrained(
tmpdirname, torch_dtype=torch_dtype, attn_implementation="flash_attention_2"
)
for key in model_fa2.config:
if isinstance(getattr(model_fa2.config, key), PretrainedConfig):
sub_config = getattr(model_fa2.config, key)
self.assertTrue(sub_config._attn_implementation == "flash_attention_2")
has_fa2 = False
for name, submodule in model_fa2.named_modules():
class_name = submodule.__class__.__name__
if "FlashAttention" in class_name:
has_fa2 = True
break
if not has_fa2:
raise ValueError("The FA2 model should have FA2 layers")
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@ -4679,7 +4843,7 @@ class ModelTesterMixin:
if 0 in inputs_dict["attention_mask"][:, -1]:
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.pad_token_id
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
model = (
model_class.from_pretrained(
@ -228,6 +228,7 @@ class ConfigTestUtils(unittest.TestCase):
"_name_or_path",
"_commit_hash",
"_attn_implementation_internal",
"_attn_implementation_autoset",
"transformers_version",
],
)
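# Hedged sketch of the guard this expected list encodes (`config_common_kwargs` comes
# from the surrounding test module, and `expected_private_keys` is a stand-in for the
# list above): any attribute a bare config carries that is not a common kwarg must be
# listed explicitly, so the new `_attn_implementation_autoset` flag is registered here.
#
# from transformers import PretrainedConfig
# missing = [key for key in PretrainedConfig().to_dict() if key not in config_common_kwargs]
# self.assertListEqual(missing, expected_private_keys)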
@ -82,6 +82,8 @@ PRIVATE_MODELS = [
"SeamlessM4Tv2TextToUnitModel",
"SeamlessM4Tv2CodeHifiGan",
"SeamlessM4Tv2TextToUnitForConditionalGeneration",
"Idefics2PerceiverResampler",
"Idefics2VisionTransformer",
"Idefics3VisionTransformer",
]
@ -225,7 +227,6 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"BeitForMaskedImageModeling",
"ChineseCLIPTextModel",
"ChineseCLIPVisionModel",
"CLIPTextModel",
"CLIPTextModelWithProjection",
"CLIPVisionModelWithProjection",
"ClvpForCausalLM",
@ -327,6 +328,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"SiglipVisionModel",
"SiglipTextModel",
"ChameleonVQVAE", # no autoclass for VQ-VAE models
"CLIPTextModel",
"MoshiForConditionalGeneration", # no auto class for speech-to-speech
]