From bf68dd9e6e25f63a17aa793b557316c5765521fd Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 3 Jun 2025 09:40:44 +0200 Subject: [PATCH] [tests] expand flex-attn test for vision models (#38434) * expand the test for VLMs * typo * mark models `supports_flex` + expand test for additional kwargs * flex attn for refactored vision models * fix copies * fix * unskip * style * address comments --- .../modeling_new_task_model.py | 444 ++++++++++++------ .../modular_new_task_model.py | 3 +- .../modeling_audio_spectrogram_transformer.py | 2 + .../models/aya_vision/modeling_aya_vision.py | 1 + .../models/chameleon/modeling_chameleon.py | 1 + src/transformers/models/clip/modeling_clip.py | 2 + .../models/colpali/modeling_colpali.py | 3 +- src/transformers/models/deit/modeling_deit.py | 2 + .../models/dinov2/modeling_dinov2.py | 2 + .../modeling_dinov2_with_registers.py | 2 + src/transformers/models/dpt/modeling_dpt.py | 2 + src/transformers/models/emu3/modeling_emu3.py | 2 +- .../models/got_ocr2/modeling_got_ocr2.py | 1 + .../models/ijepa/modeling_ijepa.py | 2 + .../models/ijepa/modular_ijepa.py | 2 + .../models/internvl/modeling_internvl.py | 3 + .../models/internvl/modular_internvl.py | 2 + .../models/llava/modeling_llava.py | 1 + .../models/llava_next/modeling_llava_next.py | 1 + .../modeling_llava_next_video.py | 1 + .../modeling_llava_onevision.py | 1 + .../models/mistral3/modeling_mistral3.py | 1 + .../models/mllama/modeling_mllama.py | 1 + .../models/paligemma/modeling_paligemma.py | 1 + .../modeling_phi4_multimodal.py | 1 + .../models/siglip/modeling_siglip.py | 2 + .../models/siglip2/modeling_siglip2.py | 2 + .../models/videomae/modeling_videomae.py | 2 + .../models/vipllava/modeling_vipllava.py | 1 + src/transformers/models/vit/modeling_vit.py | 2 + .../models/vit_mae/modeling_vit_mae.py | 2 + .../models/vit_msn/modeling_vit_msn.py | 2 + .../models/vivit/modeling_vivit.py | 2 + .../models/yolos/modeling_yolos.py | 2 + tests/models/clip/test_modeling_clip.py | 1 + tests/models/emu3/test_modeling_emu3.py | 4 - tests/models/gemma3/test_modeling_gemma3.py | 6 - .../models/got_ocr2/test_modeling_got_ocr2.py | 4 - .../models/musicgen/test_modeling_musicgen.py | 2 + .../test_modeling_musicgen_melody.py | 2 + tests/models/siglip/test_modeling_siglip.py | 4 +- tests/models/siglip2/test_modeling_siglip2.py | 4 +- .../models/videomae/test_modeling_videomae.py | 3 +- .../models/vipllava/test_modeling_vipllava.py | 4 - tests/test_modeling_common.py | 89 +++- 45 files changed, 429 insertions(+), 195 deletions(-) diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py index 8da71ab1709..e618c9a1c62 100644 --- a/examples/modular-transformers/modeling_new_task_model.py +++ b/examples/modular-transformers/modeling_new_task_model.py @@ -12,24 +12,52 @@ from torch import nn from ...cache_utils import Cache, HybridCache, StaticCache from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import PreTrainedModel -from ...utils import ( - ModelOutput, - add_start_docstrings, - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) -from ..auto import AutoModel, AutoModelForCausalLM +from ...processing_utils import Unpack +from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torchdynamo_compiling +from ..auto import AutoModel from .configuration_new_task_model import 
NewTaskModelConfig -_CONFIG_FOR_DOC = "NewTaskModelConfig" +@dataclass +class NewTaskModelModelOutputWithPast(BaseModelOutputWithPast): + """ + Base class for NewTaskModel outputs, with hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None @dataclass class NewTaskModelCausalLMOutputWithPast(ModelOutput): """ - Base class for NewTaskModelcausal language model (or autoregressive) outputs. + Base class for NewTaskModel causal language model (or autoregressive) outputs. Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): @@ -77,30 +105,10 @@ class NewTaskModelMultiModalProjector(nn.Module): return hidden_states -NEW_TASK_MODEL_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`NewTaskModelConfig`] or [`NewTaskModelVisionConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - NEW_TASK_MODEL_START_DOCSTRING, -) +@auto_docstring class NewTaskModelPreTrainedModel(PreTrainedModel): config_class = NewTaskModelConfig - base_model_prefix = "model" + base_model_prefix = "" supports_gradient_checkpointing = True _no_split_modules = ["NewTaskModelMultiModalProjector"] _skip_keys_device_placement = "past_key_values" @@ -109,6 +117,8 @@ class NewTaskModelPreTrainedModel(PreTrainedModel): _supports_static_cache = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module): # important: this ported version of NewTaskModelisn't meant for training from scratch - only @@ -121,102 +131,24 @@ class NewTaskModelPreTrainedModel(PreTrainedModel): module.bias.data.zero_() -NEW_TASK_MODEL_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): - The tensors corresponding to the input images. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([]`NewTaskModelProcessor`] uses - [`SiglipImageProcessor`] for processing images). - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
- - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - -@add_start_docstrings( - """The NEW_TASK_MODEL model which consists of a vision backbone and a language model.""", - NEW_TASK_MODEL_START_DOCSTRING, +@auto_docstring( + custom_intro=""" + The Base NewTaskModel model which consists of a vision backbone and a language model withou language modeling head., + """ ) -class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): - main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related +class NewTaskModelModel(NewTaskModelPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} - def __init__(self, config): + def __init__(self, config: NewTaskModelConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config=config.vision_config) self.multi_modal_projector = NewTaskModelMultiModalProjector(config) self.vocab_size = config.text_config.vocab_size - language_model = AutoModelForCausalLM.from_config(config=config.text_config) - - if language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] + language_model = AutoModel.from_config(config=config.text_config) self.language_model = language_model self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 - - self.embedding_dim = self.config.embedding_dim - self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim) - - if self.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys] self.post_init() def get_input_embeddings(self): @@ -225,18 +157,6 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return 
self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - def _update_causal_mask( self, attention_mask, @@ -321,8 +241,191 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): image_features = image_features / (self.config.text_config.hidden_size**0.5) return image_features - @add_start_docstrings_to_model_forward(NEW_TASK_MODEL_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=NewTaskModelCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, NewTaskModelModelOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, NewTaskModelForConditionalGeneration + + >>> model = NewTaskModelForConditionalGeneration.from_pretrained("google/new_task_model2-3b-mix-224") + >>> processor = AutoProcessor.from_pretrained("google/new_task_model2-3b-mix-224") + + >>> prompt = "Where is the cat standing?" 
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs,) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Where is the cat standing?\nsnow" + ```""" + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + is_training = token_type_ids is not None and labels is not None + + # Replace image id woth PAD if the image token if OOV, to avoid index-errors + if input_ids is not None and self.config.image_token_id >= self.vocab_size: + special_image_mask = input_ids == self.config.image_token_id + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + 1 # NewTaskModel positions are 1-indexed + + # Merge text and images + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + else: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] + raise ValueError( + f"Number of images does not match number of special image tokens in the input text. " + f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " + "tokens from image embeddings." 
+ ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training + ) + outputs = self.language_model( + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + return NewTaskModelModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +@auto_docstring( + custom_intro=""" + The Base NewTaskModel model which consists of a vision backbone and a language model without language modeling head., + """ +) +class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related + + def __init__(self, config): + super().__init__(config) + self.model = NewTaskModelModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.embedding_dim = self.config.embedding_dim + self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim) + + if self.language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys] + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + # Make modules available throught conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple + @auto_docstring def forward( self, input_ids: torch.LongTensor = None, @@ -341,19 +444,10 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): num_logits_to_keep: int = 0, ) -> Union[Tuple, NewTaskModelCausalLMOutputWithPast]: r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. 
- - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. Example: @@ -400,7 +494,8 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): # L2 normalization embeddings = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim) - embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim) + if attention_mask is not None: + embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim) return (embeddings,) + vlm_outputs @@ -420,7 +515,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): **kwargs, ): # Overwritten -- custom `position_ids` and `pixel_values` handling - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -443,13 +538,68 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): is_training = token_type_ids is not None and labels is not None if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): input_tensor = inputs_embeds if inputs_embeds is not None else input_ids - causal_mask = self._update_causal_mask( + causal_mask = self.model._update_causal_mask( attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training ) model_inputs["attention_mask"] = causal_mask return model_inputs + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+ """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + def resize_token_embeddings( self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, mean_resizing=True ) -> nn.Embedding: diff --git a/examples/modular-transformers/modular_new_task_model.py b/examples/modular-transformers/modular_new_task_model.py index f1943e37e1f..53fa5a3f09c 100644 --- a/examples/modular-transformers/modular_new_task_model.py +++ b/examples/modular-transformers/modular_new_task_model.py @@ -65,7 +65,8 @@ class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration): # L2 normalization embeddings = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim) - embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim) + if attention_mask is not None: + embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim) return (embeddings,) + vlm_outputs diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index 90fc3362159..ef7bf6ff665 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -384,6 +384,8 @@ class ASTPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index 519800f9285..1d723f1b5ae 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -97,6 +97,7 @@ class AyaVisionPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_quantized_cache = False _supports_static_cache = False + _supports_flex_attn = True _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index c816fe61053..5e39d2515e0 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ 
b/src/transformers/models/chameleon/modeling_chameleon.py @@ -830,6 +830,7 @@ class ChameleonPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_static_cache = True _supports_param_buffer_assignment = False + _supports_flex_attn = True _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 1ddf6bf8f79..d50efcb7556 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -450,6 +450,8 @@ class CLIPPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/colpali/modeling_colpali.py b/src/transformers/models/colpali/modeling_colpali.py index 69ba6e6643d..79cb24a8a50 100644 --- a/src/transformers/models/colpali/modeling_colpali.py +++ b/src/transformers/models/colpali/modeling_colpali.py @@ -162,7 +162,8 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel): # L2 normalization embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim) - embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim) + if attention_mask is not None: + embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim) return ColPaliForRetrievalOutput( embeddings=embeddings, diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 1b2ba103657..05bacf81552 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -450,6 +450,8 @@ class DeiTPreTrainedModel(PreTrainedModel): _no_split_modules = ["DeiTLayer"] _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 2fb1be118cb..30bf56d3524 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -494,6 +494,8 @@ class Dinov2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Dinov2SwiGLUFFN"] _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py index 35e5e02e86f..82d8169c38d 100644 --- a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +++ b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py @@ -512,6 +512,8 @@ class Dinov2WithRegistersPreTrainedModel(PreTrainedModel): _no_split_modules = ["Dinov2WithRegistersSwiGLUFFN"] _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git 
a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index d772ddb85b3..18cc873b552 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -826,6 +826,8 @@ class DPTPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 6540e9fc714..995e9cac7d6 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1142,8 +1142,8 @@ class Emu3PreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_static_cache = True _supports_param_buffer_assignment = False - _supports_attention_backend = True _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.get_text_config().initializer_range diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 0d6b44214ba..fc3ab807df8 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -593,6 +593,7 @@ class GotOcr2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True + _supports_flex_attn = True _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index 956987290af..dafef6d5666 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -151,6 +151,8 @@ class IJepaPreTrainedModel(PreTrainedModel): _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"] _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/ijepa/modular_ijepa.py b/src/transformers/models/ijepa/modular_ijepa.py index cf31e6ec859..2f6064acbb1 100644 --- a/src/transformers/models/ijepa/modular_ijepa.py +++ b/src/transformers/models/ijepa/modular_ijepa.py @@ -95,6 +95,8 @@ class IJepaPreTrainedModel(PreTrainedModel): _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"] _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index b2e539e7bce..4c747a2394e 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -178,6 +178,8 @@ class InternVLVisionPreTrainedModel(PreTrainedModel): _no_split_modules = ["InternVLVisionLayer"] _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module): """Initialize the weights""" @@ -537,6 +539,7 @@ class InternVLPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_quantized_cache = 
True _supports_static_cache = True + _supports_flex_attn = True _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/internvl/modular_internvl.py b/src/transformers/models/internvl/modular_internvl.py index fcf4958f623..248f58f3b78 100644 --- a/src/transformers/models/internvl/modular_internvl.py +++ b/src/transformers/models/internvl/modular_internvl.py @@ -140,6 +140,8 @@ class InternVLVisionPreTrainedModel(PreTrainedModel): _no_split_modules = ["InternVLVisionLayer"] _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 0ab2333d1ba..4321dea59d9 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -141,6 +141,7 @@ class LlavaPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True + _supports_flex_attn = True _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index bd1e1aae253..d0ea1bb233c 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -252,6 +252,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True + _supports_flex_attn = True _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 20b1647c0bd..368fd09ef32 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -195,6 +195,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True + _supports_flex_attn = True _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index a43674d2942..5a1157edebf 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -308,6 +308,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True + _supports_flex_attn = True _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index c5a5205947f..082020b3afd 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -206,6 +206,7 @@ class Mistral3PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True + _supports_flex_attn = True _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/mllama/modeling_mllama.py 
b/src/transformers/models/mllama/modeling_mllama.py index c28671c81f8..4c73295cab6 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -860,6 +860,7 @@ class MllamaPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flash_attn_2 = True _supports_quantized_cache = True + _supports_flex_attn = True _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 4e508ef3331..addba2b30fe 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -131,6 +131,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): _supports_static_cache = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index 858666b9f5a..d484f9255a8 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -375,6 +375,7 @@ class Phi4MultimodalVisionPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index d6966c31f66..fa5e3f6dc84 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -517,6 +517,8 @@ class SiglipPreTrainedModel(PreTrainedModel): ] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/siglip2/modeling_siglip2.py b/src/transformers/models/siglip2/modeling_siglip2.py index a198cbc347f..eb3bf5d4a34 100644 --- a/src/transformers/models/siglip2/modeling_siglip2.py +++ b/src/transformers/models/siglip2/modeling_siglip2.py @@ -749,6 +749,8 @@ class Siglip2PreTrainedModel(PreTrainedModel): ] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index f8af7b57550..e6a114a522b 100755 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -491,6 +491,8 @@ class VideoMAEPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 70abf26ee20..c4a20aef914 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -142,6 +142,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel): _supports_sdpa = 
True _supports_quantized_cache = True _supports_static_cache = True + _supports_flex_attn = True _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 6cae97942fc..e8a9958fc5b 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -448,6 +448,8 @@ class ViTPreTrainedModel(PreTrainedModel): _no_split_modules = ["ViTEmbeddings", "ViTLayer"] _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 541015c7da5..326f193713f 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -633,6 +633,8 @@ class ViTMAEPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 16d61b95324..7684eb2b277 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -452,6 +452,8 @@ class ViTMSNPreTrainedModel(PreTrainedModel): _no_split_modules = ["ViTMSNAttention", "ViTMSNSdpaAttention"] _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flex_attn = True + _supports_attention_backend = True # todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211 # when creating pre-training scripts. 
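
Editor's note: the repeated two-line additions above and below (`_supports_flex_attn = True`, `_supports_attention_backend = True`) opt each vision encoder into the flex-attention code path and the pluggable attention-backend dispatch, so these models can now be loaded with `attn_implementation="flex_attention"` instead of rejecting it. Below is a minimal usage sketch, assuming a transformers build that includes this patch and PyTorch >= 2.5 on a CUDA device (flex attention compiles through triton); the ViT checkpoint name is only illustrative and any of the encoders touched here (ViT, DeiT, DINOv2, SigLIP, ...) should behave the same way.

```python
import torch
from transformers import AutoModel

# Request the flex attention implementation at load time; before this patch the
# vision encoders would raise because _supports_flex_attn was not set.
model = AutoModel.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    attn_implementation="flex_attention",
    torch_dtype=torch.float16,
).to("cuda")

pixel_values = torch.rand(1, 3, 224, 224, dtype=torch.float16, device="cuda")
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)

print(model.config._attn_implementation)  # "flex_attention"
print(outputs.last_hidden_state.shape)    # torch.Size([1, 197, 768]) for 224x224 inputs with 16x16 patches
```
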
diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index 15097bd69a7..2ca76996ad1 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -456,6 +456,8 @@ class VivitPreTrainedModel(PreTrainedModel): _no_split_modules = [] _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index 2443591f633..a42982fb932 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -532,6 +532,8 @@ class YolosPreTrainedModel(PreTrainedModel): _no_split_modules = [] _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py index 739ae1f5903..82e04fc4541 100644 --- a/tests/models/clip/test_modeling_clip.py +++ b/tests/models/clip/test_modeling_clip.py @@ -535,6 +535,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase pipeline_model_mapping = ( {"feature-extraction": CLIPModel, "image-feature-extraction": CLIPVisionModel} if is_torch_available() else {} ) + additional_model_inputs = ["pixel_values"] fx_compatible = True test_head_masking = False test_pruning = False diff --git a/tests/models/emu3/test_modeling_emu3.py b/tests/models/emu3/test_modeling_emu3.py index ddbb8d02d9a..bec6f0fc1fb 100644 --- a/tests/models/emu3/test_modeling_emu3.py +++ b/tests/models/emu3/test_modeling_emu3.py @@ -401,10 +401,6 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline def test_generate_with_static_cache(self): pass - @unittest.skip("Emu3 doesn't support Flex attn yet!") - def test_flex_attention_with_grads(self): - pass - @require_torch class Emu3IntegrationTest(unittest.TestCase): diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index 7c3b59e40a9..7fed1157756 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -351,12 +351,6 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte def test_initialization(self): pass - @unittest.skip( - reason="Siglip has no FLEX attention, and we don't have a proper way to set/test attn in VLMs. TODO @raushan" - ) - def test_flex_attention_with_grads(self): - pass - def test_automodelforcausallm(self): """ Regression test for #36741/#36917 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. 
that diff --git a/tests/models/got_ocr2/test_modeling_got_ocr2.py b/tests/models/got_ocr2/test_modeling_got_ocr2.py index e4706b6db0c..047d4a0da9b 100644 --- a/tests/models/got_ocr2/test_modeling_got_ocr2.py +++ b/tests/models/got_ocr2/test_modeling_got_ocr2.py @@ -236,10 +236,6 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi def test_past_key_values_format(self): pass - @unittest.skip(reason="Vision backbone doesn't support FLEX yet!") - def test_flex_attention_with_grads(self): - pass - @require_torch class GotOcr2IntegrationTest(unittest.TestCase): diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index 1a27192506f..b051fa9657b 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -569,6 +569,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, all_generative_model_classes = () greedy_sample_model_classes = (MusicgenForConditionalGeneration,) if is_torch_available() else () pipeline_model_mapping = {"text-to-audio": MusicgenForConditionalGeneration} if is_torch_available() else {} + # Addition keys that are required for forward. MusicGen isn't encoder-decoder in config so we have to pass decoder ids as additional + additional_model_inputs = ["decoder_input_ids"] test_pruning = False # training is not supported yet for MusicGen test_headmasking = False test_resize_embeddings = False diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index abf2edd1ce3..86fe12b324a 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -589,6 +589,8 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester all_generative_model_classes = () greedy_sample_model_classes = (MusicgenMelodyForConditionalGeneration,) if is_torch_available() else () pipeline_model_mapping = {"text-to-audio": MusicgenMelodyForConditionalGeneration} if is_torch_available() else {} + # Addition keys that are required for forward. 
MusicGen isn't encoder-decoder in config so we have to pass decoder ids as additional + additional_model_inputs = ["decoder_input_ids"] test_pruning = False # training is not supported yet for MusicGen test_headmasking = False test_resize_embeddings = False diff --git a/tests/models/siglip/test_modeling_siglip.py b/tests/models/siglip/test_modeling_siglip.py index 419bb3efa51..8a5b1037eb6 100644 --- a/tests/models/siglip/test_modeling_siglip.py +++ b/tests/models/siglip/test_modeling_siglip.py @@ -103,7 +103,7 @@ class SiglipVisionModelTester: patch_size=2, num_channels=3, is_training=True, - hidden_size=32, + hidden_size=64, num_hidden_layers=2, num_attention_heads=4, intermediate_size=37, @@ -274,7 +274,7 @@ class SiglipTextModelTester: use_input_mask=True, use_labels=True, vocab_size=99, - hidden_size=32, + hidden_size=64, num_hidden_layers=2, num_attention_heads=4, intermediate_size=37, diff --git a/tests/models/siglip2/test_modeling_siglip2.py b/tests/models/siglip2/test_modeling_siglip2.py index 5c714205f72..b1c005a3cc1 100644 --- a/tests/models/siglip2/test_modeling_siglip2.py +++ b/tests/models/siglip2/test_modeling_siglip2.py @@ -180,7 +180,7 @@ class Siglip2VisionModelTester: patch_size=2, num_channels=3, is_training=True, - hidden_size=32, + hidden_size=64, num_hidden_layers=2, num_attention_heads=4, intermediate_size=37, @@ -363,7 +363,7 @@ class Siglip2TextModelTester: use_input_mask=True, use_labels=True, vocab_size=99, - hidden_size=32, + hidden_size=64, num_hidden_layers=2, num_attention_heads=4, intermediate_size=37, diff --git a/tests/models/videomae/test_modeling_videomae.py b/tests/models/videomae/test_modeling_videomae.py index 1b8b08d14b6..0b85c31f8a9 100644 --- a/tests/models/videomae/test_modeling_videomae.py +++ b/tests/models/videomae/test_modeling_videomae.py @@ -190,7 +190,8 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase if is_torch_available() else {} ) - + # Addition keys that are required for forward, used in tests where we manipulate and create new input dict from scratch + additional_model_inputs = ["bool_masked_pos"] test_pruning = False test_torchscript = False test_resize_embeddings = False diff --git a/tests/models/vipllava/test_modeling_vipllava.py b/tests/models/vipllava/test_modeling_vipllava.py index 79b57e12ffa..07d9ab3c53e 100644 --- a/tests/models/vipllava/test_modeling_vipllava.py +++ b/tests/models/vipllava/test_modeling_vipllava.py @@ -322,10 +322,6 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): pass - @unittest.skip("LLaVA vision backbones doesn't support flex attention yet") - def test_flex_attention_with_grads(self): - pass - @require_torch class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 9f56c79956e..441c99267b6 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3637,7 +3637,10 @@ class ModelTesterMixin: processed_inputs[model.main_input_name] = inputs_dict[model.main_input_name] for key in getattr(self, "additional_model_inputs", []): - processed_inputs[key] = inputs_dict[key] + # Some models don't have all `additional_model_inputs`, especially when we + # craft cases to test model in different settings + if key in inputs_dict: + processed_inputs[key] = inputs_dict[key] for key, value in processed_inputs.items(): if torch.is_floating_point(value): @@ 
-4012,19 +4015,21 @@ class ModelTesterMixin: model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) sub_models_supporting_fa2 = [ - (module._supports_flash_attn_2 or module._supports_attention_backend) + module._supports_flash_attn_2 for name, module in model.named_modules() if isinstance(module, PreTrainedModel) and name != "" ] supports_fa2_all_modules = ( all(sub_models_supporting_fa2) if len(sub_models_supporting_fa2) > 0 - else (model._supports_flash_attn_2 or model._supports_attention_backend) + else model._supports_flash_attn_2 ) if not supports_fa2_all_modules: with self.assertRaises(ValueError): model_fa2 = model_class.from_pretrained( - tmpdirname, torch_dtype=torch_dtype, attn_implementation="flash_attention_2" + tmpdirname, + torch_dtype=torch_dtype, + attn_implementation="flash_attention_2", ) else: model_fa2 = model_class.from_pretrained( @@ -4572,33 +4577,73 @@ class ModelTesterMixin: @require_torch_gpu def test_flex_attention_with_grads(self): for model_class in self.all_model_classes: - # TODO: raushan, fix for composite models after making VLMs support new attn API - if not model_class._supports_flex_attn or self._is_composite: - self.skipTest(reason="This model does not support flex attention") - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - config._attn_implementation = "flex_attention" - # Flex Attention cannot use dropout - if hasattr(config, "attention_dropout"): - config.attention_dropout = 0 - if hasattr(config, "attention_probs_dropout_prob"): - config.attention_probs_dropout_prob = 0 + inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config).to(device=torch_device) - # Flex attention relies on triton on compilation - # However, triton cannot handle hidden dimensions of less than 16 - # --> forcing at least a hidden dim of 16 - config.hidden_size *= max( - 16 // getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), 1 + # If not all sub-models support flex, skip the test + sub_models_supporting_flex = [ + module._supports_flex_attn + for name, module in model.named_modules() + if isinstance(module, PreTrainedModel) and name != "" + ] + supports_flex_all_modules = (all(sub_models_supporting_flex) and len(sub_models_supporting_flex) > 0) or ( + model._supports_flex_attn and len(sub_models_supporting_flex) == 0 ) - if hasattr(config, "head_dim"): - config.head_dim = max(16, config.head_dim) + if not supports_flex_all_modules: + self.skipTest(reason="This model's submodels does not support flex attention") + def update_config_for_flex(config): + # Flex Attention cannot use dropout + if hasattr(config, "attention_dropout"): + config.attention_dropout = 0 + if hasattr(config, "attention_probs_dropout_prob"): + config.attention_probs_dropout_prob = 0 + + # Flex attention relies on triton on compilation + # However, triton cannot handle hidden dimensions of less than 16 + # --> forcing at least a hidden dim of 16 + + # Update the head dim and try to update hidden size as well if present in config + # NOTE: some models may have none if the values in sub-config, thus we check for `Noneness` + head_dim = None + if hasattr(config, "head_dim") and config.head_dim is not None: + head_dim = config.head_dim + config.head_dim = max(16, config.head_dim) + + if ( + getattr(config, "hidden_size", None) is not None + and getattr(config, "num_attention_heads", None) is not None + ): + head_dim = head_dim if head_dim is not None else config.hidden_size // 
config.num_attention_heads + config.hidden_size *= max(16 // head_dim, 1) + + if ( + getattr(config, "decoder_hidden_size", None) is not None + and getattr(config, "decoder_num_attention_heads", None) is not None + ): + decoder_head_dim = config.decoder_hidden_size // config.decoder_num_attention_heads + config.decoder_hidden_size *= max(16 // decoder_head_dim, 1) + + # Set default attention to flex and update config values + update_config_for_flex(config) + for key in config.sub_configs: + sub_config = getattr(config, key) + update_config_for_flex(sub_config) + + config._attn_implementation = "flex_attention" model = model_class(config).to(device=torch_device) self.assertTrue(model.config._attn_implementation == "flex_attention") # Elaborate workaround for encoder-decoder models as some do not specify their main input dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)} - if config.is_encoder_decoder: + for key in getattr(self, "additional_model_inputs", []): + # Some models don't have all `additional_model_inputs`, especially when we + # craft cases to test model in different settings + if key in inputs_dict: + dummy_inputs[key] = inputs_dict[key].to(torch_device) + + if config.get_text_config(decoder=True).is_encoder_decoder: dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device) dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device)
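
Editor's note: the reworked `test_flex_attention_with_grads` above now walks every sub-model, skips when any of them lacks `_supports_flex_attn`, and rescales config values before switching to `"flex_attention"`, because the triton kernels that flex attention compiles through cannot handle head dimensions below 16. The sketch below isolates that rescaling arithmetic on a stand-in config object (`SimpleNamespace` is used only for illustration; the real test mutates `PretrainedConfig` instances and their sub-configs) and shows why the SigLIP/SigLIP2 testers earlier in this patch bump `hidden_size` from 32 to 64.

```python
from types import SimpleNamespace

def update_config_for_flex(config):
    # Flex attention cannot use dropout, so the test zeroes it out.
    if getattr(config, "attention_dropout", None) is not None:
        config.attention_dropout = 0

    # Triton cannot handle head dimensions below 16: bump head_dim / hidden_size as needed.
    head_dim = None
    if getattr(config, "head_dim", None) is not None:
        head_dim = config.head_dim
        config.head_dim = max(16, config.head_dim)

    if getattr(config, "hidden_size", None) is not None and getattr(config, "num_attention_heads", None) is not None:
        head_dim = head_dim if head_dim is not None else config.hidden_size // config.num_attention_heads
        config.hidden_size *= max(16 // head_dim, 1)
    return config

# A config shaped like the old SigLIP vision tester: hidden_size=32 with 4 heads -> head_dim 8 < 16.
tiny = SimpleNamespace(hidden_size=32, num_attention_heads=4, head_dim=None, attention_dropout=0.1)
update_config_for_flex(tiny)
print(tiny.hidden_size)  # 64 -> head_dim becomes 16, matching the 32 -> 64 change in the SigLIP testers
```
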