Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-01 20:00:09 +06:00)
[qwen] refactor attentions for vision/audio (#38930)
Some checks are pending
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Waiting to run
Build documentation / build (push) Waiting to run
New model PR merged notification / Notify new model (push) Waiting to run
Slow tests on important models (on Push - A10) / Get all modified files (push) Waiting to run
Slow tests on important models (on Push - A10) / Slow & FA2 tests (push) Blocked by required conditions
Self-hosted runner (push-caller) / Check if setup was changed (push) Waiting to run
Self-hosted runner (push-caller) / build-docker-containers (push) Blocked by required conditions
Self-hosted runner (push-caller) / Trigger Push CI (push) Blocked by required conditions
Secret Leaks / trufflehog (push) Waiting to run
Update Transformers metadata / build_and_package (push) Waiting to run
* refactor attentions in vision/audio
* remove fa2 import
* make config the only args
* pass along kwargs from modality encoders
* style
This commit is contained in:
parent 2e4c045540
commit d3d835d4fc
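For orientation, the pattern this commit converges on can be summarized in a small, self-contained sketch in plain PyTorch. It is illustrative only: TinyVarlenAttention and ATTENTION_FUNCTIONS are stand-ins for the Qwen2.5-Omni modules and transformers' ALL_ATTENTION_FUNCTIONS registry, toy sizes are made up, and grouped-query repeat_kv is left out. Each modality attention builds a block-diagonal mask from cu_seqlens, looks up the kernel named by the config's _attn_implementation with an eager fallback, and forwards extra kwargs so FlashAttention-specific arguments reach the kernel unchanged.

from typing import Callable

import torch
from torch import nn

# Stand-in registry for transformers' ALL_ATTENTION_FUNCTIONS (name taken from the diff);
# only the eager path is exercised here so the sketch runs on its own.
ATTENTION_FUNCTIONS: dict = {}


def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


class TinyVarlenAttention(nn.Module):
    """Packed-sequence self-attention that picks its kernel from a registry."""

    def __init__(self, hidden_size: int, num_heads: int, attn_implementation: str = "eager"):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scaling = self.head_dim**-0.5
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.proj = nn.Linear(hidden_size, hidden_size)
        self._attn_implementation = attn_implementation

    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, **kwargs) -> torch.Tensor:
        seq_length = hidden_states.shape[0]  # inputs are packed: (total_tokens, hidden_size)
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        # Add a dummy batch dimension so the shared (batch, heads, seq, head_dim) kernels apply.
        q, k, v = (t.transpose(0, 1).unsqueeze(0) for t in (q, k, v))

        # Block-diagonal additive mask from cumulative sequence lengths: tokens attend
        # only within their own packed segment (what the non-FA2 paths need).
        mask = torch.full((1, 1, seq_length, seq_length), torch.finfo(q.dtype).min, dtype=q.dtype, device=q.device)
        for i in range(1, len(cu_seqlens)):
            mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

        attention_interface: Callable = eager_attention_forward
        if self._attn_implementation != "eager":
            attention_interface = ATTENTION_FUNCTIONS[self._attn_implementation]

        # **kwargs is where FlashAttention-style arguments (cu_seqlens_q, max_seqlen_q, ...)
        # would be forwarded when a varlen kernel is selected.
        attn_output, _ = attention_interface(self, q, k, v, mask, scaling=self.scaling, dropout=0.0, **kwargs)
        return self.proj(attn_output.reshape(seq_length, -1))


# Usage: two packed sequences of lengths 3 and 5.
attn = TinyVarlenAttention(hidden_size=32, num_heads=4)
x = torch.randn(8, 32)
cu_seqlens = torch.tensor([0, 3, 8])
print(attn(x, cu_seqlens).shape)  # torch.Size([8, 32])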
@@ -34,11 +34,7 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import (
    FlashAttentionKwargs,
    flash_attn_supports_top_left_mask,
    is_flash_attn_available,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
@@ -59,13 +55,6 @@ from .configuration_qwen2_5_omni import (
)


if is_flash_attn_available():
    from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
else:
    flash_attn_varlen_func = None
    apply_rotary_emb = None


logger = logging.get_logger(__name__)

@@ -559,6 +548,44 @@ class Qwen2_5OmniThinkerCausalLMOutputWithPast(ModelOutput):
    rope_deltas: Optional[torch.LongTensor] = None


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights

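As a quick aside (illustrative only, not part of the patch): the docstring's claim that repeat_kv matches torch.repeat_interleave along the head dimension can be checked in a few lines of plain PyTorch.

import torch

x = torch.randn(2, 4, 5, 8)  # (batch, num_key_value_heads, seqlen, head_dim)
n_rep = 3
expanded = x[:, :, None, :, :].expand(2, 4, n_rep, 5, 8).reshape(2, 4 * n_rep, 5, 8)
assert torch.equal(expanded, torch.repeat_interleave(x, repeats=n_rep, dim=1))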
class Qwen2_5OmniAudioAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
@ -571,6 +598,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
|
||||
self.num_heads = config.encoder_attention_heads
|
||||
self.dropout = config.attention_dropout
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
self.num_key_value_groups = 1 # needed for eager attention
|
||||
self.config = config
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.embed_dim:
|
||||
@ -591,6 +619,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
@ -600,13 +629,13 @@ class Qwen2_5OmniAudioAttention(nn.Module):
|
||||
key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
|
||||
query_states = query_states.transpose(0, 1)
|
||||
key_states = key_states.transpose(0, 1)
|
||||
value_states = value_states.transpose(0, 1)
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0)
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
|
||||
attention_mask = torch.full(
|
||||
[1, seq_length, key_states.shape[1]],
|
||||
[1, 1, seq_length, key_states.shape[-2]],
|
||||
torch.finfo(query_states.dtype).min,
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
@ -614,115 +643,37 @@ class Qwen2_5OmniAudioAttention(nn.Module):
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value_states).transpose(0, 1).reshape(seq_length, self.embed_dim)
|
||||
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned across GPUs when using tensor-parallelism.
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class Qwen2_5OmniAudioFlashAttention2(Qwen2_5OmniAudioAttention):
|
||||
"""
|
||||
Qwen2.5OmniThinker flash attention module. This module inherits from `Qwen2_5OmniAudioAttention` as the weights of the module stays
|
||||
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||
flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
seq_length, all_dim = hidden_states.size()
|
||||
query_states = self.q_proj(hidden_states)
|
||||
query_states = query_states.reshape(seq_length, self.num_heads, -1)
|
||||
|
||||
key_states = self.k_proj(hidden_states)
|
||||
key_states = key_states.reshape(seq_length, self.num_heads, -1)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
value_states = value_states.reshape(seq_length, self.num_heads, -1)
|
||||
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
attn_output = flash_attn_varlen_func(
|
||||
query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0
|
||||
)
|
||||
attn_output = attn_output.reshape(seq_length, all_dim)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class Qwen2_5OmniAudioSdpaAttention(Qwen2_5OmniAudioAttention):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
seq_length, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
|
||||
attention_mask = torch.zeros(
|
||||
[1, seq_length, key_states.shape[0]], device=query_states.device, dtype=torch.bool
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
|
||||
query_states = query_states.transpose(0, 1)
|
||||
key_states = key_states.transpose(0, 1)
|
||||
value_states = value_states.transpose(0, 1)
|
||||
|
||||
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
|
||||
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
attn_output, _ = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
scaling=self.scaling,
|
||||
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen,
|
||||
max_seqlen_k=max_seqlen,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.transpose(0, 1)
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned across GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(seq_length, self.embed_dim)
|
||||
|
||||
attn_output = attn_output.reshape(seq_length, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES = {
|
||||
"eager": Qwen2_5OmniAudioAttention,
|
||||
"flash_attention_2": Qwen2_5OmniAudioFlashAttention2,
|
||||
"sdpa": Qwen2_5OmniAudioSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Qwen2_5OmniAudioEncoderConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
self.self_attn = QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.self_attn = Qwen2_5OmniAudioAttention(config)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
@ -735,6 +686,7 @@ class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -752,6 +704,7 @@ class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer):
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
residual = hidden_states
|
||||
@ -838,6 +791,7 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
|
||||
input_features,
|
||||
feature_lens=None,
|
||||
aftercnn_lens=None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
|
||||
@ -881,7 +835,7 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
|
||||
).to(torch.int32)
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
layer_outputs = encoder_layer(hidden_states, cu_seqlens)
|
||||
layer_outputs = encoder_layer(hidden_states, cu_seqlens, **kwargs)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0)
|
||||
@ -962,106 +916,68 @@ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> to
|
||||
|
||||
|
||||
class Qwen2_5OmniVisionAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
||||
def __init__(self, config: Qwen2_5OmniVisionEncoderConfig = None) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.q = nn.Linear(dim, dim, bias=True)
|
||||
self.k = nn.Linear(dim, dim, bias=True)
|
||||
self.v = nn.Linear(dim, dim, bias=True)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.dim = config.hidden_size
|
||||
self.num_heads = config.num_heads
|
||||
self.head_dim = self.dim // self.num_heads
|
||||
self.q = nn.Linear(self.dim, self.dim, bias=True)
|
||||
self.k = nn.Linear(self.dim, self.dim, bias=True)
|
||||
self.v = nn.Linear(self.dim, self.dim, bias=True)
|
||||
self.proj = nn.Linear(self.dim, self.dim)
|
||||
self.scaling = math.sqrt(self.head_dim)
|
||||
self.num_key_value_groups = 1 # needed for eager attention
|
||||
self.config = config
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
query_states = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
key_states = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
value_states = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
|
||||
attention_mask = torch.full(
|
||||
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(query_states.dtype).min,
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
|
||||
q = q.transpose(0, 1)
|
||||
k = k.transpose(0, 1)
|
||||
v = v.transpose(0, 1)
|
||||
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
|
||||
attn_output = torch.matmul(attn_weights, v)
|
||||
attn_output = attn_output.transpose(0, 1)
|
||||
attn_output = attn_output.reshape(seq_length, -1)
|
||||
attn_output = self.proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class Qwen2_5OmniVisionFlashAttention2(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.q = nn.Linear(dim, dim, bias=True)
|
||||
self.k = nn.Linear(dim, dim, bias=True)
|
||||
self.v = nn.Linear(dim, dim, bias=True)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def _apply_rotary_pos_emb_flashatt(self, tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
||||
tensor_ = tensor.float()
|
||||
cos = freqs.cos().type_as(tensor_)
|
||||
sin = freqs.sin().type_as(tensor_)
|
||||
output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
q = self._apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
k = self._apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
|
||||
seq_length, -1
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, _ = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0,
|
||||
scaling=self.scaling,
|
||||
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen,
|
||||
max_seqlen_k=max_seqlen,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = self.proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class Qwen2_5OmniVisionSdpaAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.q = nn.Linear(dim, dim, bias=True)
|
||||
self.k = nn.Linear(dim, dim, bias=True)
|
||||
self.v = nn.Linear(dim, dim, bias=True)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
|
||||
attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
|
||||
q = q.transpose(0, 1)
|
||||
k = k.transpose(0, 1)
|
||||
v = v.transpose(0, 1)
|
||||
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
|
||||
attn_output = attn_output.transpose(0, 1)
|
||||
attn_output = attn_output.reshape(seq_length, -1)
|
||||
attn_output = attn_output.reshape(seq_length, -1).contiguous()
|
||||
attn_output = self.proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
@ -1080,26 +996,23 @@ class Qwen2_5OmniMLP(nn.Module):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
|
||||
|
||||
|
||||
QWEN2_5_OMNI_VISION_ATTENTION_CLASSES = {
|
||||
"eager": Qwen2_5OmniVisionAttention,
|
||||
"flash_attention_2": Qwen2_5OmniVisionFlashAttention2,
|
||||
"sdpa": Qwen2_5OmniVisionSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None:
|
||||
super().__init__()
|
||||
self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
|
||||
self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
|
||||
self.attn = QWEN2_5_OMNI_VISION_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config.hidden_size, num_heads=config.num_heads
|
||||
)
|
||||
self.attn = Qwen2_5OmniVisionAttention(config=config)
|
||||
self.mlp = Qwen2_5OmniMLP(config, bias=True)
|
||||
|
||||
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = hidden_states + self.attn(
|
||||
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
|
||||
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, **kwargs
|
||||
)
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||
return hidden_states
|
||||
@ -1258,7 +1171,7 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
|
||||
|
||||
return window_index, cu_window_seqlens
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
|
||||
@ -1308,6 +1221,7 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens_now,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = self.merger(hidden_states)
|
||||
reverse_indices = torch.argsort(window_index)
|
||||
@ -1397,44 +1311,6 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Qwen2_5OmniAttention(nn.Module):
|
||||
"""
|
||||
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
||||
|
@ -17,7 +17,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -36,6 +36,7 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||
Qwen2_5_VLTextModel,
|
||||
Qwen2_5_VLVisionBlock,
|
||||
Qwen2RMSNorm,
|
||||
eager_attention_forward,
|
||||
)
|
||||
from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioEncoderConfig
|
||||
from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioEncoderLayer
|
||||
@ -43,7 +44,7 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbeddin
|
||||
|
||||
from ...configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
||||
from ...modeling_flash_attention_utils import is_flash_attn_available
|
||||
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
@ -1601,6 +1602,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
|
||||
self.num_heads = config.encoder_attention_heads
|
||||
self.dropout = config.attention_dropout
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
self.num_key_value_groups = 1 # needed for eager attention
|
||||
self.config = config
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.embed_dim:
|
||||
@ -1621,6 +1623,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
@ -1630,13 +1633,13 @@ class Qwen2_5OmniAudioAttention(nn.Module):
|
||||
key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
|
||||
query_states = query_states.transpose(0, 1)
|
||||
key_states = key_states.transpose(0, 1)
|
||||
value_states = value_states.transpose(0, 1)
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0)
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
|
||||
attention_mask = torch.full(
|
||||
[1, seq_length, key_states.shape[1]],
|
||||
[1, 1, seq_length, key_states.shape[-2]],
|
||||
torch.finfo(query_states.dtype).min,
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
@ -1644,125 +1647,49 @@ class Qwen2_5OmniAudioAttention(nn.Module):
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value_states).transpose(0, 1).reshape(seq_length, self.embed_dim)
|
||||
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned across GPUs when using tensor-parallelism.
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class Qwen2_5OmniAudioFlashAttention2(Qwen2_5OmniAudioAttention):
|
||||
"""
|
||||
Qwen2.5OmniThinker flash attention module. This module inherits from `Qwen2_5OmniAudioAttention` as the weights of the module stays
|
||||
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||
flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
seq_length, all_dim = hidden_states.size()
|
||||
query_states = self.q_proj(hidden_states)
|
||||
query_states = query_states.reshape(seq_length, self.num_heads, -1)
|
||||
|
||||
key_states = self.k_proj(hidden_states)
|
||||
key_states = key_states.reshape(seq_length, self.num_heads, -1)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
value_states = value_states.reshape(seq_length, self.num_heads, -1)
|
||||
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
attn_output = flash_attn_varlen_func(
|
||||
query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0
|
||||
)
|
||||
attn_output = attn_output.reshape(seq_length, all_dim)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class Qwen2_5OmniAudioSdpaAttention(Qwen2_5OmniAudioAttention):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
seq_length, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
|
||||
attention_mask = torch.zeros(
|
||||
[1, seq_length, key_states.shape[0]], device=query_states.device, dtype=torch.bool
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
|
||||
query_states = query_states.transpose(0, 1)
|
||||
key_states = key_states.transpose(0, 1)
|
||||
value_states = value_states.transpose(0, 1)
|
||||
|
||||
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
|
||||
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
attn_output, _ = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
scaling=self.scaling,
|
||||
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen,
|
||||
max_seqlen_k=max_seqlen,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.transpose(0, 1)
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned across GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(seq_length, self.embed_dim)
|
||||
|
||||
attn_output = attn_output.reshape(seq_length, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES = {
|
||||
"eager": Qwen2_5OmniAudioAttention,
|
||||
"flash_attention_2": Qwen2_5OmniAudioFlashAttention2,
|
||||
"sdpa": Qwen2_5OmniAudioSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
class Qwen2_5OmniAudioEncoderLayer(Qwen2AudioEncoderLayer):
|
||||
def __init__(self, config: Qwen2_5OmniAudioEncoderConfig):
|
||||
super().__init__(config)
|
||||
self.self_attn = QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.self_attn = Qwen2_5OmniAudioAttention(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
residual = hidden_states
|
||||
@ -1849,6 +1776,7 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
|
||||
input_features,
|
||||
feature_lens=None,
|
||||
aftercnn_lens=None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
|
||||
@ -1892,7 +1820,7 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
|
||||
).to(torch.int32)
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
layer_outputs = encoder_layer(hidden_states, cu_seqlens)
|
||||
layer_outputs = encoder_layer(hidden_states, cu_seqlens, **kwargs)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0)
|
||||
@ -1966,127 +1894,86 @@ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> to
|
||||
|
||||
|
||||
class Qwen2_5OmniVisionAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
||||
def __init__(self, config: Qwen2_5OmniVisionEncoderConfig = None) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.q = nn.Linear(dim, dim, bias=True)
|
||||
self.k = nn.Linear(dim, dim, bias=True)
|
||||
self.v = nn.Linear(dim, dim, bias=True)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.dim = config.hidden_size
|
||||
self.num_heads = config.num_heads
|
||||
self.head_dim = self.dim // self.num_heads
|
||||
self.q = nn.Linear(self.dim, self.dim, bias=True)
|
||||
self.k = nn.Linear(self.dim, self.dim, bias=True)
|
||||
self.v = nn.Linear(self.dim, self.dim, bias=True)
|
||||
self.proj = nn.Linear(self.dim, self.dim)
|
||||
self.scaling = math.sqrt(self.head_dim)
|
||||
self.num_key_value_groups = 1 # needed for eager attention
|
||||
self.config = config
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
query_states = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
key_states = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
value_states = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
|
||||
attention_mask = torch.full(
|
||||
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(query_states.dtype).min,
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
|
||||
q = q.transpose(0, 1)
|
||||
k = k.transpose(0, 1)
|
||||
v = v.transpose(0, 1)
|
||||
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
|
||||
attn_output = torch.matmul(attn_weights, v)
|
||||
attn_output = attn_output.transpose(0, 1)
|
||||
attn_output = attn_output.reshape(seq_length, -1)
|
||||
attn_output = self.proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class Qwen2_5OmniVisionFlashAttention2(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.q = nn.Linear(dim, dim, bias=True)
|
||||
self.k = nn.Linear(dim, dim, bias=True)
|
||||
self.v = nn.Linear(dim, dim, bias=True)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def _apply_rotary_pos_emb_flashatt(self, tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
||||
tensor_ = tensor.float()
|
||||
cos = freqs.cos().type_as(tensor_)
|
||||
sin = freqs.sin().type_as(tensor_)
|
||||
output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
q = self._apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
k = self._apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
|
||||
seq_length, -1
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, _ = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0,
|
||||
scaling=self.scaling,
|
||||
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen,
|
||||
max_seqlen_k=max_seqlen,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(seq_length, -1).contiguous()
|
||||
attn_output = self.proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class Qwen2_5OmniVisionSdpaAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.q = nn.Linear(dim, dim, bias=True)
|
||||
self.k = nn.Linear(dim, dim, bias=True)
|
||||
self.v = nn.Linear(dim, dim, bias=True)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1)
|
||||
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
|
||||
attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
|
||||
q = q.transpose(0, 1)
|
||||
k = k.transpose(0, 1)
|
||||
v = v.transpose(0, 1)
|
||||
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
|
||||
attn_output = attn_output.transpose(0, 1)
|
||||
attn_output = attn_output.reshape(seq_length, -1)
|
||||
attn_output = self.proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
QWEN2_5_OMNI_VISION_ATTENTION_CLASSES = {
|
||||
"eager": Qwen2_5OmniVisionAttention,
|
||||
"flash_attention_2": Qwen2_5OmniVisionFlashAttention2,
|
||||
"sdpa": Qwen2_5OmniVisionSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
class Qwen2_5OmniVisionBlock(Qwen2_5_VLVisionBlock):
|
||||
def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None:
|
||||
super().__init__(config, config._attn_implementation)
|
||||
self.attn = QWEN2_5_OMNI_VISION_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config.hidden_size, num_heads=config.num_heads
|
||||
)
|
||||
self.attn = Qwen2_5OmniVisionAttention(config=config)
|
||||
|
||||
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = hidden_states + self.attn(
|
||||
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
|
||||
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, **kwargs
|
||||
)
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||
return hidden_states
|
||||
@ -2100,7 +1987,7 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)])
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
|
||||
@ -2150,6 +2037,7 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens_now,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = self.merger(hidden_states)
|
||||
reverse_indices = torch.argsort(window_index)
|
||||
|
@@ -36,7 +36,7 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs, is_flash_attn_available
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
@@ -46,10 +46,6 @@ from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torchdynam
from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig


if is_flash_attn_available():
    from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func


logger = logging.get_logger(__name__)

@ -141,56 +137,6 @@ class Qwen2_5_VLPatchMerger(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_flashatt(
|
||||
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
||||
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
||||
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
|
||||
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class Qwen2_5_VLVisionFlashAttention2(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
|
||||
q = q.squeeze(0)
|
||||
k = k.squeeze(0)
|
||||
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
|
||||
seq_length, -1
|
||||
)
|
||||
attn_output = self.proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
@ -212,13 +158,55 @@ def apply_rotary_pos_emb_vision(
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Qwen2_5_VLVisionAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
||||
def __init__(self, config: Qwen2_5_VLVisionConfig) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.dim = config.hidden_size
|
||||
self.num_heads = config.num_heads
|
||||
self.head_dim = self.dim // self.num_heads
|
||||
self.num_key_value_groups = 1 # needed for eager attention
|
||||
self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
|
||||
self.proj = nn.Linear(self.dim, self.dim)
|
||||
self.scaling = math.sqrt(self.head_dim)
|
||||
self.config = config
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -226,9 +214,12 @@ class Qwen2_5_VLVisionAttention(nn.Module):
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
||||
query_states, key_states, value_states = (
|
||||
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
||||
)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
@ -241,87 +232,53 @@ class Qwen2_5_VLVisionAttention(nn.Module):
|
||||
sin = emb.sin()
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
|
||||
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
|
||||
|
||||
attention_mask = torch.full(
|
||||
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(value_states.dtype).min,
|
||||
device=value_states.device,
|
||||
dtype=value_states.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
|
||||
q = q.transpose(0, 1)
|
||||
k = k.transpose(0, 1)
|
||||
v = v.transpose(0, 1)
|
||||
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
|
||||
attn_output = torch.matmul(attn_weights, v)
|
||||
attn_output = attn_output.transpose(0, 1)
|
||||
attn_output = attn_output.reshape(seq_length, -1)
|
||||
attn_output = self.proj(attn_output)
|
||||
return attn_output
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
class Qwen2_5_VLVisionSdpaAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)

def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_output = F.scaled_dot_product_attention(
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attention_mask, dropout_p=0.0
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0,
scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
is_causal=False,
**kwargs,
)
attn_output = attn_output.squeeze(0).transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)

attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output)
return attn_output


QWEN2_5_VL_VISION_ATTENTION_CLASSES = {
"eager": Qwen2_5_VLVisionAttention,
"flash_attention_2": Qwen2_5_VLVisionFlashAttention2,
"sdpa": Qwen2_5_VLVisionSdpaAttention,
}
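A minimal, hedged sketch of what the refactored eager path above computes for a packed vision sequence: `cu_seqlens` marks image boundaries, and the additive mask keeps attention block-diagonal so patches of different images never attend to each other (the values below are illustrative only).

import torch

seq_length = 7
cu_seqlens = torch.tensor([0, 4, 7])  # image 1 owns positions 0-3, image 2 owns positions 4-6
dtype = torch.float32

attention_mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype)
for i in range(1, len(cu_seqlens)):
    attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
# adding this mask to the attention scores drives cross-image attention weights to ~0 after the softmax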
class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
def __init__(self, config, attn_implementation: str = "sdpa") -> None:
super().__init__()
self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation](
config.hidden_size, num_heads=config.num_heads
)
self.attn = Qwen2_5_VLVisionAttention(config=config)
self.mlp = Qwen2_5_VLMLP(config, bias=True)

def forward(
@ -330,12 +287,14 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@ -390,9 +349,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
head_dim = config.hidden_size // config.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

self.blocks = nn.ModuleList(
[Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
)
self.blocks = nn.ModuleList([Qwen2_5_VLVisionBlock(config) for _ in range(config.depth)])
self.merger = Qwen2_5_VLPatchMerger(
dim=config.out_hidden_size,
context_dim=config.hidden_size,
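Since the vision blocks above no longer receive an `attn_implementation` argument, the backend is resolved at forward time from the config. A hedged sketch of that lookup, assuming the `ALL_ATTENTION_FUNCTIONS` registry exposed by transformers.modeling_utils; the helper name is illustrative, not part of the commit.

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

def resolve_attention_backend(config, eager_fn):
    # mirrors the dispatch in the new forward: keep eager unless the config requests another
    # implementation name ("sdpa", "flash_attention_2", ...) registered in the shared mapping
    if config._attn_implementation == "eager":
        return eager_fn
    return ALL_ATTENTION_FUNCTIONS[config._attn_implementation]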
@ -470,7 +427,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):

return window_index, cu_window_seqlens

def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
@ -516,7 +473,9 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings)
hidden_states = blk(
hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, **kwargs
)

hidden_states = self.merger(hidden_states)
reverse_indices = torch.argsort(window_index)
@ -647,44 +606,6 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
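A short usage sketch for the `repeat_kv` helper added above (it assumes the function as defined in this diff; shapes are illustrative): two KV heads are expanded so they can be matmul'ed against eight query heads.

import torch

kv = torch.randn(1, 2, 10, 64)     # (batch, num_key_value_heads, seq_len, head_dim)
expanded = repeat_kv(kv, n_rep=4)  # each KV head repeated 4 times along the head dimension
assert expanded.shape == (1, 8, 10, 64)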
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)

attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, attn_weights
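A hedged sketch of calling the `eager_attention_forward` helper above directly; the stub module is only there to provide the attributes the helper reads (`num_key_value_groups`, `training`), and the shapes are illustrative.

import torch
from torch import nn

class _StubModule(nn.Module):
    num_key_value_groups = 1  # vision attention has no grouped KV heads

q = torch.randn(1, 8, 7, 64)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 8, 7, 64)
v = torch.randn(1, 8, 7, 64)
out, weights = eager_attention_forward(_StubModule(), q, k, v, attention_mask=None, scaling=64**-0.5)
# out: (1, 7, 8, 64) after the final transpose; weights: (1, 8, 7, 7)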
class Qwen2_5_VLAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
@ -40,7 +40,6 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLPreTrainedModel,
VisionAttention,
VisionRotaryEmbedding,
VisionSdpaAttention,
)
from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLImagesKwargs, Qwen2VLProcessor

@ -57,22 +56,12 @@ from ...video_utils import VideoInput


if is_flash_attn_available():
from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
pass


logger = logging.get_logger(__name__)


def apply_rotary_pos_emb_flashatt(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
return q_embed, k_embed


class Qwen2_5_VLVisionConfig(PretrainedConfig):
model_type = "qwen2_5_vl"
base_config_key = "vision_config"
@ -150,59 +139,10 @@ class Qwen2_5_VLPatchMerger(PatchMerger):
self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)


class Qwen2_5_VLVisionFlashAttention2(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)

def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
q = q.squeeze(0)
k = k.squeeze(0)

max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
seq_length, -1
)
attn_output = self.proj(attn_output)
return attn_output
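For reference, the `cu_seqlens`/`max_seqlen` bookkeeping that the removed flash-attention class computed, and that the unified path now forwards as kwargs, is just the following (illustrative values):

import torch

cu_seqlens = torch.tensor([0, 4, 7, 16])     # cumulative patch counts for three images
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([4, 3, 9]) patches per image
max_seqlen = seq_lens.max().item()           # 9, passed along as max_seqlen_q / max_seqlen_k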
class Qwen2_5_VLVisionAttention(VisionAttention):
pass


class Qwen2_5_VLVisionSdpaAttention(VisionSdpaAttention):
pass


QWEN2_5_VL_VISION_ATTENTION_CLASSES = {
"eager": Qwen2_5_VLVisionAttention,
"flash_attention_2": Qwen2_5_VLVisionFlashAttention2,
"sdpa": Qwen2_5_VLVisionSdpaAttention,
}
def __init__(self, config: Qwen2_5_VLVisionConfig) -> None:
super().__init__()
self.dim = config.hidden_size


class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
@ -210,9 +150,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
super().__init__()
self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation](
config.hidden_size, num_heads=config.num_heads
)
self.attn = Qwen2_5_VLVisionAttention(config=config)
self.mlp = Qwen2_5_VLMLP(config, bias=True)

def forward(
@ -221,12 +159,14 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@ -269,9 +209,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
head_dim = config.hidden_size // config.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

self.blocks = nn.ModuleList(
[Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
)
self.blocks = nn.ModuleList([Qwen2_5_VLVisionBlock(config) for _ in range(config.depth)])
self.merger = Qwen2_5_VLPatchMerger(
dim=config.out_hidden_size,
context_dim=config.hidden_size,
@ -349,7 +287,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):

return window_index, cu_window_seqlens

def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
@ -395,7 +333,9 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings)
hidden_states = blk(
hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, **kwargs
)

hidden_states = self.merger(hidden_states)
reverse_indices = torch.argsort(window_index)
@ -33,7 +33,7 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs, is_flash_attn_available
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
@ -49,10 +49,6 @@ from ...utils import (
from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig, Qwen2VLVisionConfig


if is_flash_attn_available():
from ...modeling_flash_attention_utils import flash_attn_varlen_func


logger = logging.get_logger(__name__)


@ -279,13 +275,56 @@ class VisionMlp(nn.Module):
return self.fc2(self.act(self.fc1(x)))


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)

attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, attn_weights


class VisionAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
def __init__(self, config: Qwen2VLVisionConfig) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)
self.dim = config.embed_dim
self.num_heads = config.num_heads
self.head_dim = self.dim // self.num_heads
self.num_key_value_groups = 1 # needed for eager attention
self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
self.proj = nn.Linear(self.dim, self.dim)
self.scaling = math.sqrt(self.head_dim)
self.config = config
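A hedged sketch of the new config-only construction (the removed signature took `dim` and `num_heads` directly). `VisionAttention` refers to the class in this file, and the field values are the Qwen2-VL defaults, used here only for illustration.

from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig

vision_config = Qwen2VLVisionConfig(embed_dim=1280, num_heads=16)
attn = VisionAttention(config=vision_config)
# head_dim = 1280 // 16 = 80; the attention backend is read from config._attn_implementation at forward time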
def forward(
self,
@ -293,9 +332,12 @@ class VisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
query_states, key_states, value_states = (
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@ -308,117 +350,47 @@ class VisionAttention(nn.Module):
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

attention_mask = torch.full(
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
[1, 1, seq_length, seq_length],
torch.finfo(value_states.dtype).min,
device=value_states.device,
dtype=value_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
class VisionFlashAttention2(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)

def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
seq_length, -1

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0,
scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
is_causal=False,
**kwargs,
)

attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output)
return attn_output
class VisionSdpaAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)

def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_output = F.scaled_dot_product_attention(
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attention_mask, dropout_p=0.0
)
attn_output = attn_output.squeeze(0).transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
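The removed SDPA class above is, in effect, scaled_dot_product_attention with a boolean block-diagonal mask (True = may attend). A hedged, self-contained sketch with toy shapes:

import torch
import torch.nn.functional as F

num_heads, seq_length, head_dim = 2, 7, 8
cu_seqlens = torch.tensor([0, 4, 7])
q = torch.randn(num_heads, seq_length, head_dim)
k = torch.randn(num_heads, seq_length, head_dim)
v = torch.randn(num_heads, seq_length, head_dim)

mask = torch.zeros(1, seq_length, seq_length, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True

out = F.scaled_dot_product_attention(q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), mask, dropout_p=0.0)
# out: (1, num_heads, seq_length, head_dim); patches belonging to different images never mix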
QWEN2_VL_VISION_ATTENTION_CLASSES = {
"eager": VisionAttention,
"flash_attention_2": VisionFlashAttention2,
"sdpa": VisionSdpaAttention,
}


class Qwen2VLVisionBlock(GradientCheckpointingLayer):
def __init__(self, config, attn_implementation: str = "sdpa") -> None:
super().__init__()
@ -426,9 +398,7 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
self.norm2 = LayerNorm(config.embed_dim, eps=1e-6)
mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)

self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](
config.embed_dim, num_heads=config.num_heads
)
self.attn = VisionAttention(config=config)
self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)

def forward(
@ -437,12 +407,14 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@ -486,45 +458,6 @@ class Qwen2MLP(nn.Module):
return down_proj


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)

attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, attn_weights


class Qwen2VLAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
@ -752,9 +685,7 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
head_dim = config.embed_dim // config.num_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

self.blocks = nn.ModuleList(
[Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
)
self.blocks = nn.ModuleList([Qwen2VLVisionBlock(config) for _ in range(config.depth)])
self.merger = PatchMerger(
dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size
)
@ -796,7 +727,12 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
return rotary_pos_emb

@auto_docstring
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
def forward(
self,
hidden_states: torch.Tensor,
grid_thw: torch.Tensor,
**kwargs,
) -> torch.Tensor:
r"""
grid_thw (`torch.LongTensor` of shape `(num_images, 3)`):
The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values.
@ -817,7 +753,9 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

for blk in self.blocks:
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
hidden_states = blk(
hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, **kwargs
)

return self.merger(hidden_states)
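A toy sketch (not the real classes) of the **kwargs plumbing introduced above: whatever extra attention kwargs the caller hands to the vision tower's forward are forwarded unchanged to every block, and from there to the attention backend. The kwarg name below is illustrative.

import torch
from torch import nn

class ToyBlock(nn.Module):
    def forward(self, hidden_states, **kwargs):
        print("block received:", sorted(kwargs))  # e.g. cumulative sequence lengths for FA2
        return hidden_states

class ToyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([ToyBlock() for _ in range(2)])

    def forward(self, hidden_states, **kwargs):
        for blk in self.blocks:
            hidden_states = blk(hidden_states, **kwargs)
        return hidden_states

ToyEncoder()(torch.zeros(4, 8), cu_seq_lens_q=torch.tensor([0, 4]))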