[qwen] refactor attentions for vision/audio (#38930)

* refactor attentions in vision/audio

* remove fa2 import

* make config the only argument

* pass along kwargs from modality encoders

* style
Authored by Raushan Turganbay on 2025-06-24 10:53:52 +02:00; committed by GitHub
parent 2e4c045540
commit d3d835d4fc
5 changed files with 414 additions and 851 deletions
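
The refactor replaces the per-backend attention subclasses (eager / SDPA / flash-attention-2) in the Qwen2-VL, Qwen2.5-VL and Qwen2.5-Omni vision/audio encoders with a single attention class that picks an implementation callable at forward time via `ALL_ATTENTION_FUNCTIONS`. The following is a minimal, self-contained sketch of that dispatch pattern, not the Transformers API: `ATTENTION_FUNCTIONS`, `ToyAttention` and all shapes are illustrative stand-ins.

# Hedged sketch of the dispatch pattern adopted here: one attention class chooses an
# implementation callable at forward time instead of subclassing per backend.
# `ATTENTION_FUNCTIONS` and `ToyAttention` are illustrative stand-ins, not the real API.
import math
from typing import Callable, Optional

import torch
from torch import nn


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,  # (batch, num_heads, seq_len, head_dim)
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,  # backend-specific kwargs (e.g. cu_seqlens for FA2) are simply ignored here
):
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


# One registry keyed by implementation name; "eager" is the fallback.
ATTENTION_FUNCTIONS: dict[str, Callable] = {"eager": eager_attention_forward}


class ToyAttention(nn.Module):
    def __init__(self, hidden_size: int = 64, num_heads: int = 4, attn_implementation: str = "eager"):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scaling = 1 / math.sqrt(self.head_dim)
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.out = nn.Linear(hidden_size, hidden_size)
        self.attn_implementation = attn_implementation

    def forward(self, hidden_states: torch.Tensor, attention_mask=None, **kwargs):
        bsz, seq_len, _ = hidden_states.shape
        q, k, v = self.qkv(hidden_states).chunk(3, dim=-1)
        q, k, v = (t.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        # Single dispatch point: look up the requested backend, fall back to eager.
        attention_interface = ATTENTION_FUNCTIONS.get(self.attn_implementation, eager_attention_forward)
        attn_output, _ = attention_interface(self, q, k, v, attention_mask, scaling=self.scaling, **kwargs)
        return self.out(attn_output.reshape(bsz, seq_len, -1))


if __name__ == "__main__":
    layer = ToyAttention()
    print(layer(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])

With the real models, the backend is selected through `attn_implementation` when loading the checkpoint; in the toy module above it is just a constructor argument.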

View File

@@ -34,11 +34,7 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import (
FlashAttentionKwargs,
flash_attn_supports_top_left_mask,
is_flash_attn_available,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
@@ -59,13 +55,6 @@ from .configuration_qwen2_5_omni import (
)
if is_flash_attn_available():
from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
else:
flash_attn_varlen_func = None
apply_rotary_emb = None
logger = logging.get_logger(__name__)
@@ -559,6 +548,44 @@ class Qwen2_5OmniThinkerCausalLMOutputWithPast(ModelOutput):
rope_deltas: Optional[torch.LongTensor] = None
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class Qwen2_5OmniAudioAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -571,6 +598,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
self.num_heads = config.encoder_attention_heads
self.dropout = config.attention_dropout
self.head_dim = self.embed_dim // self.num_heads
self.num_key_value_groups = 1 # needed for eager attention
self.config = config
if (self.head_dim * self.num_heads) != self.embed_dim:
@@ -591,6 +619,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
self,
hidden_states: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -600,13 +629,13 @@ class Qwen2_5OmniAudioAttention(nn.Module):
key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
query_states = query_states.transpose(0, 1)
key_states = key_states.transpose(0, 1)
value_states = value_states.transpose(0, 1)
attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_mask = torch.full(
[1, seq_length, key_states.shape[1]],
[1, 1, seq_length, key_states.shape[-2]],
torch.finfo(query_states.dtype).min,
device=query_states.device,
dtype=query_states.dtype,
@@ -614,115 +643,37 @@ class Qwen2_5OmniAudioAttention(nn.Module):
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
attn_weights = attn_weights + attention_mask
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states).transpose(0, 1).reshape(seq_length, self.embed_dim)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = self.out_proj(attn_output)
return attn_output
class Qwen2_5OmniAudioFlashAttention2(Qwen2_5OmniAudioAttention):
"""
Qwen2.5OmniThinker flash attention module. This module inherits from `Qwen2_5OmniAudioAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
seq_length, all_dim = hidden_states.size()
query_states = self.q_proj(hidden_states)
query_states = query_states.reshape(seq_length, self.num_heads, -1)
key_states = self.k_proj(hidden_states)
key_states = key_states.reshape(seq_length, self.num_heads, -1)
value_states = self.v_proj(hidden_states)
value_states = value_states.reshape(seq_length, self.num_heads, -1)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output = flash_attn_varlen_func(
query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0
)
attn_output = attn_output.reshape(seq_length, all_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class Qwen2_5OmniAudioSdpaAttention(Qwen2_5OmniAudioAttention):
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
seq_length, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
attention_mask = torch.zeros(
[1, seq_length, key_states.shape[0]], device=query_states.device, dtype=torch.bool
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
query_states = query_states.transpose(0, 1)
key_states = key_states.transpose(0, 1)
value_states = value_states.transpose(0, 1)
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
attention_mask,
dropout=0.0 if not self.training else self.dropout,
scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
is_causal=False,
**kwargs,
)
attn_output = attn_output.transpose(0, 1)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(seq_length, self.embed_dim)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output
QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES = {
"eager": Qwen2_5OmniAudioAttention,
"flash_attention_2": Qwen2_5OmniAudioFlashAttention2,
"sdpa": Qwen2_5OmniAudioSdpaAttention,
}
class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Qwen2_5OmniAudioEncoderConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES[config._attn_implementation](config)
self.self_attn = Qwen2_5OmniAudioAttention(config)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
@@ -735,6 +686,7 @@ class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer):
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
**kwargs,
) -> torch.Tensor:
"""
Args:
@@ -752,6 +704,7 @@ class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer):
hidden_states = self.self_attn(
hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
**kwargs,
)
hidden_states = residual + hidden_states
residual = hidden_states
@@ -838,6 +791,7 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
input_features,
feature_lens=None,
aftercnn_lens=None,
**kwargs,
):
r"""
input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
@@ -881,7 +835,7 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
).to(torch.int32)
for encoder_layer in self.layers:
layer_outputs = encoder_layer(hidden_states, cu_seqlens)
layer_outputs = encoder_layer(hidden_states, cu_seqlens, **kwargs)
hidden_states = layer_outputs[0]
hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0)
@@ -962,106 +916,68 @@ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> to
class Qwen2_5OmniVisionAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
def __init__(self, config: Qwen2_5OmniVisionEncoderConfig = None) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q = nn.Linear(dim, dim, bias=True)
self.k = nn.Linear(dim, dim, bias=True)
self.v = nn.Linear(dim, dim, bias=True)
self.proj = nn.Linear(dim, dim)
self.dim = config.hidden_size
self.num_heads = config.num_heads
self.head_dim = self.dim // self.num_heads
self.q = nn.Linear(self.dim, self.dim, bias=True)
self.k = nn.Linear(self.dim, self.dim, bias=True)
self.v = nn.Linear(self.dim, self.dim, bias=True)
self.proj = nn.Linear(self.dim, self.dim)
self.scaling = math.sqrt(self.head_dim)
self.num_key_value_groups = 1 # needed for eager attention
self.config = config
def forward(
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
query_states = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
key_states = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
value_states = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
attention_mask = torch.full(
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
[1, 1, seq_length, seq_length],
torch.finfo(query_states.dtype).min,
device=query_states.device,
dtype=query_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
class Qwen2_5OmniVisionFlashAttention2(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.q = nn.Linear(dim, dim, bias=True)
self.k = nn.Linear(dim, dim, bias=True)
self.v = nn.Linear(dim, dim, bias=True)
self.proj = nn.Linear(dim, dim)
def _apply_rotary_pos_emb_flashatt(self, tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
tensor_ = tensor.float()
cos = freqs.cos().type_as(tensor_)
sin = freqs.sin().type_as(tensor_)
output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
return output
def forward(
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
q = self._apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
k = self._apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
seq_length, -1
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0,
scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
is_causal=False,
**kwargs,
)
attn_output = self.proj(attn_output)
return attn_output
class Qwen2_5OmniVisionSdpaAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.q = nn.Linear(dim, dim, bias=True)
self.k = nn.Linear(dim, dim, bias=True)
self.v = nn.Linear(dim, dim, bias=True)
self.proj = nn.Linear(dim, dim)
def forward(
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output)
return attn_output
@@ -1080,26 +996,23 @@ class Qwen2_5OmniMLP(nn.Module):
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
QWEN2_5_OMNI_VISION_ATTENTION_CLASSES = {
"eager": Qwen2_5OmniVisionAttention,
"flash_attention_2": Qwen2_5OmniVisionFlashAttention2,
"sdpa": Qwen2_5OmniVisionSdpaAttention,
}
class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer):
def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None:
super().__init__()
self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
self.attn = QWEN2_5_OMNI_VISION_ATTENTION_CLASSES[config._attn_implementation](
config.hidden_size, num_heads=config.num_heads
)
self.attn = Qwen2_5OmniVisionAttention(config=config)
self.mlp = Qwen2_5OmniMLP(config, bias=True)
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, **kwargs
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@@ -1258,7 +1171,7 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
return window_index, cu_window_seqlens
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
@@ -1308,6 +1221,7 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
hidden_states,
cu_seqlens=cu_seqlens_now,
rotary_pos_emb=rotary_pos_emb,
**kwargs,
)
hidden_states = self.merger(hidden_states)
reverse_indices = torch.argsort(window_index)
@@ -1397,44 +1311,6 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class Qwen2_5OmniAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer

View File
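
Note on the recurring mask construction: every refactored forward in these diffs builds an additive block-diagonal mask from the cumulative sequence lengths (`cu_seqlens`), so that the eager and SDPA paths only attend within each packed segment, while FA2 still receives `cu_seqlens`/`max_seqlen` through `**kwargs`. A hedged, standalone illustration with made-up sizes; only the masking logic mirrors the diff:

import torch

seq_length = 6
cu_seqlens = torch.tensor([0, 2, 6])  # two packed segments, of lengths 2 and 4
dtype = torch.float32

# Start fully masked (large negative additive bias everywhere)...
attention_mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype)
# ...then zero out each segment's square block so tokens attend only within their own segment.
for i in range(1, len(cu_seqlens)):
    attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

print((attention_mask[0, 0] == 0).int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1]], dtype=torch.int32)

The extra leading dimensions ([1, 1, seq, seq] instead of [1, seq, seq]) match the (batch, heads, q_len, k_len) layout the shared attention interface expects, which is why the diffs also unsqueeze a batch dimension on the query/key/value states.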

@@ -17,7 +17,7 @@
import math
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union
import numpy as np
import torch
@@ -36,6 +36,7 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLTextModel,
Qwen2_5_VLVisionBlock,
Qwen2RMSNorm,
eager_attention_forward,
)
from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioEncoderConfig
from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioEncoderLayer
@@ -43,7 +44,7 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbeddin
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_flash_attention_utils import is_flash_attn_available
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
@@ -1601,6 +1602,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
self.num_heads = config.encoder_attention_heads
self.dropout = config.attention_dropout
self.head_dim = self.embed_dim // self.num_heads
self.num_key_value_groups = 1 # needed for eager attention
self.config = config
if (self.head_dim * self.num_heads) != self.embed_dim:
@@ -1621,6 +1623,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
self,
hidden_states: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -1630,13 +1633,13 @@ class Qwen2_5OmniAudioAttention(nn.Module):
key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
query_states = query_states.transpose(0, 1)
key_states = key_states.transpose(0, 1)
value_states = value_states.transpose(0, 1)
attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_mask = torch.full(
[1, seq_length, key_states.shape[1]],
[1, 1, seq_length, key_states.shape[-2]],
torch.finfo(query_states.dtype).min,
device=query_states.device,
dtype=query_states.dtype,
@@ -1644,125 +1647,49 @@ class Qwen2_5OmniAudioAttention(nn.Module):
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
attn_weights = attn_weights + attention_mask
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states).transpose(0, 1).reshape(seq_length, self.embed_dim)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = self.out_proj(attn_output)
return attn_output
class Qwen2_5OmniAudioFlashAttention2(Qwen2_5OmniAudioAttention):
"""
Qwen2.5OmniThinker flash attention module. This module inherits from `Qwen2_5OmniAudioAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
seq_length, all_dim = hidden_states.size()
query_states = self.q_proj(hidden_states)
query_states = query_states.reshape(seq_length, self.num_heads, -1)
key_states = self.k_proj(hidden_states)
key_states = key_states.reshape(seq_length, self.num_heads, -1)
value_states = self.v_proj(hidden_states)
value_states = value_states.reshape(seq_length, self.num_heads, -1)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output = flash_attn_varlen_func(
query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0
)
attn_output = attn_output.reshape(seq_length, all_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class Qwen2_5OmniAudioSdpaAttention(Qwen2_5OmniAudioAttention):
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
seq_length, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
attention_mask = torch.zeros(
[1, seq_length, key_states.shape[0]], device=query_states.device, dtype=torch.bool
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
query_states = query_states.transpose(0, 1)
key_states = key_states.transpose(0, 1)
value_states = value_states.transpose(0, 1)
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
attention_mask,
dropout=0.0 if not self.training else self.dropout,
scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
is_causal=False,
**kwargs,
)
attn_output = attn_output.transpose(0, 1)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(seq_length, self.embed_dim)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output
QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES = {
"eager": Qwen2_5OmniAudioAttention,
"flash_attention_2": Qwen2_5OmniAudioFlashAttention2,
"sdpa": Qwen2_5OmniAudioSdpaAttention,
}
class Qwen2_5OmniAudioEncoderLayer(Qwen2AudioEncoderLayer):
def __init__(self, config: Qwen2_5OmniAudioEncoderConfig):
super().__init__(config)
self.self_attn = QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES[config._attn_implementation](config)
self.self_attn = Qwen2_5OmniAudioAttention(config)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
**kwargs,
)
hidden_states = residual + hidden_states
residual = hidden_states
@@ -1849,6 +1776,7 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
input_features,
feature_lens=None,
aftercnn_lens=None,
**kwargs,
):
r"""
input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
@@ -1892,7 +1820,7 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
).to(torch.int32)
for encoder_layer in self.layers:
layer_outputs = encoder_layer(hidden_states, cu_seqlens)
layer_outputs = encoder_layer(hidden_states, cu_seqlens, **kwargs)
hidden_states = layer_outputs[0]
hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0)
@@ -1966,127 +1894,86 @@ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> to
class Qwen2_5OmniVisionAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
def __init__(self, config: Qwen2_5OmniVisionEncoderConfig = None) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q = nn.Linear(dim, dim, bias=True)
self.k = nn.Linear(dim, dim, bias=True)
self.v = nn.Linear(dim, dim, bias=True)
self.proj = nn.Linear(dim, dim)
self.dim = config.hidden_size
self.num_heads = config.num_heads
self.head_dim = self.dim // self.num_heads
self.q = nn.Linear(self.dim, self.dim, bias=True)
self.k = nn.Linear(self.dim, self.dim, bias=True)
self.v = nn.Linear(self.dim, self.dim, bias=True)
self.proj = nn.Linear(self.dim, self.dim)
self.scaling = math.sqrt(self.head_dim)
self.num_key_value_groups = 1 # needed for eager attention
self.config = config
def forward(
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
query_states = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
key_states = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
value_states = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
attention_mask = torch.full(
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
[1, 1, seq_length, seq_length],
torch.finfo(query_states.dtype).min,
device=query_states.device,
dtype=query_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
class Qwen2_5OmniVisionFlashAttention2(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.q = nn.Linear(dim, dim, bias=True)
self.k = nn.Linear(dim, dim, bias=True)
self.v = nn.Linear(dim, dim, bias=True)
self.proj = nn.Linear(dim, dim)
def _apply_rotary_pos_emb_flashatt(self, tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
tensor_ = tensor.float()
cos = freqs.cos().type_as(tensor_)
sin = freqs.sin().type_as(tensor_)
output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
return output
def forward(
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
q = self._apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
k = self._apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
seq_length, -1
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0,
scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
is_causal=False,
**kwargs,
)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output)
return attn_output
class Qwen2_5OmniVisionSdpaAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.q = nn.Linear(dim, dim, bias=True)
self.k = nn.Linear(dim, dim, bias=True)
self.v = nn.Linear(dim, dim, bias=True)
self.proj = nn.Linear(dim, dim)
def forward(
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
QWEN2_5_OMNI_VISION_ATTENTION_CLASSES = {
"eager": Qwen2_5OmniVisionAttention,
"flash_attention_2": Qwen2_5OmniVisionFlashAttention2,
"sdpa": Qwen2_5OmniVisionSdpaAttention,
}
class Qwen2_5OmniVisionBlock(Qwen2_5_VLVisionBlock):
def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None:
super().__init__(config, config._attn_implementation)
self.attn = QWEN2_5_OMNI_VISION_ATTENTION_CLASSES[config._attn_implementation](
config.hidden_size, num_heads=config.num_heads
)
self.attn = Qwen2_5OmniVisionAttention(config=config)
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, **kwargs
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@@ -2100,7 +1987,7 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
super().__init__(config, *inputs, **kwargs)
self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)])
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
@@ -2150,6 +2037,7 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
hidden_states,
cu_seqlens=cu_seqlens_now,
rotary_pos_emb=rotary_pos_emb,
**kwargs,
)
hidden_states = self.merger(hidden_states)
reverse_indices = torch.argsort(window_index)

View File
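
Note on `repeat_kv`: the diff below copies `repeat_kv` and `eager_attention_forward` next to the vision attention, and the new vision/audio attentions set `self.num_key_value_groups = 1`, so the grouped-query expansion inside the shared eager path is a no-op for these encoders. A hedged, self-contained check of that equivalence; shapes are illustrative:

import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep) on (batch, kv_heads, seq, head_dim)."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


kv = torch.randn(2, 4, 10, 64)                    # (batch, kv_heads, seq, head_dim)
assert repeat_kv(kv, 1) is kv                     # num_key_value_groups == 1: identity, the case used here
assert torch.equal(repeat_kv(kv, 3), torch.repeat_interleave(kv, repeats=3, dim=1))
print(repeat_kv(kv, 3).shape)                     # torch.Size([2, 12, 10, 64])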

@@ -36,7 +36,7 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs, is_flash_attn_available
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
@@ -46,10 +46,6 @@ from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torchdynam
from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig
if is_flash_attn_available():
from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
logger = logging.get_logger(__name__)
@@ -141,56 +137,6 @@ class Qwen2_5_VLPatchMerger(nn.Module):
return x
def apply_rotary_pos_emb_flashatt(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
return q_embed, k_embed
class Qwen2_5_VLVisionFlashAttention2(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
q = q.squeeze(0)
k = k.squeeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
seq_length, -1
)
attn_output = self.proj(attn_output)
return attn_output
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -212,13 +158,55 @@ def apply_rotary_pos_emb_vision(
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class Qwen2_5_VLVisionAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
def __init__(self, config: Qwen2_5_VLVisionConfig) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)
self.dim = config.hidden_size
self.num_heads = config.num_heads
self.head_dim = self.dim // self.num_heads
self.num_key_value_groups = 1 # needed for eager attention
self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
self.proj = nn.Linear(self.dim, self.dim)
self.scaling = math.sqrt(self.head_dim)
self.config = config
def forward(
self,
@@ -226,9 +214,12 @@ class Qwen2_5_VLVisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
query_states, key_states, value_states = (
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -241,87 +232,53 @@ class Qwen2_5_VLVisionAttention(nn.Module):
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
attention_mask = torch.full(
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
[1, 1, seq_length, seq_length],
torch.finfo(value_states.dtype).min,
device=value_states.device,
dtype=value_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
class Qwen2_5_VLVisionSdpaAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_output = F.scaled_dot_product_attention(
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attention_mask, dropout_p=0.0
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0,
scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
is_causal=False,
**kwargs,
)
attn_output = attn_output.squeeze(0).transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output)
return attn_output
QWEN2_5_VL_VISION_ATTENTION_CLASSES = {
"eager": Qwen2_5_VLVisionAttention,
"flash_attention_2": Qwen2_5_VLVisionFlashAttention2,
"sdpa": Qwen2_5_VLVisionSdpaAttention,
}
class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
def __init__(self, config, attn_implementation: str = "sdpa") -> None:
super().__init__()
self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation](
config.hidden_size, num_heads=config.num_heads
)
self.attn = Qwen2_5_VLVisionAttention(config=config)
self.mlp = Qwen2_5_VLMLP(config, bias=True)
def forward(
@@ -330,12 +287,14 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@@ -390,9 +349,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
head_dim = config.hidden_size // config.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList(
[Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
)
self.blocks = nn.ModuleList([Qwen2_5_VLVisionBlock(config) for _ in range(config.depth)])
self.merger = Qwen2_5_VLPatchMerger(
dim=config.out_hidden_size,
context_dim=config.hidden_size,
@@ -470,7 +427,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
return window_index, cu_window_seqlens
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
@@ -516,7 +473,9 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings)
hidden_states = blk(
hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, **kwargs
)
hidden_states = self.merger(hidden_states)
reverse_indices = torch.argsort(window_index)
@@ -647,44 +606,6 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class Qwen2_5_VLAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer

View File
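
Note on the `**kwargs` plumbing: the commit message's "pass along kwargs from modality encoders" shows up in these diffs as `**kwargs` threaded from the encoder `forward` through each block into the attention call, where FA2-specific arguments (`cu_seqlens_q`, `max_seqlen_q`, ...) ride along and dense backends simply ignore them. A hedged sketch of the idea; the function names are illustrative stand-ins, not the Transformers API:

import torch


def eager_like_backend(q, k, v, attention_mask, *, scaling, dropout=0.0, **kwargs):
    # A dense backend ignores varlen-specific kwargs such as cu_seqlens_q / max_seqlen_q.
    attn = torch.softmax(q @ k.transpose(-2, -1) * scaling + attention_mask, dim=-1)
    return attn @ v


def varlen_like_backend(q, k, v, attention_mask, *, scaling, cu_seqlens_q=None, max_seqlen_q=None, **kwargs):
    # A flash-style backend would consume cu_seqlens_q / max_seqlen_q instead of the dense mask.
    print("varlen backend received cu_seqlens_q =", cu_seqlens_q, "max_seqlen_q =", max_seqlen_q)
    return eager_like_backend(q, k, v, attention_mask, scaling=scaling)


def encoder_layer(q, k, v, attention_mask, backend, **kwargs):
    # The layer never inspects kwargs; it just forwards them to whichever backend is active.
    return backend(q, k, v, attention_mask, scaling=q.shape[-1] ** -0.5, **kwargs)


q = k = v = torch.randn(1, 2, 6, 8)               # (batch, heads, seq, head_dim)
mask = torch.zeros(1, 1, 6, 6)
extra = {"cu_seqlens_q": torch.tensor([0, 6]), "max_seqlen_q": 6}
out1 = encoder_layer(q, k, v, mask, eager_like_backend, **extra)   # extras silently ignored
out2 = encoder_layer(q, k, v, mask, varlen_like_backend, **extra)  # extras consumed
print(out1.shape, torch.allclose(out1, out2))     # torch.Size([1, 2, 6, 8]) True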

@@ -40,7 +40,6 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLPreTrainedModel,
VisionAttention,
VisionRotaryEmbedding,
VisionSdpaAttention,
)
from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLImagesKwargs, Qwen2VLProcessor
@@ -57,22 +56,12 @@ from ...video_utils import VideoInput
if is_flash_attn_available():
from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
pass
logger = logging.get_logger(__name__)
def apply_rotary_pos_emb_flashatt(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
return q_embed, k_embed
class Qwen2_5_VLVisionConfig(PretrainedConfig):
model_type = "qwen2_5_vl"
base_config_key = "vision_config"
@@ -150,59 +139,10 @@ class Qwen2_5_VLPatchMerger(PatchMerger):
self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
class Qwen2_5_VLVisionFlashAttention2(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
q = q.squeeze(0)
k = k.squeeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
seq_length, -1
)
attn_output = self.proj(attn_output)
return attn_output
class Qwen2_5_VLVisionAttention(VisionAttention):
pass
class Qwen2_5_VLVisionSdpaAttention(VisionSdpaAttention):
pass
QWEN2_5_VL_VISION_ATTENTION_CLASSES = {
"eager": Qwen2_5_VLVisionAttention,
"flash_attention_2": Qwen2_5_VLVisionFlashAttention2,
"sdpa": Qwen2_5_VLVisionSdpaAttention,
}
def __init__(self, config: Qwen2_5_VLVisionConfig) -> None:
super().__init__()
self.dim = config.hidden_size
class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
@@ -210,9 +150,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
super().__init__()
self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation](
config.hidden_size, num_heads=config.num_heads
)
self.attn = Qwen2_5_VLVisionAttention(config=config)
self.mlp = Qwen2_5_VLMLP(config, bias=True)
def forward(
@@ -221,12 +159,14 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
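The block keeps the standard pre-norm residual layout, with the new **kwargs forwarded untouched to the attention layer. A minimal generic sketch of that layout (hypothetical SimpleBlock; LayerNorm stands in for Qwen2RMSNorm, and attn/mlp are assumed to preserve the hidden size):

import torch
from torch import nn

class SimpleBlock(nn.Module):
    # Pre-norm residual block: x = x + attn(norm1(x), **kwargs); x = x + mlp(norm2(x)).
    def __init__(self, hidden_size: int, attn: nn.Module, mlp: nn.Module):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6)
        self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6)
        self.attn = attn
        self.mlp = mlp

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(self.norm1(hidden_states), **kwargs)
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states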
@ -269,9 +209,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
head_dim = config.hidden_size // config.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList(
[Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
)
self.blocks = nn.ModuleList([Qwen2_5_VLVisionBlock(config) for _ in range(config.depth)])
self.merger = Qwen2_5_VLPatchMerger(
dim=config.out_hidden_size,
context_dim=config.hidden_size,
@ -349,7 +287,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
return window_index, cu_window_seqlens
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
@ -395,7 +333,9 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings)
hidden_states = blk(
hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, **kwargs
)
hidden_states = self.merger(hidden_states)
reverse_indices = torch.argsort(window_index)

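The per-layer switch above between cu_seqlens and cu_window_seqlens is what alternates full attention with window-local attention; either way the attention layer only receives cumulative sequence offsets. A standalone toy illustration of how such offsets become a block-diagonal additive mask, mirroring the mask loop in VisionAttention below (toy numbers, float32):

import torch

cu_seqlens = torch.tensor([0, 3, 7])  # two packed sequences, of lengths 3 and 4
seq_length = int(cu_seqlens[-1])

mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(torch.float32).min)
for i in range(1, len(cu_seqlens)):
    mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

print((mask[0, 0] == 0).int())
# tensor([[1, 1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0, 0],
#         [0, 0, 0, 1, 1, 1, 1],
#         [0, 0, 0, 1, 1, 1, 1],
#         [0, 0, 0, 1, 1, 1, 1],
#         [0, 0, 0, 1, 1, 1, 1]])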
View File

@ -33,7 +33,7 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs, is_flash_attn_available
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
@ -49,10 +49,6 @@ from ...utils import (
from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig, Qwen2VLVisionConfig
if is_flash_attn_available():
from ...modeling_flash_attention_utils import flash_attn_varlen_func
logger = logging.get_logger(__name__)
@ -279,13 +275,56 @@ class VisionMlp(nn.Module):
return self.fc2(self.act(self.fc1(x)))
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
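A quick toy check of the two helpers above, assuming repeat_kv and eager_attention_forward from this file are in scope (numbers are arbitrary); it shows the GQA expansion and the output layout:

from types import SimpleNamespace
import torch

batch, q_heads, kv_heads, seq, head_dim = 1, 4, 2, 6, 8
module = SimpleNamespace(num_key_value_groups=q_heads // kv_heads, training=False)

query = torch.randn(batch, q_heads, seq, head_dim)
key = torch.randn(batch, kv_heads, seq, head_dim)
value = torch.randn(batch, kv_heads, seq, head_dim)

print(repeat_kv(key, module.num_key_value_groups).shape)  # torch.Size([1, 4, 6, 8])

attn_output, attn_weights = eager_attention_forward(
    module, query, key, value, attention_mask=None, scaling=head_dim**-0.5
)
print(attn_output.shape)   # torch.Size([1, 6, 4, 8]) -- back to (batch, seq, heads, head_dim)
print(attn_weights.shape)  # torch.Size([1, 4, 6, 6])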
class VisionAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
def __init__(self, config: Qwen2VLVisionConfig) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)
self.dim = config.embed_dim
self.num_heads = config.num_heads
self.head_dim = self.dim // self.num_heads
self.num_key_value_groups = 1 # needed for eager attention
self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
self.proj = nn.Linear(self.dim, self.dim)
self.scaling = self.head_dim**-0.5
self.config = config
def forward(
self,
@ -293,9 +332,12 @@ class VisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
query_states, key_states, value_states = (
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@ -308,117 +350,47 @@ class VisionAttention(nn.Module):
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
attention_mask = torch.full(
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
[1, 1, seq_length, seq_length],
torch.finfo(value_states.dtype).min,
device=value_states.device,
dtype=value_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
class VisionFlashAttention2(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
seq_length, -1
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0,
scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
is_causal=False,
**kwargs,
)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output)
return attn_output
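The attention_interface lookup above replaces the per-backend subclasses removed below: eager stays the fallback and any other implementation is resolved from a registry at call time. A self-contained sketch of that dispatch pattern with hypothetical names (a local BACKENDS dict, not the library's ALL_ATTENTION_FUNCTIONS; assumes a PyTorch recent enough for the scale argument of scaled_dot_product_attention):

from typing import Callable, Dict, Optional
import torch
import torch.nn.functional as F

def eager_forward(q, k, v, mask: Optional[torch.Tensor], scaling: float):
    # q, k, v: (batch, heads, seq, head_dim); mask is additive (0 / -inf).
    weights = torch.matmul(q, k.transpose(2, 3)) * scaling
    if mask is not None:
        weights = weights + mask
    weights = F.softmax(weights, dim=-1, dtype=torch.float32).to(q.dtype)
    return torch.matmul(weights, v)

def sdpa_forward(q, k, v, mask: Optional[torch.Tensor], scaling: float):
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=scaling)

BACKENDS: Dict[str, Callable] = {"eager": eager_forward, "sdpa": sdpa_forward}

def run_attention(impl: str, q, k, v, mask, scaling):
    interface = BACKENDS.get(impl, eager_forward)  # unknown names fall back to eager
    return interface(q, k, v, mask, scaling)

q = k = v = torch.randn(1, 2, 5, 8)
out_eager = run_attention("eager", q, k, v, None, scaling=8 ** -0.5)
out_sdpa = run_attention("sdpa", q, k, v, None, scaling=8 ** -0.5)
print(torch.allclose(out_eager, out_sdpa, atol=1e-5))  # True: same math, different kernels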
class VisionSdpaAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_output = F.scaled_dot_product_attention(
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attention_mask, dropout_p=0.0
)
attn_output = attn_output.squeeze(0).transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
QWEN2_VL_VISION_ATTENTION_CLASSES = {
"eager": VisionAttention,
"flash_attention_2": VisionFlashAttention2,
"sdpa": VisionSdpaAttention,
}
class Qwen2VLVisionBlock(GradientCheckpointingLayer):
def __init__(self, config, attn_implementation: str = "sdpa") -> None:
super().__init__()
@ -426,9 +398,7 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
self.norm2 = LayerNorm(config.embed_dim, eps=1e-6)
mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](
config.embed_dim, num_heads=config.num_heads
)
self.attn = VisionAttention(config=config)
self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
def forward(
@ -437,12 +407,14 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@ -486,45 +458,6 @@ class Qwen2MLP(nn.Module):
return down_proj
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class Qwen2VLAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
@ -752,9 +685,7 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
head_dim = config.embed_dim // config.num_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList(
[Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
)
self.blocks = nn.ModuleList([Qwen2VLVisionBlock(config) for _ in range(config.depth)])
self.merger = PatchMerger(
dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size
)
@ -796,7 +727,12 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
return rotary_pos_emb
@auto_docstring
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
def forward(
self,
hidden_states: torch.Tensor,
grid_thw: torch.Tensor,
**kwargs,
) -> torch.Tensor:
r"""
grid_thw (`torch.LongTensor` of shape `(num_images, 3)`):
The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values.
@ -817,7 +753,9 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
for blk in self.blocks:
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
hidden_states = blk(
hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, **kwargs
)
return self.merger(hidden_states)
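The end-to-end effect of threading **kwargs through the encoder, as done above, is that backend-specific arguments reach every attention layer without touching any intermediate signature. A hedged, self-contained sketch (EchoAttention, EchoBlock, and VarlenKwargs are hypothetical; VarlenKwargs is only loosely modeled on FlashAttentionKwargs):

from typing import TypedDict
import torch
from torch import nn

class VarlenKwargs(TypedDict, total=False):
    cu_seq_lens_q: torch.LongTensor
    cu_seq_lens_k: torch.LongTensor
    max_length_q: int
    max_length_k: int

class EchoAttention(nn.Module):
    # Stand-in attention layer that just reports which extra kwargs reached it.
    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        print("attention saw:", sorted(kwargs))
        return hidden_states

class EchoBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = EchoAttention()

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        return hidden_states + self.attn(hidden_states, **kwargs)

blocks = nn.ModuleList([EchoBlock() for _ in range(2)])
hidden_states = torch.zeros(5, 8)
extra: VarlenKwargs = {"cu_seq_lens_q": torch.tensor([0, 5]), "max_length_q": 5}
for blk in blocks:
    hidden_states = blk(hidden_states, **extra)
# prints "attention saw: ['cu_seq_lens_q', 'max_length_q']" once per block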