From d23aae2b8c8738a12ab1b6710e60ae5866beaf9d Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Thu, 8 May 2025 18:18:54 +0200 Subject: [PATCH] [VLMs] support attention backends (#37576) * update models * why rename * return attn weights when sdpa * fixes * fix attn implementation composite * fix moshi * add message * add typings * use explicitly all flags for each attn type * fix some tests * import what is needed * kosmos on main has ew attention already, yay * new models in main, run fixup * won't fix kosmos yet * fix-copies * clean up after rebasing * fix tests * style * dont cast attns to fp32 * did we update ruff? oke, let's just do what it asks * fix pixtral after rebase --- src/transformers/configuration_utils.py | 3 + src/transformers/modeling_utils.py | 3 - src/transformers/models/aria/modeling_aria.py | 12 +- src/transformers/models/aria/modular_aria.py | 17 +- .../models/aya_vision/modeling_aya_vision.py | 23 +- .../models/aya_vision/modular_aya_vision.py | 6 +- .../models/blip_2/modeling_blip_2.py | 115 +++-- .../models/chameleon/modeling_chameleon.py | 316 +++---------- src/transformers/models/emu3/modeling_emu3.py | 15 +- src/transformers/models/emu3/modular_emu3.py | 18 +- src/transformers/models/fuyu/modeling_fuyu.py | 31 +- .../models/got_ocr2/modeling_got_ocr2.py | 20 +- .../models/got_ocr2/modular_got_ocr2.py | 16 +- .../models/idefics/modeling_idefics.py | 129 +++--- .../models/idefics2/modeling_idefics2.py | 45 +- .../models/idefics3/modeling_idefics3.py | 45 +- .../instructblip/modeling_instructblip.py | 124 +++-- .../modeling_instructblipvideo.py | 126 ++++-- .../modular_instructblipvideo.py | 38 +- .../models/internvl/modeling_internvl.py | 23 +- .../models/kosmos2/modeling_kosmos2.py | 121 +++-- .../models/llama4/modeling_llama4.py | 30 +- .../models/llava/modeling_llava.py | 25 +- .../models/llava_next/modeling_llava_next.py | 23 +- .../modeling_llava_next_video.py | 25 +- .../modular_llava_next_video.py | 20 +- .../modeling_llava_onevision.py | 24 +- .../modular_llava_onevision.py | 20 +- .../models/mistral3/modeling_mistral3.py | 288 ++++++------ .../models/mistral3/modular_mistral3.py | 21 +- .../models/mllama/modeling_mllama.py | 426 ++++++------------ .../models/moshi/configuration_moshi.py | 2 +- .../models/moshi/modeling_moshi.py | 2 +- src/transformers/models/opt/modeling_opt.py | 273 +++-------- .../models/paligemma/modeling_paligemma.py | 23 +- .../models/pixtral/modeling_pixtral.py | 49 +- .../models/smolvlm/modeling_smolvlm.py | 205 ++++----- .../models/smolvlm/modular_smolvlm.py | 11 +- .../video_llava/modeling_video_llava.py | 23 +- .../models/vipllava/modeling_vipllava.py | 1 + tests/models/blip_2/test_modeling_blip_2.py | 57 +-- .../test_modeling_instructblip.py | 28 +- .../test_modeling_instructblipvideo.py | 28 +- tests/models/kosmos2/test_modeling_kosmos2.py | 11 + tests/models/opt/test_modeling_opt.py | 3 +- tests/models/pixtral/test_modeling_pixtral.py | 1 + tests/test_modeling_common.py | 8 +- 47 files changed, 1318 insertions(+), 1555 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 80c100563a6..1c0bff6cf39 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -833,6 +833,7 @@ class PretrainedConfig(PushToHubMixin): if "model_type" in value: # Needs to be set even if it's not in the diff diff["model_type"] = value["model_type"] + serializable_config_dict[key] = diff elif ( key not in default_config_dict @@ -1003,6 +1004,8 @@ 
class PretrainedConfig(PushToHubMixin): del d["_commit_hash"] if "_attn_implementation_internal" in d: del d["_attn_implementation_internal"] + if "_attn_implementation_autoset" in d: + del d["_attn_implementation_autoset"] # Do not serialize `base_model_tp_plan` for now if "base_model_tp_plan" in d: del d["base_model_tp_plan"] diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 344990c2c9f..cdaf68d7616 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3430,9 +3430,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi # Attach architecture to the config model_to_save.config.architectures = [model_to_save.__class__.__name__] - # Unset attn implementation so it can be set to another one when loading back - model_to_save.config._attn_implementation_autoset = False - # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be # loaded from the Hub. if self._auto_class is not None: diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 31e0980fdeb..bef0e2cec15 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -669,6 +669,7 @@ class AriaTextPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = False _supports_sdpa = True _supports_cache_class = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -1409,6 +1410,7 @@ class AriaModel(AriaPreTrainedModel): image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask) return image_features + @can_return_tuple @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING) def forward( self, @@ -1424,6 +1426,7 @@ class AriaModel(AriaPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, AriaModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1470,16 +1473,16 @@ class AriaModel(AriaPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, + **kwargs, ) - output = AriaModelOutputWithPast( + return AriaModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values if use_cache else None, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_features if pixel_values is not None else None, ) - return output if return_dict else output.to_tuple() def _create_patch_attention_mask(self, pixel_mask): if pixel_mask is None: @@ -1563,7 +1566,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, cache_position: Optional[torch.LongTensor] = None, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, AriaCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1645,6 +1648,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1655,7 +1659,7 @@ class 
AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): loss = None if labels is not None: loss = self.loss_function( - logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **loss_kwargs + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs ) return AriaCausalLMOutputWithPast( diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index a42c7227772..e6beaf89197 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -32,6 +32,7 @@ from ...image_utils import ( valid_images, validate_preprocess_arguments, ) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack @@ -40,6 +41,7 @@ from ...tokenization_utils import ( TextInput, ) from ...utils import ( + LossKwargs, TensorType, add_start_docstrings, add_start_docstrings_to_model_forward, @@ -1240,6 +1242,7 @@ class AriaTextPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = False _supports_sdpa = True _supports_cache_class = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -1290,6 +1293,9 @@ class AriaTextModel(LlamaModel): self.post_init() +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM): """ Aria model for causal language modeling tasks. @@ -1434,6 +1440,7 @@ class AriaModel(LlavaModel): image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask) return image_features + @can_return_tuple @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING) def forward( self, @@ -1449,6 +1456,7 @@ class AriaModel(LlavaModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, AriaModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1495,16 +1503,16 @@ class AriaModel(LlavaModel): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, + **kwargs, ) - output = AriaModelOutputWithPast( + return AriaModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values if use_cache else None, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_features if pixel_values is not None else None, ) - return output if return_dict else output.to_tuple() @add_start_docstrings( @@ -1533,7 +1541,7 @@ class AriaForConditionalGeneration(LlavaForConditionalGeneration): return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, cache_position: Optional[torch.LongTensor] = None, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, AriaCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1615,6 +1623,7 @@ class AriaForConditionalGeneration(LlavaForConditionalGeneration): output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1625,7 +1634,7 @@ class 
AriaForConditionalGeneration(LlavaForConditionalGeneration): loss = None if labels is not None: loss = self.loss_function( - logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **loss_kwargs + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs ) return AriaCausalLMOutputWithPast( diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index 042ff0c05a7..c700a971294 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -27,9 +27,12 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, @@ -124,6 +127,7 @@ class AyaVisionPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_quantized_cache = False _supports_static_cache = False + _supports_attention_backend = True def _init_weights(self, module): std = ( @@ -358,6 +362,7 @@ class AyaVisionModel(AyaVisionPreTrainedModel): image_features = self.multi_modal_projector(selected_image_feature) return image_features + @can_return_tuple @add_start_docstrings_to_model_forward(AYA_VISION_INPUTS_DOCSTRING) def forward( self, @@ -375,7 +380,7 @@ class AyaVisionModel(AyaVisionPreTrainedModel): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, image_sizes: torch.Tensor = None, - **lm_kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, AyaVisionModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -434,17 +439,19 @@ class AyaVisionModel(AyaVisionPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, - **lm_kwargs, + **kwargs, ) - output = AyaVisionModelOutputWithPast( + return AyaVisionModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_features if pixel_values is not None else None, ) - return output if return_dict else output.to_tuple() + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
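The `KwargsForCausalLM(FlashAttentionKwargs, LossKwargs)` marker added above is what lets a single `**kwargs: Unpack[...]` annotation carry both flash-attention metadata and loss options from the top-level forward, through the language model, and into `self.loss_function`. A minimal, self-contained sketch of the idea (not part of the patch); the TypedDict fields below are illustrative stand-ins, not the library's exact definitions:

from typing import Any, Optional
from typing_extensions import TypedDict, Unpack

class FlashAttentionKwargs(TypedDict, total=False):
    # packing metadata that flash-attention kernels can consume (illustrative fields)
    cu_seq_lens_q: Optional[Any]
    max_length_q: Optional[int]

class LossKwargs(TypedDict, total=False):
    # e.g. the token count used to normalize the loss under gradient accumulation
    num_items_in_batch: Optional[int]

class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...

def forward(input_ids, **kwargs: Unpack[KwargsForCausalLM]):
    # the same dict is forwarded untouched: attention keys reach the attention layers,
    # loss keys reach self.loss_function(**kwargs)
    return sorted(kwargs)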
@add_start_docstrings( @@ -512,7 +519,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, image_sizes: Optional[torch.Tensor] = None, - **lm_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, AyaVisionCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -589,7 +596,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi return_dict=True, cache_position=cache_position, image_sizes=image_sizes, - **lm_kwargs, + **kwargs, ) hidden_states = outputs[0] @@ -599,7 +606,9 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return AyaVisionCausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/aya_vision/modular_aya_vision.py b/src/transformers/models/aya_vision/modular_aya_vision.py index 5d7acecff88..977170708f0 100644 --- a/src/transformers/models/aya_vision/modular_aya_vision.py +++ b/src/transformers/models/aya_vision/modular_aya_vision.py @@ -20,6 +20,7 @@ import torch from torch import nn from transformers.models.llava.modeling_llava import ( + KwargsForCausalLM, LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaModel, @@ -27,6 +28,7 @@ from transformers.models.llava.modeling_llava import ( ) from ...activations import ACT2FN +from ...processing_utils import Unpack from ...utils import ( add_start_docstrings, logging, @@ -148,7 +150,7 @@ class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration): cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, image_sizes: Optional[torch.Tensor] = None, - **lm_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, AyaVisionCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -213,7 +215,7 @@ class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration): cache_position=cache_position, logits_to_keep=logits_to_keep, image_sizes=image_sizes, - **lm_kwargs, + **kwargs, ) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index d6ec49505bc..fcc7b2b4979 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import Any, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -25,15 +25,18 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from 
...utils import ( + LossKwargs, ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, @@ -255,6 +258,30 @@ class Blip2VisionEmbeddings(nn.Module): return embeddings +# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> BLIP doesn't cast attn weights to fp32 +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Blip2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -270,7 +297,8 @@ class Blip2Attention(nn.Module): f" {self.num_heads})." ) self.scale = self.head_dim**-0.5 - self.dropout = nn.Dropout(config.attention_dropout) + self.is_causal = False + self.attention_dropout = config.attention_dropout # small tweak here compared to CLIP, no bias here self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False) @@ -296,6 +324,7 @@ class Blip2Attention(nn.Module): hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -308,31 +337,32 @@ class Blip2Attention(nn.Module): ) query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) + attention_interface: Callable = eager_attention_forward - attention_scores = attention_scores * self.scale + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scale, + **kwargs, + ) - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. 
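The dispatch added to `Blip2Attention` above, where `eager_attention_forward` is defined once and `ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]` swaps in another kernel unless attention weights are requested under SDPA, is the core pattern of this PR. A runnable sketch of the mechanism with a toy registry; only `eager_attention_forward` mirrors the code above, the registry name and the SDPA stand-in are illustrative, and the `scale=` argument assumes PyTorch >= 2.1:

import torch
from torch import nn

def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # same math as the eager path added above; BLIP-2 keeps the softmax in the input dtype.
    # `module` only needs a `.training` flag here (any nn.Module works)
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights

def sdpa_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # stand-in for a fused backend sharing the same signature; no weights are returned
    attn_output = nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
    )
    return attn_output.transpose(1, 2).contiguous(), None

# toy registry mirroring the role of ALL_ATTENTION_FUNCTIONS: because every backend shares
# one signature, a module picks a callable instead of needing a subclass per backend
TOY_ATTENTION_FUNCTIONS = {"eager": eager_attention_forward, "sdpa": sdpa_attention_forward}

def pick_attention(impl, output_attentions=False):
    # SDPA cannot return attention weights, so the eager path stays as the fallback
    if impl == "eager" or (impl == "sdpa" and output_attentions):
        return eager_attention_forward
    return TOY_ATTENTION_FUNCTIONS[impl]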
- attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3) - - new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,) - context_layer = context_layer.reshape(new_context_layer_shape) - - output = self.projection(context_layer) - - outputs = (output, attention_probs) if output_attentions else (output, None) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.projection(attn_output) + outputs = (attn_output, attn_weights) if output_attentions else (attn_output, None) return outputs @@ -410,6 +440,10 @@ class Blip2PreTrainedModel(PreTrainedModel): config_class = Blip2Config base_model_prefix = "blip" supports_gradient_checkpointing = True + _supports_attention_backend = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True _no_split_modules = [ "Blip2Attention", @@ -1332,6 +1366,11 @@ class Blip2TextEmbeddings(nn.Module): BLIP_2_QFORMER_START_DOCSTRING, ) class Blip2QFormerModel(Blip2PreTrainedModel): + _supports_attention_backend = False # adds position on attn weights before last matmul + _supports_flash_attn_2 = False + _supports_sdpa = False + _supports_flex_attn = False + def __init__(self, config: Blip2QFormerConfig): super().__init__(config) self.config = config @@ -1511,6 +1550,9 @@ class Blip2QFormerModel(Blip2PreTrainedModel): ) +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + @add_start_docstrings( """ BLIP-2 Model for generating text and image features. The model consists of a vision encoder, Querying Transformer @@ -1526,10 +1568,10 @@ class Blip2Model(Blip2PreTrainedModel): def __init__(self, config: Blip2Config): super().__init__(config) - self.vision_model = Blip2VisionModel(config.vision_config) + self.vision_model = Blip2VisionModel._from_config(config.vision_config) self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) - self.qformer = Blip2QFormerModel(config.qformer_config) + self.qformer = Blip2QFormerModel._from_config(config.qformer_config) self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) if config.use_decoder_only_language_model: @@ -1580,6 +1622,7 @@ class Blip2Model(Blip2PreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + **kwargs: Unpack[KwargsForCausalLM], ): r""" Returns: @@ -1611,6 +1654,7 @@ class Blip2Model(Blip2PreTrainedModel): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + **kwargs, ) else: inputs_embeds = self.language_model.get_input_embeddings()(input_ids) @@ -1624,6 +1668,7 @@ class Blip2Model(Blip2PreTrainedModel): output_hidden_states=output_hidden_states, return_dict=return_dict, labels=labels, + **kwargs, ) return text_outputs @@ -1749,6 +1794,7 @@ class Blip2Model(Blip2PreTrainedModel): labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]: r""" Returns: @@ -1826,6 +1872,7 @@ class Blip2Model(Blip2PreTrainedModel): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + **kwargs, ) logits = 
outputs.logits if return_dict else outputs[0] loss = None @@ -1851,6 +1898,7 @@ class Blip2Model(Blip2PreTrainedModel): output_hidden_states=output_hidden_states, return_dict=True, # toggle for easier access to loss/logits below labels=labels, + **kwargs, ) loss = outputs.loss logits = outputs.logits @@ -1981,10 +2029,10 @@ class Blip2VisionModelWithProjection(Blip2PreTrainedModel): def __init__(self, config: Blip2Config): super().__init__(config) - self.vision_model = Blip2VisionModel(config.vision_config) + self.vision_model = Blip2VisionModel._from_config(config.vision_config) self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) - self.qformer = Blip2QFormerModel(config.qformer_config) + self.qformer = Blip2QFormerModel._from_config(config.qformer_config) # vision projection layer self.vision_projection = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) @@ -2102,10 +2150,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): def __init__(self, config: Blip2Config): super().__init__(config) - self.vision_model = Blip2VisionModel(config.vision_config) + self.vision_model = Blip2VisionModel._from_config(config.vision_config) self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) - self.qformer = Blip2QFormerModel(config.qformer_config) + self.qformer = Blip2QFormerModel._from_config(config.qformer_config) self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) if config.use_decoder_only_language_model: @@ -2180,6 +2228,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False, use_cache: Optional[bool] = None, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]: r""" Returns: @@ -2308,6 +2357,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): output_hidden_states=output_hidden_states, return_dict=return_dict, use_cache=use_cache, + **kwargs, ) logits = outputs.logits if return_dict else outputs[0] loss = None @@ -2334,6 +2384,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): return_dict=True, # toggle for easier access to loss/logits below labels=labels, use_cache=use_cache, + **kwargs, ) loss = outputs.loss logits = outputs.logits @@ -2463,12 +2514,12 @@ class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): def __init__(self, config: Blip2Config): super().__init__(config) - self.vision_model = Blip2VisionModel(config.vision_config) + self.vision_model = Blip2VisionModel._from_config(config.vision_config) self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) self.embeddings = Blip2TextEmbeddings(config.qformer_config) - self.qformer = Blip2QFormerModel(config.qformer_config) + self.qformer = Blip2QFormerModel._from_config(config.qformer_config) # vision projection layer self.vision_projection = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index e9eca929c3b..3928e88ae01 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -14,31 +14,32 @@ # limitations under the License. 
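Building the sub-models with `_from_config(...)`, as the BLIP-2 classes above now do, means each sub-config carries its own attention-implementation setting and is checked against that sub-model's capability flags, which is why `Blip2QFormerModel` explicitly opts out while the vision tower opts in. From the user's side the selection remains a single `from_pretrained` argument. A usage sketch (not part of the patch); the checkpoint id and dtype are illustrative assumptions, and the chosen backend has to be supported by every sub-model of the checkpoint:

import torch
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",        # any of the VLMs touched by this patch works the same way
    torch_dtype=torch.float16,
    attn_implementation="sdpa",        # or "eager" / "flash_attention_2" when flash-attn is installed
)
print(model.config._attn_implementation)  # sub-models read the setting from their own sub-config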
"""PyTorch Chameleon model.""" -import math from functools import cached_property -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( + LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging, @@ -235,6 +236,33 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +# Copied from transformers.models.llama.modeling_llama.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class ChameleonAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -259,6 +287,7 @@ class ChameleonAttention(nn.Module): self.rope_theta = config.rope_theta self.is_causal = True self.model_parallel_size = config.model_parallel_size + self.scaling = self.head_dim**-0.5 if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( @@ -338,144 +367,26 @@ class ChameleonAttention(nn.Module): cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + attention_interface: Callable = eager_attention_forward - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = 
nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# NO LONGER EXIST copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon -# TODO(joao): add me back asap :) -class ChameleonFlashAttention2(ChameleonAttention): - """ - Chameleon flash attention module. This module inherits from `ChameleonAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - # Ignore copy - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.reshape(-1, self.num_heads, self.head_dim) - query_states = self.q_norm(query_states) - - key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) - key_states = self.k_norm(key_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. - # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (ChameleonRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. 
Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -487,114 +398,13 @@ class ChameleonFlashAttention2(ChameleonAttention): return attn_output, attn_weights, past_key_value -class ChameleonSdpaAttention(ChameleonAttention): - """ - Chameleon attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `ChameleonAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from ChameleonAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "ChameleonModel is using ChameleonSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.reshape(-1, self.num_heads, self.head_dim) - query_states = self.q_norm(query_states) - - key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) - key_states = self.k_norm(key_states) - - query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None and cache_position is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -CHAMELEON_ATTENTION_CLASSES = { - "eager": ChameleonAttention, - "flash_attention_2": ChameleonFlashAttention2, - "sdpa": ChameleonSdpaAttention, -} - - # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON -# TODO(joao): add me back asap :) class ChameleonDecoderLayer(nn.Module): def __init__(self, config: ChameleonConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = CHAMELEON_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = ChameleonAttention(config=config, layer_idx=layer_idx) self.mlp = ChameleonMLP(config) self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -669,7 +479,7 @@ class ChameleonSwinDecoderLayer(nn.Module): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = CHAMELEON_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = ChameleonAttention(config=config, layer_idx=layer_idx) self.mlp = ChameleonMLP(config) self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -1052,6 +862,7 @@ class ChameleonPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_static_cache = True _supports_param_buffer_assignment = False + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -1256,6 +1067,7 @@ class ChameleonModel(ChameleonPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1342,6 +1154,7 @@ class ChameleonModel(ChameleonPreTrainedModel): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) hidden_states = layer_outputs[0] @@ -1498,6 +1311,9 @@ class ChameleonModel(ChameleonPreTrainedModel): return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
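With the per-backend subclasses (`ChameleonFlashAttention2`, `ChameleonSdpaAttention`) and the `CHAMELEON_ATTENTION_CLASSES` mapping deleted above, the SDPA-specific boilerplate now lives once behind the shared backend table instead of being re-implemented per model. Roughly what the removed SDPA path reduces to, as a simplified stand-in rather than the registered backend itself:

import torch

def sdpa_forward(query, key, value, causal_mask=None, dropout=0.0, training=False):
    # mirror of the removed ChameleonSdpaAttention logic: dispatch to the fused kernel,
    # using is_causal only when no explicit mask is given and the query length exceeds 1
    is_causal = causal_mask is None and query.shape[2] > 1
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value,
        attn_mask=causal_mask,
        dropout_p=dropout if training else 0.0,
        is_causal=is_causal,
    )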
+ + @add_start_docstrings( "Chameleon Model with a head on top used for outputting logits for next token prediction.", CHAMELEON_START_DOCSTRING, @@ -1532,6 +1348,7 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixi def get_decoder(self): return self.model + @can_return_tuple @add_start_docstrings_to_model_forward(CHAMELEON_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1548,6 +1365,7 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixi output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1596,6 +1414,7 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixi output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1607,22 +1426,7 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixi loss = None if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) return CausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 45031a1a647..761e40a2fe1 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1010,6 +1010,7 @@ class Emu3VQVAE(PreTrainedModel): _supports_sdpa = True _supports_flash_attn_2 = True _supports_flex_attn = True + _supports_attention_backend = True _no_split_modules = [ "Emu3VQVAETemporalResnetBlock", "Emu3VQVAEAttentionBlock", @@ -1202,6 +1203,7 @@ class Emu3PreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_static_cache = True _supports_param_buffer_assignment = False + _supports_attention_backend = True _supports_flex_attn = True def _init_weights(self, module): @@ -1836,6 +1838,7 @@ class Emu3Model(Emu3PreTrainedModel): image = self.vqmodel.decode(image_tokens) return image + @can_return_tuple @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING) def forward( self, @@ -1851,6 +1854,7 @@ class Emu3Model(Emu3PreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1884,8 +1888,9 @@ class Emu3Model(Emu3PreTrainedModel): use_cache=use_cache, 
output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, + **kwargs, ) return outputs @@ -1941,6 +1946,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): cache_position: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -2007,8 +2013,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -2018,7 +2025,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return CausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index 6075b8d7370..971d94e29d7 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -25,10 +25,12 @@ import torch.utils.checkpoint from ...cache_utils import Cache from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( CausalLMOutputWithPast, ) from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -41,6 +43,7 @@ from ..chameleon.modeling_chameleon import ( ChameleonVQVAEEncoderConvDownsample, ) from ..llama.modeling_llama import ( + KwargsForCausalLM, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, @@ -736,6 +739,7 @@ class Emu3VQVAE(PreTrainedModel): _supports_sdpa = True _supports_flash_attn_2 = True _supports_flex_attn = True + _supports_attention_backend = True _no_split_modules = [ "Emu3VQVAETemporalResnetBlock", "Emu3VQVAEAttentionBlock", @@ -898,6 +902,7 @@ class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE): "Emu3DecoderLayer", ] _supports_flex_attn = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.get_text_config().initializer_range @@ -1179,6 +1184,7 @@ class Emu3Model(Emu3PreTrainedModel): image = self.vqmodel.decode(image_tokens) return image + @can_return_tuple @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING) def forward( self, @@ -1194,6 +1200,7 @@ class Emu3Model(Emu3PreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1227,8 +1234,9 @@ class Emu3Model(Emu3PreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, 
+ **kwargs, ) return outputs @@ -1284,6 +1292,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): cache_position: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1350,8 +1359,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1361,7 +1371,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return CausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index ec74b9be3d5..a580d47ef1c 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -21,10 +21,19 @@ import torch.utils.checkpoint from torch import nn from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel from ...models.auto.modeling_auto import AutoModel -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + logging, + replace_return_docstrings, +) from .configuration_fuyu import FuyuConfig @@ -58,6 +67,10 @@ class FuyuPreTrainedModel(PreTrainedModel): config_class = FuyuConfig base_model_prefix = "fuyu" supports_gradient_checkpointing = True + _supports_attention_backend = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True _no_split_modules = [] _skip_keys_device_placement = "past_key_values" @@ -142,6 +155,9 @@ FUYU_INPUTS_DOCSTRING = r""" """ +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
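The loss handling follows the same direction: the hand-rolled shift-and-cross-entropy block removed from Chameleon earlier in this patch becomes a call to `self.loss_function(logits=..., labels=..., vocab_size=..., **kwargs)`, so loss-related kwargs (notably `num_items_in_batch`) reach the loss computation. A simplified stand-in for what such a causal-LM loss does, not the library's exact implementation:

import torch
from torch.nn import functional as F

def causal_lm_loss(logits, labels, vocab_size, num_items_in_batch=None, ignore_index=-100, **kwargs):
    logits = logits.float()                                   # upcast once for a stable softmax
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index, reduction=reduction)
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch                      # exact mean under gradient accumulation
    return loss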
+ + @add_start_docstrings( """The Fuyu model which consists of a vision backbone and a language model, without a language modeling head.""", FUYU_START_DOCSTRING, @@ -224,8 +240,8 @@ class FuyuModel(FuyuPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **kwargs, - ) -> Union[Tuple, BaseModelOutputWithPast]: + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -323,6 +339,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model.get_decoder() + @can_return_tuple @add_start_docstrings_to_model_forward(FUYU_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -392,7 +409,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin): output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, - return_dict=return_dict, + return_dict=True, # don't pass kwargs because Persimmon-backbone doesn't accept FA2 kwargs yet, TODO: raushan ) @@ -407,10 +424,6 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin): logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs ) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 15013677c23..145a9ce752f 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -30,9 +30,12 @@ import torch.nn.functional as F from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, @@ -619,6 +622,7 @@ class GotOcr2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) @@ -745,6 +749,7 @@ class GotOcr2Model(GotOcr2PreTrainedModel): image_outputs = self.vision_tower(pixel_values).last_hidden_state return self.multi_modal_projector(image_outputs) + @can_return_tuple @add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING) def forward( self, @@ -759,6 +764,7 @@ class GotOcr2Model(GotOcr2PreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, GotOcr2ModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -800,16 +806,19 @@ class GotOcr2Model(GotOcr2PreTrainedModel): 
output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, + **kwargs, ) - output = GotOcr2ModelOutputWithPast( + return GotOcr2ModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_features if pixel_values is not None else None, ) - return output if return_dict else output.to_tuple() + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... @add_start_docstrings( @@ -874,6 +883,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -937,6 +947,8 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, ) hidden_states = outputs[0] @@ -946,7 +958,9 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return GotOcr2CausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py index f485146f2e0..fe02b614777 100644 --- a/src/transformers/models/got_ocr2/modular_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py @@ -21,6 +21,7 @@ import torch.nn as nn import torch.utils.checkpoint from transformers.models.llava.modeling_llava import ( + KwargsForCausalLM, LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaModel, @@ -30,6 +31,8 @@ from transformers.models.llava.modeling_llava import ( from transformers.models.sam.modeling_sam import SamMLPBlock, SamVisionAttention, SamVisionEncoder, SamVisionLayer from ...configuration_utils import PretrainedConfig +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...processing_utils import Unpack from ...utils import ( add_start_docstrings_to_model_forward, can_return_tuple, @@ -393,6 +396,7 @@ class GotOcr2Model(LlavaModel): image_outputs = self.vision_tower(pixel_values).last_hidden_state return self.multi_modal_projector(image_outputs) + @can_return_tuple @add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING) def forward( self, @@ -407,6 +411,7 @@ class GotOcr2Model(LlavaModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, GotOcr2ModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -448,16 +453,16 @@ class GotOcr2Model(LlavaModel): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, + **kwargs, ) - output = GotOcr2ModelOutputWithPast( + return GotOcr2ModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, 
past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_features if pixel_values is not None else None, ) - return output if return_dict else output.to_tuple() class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): @@ -479,6 +484,7 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -542,6 +548,8 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, ) hidden_states = outputs[0] @@ -551,7 +559,9 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return GotOcr2CausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index eae379be381..ebe43db4598 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -20,24 +20,27 @@ """PyTorch Idefics model.""" from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ModelOutput -from ...modeling_utils import PretrainedConfig, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PretrainedConfig, PreTrainedModel +from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -500,6 +503,30 @@ class IdeficsMLP(nn.Module): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) +# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = 
attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # this was adapted from LlamaAttention class IdeficsAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -515,11 +542,13 @@ class IdeficsAttention(nn.Module): layer_idx: Optional[int] = None, ): super().__init__() + self.config = config self.hidden_size = hidden_size self.num_heads = num_heads self.head_dim = hidden_size // num_heads self.dropout = dropout self.is_causal = True + self.scaling = self.head_dim**-0.5 self.layer_idx = layer_idx if layer_idx is None: @@ -596,6 +625,7 @@ class IdeficsAttention(nn.Module): output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # if key_value_states are provided this layer is used as a cross-attention layer is_cross_attention = self.is_cross_attention or key_value_states is not None @@ -631,47 +661,33 @@ class IdeficsAttention(nn.Module): query_states = self.q_layer_norm(query_states) key_states = self.k_layer_norm(key_states) - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + attention_interface: Callable = eager_attention_forward - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
- is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attn_mask=causal_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, ) - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - attn_weights = None if output_attentions: - logger.warning_once( - "attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead" - ) + attn_weights = None return attn_output, attn_weights, past_key_value @@ -706,6 +722,7 @@ class IdeficsDecoderLayer(nn.Module): output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -734,6 +751,7 @@ class IdeficsDecoderLayer(nn.Module): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -833,6 +851,7 @@ class IdeficsGatedCrossAttentionLayer(nn.Module): output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -875,6 +894,7 @@ class IdeficsGatedCrossAttentionLayer(nn.Module): key_value_states=image_hidden_states, attention_mask=image_attention_mask, output_attentions=output_attentions, + **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training) # Fill in zeros for cross_attention hidden_states of tokens attending to no images @@ -927,7 +947,9 @@ class IdeficsPreTrainedModel(PreTrainedModel): _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"] _supports_sdpa = True _supports_cache_class = True + _supports_flash_attn_2 = True _supports_static_cache = False # IDEFICS cannot compile due to dynamic control flow when checking inputs + _supports_attention_backend = True def _init_weights(self, module): # important: this ported version of Idefics isn't meant for training from scratch - only @@ -1029,6 +1051,9 @@ LLAMA_INPUTS_DOCSTRING = r""" """ +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
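The Idefics attention rewrite above (hand-rolled SDPA replaced by an `attention_interface` looked up in `ALL_ATTENTION_FUNCTIONS`, with a fall-back to eager when SDPA is asked for attention weights) is the same dispatch applied to the other models in this patch. A toy, self-contained version of that dispatch, with a local registry standing in for `ALL_ATTENTION_FUNCTIONS`:

import torch
import torch.nn.functional as F

def eager_backend(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # eager path: explicit matmul + softmax (the Idefics copy casts to fp32;
    # the BLIP-family variants above keep the working dtype)
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = F.dropout(attn_weights, p=dropout, training=False)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights

def sdpa_backend(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # SDPA's default scale is 1/sqrt(head_dim), i.e. the same `scaling` value used above
    attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=dropout)
    return attn_output.transpose(1, 2).contiguous(), None  # SDPA cannot return weights

ATTENTION_REGISTRY = {"eager": eager_backend, "sdpa": sdpa_backend}  # stand-in for ALL_ATTENTION_FUNCTIONS

def run_attention(impl, query, key, value, attention_mask, scaling, output_attentions=False):
    backend = eager_backend
    if impl != "eager" and not (impl == "sdpa" and output_attentions):
        # the real code logs a warning for the sdpa + output_attentions fallback
        backend = ATTENTION_REGISTRY[impl]
    return backend(None, query, key, value, attention_mask, scaling)

q = k = v = torch.randn(1, 4, 7, 16)  # (batch, num_heads, seq_len, head_dim)
out, weights = run_attention("sdpa", q, k, v, None, scaling=16 ** -0.5)
print(out.shape)  # torch.Size([1, 7, 4, 16]); the caller then reshapes to (batch, seq_len, hidden)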
+ + @add_start_docstrings( "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", LLAMA_START_DOCSTRING, @@ -1112,6 +1137,7 @@ class IdeficsModel(IdeficsPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, @@ -1130,6 +1156,7 @@ class IdeficsModel(IdeficsPreTrainedModel): interpolate_pos_encoding: Optional[bool] = False, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, IdeficsBaseModelOutputWithPast]: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -1292,6 +1319,7 @@ class IdeficsModel(IdeficsPreTrainedModel): output_attentions=output_attentions, use_cache=use_cache, past_key_value=None, # not implemented + **kwargs, ) hidden_states = outputs[0] @@ -1303,6 +1331,7 @@ class IdeficsModel(IdeficsPreTrainedModel): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) return layer_outputs @@ -1348,6 +1377,7 @@ class IdeficsModel(IdeficsPreTrainedModel): cross_layer_interval=self.cross_layer_interval, gated_cross_attn_layers=self.gated_cross_attn_layers, cache_position=cache_position, + **kwargs, ) hidden_states = layer_outputs[0] @@ -1368,12 +1398,7 @@ class IdeficsModel(IdeficsPreTrainedModel): if return_legacy_cache: next_cache = next_cache.to_legacy_cache() image_hidden_states = image_hidden_states.view(batch_size, num_images, image_seq_len, image_hidden_size) - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, image_hidden_states] - if v is not None - ) + return IdeficsBaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1565,6 +1590,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin): ): output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings + @can_return_tuple @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=IdeficsCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1585,6 +1611,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin): interpolate_pos_encoding: Optional[bool] = False, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, IdeficsCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1641,8 +1668,9 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin): output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1650,24 +1678,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin): loss = None if labels is not None: - labels = labels.to(logits.device) - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. 
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) return IdeficsCausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 186c7be6bbc..e5e3b9f3a53 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -20,17 +20,20 @@ from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -514,6 +517,7 @@ class Idefics2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True + _supports_attention_backend = True def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) @@ -1089,6 +1093,7 @@ class Idefics2Model(Idefics2PreTrainedModel): new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states.to(new_inputs_embeds.device) return new_inputs_embeds + @can_return_tuple @add_start_docstrings_to_model_forward( """ Inputs fed to the model can have an arbitrary number of images. 
To account for this, pixel_values fed to @@ -1117,6 +1122,7 @@ class Idefics2Model(Idefics2PreTrainedModel): output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, Idefics2BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1226,15 +1232,13 @@ class Idefics2Model(Idefics2PreTrainedModel): output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, - return_dict=return_dict, + return_dict=True, + **kwargs, ) if return_legacy_cache and use_cache: outputs.past_key_values = outputs.past_key_values.to_legacy_cache() - if not return_dict: - return tuple(v for v in [*outputs, image_hidden_states] if v is not None) - return Idefics2BaseModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, @@ -1244,6 +1248,9 @@ class Idefics2Model(Idefics2PreTrainedModel): ) +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + @add_start_docstrings( """The Idefics2 Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top. """, IDEFICS2_START_DOCSTRING, @@ -1292,6 +1299,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin) def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + @can_return_tuple @add_start_docstrings_to_model_forward(IDEFICS2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Idefics2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1311,6 +1319,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin) return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, Idefics2CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1386,7 +1395,8 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin) output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, - return_dict=return_dict, + return_dict=True, + **kwargs, ) hidden_states = outputs[0] @@ -1396,26 +1406,9 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin) loss = None if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - labels = labels.to(logits.device) - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. 
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return Idefics2CausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 5945fd71c5f..4e1a00fa98c 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -20,17 +20,20 @@ from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...cache_utils import DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -532,6 +535,7 @@ class Idefics3PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True + _supports_attention_backend = True def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) @@ -816,6 +820,7 @@ class Idefics3Model(Idefics3PreTrainedModel): new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states return new_inputs_embeds + @can_return_tuple @add_start_docstrings_to_model_forward( """ Inputs fed to the model can have an arbitrary number of images. 
To account for this, pixel_values fed to @@ -843,6 +848,7 @@ class Idefics3Model(Idefics3PreTrainedModel): output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, Idefics3BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -939,12 +945,10 @@ class Idefics3Model(Idefics3PreTrainedModel): output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, - return_dict=return_dict, + return_dict=True, + **kwargs, ) - if not return_dict: - return tuple(v for v in [*outputs, image_hidden_states] if v is not None) - return Idefics3BaseModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, @@ -954,6 +958,9 @@ class Idefics3Model(Idefics3PreTrainedModel): ) +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + @add_start_docstrings( """The Idefics3 Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top. """, IDEFICS3_START_DOCSTRING, @@ -1009,6 +1016,7 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin) def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + @can_return_tuple @add_start_docstrings_to_model_forward(IDEFICS3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Idefics3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1028,6 +1036,7 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin) cache_position: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, Idefics3CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1117,7 +1126,8 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin) output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, - return_dict=return_dict, + return_dict=True, + **kwargs, ) hidden_states = outputs[0] @@ -1127,26 +1137,9 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin) loss = None if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - labels = labels.to(logits.device) - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. 
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return Idefics3CausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 6abed48c86a..ec051387d65 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -16,27 +16,30 @@ import math from dataclasses import dataclass -from typing import Any, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( + LossKwargs, ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, torch_int, @@ -159,6 +162,30 @@ class InstructBlipVisionEmbeddings(nn.Module): return embeddings +# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBLIP doesn't cast attn weights to fp32 +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.blip_2.modeling_blip_2.Blip2Attention with Blip2->InstructBlip class InstructBlipAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -175,7 +202,8 @@ class InstructBlipAttention(nn.Module): f" {self.num_heads})." 
) self.scale = self.head_dim**-0.5 - self.dropout = nn.Dropout(config.attention_dropout) + self.is_causal = False + self.attention_dropout = config.attention_dropout # small tweak here compared to CLIP, no bias here self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False) @@ -201,6 +229,7 @@ class InstructBlipAttention(nn.Module): hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -213,31 +242,32 @@ class InstructBlipAttention(nn.Module): ) query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) + attention_interface: Callable = eager_attention_forward - attention_scores = attention_scores * self.scale + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scale, + **kwargs, + ) - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3) - - new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,) - context_layer = context_layer.reshape(new_context_layer_shape) - - output = self.projection(context_layer) - - outputs = (output, attention_probs) if output_attentions else (output, None) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.projection(attn_output) + outputs = (attn_output, attn_weights) if output_attentions else (attn_output, None) return outputs @@ -315,6 +345,10 @@ class InstructBlipPreTrainedModel(PreTrainedModel): config_class = InstructBlipConfig base_model_prefix = "blip" supports_gradient_checkpointing = True + _supports_attention_backend = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_static_cache = True _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) @@ -1087,6 +1121,11 @@ class InstructBlipQFormerModel(InstructBlipPreTrainedModel): instruction as input. 
""" + _supports_attention_backend = False # adds position on attn weights before last matmul + _supports_flash_attn_2 = False + _supports_sdpa = False + _supports_flex_attn = False + def __init__(self, config: InstructBlipQFormerConfig): super().__init__(config) self.config = config @@ -1277,6 +1316,9 @@ class InstructBlipQFormerModel(InstructBlipPreTrainedModel): ) +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + @add_start_docstrings( """ InstructBLIP base Model consisting of language model, qformer and vision encoder. @@ -1337,6 +1379,7 @@ class InstructBlipModel(InstructBlipPreTrainedModel): if hasattr(self.language_model, "_hf_hook"): self.language_model._hf_hook.io_same_device = True # For `generate` compatibility + @can_return_tuple @add_start_docstrings_to_model_forward(INSTRUCTBLIP_INPUTS_DOCSTRING) def forward( self, @@ -1352,6 +1395,7 @@ class InstructBlipModel(InstructBlipPreTrainedModel): return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False, use_cache: Optional[bool] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1404,6 +1448,7 @@ class InstructBlipModel(InstructBlipPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=return_dict, use_cache=use_cache, + **kwargs, ) else: outputs = self.language_model( @@ -1415,11 +1460,9 @@ class InstructBlipModel(InstructBlipPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=return_dict, use_cache=use_cache, + **kwargs, ) - if not return_dict: - return (vision_outputs, query_outputs, outputs) - return InstructBlipForConditionalGenerationModelOutput( vision_outputs=vision_outputs, qformer_outputs=query_outputs, @@ -1448,10 +1491,10 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati def __init__(self, config: InstructBlipConfig): super().__init__(config) - self.vision_model = InstructBlipVisionModel(config.vision_config) + self.vision_model = InstructBlipVisionModel._from_config(config.vision_config) self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) - self.qformer = InstructBlipQFormerModel(config.qformer_config) + self.qformer = InstructBlipQFormerModel._from_config(config.qformer_config) self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) @@ -1516,6 +1559,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati if hasattr(self.language_model, "_hf_hook"): self.language_model._hf_hook.io_same_device = True # For `generate` compatibility + @can_return_tuple @add_start_docstrings_to_model_forward(INSTRUCTBLIP_INPUTS_DOCSTRING) @replace_return_docstrings( output_type=InstructBlipForConditionalGenerationModelOutput, config_class=InstructBlipVisionConfig @@ -1535,6 +1579,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False, use_cache: Optional[bool] = None, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1646,21 +1691,15 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati output_hidden_states=output_hidden_states, return_dict=return_dict, use_cache=use_cache, + **kwargs, 
) logits = outputs.logits if return_dict else outputs[0] loss = None - # we compute the loss here since we need to take into account the sequence length of the query embeds if labels is not None: - labels = labels.to(logits.device) - logits = logits[:, -labels.size(1) :, :] - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous().to(logits.device) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) - # Flatten the tokens - loss_fct = CrossEntropyLoss(reduction="mean") - - loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1)) else: outputs = self.language_model( inputs_embeds=inputs_embeds, @@ -1672,14 +1711,11 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati return_dict=return_dict, labels=labels, use_cache=use_cache, + **kwargs, ) loss = outputs.loss if return_dict else outputs[0] logits = outputs.logits if return_dict else outputs[1] - if not return_dict: - output = (logits, vision_outputs, query_outputs, outputs) - return ((loss,) + output) if loss is not None else output - return InstructBlipForConditionalGenerationModelOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index 0ce752bed8b..4e78afaa882 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -21,26 +21,29 @@ import math from dataclasses import dataclass -from typing import Any, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union import torch from torch import nn -from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( + LossKwargs, ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, torch_int, @@ -130,6 +133,30 @@ class InstructBlipVideoVisionEmbeddings(nn.Module): return embeddings +# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBlipVideo doesn't cast attn weights to fp32 +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights 
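Every manual `shift_logits` / `CrossEntropyLoss` block deleted in this patch (Idefics, Idefics2/3, InstructBlip, InstructBlipVideo, ...) is replaced by a call to `self.loss_function(logits=..., labels=..., vocab_size=..., **kwargs)`. The sketch below reproduces the shifted cross-entropy those deleted blocks computed, with `num_items_in_batch` shown as the optional normalizer suggested by `LossKwargs`; treat the exact signature as an assumption rather than the library's definition.

import torch
import torch.nn.functional as F

def causal_lm_loss(logits, labels, vocab_size, num_items_in_batch=None, ignore_index=-100):
    # shift so that tokens < n predict token n, then flatten
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(shift_logits.float(), shift_labels, ignore_index=ignore_index, reduction=reduction)
    if num_items_in_batch is not None:
        # normalize over the true token count, e.g. under gradient accumulation
        loss = loss / num_items_in_batch
    return loss

logits = torch.randn(2, 5, 32)           # (batch, seq_len, vocab)
labels = torch.randint(0, 32, (2, 5))
print(causal_lm_loss(logits, labels, vocab_size=32))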
+ + class InstructBlipVideoAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -145,7 +172,8 @@ class InstructBlipVideoAttention(nn.Module): f" {self.num_heads})." ) self.scale = self.head_dim**-0.5 - self.dropout = nn.Dropout(config.attention_dropout) + self.is_causal = False + self.attention_dropout = config.attention_dropout # small tweak here compared to CLIP, no bias here self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False) @@ -171,6 +199,7 @@ class InstructBlipVideoAttention(nn.Module): hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -183,31 +212,32 @@ class InstructBlipVideoAttention(nn.Module): ) query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) + attention_interface: Callable = eager_attention_forward - attention_scores = attention_scores * self.scale + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scale, + **kwargs, + ) - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3) - - new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,) - context_layer = context_layer.reshape(new_context_layer_shape) - - output = self.projection(context_layer) - - outputs = (output, attention_probs) if output_attentions else (output, None) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.projection(attn_output) + outputs = (attn_output, attn_weights) if output_attentions else (attn_output, None) return outputs @@ -852,6 +882,9 @@ class InstructBlipVideoQFormerEmbeddings(nn.Module): return embeddings +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + INSTRUCTBLIPVIDEO_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -945,6 +978,10 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel): config_class = InstructBlipVideoConfig base_model_prefix = "blip" supports_gradient_checkpointing = True + _supports_attention_backend = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_static_cache = True _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) @@ -1049,6 +1086,11 @@ class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel): instruction as input. """ + _supports_attention_backend = False # adds position on attn weights before last matmul + _supports_flash_attn_2 = False + _supports_sdpa = False + _supports_flex_attn = False + def __init__(self, config: InstructBlipVideoQFormerConfig): super().__init__(config) self.config = config @@ -1332,6 +1374,7 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel): if hasattr(self.language_model, "_hf_hook"): self.language_model._hf_hook.io_same_device = True # For `generate` compatibility + @can_return_tuple @add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING) def forward( self, @@ -1347,6 +1390,7 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel): return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False, use_cache: Optional[bool] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1409,6 +1453,7 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=return_dict, use_cache=use_cache, + **kwargs, ) else: outputs = self.language_model( @@ -1420,11 +1465,9 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=return_dict, use_cache=use_cache, + **kwargs, ) - if not return_dict: - return (vision_outputs, query_outputs, outputs) - return InstructBlipVideoForConditionalGenerationModelOutput( vision_outputs=vision_outputs, qformer_outputs=query_outputs, @@ -1453,10 +1496,10 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel def __init__(self, config: InstructBlipVideoConfig): super().__init__(config) - self.vision_model = InstructBlipVideoVisionModel(config.vision_config) + self.vision_model = InstructBlipVideoVisionModel._from_config(config.vision_config) self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) - self.qformer = InstructBlipVideoQFormerModel(config.qformer_config) + self.qformer = InstructBlipVideoQFormerModel._from_config(config.qformer_config) self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) @@ -1519,9 +1562,10 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel if hasattr(self.language_model, "_hf_hook"): self.language_model._hf_hook.io_same_device = True # For `generate` compatibility + @can_return_tuple @add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING) @replace_return_docstrings( - output_type=InstructBlipVideoForConditionalGenerationModelOutput, config_class=InstructBlipVideoVisionConfig + 
output_type=InstructBlipVideoForConditionalGenerationModelOutput, config_class=InstructBlipVideoConfig ) def forward( self, @@ -1538,6 +1582,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False, use_cache: Optional[bool] = None, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1682,21 +1727,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel output_hidden_states=output_hidden_states, return_dict=return_dict, use_cache=use_cache, + **kwargs, ) logits = outputs.logits if return_dict else outputs[0] loss = None - # we compute the loss here since we need to take into account the sequence length of the query embeds if labels is not None: - labels = labels.to(logits.device) - logits = logits[:, -labels.size(1) :, :] - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous().to(logits.device) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) - # Flatten the tokens - loss_fct = CrossEntropyLoss(reduction="mean") - - loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1)) else: outputs = self.language_model( inputs_embeds=inputs_embeds, @@ -1708,14 +1747,11 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel return_dict=return_dict, labels=labels, use_cache=use_cache, + **kwargs, ) loss = outputs.loss if return_dict else outputs[0] logits = outputs.logits if return_dict else outputs[1] - if not return_dict: - output = (logits, vision_outputs, query_outputs, outputs) - return ((loss,) + output) if loss is not None else output - return InstructBlipVideoForConditionalGenerationModelOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py index d28485545f7..566b00d983c 100644 --- a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py @@ -18,7 +18,6 @@ from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss from transformers.models.instructblip.configuration_instructblip import ( InstructBlipQFormerConfig, @@ -31,11 +30,14 @@ from transformers.models.instructblip.modeling_instructblip import ( InstructBlipPreTrainedModel, InstructBlipQFormerModel, InstructBlipVisionModel, + KwargsForCausalLM, ) from ...configuration_utils import PretrainedConfig +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES -from ...utils import add_start_docstrings_to_model_forward, logging +from ...processing_utils import Unpack +from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, logging, replace_return_docstrings from ..auto import CONFIG_MAPPING, AutoConfig @@ -196,6 +198,7 @@ INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = None class InstructBlipVideoModel(InstructBlipModel): + @can_return_tuple @add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING) def forward( self, @@ -211,6 +214,7 @@ class 
InstructBlipVideoModel(InstructBlipModel): return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False, use_cache: Optional[bool] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -273,6 +277,7 @@ class InstructBlipVideoModel(InstructBlipModel): output_hidden_states=output_hidden_states, return_dict=return_dict, use_cache=use_cache, + **kwargs, ) else: outputs = self.language_model( @@ -284,11 +289,9 @@ class InstructBlipVideoModel(InstructBlipModel): output_hidden_states=output_hidden_states, return_dict=return_dict, use_cache=use_cache, + **kwargs, ) - if not return_dict: - return (vision_outputs, query_outputs, outputs) - return InstructBlipVideoForConditionalGenerationModelOutput( vision_outputs=vision_outputs, qformer_outputs=query_outputs, @@ -297,6 +300,11 @@ class InstructBlipVideoModel(InstructBlipModel): class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration): + @can_return_tuple + @add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=InstructBlipVideoForConditionalGenerationModelOutput, config_class=InstructBlipVideoConfig + ) def forward( self, pixel_values: torch.FloatTensor, @@ -312,6 +320,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False, use_cache: Optional[bool] = None, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]: r""" ```python @@ -447,21 +456,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera output_hidden_states=output_hidden_states, return_dict=return_dict, use_cache=use_cache, + **kwargs, ) logits = outputs.logits if return_dict else outputs[0] loss = None - # we compute the loss here since we need to take into account the sequence length of the query embeds if labels is not None: - labels = labels.to(logits.device) - logits = logits[:, -labels.size(1) :, :] - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous().to(logits.device) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) - # Flatten the tokens - loss_fct = CrossEntropyLoss(reduction="mean") - - loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1)) else: outputs = self.language_model( inputs_embeds=inputs_embeds, @@ -473,14 +476,11 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera return_dict=return_dict, labels=labels, use_cache=use_cache, + **kwargs, ) loss = outputs.loss if return_dict else outputs[0] logits = outputs.logits if return_dict else outputs[1] - if not return_dict: - output = (logits, vision_outputs, query_outputs, outputs) - return ((loss,) + output) if loss is not None else output - return InstructBlipVideoForConditionalGenerationModelOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 6c59d06d2ee..2b1683e497c 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -35,6 +35,7 @@ from ...modeling_outputs import 
BaseModelOutput, BaseModelOutputWithPast, BaseMo from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( + LossKwargs, ModelOutput, add_code_sample_docstrings, add_start_docstrings, @@ -621,6 +622,7 @@ class InternVLPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) @@ -828,6 +830,7 @@ class InternVLModel(InternVLPreTrainedModel): return vision_features + @can_return_tuple @add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING) def forward( self, @@ -845,7 +848,7 @@ class InternVLModel(InternVLPreTrainedModel): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, image_sizes: torch.Tensor = None, - **lm_kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, InternVLModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -904,17 +907,16 @@ class InternVLModel(InternVLPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, - **lm_kwargs, + **kwargs, ) - output = InternVLModelOutputWithPast( + return InternVLModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_features if pixel_values is not None else None, ) - return output if return_dict else output.to_tuple() def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5): """Perform pixel shuffle downsampling on vision features. @@ -992,6 +994,9 @@ class InternVLCausalLMOutputWithPast(ModelOutput): image_hidden_states: Optional[torch.FloatTensor] = None +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
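InternVL above follows the same output handling the earlier hunks apply everywhere: the forward always builds a `ModelOutput`, calls submodules with `return_dict=True`, and drops the `if not return_dict` branches in favor of the `@can_return_tuple` decorator. A minimal sketch of such a decorator, assuming it keys off a `return_dict` argument (the real helper lives in `transformers.utils` and may differ in detail):

import functools
from dataclasses import dataclass, astuple
from typing import Optional

@dataclass
class TinyOutput:
    last_hidden_state: float
    loss: Optional[float] = None

    def to_tuple(self):
        # same contract as ModelOutput.to_tuple(): drop the None fields
        return tuple(v for v in astuple(self) if v is not None)

def can_return_tuple(forward):
    @functools.wraps(forward)
    def wrapper(self, *args, return_dict: Optional[bool] = None, **kwargs):
        use_dict = return_dict if return_dict is not None else getattr(self, "use_return_dict", True)
        output = forward(self, *args, **kwargs)  # forward itself never branches on return_dict
        return output if use_dict else output.to_tuple()
    return wrapper

class TinyModel:
    use_return_dict = True

    @can_return_tuple
    def forward(self, x):
        return TinyOutput(last_hidden_state=x * 2.0)

m = TinyModel()
print(m.forward(1.5))                     # TinyOutput(last_hidden_state=3.0, loss=None)
print(m.forward(1.5, return_dict=False))  # (3.0,)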
+ + @add_start_docstrings( """The INTERNVL model which consists of a vision backbone and a language model.""", INTERNVL_START_DOCSTRING, @@ -1056,8 +1061,8 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin) return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - image_sizes: torch.Tensor = None, - **lm_kwargs, + image_sizes: Optional[torch.Tensor] = None, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, InternVLCausalLMOutputWithPast]: r""" Args: @@ -1138,7 +1143,7 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin) return_dict=True, cache_position=cache_position, image_sizes=image_sizes, - **lm_kwargs, + **kwargs, ) hidden_states = outputs[0] @@ -1148,7 +1153,9 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin) loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return InternVLCausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index e557293ee37..b5eed095499 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -21,10 +21,10 @@ from typing import Any, Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -32,10 +32,13 @@ from ...modeling_outputs import ( CausalLMOutputWithCrossAttentions, ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, torch_int, @@ -466,7 +469,7 @@ class Kosmos2VisionEmbeddings(nn.Module): return embeddings -# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward +# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> Kosmos2 doesn't cast attn weights to fp32 def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -481,7 +484,7 @@ def eager_attention_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -892,6 +895,7 @@ class KosmosTextAttention(nn.Module): bias: bool = True, ): super().__init__() + self.config = config self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout @@ -929,6 +933,7 @@ class KosmosTextAttention(nn.Module): attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 
"""Input shape: Batch x Time x Channel""" @@ -953,8 +958,7 @@ class KosmosTextAttention(nn.Module): key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) - query_states = self._shape(self.q_proj(hidden_states) * self.scaling) - attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) + query_states = self._shape(self.q_proj(hidden_states)) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -966,32 +970,33 @@ class KosmosTextAttention(nn.Module): # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - src_len = key_states.size(2) + attention_interface: Callable = eager_attention_forward - if attention_mask is not None: - if attention_mask.size() != (batch_size, 1, seq_length, src_len): - raise ValueError( - f"Attention mask should be of size {(batch_size, 1, seq_length, src_len)}, but is {attention_mask.size()}" + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) - attn_weights = attn_weights + attention_mask + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - # Mask heads if we want to - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask - - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - # attn_output = torch.bmm(attn_probs, value_states) ? - context_states = torch.matmul(attn_weights, value_states) - # attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) ? 
- context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() if self.inner_attn_ln is not None: - context_states = self.inner_attn_ln(context_states) + attn_output = self.inner_attn_ln(attn_output) - attn_output = self.out_proj(context_states) + attn_output = self.out_proj(attn_output) return attn_output, attn_weights, past_key_value @@ -1060,6 +1065,7 @@ class Kosmos2TextBlock(nn.Module): past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -1076,6 +1082,7 @@ class Kosmos2TextBlock(nn.Module): attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -1103,6 +1110,7 @@ class Kosmos2TextBlock(nn.Module): layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, + **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -1216,6 +1224,7 @@ class Kosmos2TextTransformer(nn.Module): return hidden_states + @can_return_tuple def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -1233,6 +1242,7 @@ class Kosmos2TextTransformer(nn.Module): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1338,6 +1348,7 @@ class Kosmos2TextTransformer(nn.Module): past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + **kwargs, ) hidden_states = layer_outputs[0] @@ -1357,18 +1368,6 @@ class Kosmos2TextTransformer(nn.Module): if output_hidden_states: all_hidden_states += (hidden_states,) - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - present_key_value_states, - all_hidden_states, - all_self_attns, - all_cross_attentions, - ] - if v is not None - ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=present_key_value_states, @@ -1387,6 +1386,9 @@ class Kosmos2PreTrainedModel(PreTrainedModel): config_class = Kosmos2Config supports_gradient_checkpointing = True _no_split_modules = ["Kosmos2VisionEncoderLayer", "Kosmos2TextBlock"] + _supports_attention_backend = True + _supports_flash_attn_2 = True + _supports_sdpa = True def _init_weights(self, module): """Initialize the weights""" @@ -1525,6 +1527,7 @@ class Kosmos2TextModel(Kosmos2PreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(KOSMOS2_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=Kosmos2TextConfig) def 
forward( @@ -1544,6 +1547,7 @@ class Kosmos2TextModel(Kosmos2PreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" Returns: @@ -1565,9 +1569,13 @@ class Kosmos2TextModel(Kosmos2PreTrainedModel): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + **kwargs, ) +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + @add_start_docstrings( """ The text model from KOSMOS-2 with a language modeling head on top (linear layer with weights tied to the input @@ -1600,6 +1608,7 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + @can_return_tuple @add_start_docstrings_to_model_forward(KOSMOS2_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=Kosmos2TextConfig) def forward( @@ -1620,6 +1629,7 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1652,27 +1662,14 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, + **kwargs, ) lm_logits = self.lm_head(outputs[0]) loss = None if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - batch_size, seq_length, vocab_size = shift_logits.shape - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) - ) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function(logits=lm_logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) return CausalLMOutputWithCrossAttentions( loss=loss, @@ -1804,6 +1801,7 @@ class Kosmos2Model(Kosmos2PreTrainedModel): def set_input_embeddings(self, value): self.text_model.model.embed_tokens = value + @can_return_tuple @add_start_docstrings_to_model_forward(KOSMOS2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Kosmos2ModelOutput, config_class=_CONFIG_FOR_DOC) def forward( @@ -1822,6 +1820,7 @@ class Kosmos2Model(Kosmos2PreTrainedModel): output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, Kosmos2ModelOutput]: r""" Returns: @@ -1893,13 +1892,10 @@ class Kosmos2Model(Kosmos2PreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, + **kwargs, ) - if not return_dict: - outputs = outputs + (image_embeds, projection_attentions, vision_model_output) - return tuple(output for 
output in outputs if output is not None) - return Kosmos2ModelOutput( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, @@ -1946,6 +1942,7 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin): def set_output_embeddings(self, new_embeddings): self.text_model.set_output_embeddings(new_embeddings) + @can_return_tuple @add_start_docstrings_to_model_forward(KOSMOS2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Kosmos2ForConditionalGenerationModelOutput, config_class=_CONFIG_FOR_DOC) def forward( @@ -1964,6 +1961,7 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, Kosmos2ForConditionalGenerationModelOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -2048,13 +2046,10 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, + **kwargs, ) - if not return_dict: - outputs = lm_outputs + (image_embeds, projection_attentions, vision_model_output) - return tuple(output for output in outputs if output is not None) - return Kosmos2ForConditionalGenerationModelOutput( loss=lm_outputs.loss, logits=lm_outputs.logits, diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 4d61a011483..00f924b4d74 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -39,8 +39,10 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -244,6 +246,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +# Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32 def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -256,12 +259,13 @@ def eager_attention_forward( ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() @@ -607,6 +611,7 @@ class Llama4TextModel(Llama4PreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple 
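Note on the loss change in the Kosmos-2 causal-LM hunk above: the hand-rolled shift-and-flatten cross-entropy is replaced by the shared `self.loss_function(logits=..., labels=..., vocab_size=..., **kwargs)`, which also accepts the loss kwargs forwarded through `KwargsForCausalLM`. Roughly, that helper does the following (a sketch under stated assumptions, not the exact transformers implementation, which handles a few more arguments):

    from typing import Optional

    import torch
    import torch.nn.functional as F


    def causal_lm_loss(
        logits: torch.Tensor,
        labels: torch.Tensor,
        vocab_size: int,
        num_items_in_batch: Optional[int] = None,
        ignore_index: int = -100,
    ) -> torch.Tensor:
        # shift so that token t predicts token t+1, as the removed model-specific code did
        shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
        shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
        if num_items_in_batch is None:
            return F.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index)
        # with gradient accumulation, summing and dividing by the true token count
        # keeps the loss scale consistent across accumulation steps
        loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index, reduction="sum")
        return loss / num_items_in_batch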
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING) def forward( self, @@ -712,13 +717,12 @@ class Llama4TextModel(Llama4PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() @torch.compiler.disable(recursive=False) # the operations in this method are not compilable def _update_causal_mask( @@ -931,6 +935,9 @@ class Llama4TextModel(Llama4PreTrainedModel): return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin): _no_split_modules = ["Llama4TextDecoderLayer"] base_model_prefix = "language_model" @@ -965,6 +972,7 @@ class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model + @can_return_tuple @add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -981,7 +989,7 @@ class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1031,7 +1039,7 @@ class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, **kwargs, ) @@ -1044,10 +1052,6 @@ class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -1208,6 +1212,7 @@ class Llama4VisionAttention(nn.Module): self.head_dim = config.hidden_size // config.num_attention_heads self.num_key_value_groups = 1 self.attention_dropout = config.attention_dropout + self.scaling = self.head_dim**-0.5 self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=True) self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=True) @@ -1593,7 +1598,6 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin): _tp_plan = {} base_model_prefix = "" config_class = Llama4Config - _supports_flex_attn = True def __init__(self, config: Llama4Config): super().__init__(config) @@ -1673,7 +1677,7 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin): cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, image_sizes: torch.Tensor = None, - **lm_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, Llama4CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1780,7 +1784,7 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin): return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, - **lm_kwargs, + **kwargs, ) logits = outputs[0] diff --git 
a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 3273d595a5a..42fc86c77cc 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -23,9 +23,12 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, @@ -171,6 +174,7 @@ class LlavaPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): # important: this ported version of Llava isn't meant for training from scratch - only @@ -331,6 +335,7 @@ class LlavaModel(LlavaPreTrainedModel): image_features = self.multi_modal_projector(selected_image_feature) return image_features + @can_return_tuple @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING) def forward( self, @@ -348,7 +353,7 @@ class LlavaModel(LlavaPreTrainedModel): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, image_sizes: torch.Tensor = None, - **lm_kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, LlavaModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -407,17 +412,19 @@ class LlavaModel(LlavaPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, - **lm_kwargs, + **kwargs, ) - output = LlavaModelOutputWithPast( + return LlavaModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_features if pixel_values is not None else None, ) - return output if return_dict else output.to_tuple() + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
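Note on the `_supports_attention_backend = True` flag (here on `LlavaPreTrainedModel`, and on the other VLM `PreTrainedModel` subclasses in this patch alongside `_supports_flash_attn_2` and `_supports_sdpa`): the attention implementation becomes a load-time choice dispatched through `ALL_ATTENTION_FUNCTIONS` instead of per-model attention subclasses. A usage sketch; the checkpoint name and dtype are only examples, and `"flash_attention_2"` additionally requires the flash-attn package and a supported GPU:

    import torch
    from transformers import LlavaForConditionalGeneration

    # any of "eager", "sdpa" or "flash_attention_2" can be requested at load time
    model = LlavaForConditionalGeneration.from_pretrained(
        "llava-hf/llava-1.5-7b-hf",
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    )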
@add_start_docstrings( @@ -484,8 +491,8 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - image_sizes: torch.Tensor = None, - **lm_kwargs, + image_sizes: Optional[torch.Tensor] = None, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -553,7 +560,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): return_dict=True, cache_position=cache_position, image_sizes=image_sizes, - **lm_kwargs, + **kwargs, ) hidden_states = outputs[0] @@ -563,7 +570,9 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return LlavaCausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index c17eb9622c8..ff1cd57d9dc 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -26,9 +26,12 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, @@ -280,6 +283,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) @@ -528,6 +532,7 @@ class LlavaNextModel(LlavaNextPreTrainedModel): image_features = torch.split(image_features, image_num_patches, dim=0) return image_features + @can_return_tuple @add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING) def forward( self, @@ -545,7 +550,7 @@ class LlavaNextModel(LlavaNextPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **lm_kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, LlavaNextModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -609,17 +614,19 @@ class LlavaNextModel(LlavaNextPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, - **lm_kwargs, + **kwargs, ) - output = LlavaNextModelOutputWithPast( + return LlavaNextModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_features if pixel_values is not None else None, ) - return output if return_dict else output.to_tuple() + + 
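Note on `@can_return_tuple`, used throughout this patch: the decorated forwards now build and return the `ModelOutput` dataclass unconditionally, and the trailing `return output if return_dict else output.to_tuple()` lines are dropped because the decorator centralises that conversion. A rough sketch of the idea, assuming `return_dict` is passed by keyword (as it is in these models); the real decorator in transformers also consults a few internal flags:

    import functools


    def can_return_tuple(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # fall back to the config default when the caller did not pass return_dict
            return_dict = kwargs.get("return_dict")
            if return_dict is None:
                return_dict = getattr(self.config, "use_return_dict", True)
            output = func(self, *args, **kwargs)
            return output if return_dict else output.to_tuple()

        return wrapper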
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... @add_start_docstrings( @@ -688,7 +695,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **lm_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -756,7 +763,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - **lm_kwargs, + **kwargs, ) hidden_states = outputs[0] @@ -766,7 +773,9 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return LlavaNextCausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 6b6f14bc342..f124e3b4c1f 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -30,9 +30,12 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, @@ -223,6 +226,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) @@ -581,6 +585,7 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): image_features = torch.split(image_features, image_num_patches, dim=0) return image_features + @can_return_tuple @add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING) def forward( self, @@ -599,7 +604,7 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **lm_kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, LlavaNextVideoModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -684,10 +689,10 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, - **lm_kwargs, + **kwargs, ) - output = LlavaNextVideoModelOutputWithPast( + return LlavaNextVideoModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, @@ -695,7 +700,6 @@ class 
LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): image_hidden_states=image_features if pixel_values is not None else None, video_hidden_states=video_features if pixel_values_videos is not None else None, ) - return output if return_dict else output.to_tuple() def get_video_features( self, @@ -744,6 +748,9 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): return video_features +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + @add_start_docstrings( """The LLAVA-NeXT model which consists of a vision backbone and a language model.""", LLAVA_NEXT_VIDEO_START_DOCSTRING, @@ -811,7 +818,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **lm_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]: r""" pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, image_size, image_size)): @@ -915,10 +922,10 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, image_sizes=image_sizes, - **lm_kwargs, + **kwargs, ) hidden_states = outputs[0] @@ -928,7 +935,9 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return LlavaNextVideoCausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index 985a69a68ec..cacf167a86e 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -22,6 +22,7 @@ import torch.utils.checkpoint from torch import nn from transformers.models.llava_next.modeling_llava_next import ( + KwargsForCausalLM, LlavaNextCausalLMOutputWithPast, LlavaNextForConditionalGeneration, LlavaNextModel, @@ -31,6 +32,8 @@ from transformers.models.llava_next.modeling_llava_next import ( ) from ...configuration_utils import PretrainedConfig +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...processing_utils import Unpack from ...utils import ( add_start_docstrings_to_model_forward, can_return_tuple, @@ -378,7 +381,7 @@ class LlavaNextVideoModel(LlavaNextModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **lm_kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, LlavaNextVideoModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -463,10 +466,10 @@ class LlavaNextVideoModel(LlavaNextModel): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, - **lm_kwargs, + **kwargs, ) - output = LlavaNextVideoModelOutputWithPast( + return LlavaNextVideoModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, 
hidden_states=outputs.hidden_states, @@ -474,7 +477,6 @@ class LlavaNextVideoModel(LlavaNextModel): image_hidden_states=image_features if pixel_values is not None else None, video_hidden_states=video_features if pixel_values_videos is not None else None, ) - return output if return_dict else output.to_tuple() LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r""" @@ -580,7 +582,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **lm_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]: r""" pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, image_size, image_size)): @@ -684,10 +686,10 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, image_sizes=image_sizes, - **lm_kwargs, + **kwargs, ) hidden_states = outputs[0] @@ -697,7 +699,9 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return LlavaNextVideoCausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 5ef23387e9f..73700e2448f 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -30,9 +30,12 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, can_return_tuple, is_torchdynamo_compiling, @@ -405,6 +408,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) @@ -570,6 +574,7 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): image_features = torch.split(image_features, image_num_patches, dim=0) return image_features + @can_return_tuple @add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING) def forward( self, @@ -590,7 +595,7 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **lm_kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, LlavaOnevisionModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -681,10 +686,10 @@ class 
LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, - **lm_kwargs, + **kwargs, ) - output = LlavaOnevisionModelOutputWithPast( + return LlavaOnevisionModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, @@ -693,8 +698,6 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): video_hidden_states=video_features if pixel_values_videos is not None else None, ) - return output if return_dict else output.to_tuple() - def get_video_features( self, pixel_values: torch.FloatTensor, @@ -756,6 +759,9 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): return image_features +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + @add_start_docstrings( """The LLAVA-NeXT model which consists of a vision backbone and a language model.""", LLAVA_ONEVISION_START_DOCSTRING, @@ -824,7 +830,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **lm_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -909,7 +915,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene return_dict=True, cache_position=cache_position, logits_to_keep=logits_to_keep, - **lm_kwargs, + **kwargs, ) hidden_states = outputs[0] @@ -919,7 +925,9 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return LlavaOnevisionCausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/llava_onevision/modular_llava_onevision.py b/src/transformers/models/llava_onevision/modular_llava_onevision.py index bc692c10a64..c96d2e36195 100644 --- a/src/transformers/models/llava_onevision/modular_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modular_llava_onevision.py @@ -21,6 +21,7 @@ from torch import nn from transformers.models.llava_next.image_processing_llava_next_fast import LlavaNextImageProcessorFast from transformers.models.llava_next_video.modeling_llava_next_video import ( + KwargsForCausalLM, LlavaNextVideoCausalLMOutputWithPast, LlavaNextVideoForConditionalGeneration, LlavaNextVideoModel, @@ -36,6 +37,8 @@ from ...image_utils import ( OPENAI_CLIP_STD, PILImageResampling, ) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...processing_utils import Unpack from ...utils import add_start_docstrings, can_return_tuple, is_torchdynamo_compiling, logging @@ -217,6 +220,7 @@ class LlavaOnevisionModel(LlavaNextVideoModel): return video_features + @can_return_tuple @add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING) def forward( self, @@ -237,7 +241,7 @@ class LlavaOnevisionModel(LlavaNextVideoModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **lm_kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, LlavaOnevisionModelOutputWithPast]: output_attentions = 
output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -328,10 +332,10 @@ class LlavaOnevisionModel(LlavaNextVideoModel): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, - **lm_kwargs, + **kwargs, ) - output = LlavaOnevisionModelOutputWithPast( + return LlavaOnevisionModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, @@ -340,8 +344,6 @@ class LlavaOnevisionModel(LlavaNextVideoModel): video_hidden_states=video_features if pixel_values_videos is not None else None, ) - return output if return_dict else output.to_tuple() - class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGeneration): @can_return_tuple @@ -367,7 +369,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **lm_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -452,7 +454,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat return_dict=True, cache_position=cache_position, logits_to_keep=logits_to_keep, - **lm_kwargs, + **kwargs, ) hidden_states = outputs[0] @@ -462,7 +464,9 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return LlavaOnevisionCausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index 7078631552f..24cc21af609 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -28,9 +28,12 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, @@ -233,6 +236,7 @@ class Mistral3PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): # important: this ported version of Mistral3 isn't meant for training from scratch - only @@ -251,6 +255,144 @@ class Mistral3PreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) +@add_start_docstrings( + """The Mistral3 model which consists of a vision backbone and a language model, without a language modeling head.""", + MISTRAL3_START_DOCSTRING, +) +class Mistral3Model(Mistral3PreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + + def __init__(self, config: Mistral3Config): + super().__init__(config) + self.vision_tower = 
AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = Mistral3MultiModalProjector(config) + self.language_model = AutoModel.from_config(config.text_config) + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + vision_feature_layer: Union[int, List[int]], + image_sizes: torch.Tensor, + **kwargs, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): + The tensors corresponding to the input images. + vision_feature_layer (`Union[int, List[int]]`): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + image_sizes (`torch.Tensor`): + Tensor containing the image sizes as returned by the processor. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + kwargs = {k: v for k, v in kwargs.items() if v is not None} + # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states. + image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs) + # If we have one vision feature layer, return the corresponding hidden states, + # otherwise, select the hidden states of each feature layer and concatenate them + if isinstance(vision_feature_layer, int): + selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + else: + hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer] + selected_image_feature = torch.cat(hs_pool, dim=-1) + + image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes) + return image_features + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + image_sizes: torch.Tensor = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, Mistral3ModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and 
inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + image_sizes=image_sizes, + ) + + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_id).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + return Mistral3ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + MISTRAL3_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -328,142 +470,6 @@ MISTRAL3_INPUTS_DOCSTRING = r""" """ -@add_start_docstrings( - """The Mistral3 model which consists of a vision backbone and a language model, without a language modeling head.""", - MISTRAL3_START_DOCSTRING, -) -class Mistral3Model(Mistral3PreTrainedModel): - _checkpoint_conversion_mapping = {"language_model.model": "language_model"} - - def __init__(self, config: Mistral3Config): - super().__init__(config) - self.vision_tower = AutoModel.from_config(config.vision_config) - - self.multi_modal_projector = Mistral3MultiModalProjector(config) - self.language_model = AutoModel.from_config(config.text_config) - self.post_init() - - def get_input_embeddings(self): - return self.language_model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.language_model.set_input_embeddings(value) - - def get_image_features( - self, - pixel_values: torch.FloatTensor, - vision_feature_layer: Union[int, List[int]], - image_sizes: torch.Tensor, - **kwargs, - ): - """ - Obtains image last hidden states from the vision tower and apply multimodal projection. - - Args: - pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): - The tensors corresponding to the input images. - vision_feature_layer (`Union[int, List[int]]`): - The index of the layer to select the vision feature. If multiple indices are provided, - the vision feature of the corresponding indices will be concatenated to form the - vision features. - image_sizes (`torch.Tensor`): - Tensor containing the image sizes as returned by the processor. - Returns: - image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). 
- """ - kwargs = {k: v for k, v in kwargs.items() if v is not None} - # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states. - image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs) - # If we have one vision feature layer, return the corresponding hidden states, - # otherwise, select the hidden states of each feature layer and concatenate them - if isinstance(vision_feature_layer, int): - selected_image_feature = image_outputs.hidden_states[vision_feature_layer] - else: - hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer] - selected_image_feature = torch.cat(hs_pool, dim=-1) - - image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes) - return image_features - - @add_start_docstrings_to_model_forward(MISTRAL3_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - vision_feature_layer: Optional[Union[int, List[int]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - image_sizes: torch.Tensor = None, - **lm_kwargs, - ) -> Union[Tuple, Mistral3ModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - vision_feature_layer = ( - vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer - ) - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if pixel_values is not None: - image_features = self.get_image_features( - pixel_values=pixel_values, - vision_feature_layer=vision_feature_layer, - image_sizes=image_sizes, - ) - - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - outputs = self.language_model( - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - 
use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, - cache_position=cache_position, - **lm_kwargs, - ) - - output = Mistral3ModelOutputWithPast( - last_hidden_state=outputs.last_hidden_state, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) - return output if return_dict else output.to_tuple() - - @add_start_docstrings( """The MISTRAL3 model which consists of a vision backbone and a language model.""", MISTRAL3_START_DOCSTRING, @@ -526,8 +532,8 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin) return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - image_sizes: torch.Tensor = None, - **lm_kwargs, + image_sizes: Optional[torch.Tensor] = None, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, Mistral3CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -585,7 +591,7 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin) return_dict=True, cache_position=cache_position, image_sizes=image_sizes, - **lm_kwargs, + **kwargs, ) hidden_states = outputs[0] @@ -595,7 +601,9 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin) loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return Mistral3CausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/mistral3/modular_mistral3.py b/src/transformers/models/mistral3/modular_mistral3.py index 5ef6663bde0..11dcfa2c3a4 100644 --- a/src/transformers/models/mistral3/modular_mistral3.py +++ b/src/transformers/models/mistral3/modular_mistral3.py @@ -19,6 +19,8 @@ import torch from torch import nn from ...activations import ACT2FN +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...processing_utils import Unpack from ...utils import ( add_start_docstrings_to_model_forward, can_return_tuple, @@ -27,6 +29,7 @@ from ...utils import ( replace_return_docstrings, ) from ..llava.modeling_llava import ( + KwargsForCausalLM, LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaModel, @@ -174,6 +177,7 @@ class Mistral3Model(LlavaModel): image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes) return image_features + @can_return_tuple def forward( self, input_ids: torch.LongTensor = None, @@ -189,7 +193,7 @@ class Mistral3Model(LlavaModel): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, image_sizes: torch.Tensor = None, - **lm_kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, Mistral3ModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -239,17 +243,16 @@ class Mistral3Model(LlavaModel): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, - **lm_kwargs, + **kwargs, ) - output = Mistral3ModelOutputWithPast( + return Mistral3ModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, 
hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_features if pixel_values is not None else None, ) - return output if return_dict else output.to_tuple() class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration): @@ -271,8 +274,8 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - image_sizes: torch.Tensor = None, - **lm_kwargs, + image_sizes: Optional[torch.Tensor] = None, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, Mistral3CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -330,7 +333,7 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration): return_dict=True, cache_position=cache_position, image_sizes=image_sizes, - **lm_kwargs, + **kwargs, ) hidden_states = outputs[0] @@ -340,7 +343,9 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration): loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return Mistral3CausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 79278c9892e..300b172c2d8 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -15,21 +15,24 @@ """PyTorch Mllama model.""" import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from ... import PreTrainedModel from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, @@ -180,13 +183,56 @@ class MllamaVisionMLP(nn.Module): return hidden_states +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Copied from transformers.models.llama.modeling_llama.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class MllamaVisionAttention(nn.Module): def __init__(self, config: MllamaVisionConfig): super().__init__() + self.config = config self.embed_dim = config.hidden_size self.num_heads = config.attention_heads self.head_dim = config.hidden_size // config.attention_heads + self.scaling = self.head_dim**-0.5 + self.num_key_value_groups = 1 self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False) @@ -198,6 +244,7 @@ class MllamaVisionAttention(nn.Module): hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: query = self.q_proj(hidden_state) key = self.k_proj(hidden_state) @@ -210,73 +257,35 @@ class MllamaVisionAttention(nn.Module): key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2) value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2) - attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim) + attention_interface: Callable = eager_attention_forward - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key.shape[-2]] - attn_weights = attn_weights + causal_mask + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_output = torch.matmul(attn_weights, value) + attn_output, attn_weights = attention_interface( + self, + query, + key, + value, + attention_mask, + dropout=0.0, + scaling=self.scaling, + **kwargs, + ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_seq_len, -1) - - output = self.o_proj(attn_output) + attn_output = attn_output.reshape(batch_size, q_seq_len, -1).contiguous() + attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None - return output, attn_weights - - -class MllamaVisionSdpaAttention(MllamaVisionAttention): - # Adapted from MllamaVisionAttention - def forward( - self, - hidden_state: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - ) -> torch.Tensor: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - if output_attentions: - logger.warning_once( - "MllamaModel is using MllamaVisionSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_state=hidden_state, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - query = self.q_proj(hidden_state) - key = self.k_proj(hidden_state) - value = self.v_proj(hidden_state) - - batch_size, q_seq_len, _ = query.shape - _, kv_seq_len, _ = key.shape - - query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim) - key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) - value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) - - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_seq_len, -1) - - output = self.o_proj(attn_output) - - return output, None - - -MLLAMA_VISION_ATTENTION_CLASSES = {"eager": MllamaVisionAttention, "sdpa": MllamaVisionSdpaAttention} + return attn_output, attn_weights class MllamaVisionEncoderLayer(nn.Module): @@ -288,7 +297,7 @@ class MllamaVisionEncoderLayer(nn.Module): self.is_gated = is_gated self.intermediate_size = config.intermediate_size - self.self_attn = MLLAMA_VISION_ATTENTION_CLASSES[config._attn_implementation](config) + self.self_attn = MllamaVisionAttention(config) self.mlp = MllamaVisionMLP(config) self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) @@ -453,6 +462,7 @@ class MllamaTextCrossAttention(nn.Module): self.head_dim = config.hidden_size // self.num_heads self.layer_idx = layer_idx self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) @@ -471,6 +481,7 @@ 
class MllamaTextCrossAttention(nn.Module): output_attentions: bool = False, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" bsz, q_len, _ = hidden_states.size() @@ -503,17 +514,29 @@ class MllamaTextCrossAttention(nn.Module): "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" ) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attention_interface: Callable = eager_attention_forward - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: @@ -522,100 +545,6 @@ class MllamaTextCrossAttention(nn.Module): return attn_output, attn_weights, past_key_value -class MllamaTextCrossSdpaAttention(MllamaTextCrossAttention): - """ - Mllama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `MllamaTextCrossAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from MllamaTextCrossAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "MllamaModel is using MllamaTextCrossSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - cross_attention_states=cross_attention_states, - attention_mask=attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - query_states = self.q_norm(query_states) - - if cross_attention_states is not None: - key_states = self.k_proj(cross_attention_states) - value_states = self.v_proj(cross_attention_states) - key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - # if we have a new image + new tokens, we only computed key_states on that new image - # we still update the cross key states, past_image, new_image. And use it! - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - elif cache_position[0] != 0: - key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], - ) - else: - raise ValueError( - "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" - ) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - key_states = self.k_norm(key_states) - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if attention_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -652,19 +581,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - class MllamaTextSelfAttention(nn.Module): def __init__(self, config: MllamaTextConfig, layer_idx: int): super().__init__() @@ -675,6 +591,7 @@ class MllamaTextSelfAttention(nn.Module): self.num_key_value_heads = config.num_key_value_heads self.head_dim = config.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.rope_theta = config.rope_theta self.layer_idx = layer_idx @@ -712,23 +629,29 @@ class MllamaTextSelfAttention(nn.Module): cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + attention_interface: Callable = eager_attention_forward - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: @@ -737,92 +660,6 @@ class MllamaTextSelfAttention(nn.Module): return attn_output, attn_weights, past_key_value -class MllamaTextSelfSdpaAttention(MllamaTextSelfAttention): - # Adapted from MllamaTextSelfAttention - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - position_embeddings: torch.Tensor, - output_attentions: bool = False, - use_cache: bool = False, - past_key_value=None, - cache_position=None, - **kwargs, - ): - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
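# A minimal, self-contained sketch of the backend-dispatch pattern the refactored attention
# modules above now follow: a single eager implementation plus a registry lookup for every other
# backend, instead of one subclass per backend. The registry name, the toy module and the shapes
# below are illustrative assumptions, not the actual transformers internals (which use
# ALL_ATTENTION_FUNCTIONS and the model configs shown in this patch).
import torch
import torch.nn.functional as F
from torch import nn

ATTENTION_BACKENDS = {}  # toy stand-in for ALL_ATTENTION_FUNCTIONS


def toy_eager_attention(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


def toy_sdpa_attention(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    attn_output = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
    )
    return attn_output.transpose(1, 2).contiguous(), None  # SDPA cannot return attention weights


ATTENTION_BACKENDS["eager"] = toy_eager_attention
ATTENTION_BACKENDS["sdpa"] = toy_sdpa_attention


class ToySelfAttention(nn.Module):
    def __init__(self, hidden_size=64, num_heads=4, attn_implementation="sdpa"):
        super().__init__()
        self.num_heads, self.head_dim = num_heads, hidden_size // num_heads
        self.scaling = self.head_dim**-0.5
        self.attn_implementation = attn_implementation
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        bsz, seq_len, _ = hidden_states.shape
        query, key, value = self.qkv(hidden_states).chunk(3, dim=-1)
        query, key, value = (
            t.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) for t in (query, key, value)
        )
        # default to eager; only dispatch to another backend when it can honour the request
        attention_interface = toy_eager_attention
        if self.attn_implementation != "eager" and not output_attentions:
            attention_interface = ATTENTION_BACKENDS[self.attn_implementation]
        attn_output, attn_weights = attention_interface(
            self, query, key, value, attention_mask, scaling=self.scaling
        )
        attn_output = attn_output.reshape(bsz, seq_len, -1)
        return self.o_proj(attn_output), attn_weights


print(ToySelfAttention()(torch.randn(2, 8, 64))[0].shape)  # torch.Size([2, 8, 64])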
- logger.warning_once( - "MllamaModel is using MllamaTextSelfSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_embeddings=position_embeddings, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
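# With the per-backend attention subclasses and the *_ATTENTION_CLASSES mappings removed in this
# patch, the backend is selected once at load time via `attn_implementation` and recorded on the
# config; the modules then dispatch through the shared registry. A short usage sketch — the
# checkpoint id is just an example Mllama checkpoint, and "flash_attention_2" additionally
# requires the flash-attn package and a supported GPU:
import torch
from transformers import MllamaForConditionalGeneration

model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
)
print(model.config._attn_implementation)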
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value - - -MLLAMA_TEXT_CROSS_ATTENTION_CLASSES = {"eager": MllamaTextCrossAttention, "sdpa": MllamaTextCrossSdpaAttention} -MLLAMA_TEXT_ATTENTION_CLASSES = {"eager": MllamaTextSelfAttention, "sdpa": MllamaTextSelfSdpaAttention} - - # Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText class MllamaTextMLP(nn.Module): def __init__(self, config): @@ -847,7 +684,7 @@ class MllamaSelfAttentionDecoderLayer(nn.Module): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = MLLAMA_TEXT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = MllamaTextSelfAttention(config=config, layer_idx=layer_idx) self.mlp = MllamaTextMLP(config) self.input_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -868,6 +705,7 @@ class MllamaSelfAttentionDecoderLayer(nn.Module): use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -905,6 +743,7 @@ class MllamaSelfAttentionDecoderLayer(nn.Module): use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = residual + hidden_states @@ -931,7 +770,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: super().__init__() self.layer_idx = layer_idx - self.cross_attn = MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + self.cross_attn = MllamaTextCrossAttention(config, layer_idx=layer_idx) self.input_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1)) @@ -953,6 +792,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -964,6 +804,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): past_key_value=past_key_value, output_attentions=output_attentions, cache_position=cache_position, + **kwargs, ) hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states @@ -1026,7 +867,9 @@ class MllamaPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_static_cache = False # static cache cannot have different shapes for each layer _supports_sdpa = True + _supports_flash_attn_2 = True _supports_quantized_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) @@ -1694,6 +1537,7 @@ class 
MllamaTextModel(MllamaPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: """ @@ -1804,6 +1648,7 @@ class MllamaTextModel(MllamaPreTrainedModel): use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = layer_outputs[0] @@ -1832,6 +1677,9 @@ class MllamaTextModel(MllamaPreTrainedModel): ) +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + @add_start_docstrings( """The Mllama Text Model with a language modeling head on top.""", MLLAMA_START_DOCSTRING, @@ -1888,7 +1736,7 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1945,6 +1793,7 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin): output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1953,7 +1802,7 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin): loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] @@ -1999,6 +1848,7 @@ class MllamaModel(MllamaPreTrainedModel): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) + @can_return_tuple @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) def forward( self, @@ -2017,6 +1867,7 @@ class MllamaModel(MllamaPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -2079,15 +1930,15 @@ class MllamaModel(MllamaPreTrainedModel): output_attentions=output_attentions, return_dict=True, cache_position=cache_position, + **kwargs, ) - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) - return output if return_dict else output.to_tuple() @add_start_docstrings( @@ -2153,7 +2004,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -2220,6 +2071,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -2229,7 +2081,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, 
GenerationMixin): loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.config.text_config.vocab_size, **loss_kwargs) + loss = self.loss_function(logits, labels, self.config.text_config.vocab_size, **kwargs) return CausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/moshi/configuration_moshi.py b/src/transformers/models/moshi/configuration_moshi.py index 1b31141f020..1da69740a7e 100644 --- a/src/transformers/models/moshi/configuration_moshi.py +++ b/src/transformers/models/moshi/configuration_moshi.py @@ -236,7 +236,7 @@ class MoshiConfig(PretrainedConfig): model_type = "moshi" keys_to_ignore_at_inference = ["past_key_values"] - sub_configs = {"audio_encoder_config": AutoConfig} + sub_configs = {"audio_encoder_config": AutoConfig, "depth_decoder_config": MoshiDepthConfig} def __init__( self, diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index ea436706348..4bc31452a3a 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -1907,7 +1907,7 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin): self.audio_encoder = AutoModel.from_config(config.audio_encoder_config) self.decoder = MoshiForCausalLM(config) - self.depth_decoder = MoshiDepthDecoder(config.depth_decoder_config) + self.depth_decoder = MoshiDepthDecoder._from_config(config.depth_decoder_config) self.num_codebooks = config.num_codebooks self.post_init() diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 90d8418db37..c26e46ba5a1 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -14,7 +14,7 @@ # limitations under the License. 
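# The moshi change above builds the depth decoder with `_from_config` rather than calling the
# class constructor directly, so that model-level settings (notably the requested attention
# implementation) are resolved for the sub-model as well. The public counterpart of that helper
# is `AutoModel.from_config`; a small sketch using an arbitrary tiny Llama config:
from transformers import AutoModel, LlamaConfig

tiny_config = LlamaConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=1,
    num_attention_heads=4,
    num_key_value_heads=4,
)
sub_model = AutoModel.from_config(tiny_config, attn_implementation="eager")
print(sub_model.config._attn_implementation)  # "eager"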
"""PyTorch OPT model.""" -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -27,18 +27,21 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, ) -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_flash_attention_utils import FlashAttentionKwargs, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -53,7 +56,7 @@ if is_torch_flex_attn_available(): if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward + pass logger = logging.get_logger(__name__) @@ -98,6 +101,30 @@ class OPTLearnedPositionalEmbedding(nn.Embedding): return super().forward(position_ids + self.offset) +# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class OPTAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -143,9 +170,8 @@ class OPTAttention(nn.Module): attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - # isn't needed in normal attention, but needed in flash attention so to keep the signature same - position_ids: Optional[torch.Tensor] = None, cache_position: Optional[torch.Tensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: """Input shape: Batch x Time x Channel""" bsz, tgt_len, _ = hidden_states.size() @@ -165,206 +191,35 @@ class OPTAttention(nn.Module): key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - attn_weights = torch.matmul(query_states, key_states.transpose(3, 2)) - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask + attention_interface: Callable = eager_attention_forward - # upcast to fp32 if the weights are in fp16. 
Please see https://github.com/huggingface/transformers/pull/17437 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_probs, value_states) - - attn_output = attn_output.transpose(1, 2).contiguous() - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - attn_output = self.out_proj(attn_output) - - return attn_output, attn_probs, past_key_value - - -class OptFlashAttention2(OPTAttention): - """ - OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched. - The only required change would be on the forward pass where it needs to correctly call the public API of flash - attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
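# Throughout this patch, flash-attention specific arguments and loss arguments no longer need a
# dedicated attention subclass or ad-hoc `**loss_kwargs` / `**lm_kwargs`: they flow through
# `**kwargs: Unpack[...]`, where the kwargs class simply merges two TypedDicts. A self-contained
# sketch of that pattern with toy field names (the real classes are FlashAttentionKwargs and
# LossKwargs from transformers):
from typing_extensions import TypedDict, Unpack


class ToyFlashKwargs(TypedDict, total=False):  # stand-in for FlashAttentionKwargs
    cu_seq_lens_q: object
    max_length_q: int


class ToyLossKwargs(TypedDict, total=False):  # stand-in for LossKwargs
    num_items_in_batch: int


class ToyKwargsForCausalLM(ToyFlashKwargs, ToyLossKwargs): ...


def causal_lm_forward(**kwargs: Unpack[ToyKwargsForCausalLM]) -> None:
    # static type checkers now know which optional keys may appear; at runtime this is still
    # plain **kwargs forwarding, exactly as in the model code in this patch
    print(sorted(kwargs))


causal_lm_forward(num_items_in_batch=3, max_length_q=128)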
- self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def forward( - self, - hidden_states: torch.Tensor, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - position_ids: Optional[torch.Tensor] = None, - cache_position: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - bsz, query_length, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) - - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - # save all key/value_states to cache to be re-used for fast auto-regressive generation - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - - attn_dropout = self.dropout if self.training else 0.0 - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - query_length, - position_ids=position_ids, - dropout=attn_dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, ) - attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) - attn_output = self.out_proj(attn_weights_reshaped) - - if not output_attentions: - attn_weights_reshaped = None - - return attn_output, attn_weights_reshaped, past_key_value - - -class OPTSdpaAttention(OPTAttention): - """ - OPT sdpa attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched. 
- The only required change would be on the forward pass where it needs to correctly call the public API of sdpa - attention and deal with padding tokens in case the input contains any of them. - """ - - def forward( - self, - hidden_states: torch.Tensor, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - position_ids: Optional[torch.Tensor] = None, - cache_position: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions or layer_head_mask is not None: - logger.warning_once( - "OPTModel is using SDPA attention, which currently does not support output_attentions=True." - 'failing back to eager attention. remove warning using attn_implementation="eager".' - ) - - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - # save all key/value_states to cache to be re-used for fast auto-regressive generation - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
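# The forward passes below drop their hand-written `if not return_dict: return tuple(...)`
# branches in favour of the `@can_return_tuple` decorator imported from `...utils`. The real
# decorator lives in transformers; the simplified stand-in below only illustrates the contract:
# the wrapped forward always builds a ModelOutput-style object, and the wrapper converts it to a
# tuple when the caller asked for one.
from dataclasses import dataclass
from functools import wraps
from typing import Optional


@dataclass
class ToyModelOutput:
    last_hidden_state: Optional[float] = None
    pooler_output: Optional[float] = None

    def to_tuple(self):
        return tuple(v for v in (self.last_hidden_state, self.pooler_output) if v is not None)


def toy_can_return_tuple(fn):  # simplified stand-in for transformers.utils.can_return_tuple
    @wraps(fn)
    def wrapper(*args, return_dict: Optional[bool] = None, **kwargs):
        output = fn(*args, **kwargs)
        return output if return_dict is not False else output.to_tuple()

    return wrapper


@toy_can_return_tuple
def toy_forward(x: float) -> ToyModelOutput:
    return ToyModelOutput(last_hidden_state=2 * x, pooler_output=x)


print(toy_forward(1.0))                     # ToyModelOutput(last_hidden_state=2.0, pooler_output=1.0)
print(toy_forward(1.0, return_dict=False))  # (2.0, 1.0)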
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, None, past_key_value + if not output_attentions: + attn_weights = None - -OPT_ATTENTION_CLASSES = { - "eager": OPTAttention, - "flash_attention_2": OptFlashAttention2, - "sdpa": OPTSdpaAttention, -} + return attn_output, attn_weights, past_key_value class OPTDecoderLayer(nn.Module): @@ -372,7 +227,7 @@ class OPTDecoderLayer(nn.Module): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = OPTAttention(config=config, layer_idx=layer_idx) self.do_layer_norm_before = config.do_layer_norm_before self.dropout = config.dropout @@ -395,6 +250,7 @@ class OPTDecoderLayer(nn.Module): use_cache: Optional[bool] = False, position_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -429,6 +285,7 @@ class OPTDecoderLayer(nn.Module): layer_head_mask=layer_head_mask, output_attentions=output_attentions, cache_position=cache_position, + **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -495,8 +352,10 @@ class OPTPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["OPTDecoderLayer"] + _supports_attention_backend = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True @@ -763,6 +622,7 @@ class OPTDecoder(OPTPreTrainedModel): return causal_mask + @can_return_tuple def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -776,6 +636,7 @@ class OPTDecoder(OPTPreTrainedModel): return_dict: Optional[bool] = None, position_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: r""" Args: @@ -942,6 +803,7 @@ class OPTDecoder(OPTPreTrainedModel): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) hidden_states = layer_outputs[0] @@ -966,8 +828,6 @@ class OPTDecoder(OPTPreTrainedModel): if return_legacy_cache: next_cache = next_cache.to_legacy_cache() - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -996,6 +856,7 @@ class OPTModel(OPTPreTrainedModel): def get_decoder(self): return self.decoder + @can_return_tuple @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1016,6 +877,7 @@ class OPTModel(OPTPreTrainedModel): return_dict: Optional[bool] = None, position_ids: Optional[torch.LongTensor] = None, 
cache_position: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1035,13 +897,11 @@ class OPTModel(OPTPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, + **kwargs, ) - if not return_dict: - return decoder_outputs - return BaseModelOutputWithPast( last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=decoder_outputs.past_key_values, @@ -1050,6 +910,9 @@ class OPTModel(OPTPreTrainedModel): ) +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] @@ -1081,6 +944,7 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin): def get_decoder(self): return self.model.decoder + @can_return_tuple @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, @@ -1096,7 +960,7 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin): return_dict: Optional[bool] = None, position_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.Tensor] = None, - **kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1198,8 +1062,9 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, + **kwargs, ) logits = self.lm_head(outputs[0]).contiguous() @@ -1215,10 +1080,6 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin): **kwargs, ) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 2a4dc3ef62d..d8096d8113f 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -23,9 +23,12 @@ from torch import nn from ...cache_utils import Cache, HybridCache, StaticCache from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, @@ -159,6 +162,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): _supports_static_cache = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_attention_backend = True def _init_weights(self, module): # important: this ported version of PaliGemmaisn't meant for training from scratch - only @@ -352,6 +356,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel): image_features = image_features / (self.config.text_config.hidden_size**0.5) return image_features + @can_return_tuple @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING) def forward( self, @@ -368,7 +373,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel): output_attentions: Optional[bool] = None, 
output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **lm_kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, PaligemmaModelOutputWithPast]: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -436,17 +441,19 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, - **lm_kwargs, + **kwargs, ) - output = PaligemmaModelOutputWithPast( + return PaligemmaModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_features if pixel_values is not None else None, ) - return output if return_dict else output.to_tuple() + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... @add_start_docstrings( @@ -512,7 +519,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **lm_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -570,7 +577,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, - **lm_kwargs, + **kwargs, ) hidden_states = outputs[0] @@ -580,7 +587,9 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return PaliGemmaCausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index 57cb4f6591c..171bf63986e 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -21,14 +21,18 @@ import torch import torch.utils.checkpoint from torch import nn -from ... 
import PreTrainedModel from ...activations import ACT2FN from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput from ...modeling_rope_utils import dynamic_rope_update -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + logging, +) from .configuration_pixtral import PixtralVisionConfig @@ -132,7 +136,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -# Copied from transformers.models.smolvlm.modeling_smolvlm.eager_attention_forward +# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -167,10 +171,11 @@ class PixtralAttention(nn.Module): self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads - self.is_causal = False - self.scale = self.head_dim**-0.5 + self.scaling = self.head_dim**-0.5 + self.is_causal = False + self.dropout = config.attention_dropout self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) @@ -211,28 +216,22 @@ class PixtralAttention(nn.Module): else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # Since we use packing, if Flash-Attn 2 is selected we rely on position_ids - if self.config._attn_implementation == "flash_attention_2": - kwargs["position_ids"] = kwargs["position_ids"].to(hidden_states.device, non_blocking=True) - attention_mask = None - attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, - scaling=self.scale, dropout=0.0 if not self.training else self.dropout, - is_causal=self.is_causal, - output_attentions=output_attentions, + scaling=self.scaling, **kwargs, ) - attn_output = attn_output.reshape(batch_size, patches, -1) - + attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() attn_output = self.o_proj(attn_output) + if not output_attentions: + attn_weights = None return attn_output, attn_weights @@ -288,7 +287,7 @@ class PixtralAttentionLayer(nn.Module): attention_mask: torch.Tensor, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, output_attentions: Optional[bool] = None, - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor]: """ Args: @@ -341,7 +340,7 @@ class PixtralTransformer(nn.Module): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutput]: r""" Args: @@ -383,7 +382,6 @@ class PixtralTransformer(nn.Module): attention_mask, position_embeddings, output_attentions, - **kwargs, ) else: layer_outputs = encoder_layer( @@ -431,6 +429,10 @@ class PixtralPreTrainedModel(PreTrainedModel): base_model_prefix = "model" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _supports_attention_backend = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True _no_split_modules = ["PixtralAttentionLayer"] _supports_flash_attn_2 = True _supports_sdpa = True @@ -508,6 +510,7 @@ class 
PixtralVisionModel(PixtralPreTrainedModel): def get_input_embeddings(self): return self.patch_conv + @can_return_tuple @add_start_docstrings_to_model_forward(PIXTRAL_INPUTS_DOCSTRING) def forward( self, @@ -517,7 +520,7 @@ class PixtralVisionModel(PixtralPreTrainedModel): output_attentions: Optional[bool] = None, return_dict: Optional[bool] = None, *args, - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutput]: """ Returns: @@ -551,17 +554,15 @@ class PixtralVisionModel(PixtralPreTrainedModel): [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds ) - out = self.transformer( + return self.transformer( patch_embeds, attention_mask=attention_mask, position_embeddings=position_embeddings, output_hidden_states=output_hidden_states, output_attentions=output_attentions, - return_dict=return_dict, + return_dict=True, **kwargs, ) - return out - __all__ = ["PixtralVisionModel", "PixtralPreTrainedModel"] diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index 6c33479e913..e5abf36a3e4 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -24,17 +24,20 @@ from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn -from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...cache_utils import DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + can_return_tuple, logging, replace_return_docstrings, ) @@ -78,6 +81,7 @@ class SmolVLMPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True + _supports_attention_backend = True def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) @@ -574,80 +578,6 @@ class SmolVLMConnector(nn.Module): return image_hidden_states -SMOLVLM_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. 
See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): - The tensors corresponding to the input images. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses - [`CLIPImageProcessor`] for processing images). - pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*): - Mask to avoid performing attention on padding pixel indices. - image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): - The hidden states of the image encoder after modality projection. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. 
-""" - - @add_start_docstrings( """SmolVLM model consisting of a SIGLIP vision encoder and Llama3 language decoder""", SMOLVLM_START_DOCSTRING, @@ -746,18 +676,7 @@ class SmolVLMModel(SmolVLMPreTrainedModel): merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds) return merged_embeds - @add_start_docstrings_to_model_forward( - """ - Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to - the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where - max_num_images is the maximum number of images among the batch_size samples in the batch. - Padding images are not needed beyond padding the pixel_values at the entrance of the model. - For efficiency, we only pass through the vision_model's forward the real images by - discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where - image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3. - """, - SMOLVLM_INPUTS_DOCSTRING, - ) + @can_return_tuple def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -773,6 +692,7 @@ class SmolVLMModel(SmolVLMPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, SmolVLMBaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -873,13 +793,11 @@ class SmolVLMModel(SmolVLMPreTrainedModel): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, + **kwargs, ) - if not return_dict: - return tuple(v for v in [*outputs, image_hidden_states] if v is not None) - return SmolVLMBaseModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, @@ -927,6 +845,83 @@ class SmolVLMCausalLMOutputWithPast(ModelOutput): image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +SMOLVLM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. 
+ + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`LlavaProcessor`] uses + [`CLIPImageProcessor`] for processing images). + pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*): + Mask to avoid performing attention on padding pixel indices. + image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The hidden states of the image encoder after modality projection. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + """The SmolVLM Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top.
""", SMOLVLM_START_DOCSTRING, @@ -979,6 +974,7 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + @can_return_tuple @add_start_docstrings_to_model_forward(SMOLVLM_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=SmolVLMCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -998,6 +994,7 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin): cache_position: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, SmolVLMCausalLMOutputWithPast]: r""" Args: @@ -1066,7 +1063,8 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin): output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, - return_dict=return_dict, + return_dict=True, + **kwargs, ) hidden_states = outputs[0] @@ -1076,26 +1074,9 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin): loss = None if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - labels = labels.to(logits.device) - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. - # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return SmolVLMCausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/smolvlm/modular_smolvlm.py b/src/transformers/models/smolvlm/modular_smolvlm.py index 4745fe30dad..66f77752610 100644 --- a/src/transformers/models/smolvlm/modular_smolvlm.py +++ b/src/transformers/models/smolvlm/modular_smolvlm.py @@ -20,7 +20,10 @@ import torch.utils.checkpoint from torch import nn from ...cache_utils import DynamicCache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...processing_utils import Unpack from ...utils import ( + can_return_tuple, logging, ) from ..idefics3.configuration_idefics3 import Idefics3Config, Idefics3VisionConfig @@ -195,6 +198,7 @@ class SmolVLMModel(Idefics3Model): merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds) return merged_embeds + @can_return_tuple def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -210,6 +214,7 @@ class SmolVLMModel(Idefics3Model): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, SmolVLMBaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is 
not None else self.config.output_attentions output_hidden_states = ( @@ -310,13 +315,11 @@ class SmolVLMModel(Idefics3Model): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, + **kwargs, ) - if not return_dict: - return tuple(v for v in [*outputs, image_hidden_states] if v is not None) - return SmolVLMBaseModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index d40dd4db887..fc970dadc0e 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -23,9 +23,12 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, @@ -181,6 +184,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = ( @@ -387,6 +391,7 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel): return video_features, num_frames + @can_return_tuple @add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING) def forward( self, @@ -404,7 +409,7 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **lm_kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, VideoLlavaModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -475,10 +480,10 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, - **lm_kwargs, + **kwargs, ) - output = VideoLlavaModelOutputWithPast( + return VideoLlavaModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, @@ -486,7 +491,9 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel): image_hidden_states=image_features if pixel_values_images is not None else None, video_hidden_states=video_features if pixel_values_videos is not None else None, ) - return output if return_dict else output.to_tuple() + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
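Note for reviewers: the `KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...` one-liner added just above, together with the `**kwargs: Unpack[...]` annotations used throughout this patch, is the TypedDict-composition pattern used to thread typed keyword arguments through `forward`. Below is a minimal self-contained sketch of the idea; the field names are illustrative assumptions, not the library's exact definitions.

```python
from typing import TypedDict

try:
    from typing import Unpack              # Python 3.11+
except ImportError:
    from typing_extensions import Unpack   # older interpreters

class FlashAttentionKwargs(TypedDict, total=False):
    # Illustrative fields: varlen flash-attention kernels want cumulative sequence lengths.
    cu_seq_lens_q: list
    max_length_q: int

class LossKwargs(TypedDict, total=False):
    # Illustrative field: lets the loss normalize over real (unpadded) target tokens.
    num_items_in_batch: int

class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
    """Merged kwarg group; the `...` body in the patch adds no extra fields."""

def forward(input_ids: list, **kwargs: Unpack[KwargsForCausalLM]) -> dict:
    # A model's forward just threads the whole bundle through to the attention / loss code.
    return {"num_kwargs": len(kwargs), **kwargs}

print(forward([1, 2, 3], max_length_q=8, num_items_in_batch=5))
```

The empty `...` body simply merges the two kwarg groups, so a causal-LM head can forward one bundle to both its attention backend and its loss function.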
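The SmolVLM hunks above also replace the hand-rolled shift-and-flatten cross-entropy with `self.loss_function(logits=..., labels=..., vocab_size=..., **kwargs)`. As a reference, this is roughly the computation being delegated; a minimal plain-PyTorch sketch with toy shapes and the usual `-100` ignore index assumed:

```python
import torch
import torch.nn.functional as F

def shifted_causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    # logits: (batch, seq_len, vocab_size); labels: (batch, seq_len)
    shift_logits = logits[:, :-1, :].contiguous()   # predictions for positions 0 .. n-2
    shift_labels = labels[:, 1:].contiguous()       # targets are the next tokens 1 .. n-1
    return F.cross_entropy(
        shift_logits.view(-1, vocab_size),
        shift_labels.view(-1),
        ignore_index=-100,                          # prompt / padding positions are masked out
    )

logits = torch.randn(2, 5, 32)                      # toy batch: 2 sequences, 5 tokens, vocab of 32
labels = torch.randint(0, 32, (2, 5))
labels[:, :2] = -100                                # pretend the first tokens are prompt, not supervised
print(shifted_causal_lm_loss(logits, labels, vocab_size=32))
```

Forwarding `**kwargs` into the shared loss helper matters because it can pick up `num_items_in_batch` from the bundle and normalize over the real number of supervised tokens.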
@add_start_docstrings( @@ -559,7 +566,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **lm_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, VideoLlavaCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -671,7 +678,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, - **lm_kwargs, + **kwargs, ) hidden_states = outputs[0] @@ -681,7 +688,9 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) return VideoLlavaCausalLMOutputWithPast( loss=loss, diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 49169593320..c1dd0f9941c 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -171,6 +171,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): # important: this ported version of VipLlava isn't meant for training from scratch - only diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index b7dc7e541c9..89579ae833a 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -461,6 +461,7 @@ class Blip2ForConditionalGenerationDecoderOnlyModelTester: @require_torch class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (Blip2ForConditionalGeneration,) if is_torch_available() else () + additional_model_inputs = ["input_ids"] fx_compatible = False test_head_masking = False test_pruning = False @@ -526,15 +527,11 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT model_sdpa = model_class.from_pretrained(tmpdirname) model_sdpa = model_sdpa.eval().to(torch_device) - text_attn = "sdpa" if model.language_model._supports_sdpa else "eager" - vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager" - qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager" - # `None` as it is the requested one which will be assigned to each sub-config # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) - self.assertTrue(model.language_model.config._attn_implementation == text_attn) - self.assertTrue(model.vision_model.config._attn_implementation == vision_attn) - self.assertTrue(model.qformer.config._attn_implementation == qformer_attn) + self.assertTrue(model.language_model.config._attn_implementation == "sdpa") + self.assertTrue(model.vision_model.config._attn_implementation == "sdpa") + self.assertTrue(model.qformer.config._attn_implementation == "eager") model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") model_eager = model_eager.eval().to(torch_device) @@ -545,20 
+542,13 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT for name, submodule in model_eager.named_modules(): class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + if ( + class_name.endswith("Attention") + and getattr(submodule, "config", None) + and submodule.config._attn_implementation == "sdpa" + ): raise ValueError("The eager model should not have SDPA attention layers") - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa and any( - module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn] - ): - raise ValueError("The SDPA model should have SDPA attention layers") - def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -869,6 +859,7 @@ class Blip2ModelTester: @require_torch class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (Blip2ForConditionalGeneration, Blip2Model) if is_torch_available() else () + additional_model_inputs = ["input_ids", "decoder_input_ids"] # Doesn't run generation tests. TODO: fix generation tests for Blip2ForConditionalGeneration all_generative_model_classes = () pipeline_model_mapping = ( @@ -967,15 +958,11 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): model_sdpa = model_class.from_pretrained(tmpdirname) model_sdpa = model_sdpa.eval().to(torch_device) - text_attn = "sdpa" if model.language_model._supports_sdpa else "eager" - vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager" - qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager" - # `None` as it is the requested one which will be assigned to each sub-config # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) - self.assertTrue(model.language_model.config._attn_implementation == text_attn) - self.assertTrue(model.vision_model.config._attn_implementation == vision_attn) - self.assertTrue(model.qformer.config._attn_implementation == qformer_attn) + self.assertTrue(model.language_model.config._attn_implementation == "eager") + self.assertTrue(model.vision_model.config._attn_implementation == "sdpa") + self.assertTrue(model.qformer.config._attn_implementation == "eager") model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") model_eager = model_eager.eval().to(torch_device) @@ -986,20 +973,13 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): for name, submodule in model_eager.named_modules(): class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + if ( + class_name.endswith("Attention") + and getattr(submodule, "config", None) + and submodule.config._attn_implementation == "sdpa" + ): raise ValueError("The eager model should not have SDPA attention layers") - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa and any( - module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn] - ): - raise ValueError("The SDPA model should have SDPA attention layers") - def test_forward_signature(self): config, _ = 
self.model_tester.prepare_config_and_inputs_for_common() @@ -1485,6 +1465,7 @@ class Blip2TextRetrievalModelTester: @require_torch class Blip2TextRetrievalModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (Blip2ForImageTextRetrieval,) if is_torch_available() else () + additional_model_inputs = ["input_ids"] fx_compatible = False test_head_masking = False test_pruning = False diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py index b154a09c2bc..923a8749c61 100644 --- a/tests/models/instructblip/test_modeling_instructblip.py +++ b/tests/models/instructblip/test_modeling_instructblip.py @@ -475,6 +475,7 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene else () ) pipeline_model_mapping = {"image-text-to-text": InstructBlipForConditionalGeneration} + additional_model_inputs = ["qformer_input_ids", "input_ids"] fx_compatible = False test_head_masking = False test_pruning = False @@ -687,15 +688,11 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene model_sdpa = model_class.from_pretrained(tmpdirname) model_sdpa = model_sdpa.eval().to(torch_device) - text_attn = "sdpa" if model.language_model._supports_sdpa else "eager" - vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager" - qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager" - # `None` as it is the requested one which will be assigned to each sub-config # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) - self.assertTrue(model.language_model.config._attn_implementation == text_attn) - self.assertTrue(model.vision_model.config._attn_implementation == vision_attn) - self.assertTrue(model.qformer.config._attn_implementation == qformer_attn) + self.assertTrue(model.language_model.config._attn_implementation == "sdpa") + self.assertTrue(model.vision_model.config._attn_implementation == "sdpa") + self.assertTrue(model.qformer.config._attn_implementation == "eager") model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") model_eager = model_eager.eval().to(torch_device) @@ -706,20 +703,13 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene for name, submodule in model_eager.named_modules(): class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + if ( + class_name.endswith("Attention") + and getattr(submodule, "config", None) + and submodule.config._attn_implementation == "sdpa" + ): raise ValueError("The eager model should not have SDPA attention layers") - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa and any( - module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn] - ): - raise ValueError("The SDPA model should have SDPA attention layers") - # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py index 9bd617b4666..a7870a4b29c 100644 --- a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py +++ b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py @@ -492,6 +492,7 @@ class 
InstructBlipVideoForConditionalGenerationDecoderOnlyTest( all_model_classes = ( (InstructBlipVideoForConditionalGeneration, InstructBlipVideoModel) if is_torch_available() else () ) + additional_model_inputs = ["qformer_input_ids", "input_ids"] fx_compatible = False test_head_masking = False test_pruning = False @@ -702,15 +703,11 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest( model_sdpa = model_class.from_pretrained(tmpdirname) model_sdpa = model_sdpa.eval().to(torch_device) - text_attn = "sdpa" if model.language_model._supports_sdpa else "eager" - vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager" - qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager" - # `None` as it is the requested one which will be assigned to each sub-config # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) - self.assertTrue(model.language_model.config._attn_implementation == text_attn) - self.assertTrue(model.vision_model.config._attn_implementation == vision_attn) - self.assertTrue(model.qformer.config._attn_implementation == qformer_attn) + self.assertTrue(model.language_model.config._attn_implementation == "sdpa") + self.assertTrue(model.vision_model.config._attn_implementation == "sdpa") + self.assertTrue(model.qformer.config._attn_implementation == "eager") model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") model_eager = model_eager.eval().to(torch_device) @@ -721,20 +718,13 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest( for name, submodule in model_eager.named_modules(): class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + if ( + class_name.endswith("Attention") + and getattr(submodule, "config", None) + and submodule.config._attn_implementation == "sdpa" + ): raise ValueError("The eager model should not have SDPA attention layers") - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa and any( - module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn] - ): - raise ValueError("The SDPA model should have SDPA attention layers") - # We will verify our results on an image of cute cats def prepare_video(): diff --git a/tests/models/kosmos2/test_modeling_kosmos2.py b/tests/models/kosmos2/test_modeling_kosmos2.py index 102db8b57ed..1e61d536d75 100644 --- a/tests/models/kosmos2/test_modeling_kosmos2.py +++ b/tests/models/kosmos2/test_modeling_kosmos2.py @@ -30,6 +30,7 @@ from transformers.testing_utils import ( IS_ROCM_SYSTEM, IS_XPU_SYSTEM, require_torch, + require_torch_sdpa, require_vision, slow, torch_device, @@ -42,6 +43,7 @@ from transformers.utils import ( from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ( + TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, ModelTesterMixin, _config_zero_init, floats_tensor, @@ -259,6 +261,7 @@ class Kosmos2ModelTester: @require_torch class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (Kosmos2Model, Kosmos2ForConditionalGeneration) if is_torch_available() else () + additional_model_inputs = ["input_ids", "image_embeds_position_mask"] pipeline_model_mapping = ( { "feature-extraction": Kosmos2Model, 
@@ -462,6 +465,14 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi def test_generate_from_inputs_embeds(self): pass + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) + @require_torch_sdpa + @unittest.skip("KOSMOS-2 doesn't support padding") + def test_eager_matches_sdpa_inference( + self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels + ): + pass + @pytest.mark.generate def test_left_padding_compatibility(self): # Overwrite because Kosmos-2 need to padd pixel values and pad image-attn-mask diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py index 79791e71513..5916f42f5f6 100644 --- a/tests/models/opt/test_modeling_opt.py +++ b/tests/models/opt/test_modeling_opt.py @@ -219,9 +219,10 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, else {} ) is_encoder_decoder = False - fx_compatible = True + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez test_pruning = False test_missing_keys = False + test_head_masking = False # new attn API doesn't support head mask # TODO: Fix the failed tests def is_pipeline_test_to_skip( diff --git a/tests/models/pixtral/test_modeling_pixtral.py b/tests/models/pixtral/test_modeling_pixtral.py index de3f5efc2e5..1a7b2ad01d3 100644 --- a/tests/models/pixtral/test_modeling_pixtral.py +++ b/tests/models/pixtral/test_modeling_pixtral.py @@ -109,6 +109,7 @@ class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase): """ all_model_classes = (PixtralVisionModel,) if is_torch_available() else () + additional_model_inputs = ["image_sizes"] test_pruning = False test_head_masking = False test_torchscript = False diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index b8a4fda96d3..ff86e157fdb 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3765,6 +3765,10 @@ class ModelTesterMixin: key = "decoder_hidden_states" elif "logits" in outputs_eager and "Classification" in model_class.__name__: key = "logits" + elif "language_model_outputs" in outputs_eager and "blip" in model_class.__name__.lower(): + outputs_eager = outputs_eager["language_model_outputs"] + outputs_sdpa = outputs_sdpa["language_model_outputs"] + key = "hidden_states" if "hidden_states" in outputs_eager else "decoder_hidden_states" else: key = "hidden_states" @@ -4002,14 +4006,14 @@ class ModelTesterMixin: model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) sub_models_supporting_fa2 = [ - module._supports_flash_attn_2 + (module._supports_flash_attn_2 or module._supports_attention_backend) for name, module in model.named_modules() if isinstance(module, PreTrainedModel) and name != "" ] supports_fa2_all_modules = ( all(sub_models_supporting_fa2) if len(sub_models_supporting_fa2) > 0 - else model._supports_flash_attn_2 + else (model._supports_flash_attn_2 or model._supports_attention_backend) ) if not supports_fa2_all_modules: with self.assertRaises(ValueError):
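The test changes above stop looking for `*SdpaAttention` class names (those classes no longer exist after the attention refactor) and instead read each attention module's `config._attn_implementation`. A standalone sketch of that scan, with toy modules standing in for real transformers layers:

```python
import torch.nn as nn
from types import SimpleNamespace

class ToySdpaStyleAttention(nn.Module):  # name ends with "Attention", like the real layers
    def __init__(self, attn_implementation: str):
        super().__init__()
        self.config = SimpleNamespace(_attn_implementation=attn_implementation)

class ToyVLM(nn.Module):
    def __init__(self, text_impl: str, vision_impl: str):
        super().__init__()
        self.text_attn = ToySdpaStyleAttention(text_impl)
        self.vision_attn = ToySdpaStyleAttention(vision_impl)

def uses_sdpa_anywhere(model: nn.Module) -> bool:
    # Same predicate as the updated tests: identify attention modules by class name
    # and inspect the implementation recorded on their config.
    for _, submodule in model.named_modules():
        class_name = submodule.__class__.__name__
        if (
            class_name.endswith("Attention")
            and getattr(submodule, "config", None) is not None
            and submodule.config._attn_implementation == "sdpa"
        ):
            return True
    return False

print(uses_sdpa_anywhere(ToyVLM("sdpa", "eager")))   # True
print(uses_sdpa_anywhere(ToyVLM("eager", "eager")))  # False
```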
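Finally, a usage-level sketch of what `_supports_attention_backend = True` enables for the VLMs touched in this patch: a single `attn_implementation` request is propagated to every sub-model, and each sub-config records the backend it was dispatched to. The checkpoint name and Auto class below are examples chosen for illustration, not prescribed by the patch, and running this downloads the weights:

```python
import torch
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained(
    "HuggingFaceTB/SmolVLM-Instruct",   # example checkpoint only; any VLM updated here behaves the same
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",         # or "eager"; "flash_attention_2" if it is installed
)

# Each sub-config now records the implementation it was dispatched to.
print(model.config.text_config._attn_implementation)
print(model.config.vision_config._attn_implementation)
```

This mirrors what the updated BLIP-2 / InstructBLIP tests assert on their `language_model`, `vision_model`, and `qformer` sub-configs.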