From d3d835d4fc145e5062d2153ac23ccd4b3e2c2cbd Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 24 Jun 2025 10:53:52 +0200 Subject: [PATCH] [qwen] refactor attentions for vision/audio (#38930) * refactor attentions in vision/audio * remove fa2 import * make config the only args * pass along kwargs from modality encoders * style --- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 378 ++++++------------ .../qwen2_5_omni/modular_qwen2_5_omni.py | 296 +++++--------- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 255 ++++-------- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 84 +--- .../models/qwen2_vl/modeling_qwen2_vl.py | 252 +++++------- 5 files changed, 414 insertions(+), 851 deletions(-) diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 880f209cc21..3ccebbd3423 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -34,11 +34,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask -from ...modeling_flash_attention_utils import ( - FlashAttentionKwargs, - flash_attn_supports_top_left_mask, - is_flash_attn_available, -) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update @@ -59,13 +55,6 @@ from .configuration_qwen2_5_omni import ( ) -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func -else: - flash_attn_varlen_func = None - apply_rotary_emb = None - - logger = logging.get_logger(__name__) @@ -559,6 +548,44 @@ class Qwen2_5OmniThinkerCausalLMOutputWithPast(ModelOutput): rope_deltas: Optional[torch.LongTensor] = None +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Qwen2_5OmniAudioAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -571,6 +598,7 @@ class Qwen2_5OmniAudioAttention(nn.Module): self.num_heads = config.encoder_attention_heads self.dropout = config.attention_dropout self.head_dim = self.embed_dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention self.config = config if (self.head_dim * self.num_heads) != self.embed_dim: @@ -591,6 +619,7 @@ class Qwen2_5OmniAudioAttention(nn.Module): self, hidden_states: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, + **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -600,13 +629,13 @@ class Qwen2_5OmniAudioAttention(nn.Module): key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1) value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1) - query_states = query_states.transpose(0, 1) - key_states = key_states.transpose(0, 1) - value_states = value_states.transpose(0, 1) - attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim) + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attention_mask = torch.full( - [1, seq_length, key_states.shape[1]], + [1, 1, seq_length, key_states.shape[-2]], torch.finfo(query_states.dtype).min, device=query_states.device, dtype=query_states.dtype, @@ -614,115 +643,37 @@ class Qwen2_5OmniAudioAttention(nn.Module): for i in range(1, len(cu_seqlens)): attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - attn_weights = attn_weights + attention_mask + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype) - - attn_output = torch.matmul(attn_weights, value_states).transpose(0, 
1).reshape(seq_length, self.embed_dim) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - - attn_output = self.out_proj(attn_output) - - return attn_output - - -class Qwen2_5OmniAudioFlashAttention2(Qwen2_5OmniAudioAttention): - """ - Qwen2.5OmniThinker flash attention module. This module inherits from `Qwen2_5OmniAudioAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def forward( - self, - hidden_states: torch.Tensor, - cu_seqlens: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - seq_length, all_dim = hidden_states.size() - query_states = self.q_proj(hidden_states) - query_states = query_states.reshape(seq_length, self.num_heads, -1) - - key_states = self.k_proj(hidden_states) - key_states = key_states.reshape(seq_length, self.num_heads, -1) - value_states = self.v_proj(hidden_states) - value_states = value_states.reshape(seq_length, self.num_heads, -1) - - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - attn_output = flash_attn_varlen_func( - query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0 - ) - attn_output = attn_output.reshape(seq_length, all_dim) - attn_output = self.out_proj(attn_output) - - return attn_output - - -class Qwen2_5OmniAudioSdpaAttention(Qwen2_5OmniAudioAttention): - def forward( - self, - hidden_states: torch.Tensor, - cu_seqlens: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - seq_length, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1) - key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1) - value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1) - - attention_mask = torch.zeros( - [1, seq_length, key_states.shape[0]], device=query_states.device, dtype=torch.bool - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. 
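# A standalone toy sketch (segment lengths made up for illustration, not taken from
# this patch) of the block-diagonal additive mask the refactored audio/vision
# attentions build above from `cu_seqlens`: tokens may only attend inside their own
# segment, so blocks on the diagonal are 0 and everything else is the dtype minimum.
import torch

cu_seqlens = torch.tensor([0, 2, 5, 6])  # three segments of lengths 2, 3 and 1
seq_length = int(cu_seqlens[-1])
mask = torch.full((1, 1, seq_length, seq_length), torch.finfo(torch.float32).min)
for i in range(1, len(cu_seqlens)):
    mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
assert (mask == 0).sum() == 2 * 2 + 3 * 3 + 1 * 1  # zeros form 2x2, 3x3 and 1x1 blocks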
- query_states = query_states.transpose(0, 1) - key_states = key_states.transpose(0, 1) - value_states = value_states.transpose(0, 1) - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( + attn_output, _ = attention_interface( + self, query_states, key_states, value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + is_causal=False, + **kwargs, ) - attn_output = attn_output.transpose(0, 1) - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(seq_length, self.embed_dim) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.out_proj(attn_output) + return attn_output -QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES = { - "eager": Qwen2_5OmniAudioAttention, - "flash_attention_2": Qwen2_5OmniAudioFlashAttention2, - "sdpa": Qwen2_5OmniAudioSdpaAttention, -} - - class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen2_5OmniAudioEncoderConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES[config._attn_implementation](config) + self.self_attn = Qwen2_5OmniAudioAttention(config) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -735,6 +686,7 @@ class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer): self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + **kwargs, ) -> torch.Tensor: """ Args: @@ -752,6 +704,7 @@ class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer): hidden_states = self.self_attn( hidden_states=hidden_states, cu_seqlens=cu_seqlens, + **kwargs, ) hidden_states = residual + hidden_states residual = hidden_states @@ -838,6 +791,7 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel): input_features, feature_lens=None, aftercnn_lens=None, + **kwargs, ): r""" input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`): @@ -881,7 +835,7 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel): ).to(torch.int32) for encoder_layer in self.layers: - layer_outputs = encoder_layer(hidden_states, cu_seqlens) + layer_outputs = encoder_layer(hidden_states, cu_seqlens, **kwargs) hidden_states = layer_outputs[0] hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) @@ -962,106 +916,68 @@ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> to class Qwen2_5OmniVisionAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: + def __init__(self, config: Qwen2_5OmniVisionEncoderConfig = None) -> None: super().__init__() - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.q = nn.Linear(dim, dim, bias=True) - self.k = nn.Linear(dim, dim, bias=True) - self.v = nn.Linear(dim, dim, bias=True) - self.proj = nn.Linear(dim, dim) + self.dim = 
config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.q = nn.Linear(self.dim, self.dim, bias=True) + self.k = nn.Linear(self.dim, self.dim, bias=True) + self.v = nn.Linear(self.dim, self.dim, bias=True) + self.proj = nn.Linear(self.dim, self.dim) + self.scaling = math.sqrt(self.head_dim) + self.num_key_value_groups = 1 # needed for eager attention + self.config = config def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] - q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) - k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) - v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + query_states = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) + key_states = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) + value_states = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) + query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0) + key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0) attention_mask = torch.full( - [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + [1, 1, seq_length, seq_length], + torch.finfo(query_states.dtype).min, + device=query_states.device, + dtype=query_states.dtype, ) for i in range(1, len(cu_seqlens)): attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) - attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) - attn_output = torch.matmul(attn_weights, v) - attn_output = attn_output.transpose(0, 1) - attn_output = attn_output.reshape(seq_length, -1) - attn_output = self.proj(attn_output) - return attn_output - - -class Qwen2_5OmniVisionFlashAttention2(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: - super().__init__() - self.num_heads = num_heads - self.q = nn.Linear(dim, dim, bias=True) - self.k = nn.Linear(dim, dim, bias=True) - self.v = nn.Linear(dim, dim, bias=True) - self.proj = nn.Linear(dim, dim) - - def _apply_rotary_pos_emb_flashatt(self, tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - tensor_ = tensor.float() - cos = freqs.cos().type_as(tensor_) - sin = freqs.sin().type_as(tensor_) - output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor) - return output - - def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) - k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) - v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) - q = self._apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = self._apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0) - + 
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim + key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim + value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( - seq_length, -1 + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0, + scaling=self.scaling, + cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + is_causal=False, + **kwargs, ) - attn_output = self.proj(attn_output) - return attn_output - -class Qwen2_5OmniVisionSdpaAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: - super().__init__() - self.num_heads = num_heads - self.q = nn.Linear(dim, dim, bias=True) - self.k = nn.Linear(dim, dim, bias=True) - self.v = nn.Linear(dim, dim, bias=True) - self.proj = nn.Linear(dim, dim) - - def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) - k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) - v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) - - attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) - attn_output = attn_output.transpose(0, 1) - attn_output = attn_output.reshape(seq_length, -1) + attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.proj(attn_output) return attn_output @@ -1080,26 +996,23 @@ class Qwen2_5OmniMLP(nn.Module): return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) -QWEN2_5_OMNI_VISION_ATTENTION_CLASSES = { - "eager": Qwen2_5OmniVisionAttention, - "flash_attention_2": Qwen2_5OmniVisionFlashAttention2, - "sdpa": Qwen2_5OmniVisionSdpaAttention, -} - - class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer): def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None: super().__init__() self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) - self.attn = QWEN2_5_OMNI_VISION_ATTENTION_CLASSES[config._attn_implementation]( - config.hidden_size, num_heads=config.num_heads - ) + self.attn = Qwen2_5OmniVisionAttention(config=config) self.mlp = Qwen2_5OmniMLP(config, bias=True) - def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: hidden_states = 
hidden_states + self.attn( - self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, **kwargs ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states @@ -1258,7 +1171,7 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel): return window_index, cu_window_seqlens - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): @@ -1308,6 +1221,7 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel): hidden_states, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb, + **kwargs, ) hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) @@ -1397,44 +1311,6 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim return q_embed, k_embed -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - class Qwen2_5OmniAttention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. 
Modified to use sliding window attention: Longformer diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index a6e330845cc..ac134bd4837 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -17,7 +17,7 @@ import math from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import numpy as np import torch @@ -36,6 +36,7 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLTextModel, Qwen2_5_VLVisionBlock, Qwen2RMSNorm, + eager_attention_forward, ) from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioEncoderConfig from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioEncoderLayer @@ -43,7 +44,7 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbeddin from ...configuration_utils import PretrainedConfig, layer_type_validation from ...generation import GenerationMixin -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_flash_attention_utils import is_flash_attn_available from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -1601,6 +1602,7 @@ class Qwen2_5OmniAudioAttention(nn.Module): self.num_heads = config.encoder_attention_heads self.dropout = config.attention_dropout self.head_dim = self.embed_dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention self.config = config if (self.head_dim * self.num_heads) != self.embed_dim: @@ -1621,6 +1623,7 @@ class Qwen2_5OmniAudioAttention(nn.Module): self, hidden_states: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, + **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -1630,13 +1633,13 @@ class Qwen2_5OmniAudioAttention(nn.Module): key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1) value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1) - query_states = query_states.transpose(0, 1) - key_states = key_states.transpose(0, 1) - value_states = value_states.transpose(0, 1) - attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim) + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attention_mask = torch.full( - [1, seq_length, key_states.shape[1]], + [1, 1, seq_length, key_states.shape[-2]], torch.finfo(query_states.dtype).min, device=query_states.device, dtype=query_states.dtype, @@ -1644,125 +1647,49 @@ class Qwen2_5OmniAudioAttention(nn.Module): for i in range(1, len(cu_seqlens)): attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - attn_weights = attn_weights + attention_mask + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype) - - attn_output = torch.matmul(attn_weights, 
value_states).transpose(0, 1).reshape(seq_length, self.embed_dim) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - - attn_output = self.out_proj(attn_output) - - return attn_output - - -class Qwen2_5OmniAudioFlashAttention2(Qwen2_5OmniAudioAttention): - """ - Qwen2.5OmniThinker flash attention module. This module inherits from `Qwen2_5OmniAudioAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def forward( - self, - hidden_states: torch.Tensor, - cu_seqlens: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - seq_length, all_dim = hidden_states.size() - query_states = self.q_proj(hidden_states) - query_states = query_states.reshape(seq_length, self.num_heads, -1) - - key_states = self.k_proj(hidden_states) - key_states = key_states.reshape(seq_length, self.num_heads, -1) - value_states = self.v_proj(hidden_states) - value_states = value_states.reshape(seq_length, self.num_heads, -1) - - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - attn_output = flash_attn_varlen_func( - query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0 - ) - attn_output = attn_output.reshape(seq_length, all_dim) - attn_output = self.out_proj(attn_output) - - return attn_output - - -class Qwen2_5OmniAudioSdpaAttention(Qwen2_5OmniAudioAttention): - def forward( - self, - hidden_states: torch.Tensor, - cu_seqlens: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - seq_length, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1) - key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1) - value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1) - - attention_mask = torch.zeros( - [1, seq_length, key_states.shape[0]], device=query_states.device, dtype=torch.bool - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
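# A standalone toy check (tiny made-up shapes; assumes scaling = head_dim ** -0.5) of
# the eager path used above: multiplying the logits by `scaling` reproduces the removed
# `attn_weights / math.sqrt(head_dim)` division, and since these encoder modules set
# num_key_value_groups = 1, repeat_kv is a no-op for them.
import math
import torch

head_dim, seq = 4, 3
q = torch.randn(1, 2, seq, head_dim)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 2, seq, head_dim)
scaling = head_dim ** -0.5
new_logits = torch.matmul(q, k.transpose(2, 3)) * scaling
old_logits = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
assert torch.allclose(new_logits, old_logits)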
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - query_states = query_states.transpose(0, 1) - key_states = key_states.transpose(0, 1) - value_states = value_states.transpose(0, 1) - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( + attn_output, _ = attention_interface( + self, query_states, key_states, value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + is_causal=False, + **kwargs, ) - attn_output = attn_output.transpose(0, 1) - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(seq_length, self.embed_dim) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.out_proj(attn_output) + return attn_output -QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES = { - "eager": Qwen2_5OmniAudioAttention, - "flash_attention_2": Qwen2_5OmniAudioFlashAttention2, - "sdpa": Qwen2_5OmniAudioSdpaAttention, -} - - class Qwen2_5OmniAudioEncoderLayer(Qwen2AudioEncoderLayer): def __init__(self, config: Qwen2_5OmniAudioEncoderConfig): super().__init__(config) - self.self_attn = QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES[config._attn_implementation](config) + self.self_attn = Qwen2_5OmniAudioAttention(config) def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + **kwargs, ) -> torch.Tensor: residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn( hidden_states=hidden_states, cu_seqlens=cu_seqlens, + **kwargs, ) hidden_states = residual + hidden_states residual = hidden_states @@ -1849,6 +1776,7 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel): input_features, feature_lens=None, aftercnn_lens=None, + **kwargs, ): r""" input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`): @@ -1892,7 +1820,7 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel): ).to(torch.int32) for encoder_layer in self.layers: - layer_outputs = encoder_layer(hidden_states, cu_seqlens) + layer_outputs = encoder_layer(hidden_states, cu_seqlens, **kwargs) hidden_states = layer_outputs[0] hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) @@ -1966,127 +1894,86 @@ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> to class Qwen2_5OmniVisionAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: + def __init__(self, config: Qwen2_5OmniVisionEncoderConfig = None) -> None: super().__init__() - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.q = nn.Linear(dim, dim, bias=True) - self.k = nn.Linear(dim, dim, bias=True) - self.v = nn.Linear(dim, dim, bias=True) - self.proj = nn.Linear(dim, dim) + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.q = 
nn.Linear(self.dim, self.dim, bias=True) + self.k = nn.Linear(self.dim, self.dim, bias=True) + self.v = nn.Linear(self.dim, self.dim, bias=True) + self.proj = nn.Linear(self.dim, self.dim) + self.scaling = math.sqrt(self.head_dim) + self.num_key_value_groups = 1 # needed for eager attention + self.config = config def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] - q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) - k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) - v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + query_states = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) + key_states = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) + value_states = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) + query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0) + key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0) attention_mask = torch.full( - [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + [1, 1, seq_length, seq_length], + torch.finfo(query_states.dtype).min, + device=query_states.device, + dtype=query_states.dtype, ) for i in range(1, len(cu_seqlens)): attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) - attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) - attn_output = torch.matmul(attn_weights, v) - attn_output = attn_output.transpose(0, 1) - attn_output = attn_output.reshape(seq_length, -1) - attn_output = self.proj(attn_output) - return attn_output - - -class Qwen2_5OmniVisionFlashAttention2(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: - super().__init__() - self.num_heads = num_heads - self.q = nn.Linear(dim, dim, bias=True) - self.k = nn.Linear(dim, dim, bias=True) - self.v = nn.Linear(dim, dim, bias=True) - self.proj = nn.Linear(dim, dim) - - def _apply_rotary_pos_emb_flashatt(self, tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - tensor_ = tensor.float() - cos = freqs.cos().type_as(tensor_) - sin = freqs.sin().type_as(tensor_) - output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor) - return output - - def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) - k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) - v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) - q = self._apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = self._apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0) - + query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim + key_states = key_states.transpose(0, 
1).unsqueeze(0) # unsqueeze batch_dim + value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( - seq_length, -1 + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0, + scaling=self.scaling, + cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + is_causal=False, + **kwargs, ) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.proj(attn_output) return attn_output -class Qwen2_5OmniVisionSdpaAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: - super().__init__() - self.num_heads = num_heads - self.q = nn.Linear(dim, dim, bias=True) - self.k = nn.Linear(dim, dim, bias=True) - self.v = nn.Linear(dim, dim, bias=True) - self.proj = nn.Linear(dim, dim) - - def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) - k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) - v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) - - attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) - attn_output = attn_output.transpose(0, 1) - attn_output = attn_output.reshape(seq_length, -1) - attn_output = self.proj(attn_output) - return attn_output - - -QWEN2_5_OMNI_VISION_ATTENTION_CLASSES = { - "eager": Qwen2_5OmniVisionAttention, - "flash_attention_2": Qwen2_5OmniVisionFlashAttention2, - "sdpa": Qwen2_5OmniVisionSdpaAttention, -} - - class Qwen2_5OmniVisionBlock(Qwen2_5_VLVisionBlock): def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None: super().__init__(config, config._attn_implementation) - self.attn = QWEN2_5_OMNI_VISION_ATTENTION_CLASSES[config._attn_implementation]( - config.hidden_size, num_heads=config.num_heads - ) + self.attn = Qwen2_5OmniVisionAttention(config=config) - def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: hidden_states = hidden_states + self.attn( - self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, **kwargs ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states @@ -2100,7 +1987,7 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel): 
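# A toy sketch of the dispatch pattern this patch repeats in every attention module:
# the per-implementation classes removed above are replaced by a single module that
# falls back to eager_attention_forward unless config._attn_implementation names a
# registered kernel. ALL_ATTENTION_FUNCTIONS is the real transformers registry; the
# stand-in dict and function names below are assumptions for illustration only.
from typing import Callable

ATTENTION_REGISTRY: dict[str, Callable] = {"sdpa": lambda *args, **kwargs: "sdpa kernel"}

def pick_attention(attn_implementation: str, eager_fn: Callable) -> Callable:
    # keep the eager reference implementation unless another kernel is requested
    if attn_implementation == "eager":
        return eager_fn
    return ATTENTION_REGISTRY[attn_implementation]

attention_interface = pick_attention("sdpa", eager_fn=lambda *args, **kwargs: "eager path")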
super().__init__(config, *inputs, **kwargs) self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)]) - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): @@ -2150,6 +2037,7 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel): hidden_states, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb, + **kwargs, ) hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 0bcbf1cb506..0122aa37e02 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -36,7 +36,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask -from ...modeling_flash_attention_utils import FlashAttentionKwargs, is_flash_attn_available +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update @@ -46,10 +46,6 @@ from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torchdynam from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func - - logger = logging.get_logger(__name__) @@ -141,56 +137,6 @@ class Qwen2_5_VLPatchMerger(nn.Module): return x -def apply_rotary_pos_emb_flashatt( - q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - cos = cos.chunk(2, dim=-1)[0].contiguous() - sin = sin.chunk(2, dim=-1)[0].contiguous() - q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) - k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) - return q_embed, k_embed - - -class Qwen2_5_VLVisionFlashAttention2(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: - super().__init__() - self.num_heads = num_heads - self.qkv = nn.Linear(dim, dim * 3, bias=True) - self.proj = nn.Linear(dim, dim) - - def forward( - self, - hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - cos = emb.cos() - sin = emb.sin() - else: - cos, sin = position_embeddings - q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin) - q = q.squeeze(0) - k = k.squeeze(0) - - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( - seq_length, -1 - ) - attn_output = self.proj(attn_output) - return attn_output - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -212,13 +158,55 @@ def apply_rotary_pos_emb_vision( return q_embed, k_embed +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Qwen2_5_VLVisionAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: + def __init__(self, config: Qwen2_5_VLVisionConfig) -> None: super().__init__() - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.qkv = nn.Linear(dim, dim * 3, bias=True) - self.proj = nn.Linear(dim, dim) + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) + self.proj = nn.Linear(self.dim, self.dim) + self.scaling = math.sqrt(self.head_dim) + self.config = config def forward( self, @@ -226,9 +214,12 @@ class Qwen2_5_VLVisionAttention(nn.Module): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + query_states, key_states, value_states = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE 
embeddings internally " @@ -241,87 +232,53 @@ class Qwen2_5_VLVisionAttention(nn.Module): sin = emb.sin() else: cos, sin = position_embeddings - q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) attention_mask = torch.full( - [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + [1, 1, seq_length, seq_length], + torch.finfo(value_states.dtype).min, + device=value_states.device, + dtype=value_states.dtype, ) for i in range(1, len(cu_seqlens)): attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) - attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) - attn_output = torch.matmul(attn_weights, v) - attn_output = attn_output.transpose(0, 1) - attn_output = attn_output.reshape(seq_length, -1) - attn_output = self.proj(attn_output) - return attn_output + query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim + key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim + value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] -class Qwen2_5_VLVisionSdpaAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: - super().__init__() - self.num_heads = num_heads - self.qkv = nn.Linear(dim, dim * 3, bias=True) - self.proj = nn.Linear(dim, dim) - - def forward( - self, - hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - cos = emb.cos() - sin = emb.sin() - else: - cos, sin = position_embeddings - q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) - - attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - attn_output = F.scaled_dot_product_attention( - q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attention_mask, dropout_p=0.0 + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0, + scaling=self.scaling, + cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + is_causal=False, + **kwargs, ) - attn_output = attn_output.squeeze(0).transpose(0, 1) - attn_output = attn_output.reshape(seq_length, -1) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.proj(attn_output) return attn_output -QWEN2_5_VL_VISION_ATTENTION_CLASSES = { - "eager": Qwen2_5_VLVisionAttention, - "flash_attention_2": Qwen2_5_VLVisionFlashAttention2, - "sdpa": Qwen2_5_VLVisionSdpaAttention, -} - - class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): def __init__(self, config, attn_implementation: str = "sdpa") -> None: super().__init__() self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) - self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation]( - config.hidden_size, num_heads=config.num_heads - ) + self.attn = Qwen2_5_VLVisionAttention(config=config) self.mlp = Qwen2_5_VLMLP(config, bias=True) def forward( @@ -330,12 +287,14 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states @@ -390,9 +349,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): head_dim = config.hidden_size // config.num_heads self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) - self.blocks = nn.ModuleList( - [Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] - ) + self.blocks = nn.ModuleList([Qwen2_5_VLVisionBlock(config) for _ in range(config.depth)]) self.merger = Qwen2_5_VLPatchMerger( dim=config.out_hidden_size, context_dim=config.hidden_size, @@ -470,7 +427,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): return window_index, cu_window_seqlens - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): @@ -516,7 +473,9 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, 
position_embeddings=position_embeddings) + hidden_states = blk( + hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, **kwargs + ) hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) @@ -647,44 +606,6 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim return q_embed, k_embed -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - class Qwen2_5_VLAttention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. 
Modified to use sliding window attention: Longformer diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 81f3dddf4ac..84a7a69ac81 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -40,7 +40,6 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import ( Qwen2VLPreTrainedModel, VisionAttention, VisionRotaryEmbedding, - VisionSdpaAttention, ) from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLImagesKwargs, Qwen2VLProcessor @@ -57,22 +56,12 @@ from ...video_utils import VideoInput if is_flash_attn_available(): - from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func + pass logger = logging.get_logger(__name__) -def apply_rotary_pos_emb_flashatt( - q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - cos = cos.chunk(2, dim=-1)[0].contiguous() - sin = sin.chunk(2, dim=-1)[0].contiguous() - q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) - k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) - return q_embed, k_embed - - class Qwen2_5_VLVisionConfig(PretrainedConfig): model_type = "qwen2_5_vl" base_config_key = "vision_config" @@ -150,59 +139,10 @@ class Qwen2_5_VLPatchMerger(PatchMerger): self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) -class Qwen2_5_VLVisionFlashAttention2(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: - super().__init__() - self.num_heads = num_heads - self.qkv = nn.Linear(dim, dim * 3, bias=True) - self.proj = nn.Linear(dim, dim) - - def forward( - self, - hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " - "removed and `position_embeddings` will be mandatory." 
@@ -150,59 +139,10 @@ class Qwen2_5_VLPatchMerger(PatchMerger):
         self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
 
 
-class Qwen2_5_VLVisionFlashAttention2(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 16) -> None:
-        super().__init__()
-        self.num_heads = num_heads
-        self.qkv = nn.Linear(dim, dim * 3, bias=True)
-        self.proj = nn.Linear(dim, dim)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> torch.Tensor:
-        seq_length = hidden_states.shape[0]
-        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        else:
-            cos, sin = position_embeddings
-        q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
-        q = q.squeeze(0)
-        k = k.squeeze(0)
-
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
-            seq_length, -1
-        )
-        attn_output = self.proj(attn_output)
-        return attn_output
-
-
 class Qwen2_5_VLVisionAttention(VisionAttention):
-    pass
-
-
-class Qwen2_5_VLVisionSdpaAttention(VisionSdpaAttention):
-    pass
-
-
-QWEN2_5_VL_VISION_ATTENTION_CLASSES = {
-    "eager": Qwen2_5_VLVisionAttention,
-    "flash_attention_2": Qwen2_5_VLVisionFlashAttention2,
-    "sdpa": Qwen2_5_VLVisionSdpaAttention,
-}
+    def __init__(self, config: Qwen2_5_VLVisionConfig) -> None:
+        super().__init__()
+        self.dim = config.hidden_size
 
 
 class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
@@ -210,9 +150,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
         super().__init__()
         self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
         self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
-        self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation](
-            config.hidden_size, num_heads=config.num_heads
-        )
+        self.attn = Qwen2_5_VLVisionAttention(config=config)
         self.mlp = Qwen2_5_VLMLP(config, bias=True)
 
     def forward(
@@ -221,12 +159,14 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb=rotary_pos_emb,
             position_embeddings=position_embeddings,
+            **kwargs,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
         return hidden_states
@@ -269,9 +209,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
         head_dim = config.hidden_size // config.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
 
-        self.blocks = nn.ModuleList(
-            [Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
-        )
+        self.blocks = nn.ModuleList([Qwen2_5_VLVisionBlock(config) for _ in range(config.depth)])
         self.merger = Qwen2_5_VLPatchMerger(
             dim=config.out_hidden_size,
             context_dim=config.hidden_size,
@@ -349,7 +287,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
 
         return window_index, cu_window_seqlens
 
-    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
@@ -395,7 +333,9 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
                 cu_seqlens_now = cu_seqlens
             else:
                 cu_seqlens_now = cu_window_seqlens
-            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings)
+            hidden_states = blk(
+                hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, **kwargs
+            )
 
         hidden_states = self.merger(hidden_states)
         reverse_indices = torch.argsort(window_index)
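With the per-backend vision classes collapsed into a single config-driven attention, the kernel for the vision tower is resolved from `config._attn_implementation` through the same registry the text model uses, so the `attn_implementation` chosen at load time applies uniformly. An illustrative usage sketch; the checkpoint id and dtype are examples, not mandated by the patch:

import torch
from transformers import Qwen2_5_VLForConditionalGeneration

# "eager", "sdpa" or "flash_attention_2" now selects the kernel for the
# vision blocks as well as for the language model.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",  # illustrative checkpoint
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)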
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index f6d43308145..3b3c460c0c6 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -33,7 +33,7 @@ from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs, is_flash_attn_available
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
@@ -49,10 +49,6 @@ from ...utils import (
 from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig, Qwen2VLVisionConfig
 
 
-if is_flash_attn_available():
-    from ...modeling_flash_attention_utils import flash_attn_varlen_func
-
-
 logger = logging.get_logger(__name__)
 
 
@@ -279,13 +275,56 @@ class VisionMlp(nn.Module):
         return self.fc2(self.act(self.fc1(x)))
 
 
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
 class VisionAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 16) -> None:
+    def __init__(self, config: Qwen2VLVisionConfig) -> None:
         super().__init__()
-        self.num_heads = num_heads
-        self.head_dim = dim // num_heads
-        self.qkv = nn.Linear(dim, dim * 3, bias=True)
-        self.proj = nn.Linear(dim, dim)
+        self.dim = config.embed_dim
+        self.num_heads = config.num_heads
+        self.head_dim = self.dim // self.num_heads
+        self.num_key_value_groups = 1  # needed for eager attention
+        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
+        self.proj = nn.Linear(self.dim, self.dim)
+        self.scaling = 1 / math.sqrt(self.head_dim)
+        self.config = config
 
     def forward(
         self,
@@ -293,9 +332,12 @@ class VisionAttention(nn.Module):
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
-        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+        query_states, key_states, value_states = (
+            self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+        )
         if position_embeddings is None:
             logger.warning_once(
                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -308,117 +350,47 @@ class VisionAttention(nn.Module):
             sin = emb.sin()
         else:
             cos, sin = position_embeddings
-        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
 
         attention_mask = torch.full(
-            [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
+            [1, 1, seq_length, seq_length],
+            torch.finfo(value_states.dtype).min,
+            device=value_states.device,
+            dtype=value_states.dtype,
         )
         for i in range(1, len(cu_seqlens)):
             attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
 
-        q = q.transpose(0, 1)
-        k = k.transpose(0, 1)
-        v = v.transpose(0, 1)
-        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
-        attn_weights = attn_weights + attention_mask
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
-        attn_output = torch.matmul(attn_weights, v)
-        attn_output = attn_output.transpose(0, 1)
-        attn_output = attn_output.reshape(seq_length, -1)
-        attn_output = self.proj(attn_output)
-        return attn_output
-
-
-class VisionFlashAttention2(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 16) -> None:
-        super().__init__()
-        self.num_heads = num_heads
-        self.qkv = nn.Linear(dim, dim * 3, bias=True)
-        self.proj = nn.Linear(dim, dim)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> torch.Tensor:
-        seq_length = hidden_states.shape[0]
-        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        else:
-            cos, sin = position_embeddings
-        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
-
+        query_states = query_states.transpose(0, 1).unsqueeze(0)  # unsqueeze batch_dim
+        key_states = key_states.transpose(0, 1).unsqueeze(0)  # unsqueeze batch_dim
+        value_states = value_states.transpose(0, 1).unsqueeze(0)  # unsqueeze batch_dim
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
-            seq_length, -1
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, _ = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0,
+            scaling=self.scaling,
+            cu_seqlens_q=cu_seqlens,  # pass cu seq lens for FA2
+            cu_seqlens_k=cu_seqlens,
+            max_seqlen_q=max_seqlen,
+            max_seqlen_k=max_seqlen,
+            is_causal=False,
+            **kwargs,
         )
+
+        attn_output = attn_output.reshape(seq_length, -1).contiguous()
         attn_output = self.proj(attn_output)
         return attn_output
 
 
-class VisionSdpaAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 16) -> None:
-        super().__init__()
-        self.num_heads = num_heads
-        self.qkv = nn.Linear(dim, dim * 3, bias=True)
-        self.proj = nn.Linear(dim, dim)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> torch.Tensor:
-        seq_length = hidden_states.shape[0]
-        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        else:
-            cos, sin = position_embeddings
-        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
-
-        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
-        q = q.transpose(0, 1)
-        k = k.transpose(0, 1)
-        v = v.transpose(0, 1)
-        attn_output = F.scaled_dot_product_attention(
-            q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attention_mask, dropout_p=0.0
-        )
-        attn_output = attn_output.squeeze(0).transpose(0, 1)
-        attn_output = attn_output.reshape(seq_length, -1)
-        attn_output = self.proj(attn_output)
-        return attn_output
-
-
-QWEN2_VL_VISION_ATTENTION_CLASSES = {
-    "eager": VisionAttention,
-    "flash_attention_2": VisionFlashAttention2,
-    "sdpa": VisionSdpaAttention,
-}
-
-
 class Qwen2VLVisionBlock(GradientCheckpointingLayer):
     def __init__(self, config, attn_implementation: str = "sdpa") -> None:
         super().__init__()
@@ -426,9 +398,7 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
         self.norm2 = LayerNorm(config.embed_dim, eps=1e-6)
         mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
 
-        self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](
-            config.embed_dim, num_heads=config.num_heads
-        )
+        self.attn = VisionAttention(config=config)
        self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
 
     def forward(
@@ -437,12 +407,14 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb=rotary_pos_emb,
             position_embeddings=position_embeddings,
+            **kwargs,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
         return hidden_states
@@ -486,45 +458,6 @@ class Qwen2MLP(nn.Module):
         return down_proj
 
 
-# Copied from transformers.models.llama.modeling_llama.repeat_kv
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def eager_attention_forward(
-    module: nn.Module,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    scaling: float,
-    dropout: float = 0.0,
-    **kwargs,
-):
-    key_states = repeat_kv(key, module.num_key_value_groups)
-    value_states = repeat_kv(value, module.num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-    if attention_mask is not None:
-        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-        attn_weights = attn_weights + causal_mask
-
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
-    attn_output = torch.matmul(attn_weights, value_states)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    return attn_output, attn_weights
-
-
 class Qwen2VLAttention(nn.Module):
     """
     Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
@@ -752,9 +685,7 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
         head_dim = config.embed_dim // config.num_heads
         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
 
-        self.blocks = nn.ModuleList(
-            [Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
-        )
+        self.blocks = nn.ModuleList([Qwen2VLVisionBlock(config) for _ in range(config.depth)])
         self.merger = PatchMerger(
             dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size
         )
@@ -796,7 +727,12 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
         return rotary_pos_emb
 
     @auto_docstring
-    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        grid_thw: torch.Tensor,
+        **kwargs,
+    ) -> torch.Tensor:
         r"""
         grid_thw (`torch.LongTensor` of shape `(num_images, 3)`):
             The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values.
@@ -817,7 +753,9 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
 
         for blk in self.blocks:
-            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
+            hidden_states = blk(
+                hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, **kwargs
+            )
 
         return self.merger(hidden_states)
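Across the rewritten vision and audio attentions, packed sequences are handled uniformly: the eager and SDPA paths materialize a block-diagonal additive mask from `cu_seqlens`, while flash attention keeps consuming `cu_seqlens_q`/`cu_seqlens_k` through the forwarded kwargs. A standalone sketch of that mask construction, with made-up per-image patch counts:

import torch

def block_diagonal_bias(cu_seqlens: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # Additive mask: 0 inside each image's block, a large negative value
    # elsewhere, shaped (1, 1, seq, seq) as in the refactored attentions.
    total = int(cu_seqlens[-1])
    mask = torch.full((1, 1, total, total), torch.finfo(dtype).min, dtype=dtype)
    for i in range(1, len(cu_seqlens)):
        mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
    return mask

seq_lens = torch.tensor([4, 6])  # hypothetical patches per image
cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0), (1, 0), value=0)
bias = block_diagonal_bias(cu_seqlens)
assert bias.shape == (1, 1, 10, 10)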