[qwen2-vl] fix FA2 inference (#39121)

* fix FA2

* update is causal flag and remove mask for FA2

* update for FA2 with varlen path

* how the tests were passing with different devices?

* add comment and ref to the PR

* move mask preparation to base pretrained model

* seq len is the first dim, not second

* fix copies to fix GLM4V
Raushan Turganbay 2025-07-01 12:18:37 +02:00 committed by GitHub
parent def9663239
commit 7a25f8dfdb
10 changed files with 363 additions and 199 deletions

View File

@ -508,6 +508,22 @@ def _flash_attention_forward(
query_states, key_states, value_states, target_dtype
)
# We will use `flash_attn_varlen_func` to prevent cross-example attention and also allow a padding-free approach
# under two cases:
# Case 1. `position_ids` is provided and not every example contains a single sequence: if the tensor is monotonically
# increasing we probably have one sequence, otherwise the batch is packed. We additionally check that we are in the
# pre-fill/training stage.
# Case 2. Some models directly pass pre-computed `cu_seqlens`, so we don't need to infer them from position ids. It is safe to
# use `flash_attn_varlen_func` knowing we already have all the necessary kwargs. NOTE: it is the user's responsibility
# to take care of flattening `position_ids` if that's needed by the model. See #39121 for more information
is_fa2_with_position_ids = (
position_ids is not None
and query_states.shape[0] == 1
and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
)
is_fa2_with_varlen_kwargs = all(
kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
)
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
@ -531,14 +547,7 @@ def _flash_attention_forward(
)
attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length)
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
elif (
position_ids is not None
and query_states.shape[0] == 1
and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
):
elif is_fa2_with_varlen_kwargs or is_fa2_with_position_ids:
batch_size = query_states.size(0)
if cu_seq_lens_q is None or cu_seq_lens_k is None:
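For intuition on the dispatch above, here is a minimal standalone sketch (not the library code) of how flattened `position_ids` can signal a packed batch and how `cu_seqlens` could be inferred from them. The helpers `looks_packed` and `infer_cu_seqlens` are hypothetical names used only for illustration:

```python
import torch

def looks_packed(position_ids: torch.Tensor) -> bool:
    # A packed batch restarts positions at 0 for each example, so the
    # flattened tensor is not monotonically non-decreasing.
    return not (torch.diff(position_ids, dim=-1) >= 0).all()

def infer_cu_seqlens(position_ids: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: sequence boundaries are where the position resets to 0.
    flat = position_ids.flatten()
    starts = torch.nonzero(flat == 0, as_tuple=False).flatten()
    return torch.cat([starts, torch.tensor([flat.numel()])]).to(torch.int32)

# Two sequences of lengths 3 and 2 packed into a single row
packed = torch.tensor([[0, 1, 2, 0, 1]])
print(looks_packed(packed))       # True
print(infer_cu_seqlens(packed))   # tensor([0, 3, 5], dtype=torch.int32)
```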

View File

@ -279,14 +279,15 @@ def eager_attention_forward(
class Glm4vVisionAttention(nn.Module):
def __init__(self, config: Glm4vVisionConfig) -> None:
super().__init__()
self.config = config
self.dim = config.hidden_size
self.num_heads = config.num_heads
self.head_dim = config.hidden_size // self.num_heads
self.num_key_value_groups = 1
self.scale = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.head_dim = self.dim // self.num_heads
self.num_key_value_groups = 1 # needed for eager attention
self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.scaling = self.head_dim**-0.5
self.config = config
self.attention_dropout = config.attention_dropout
self.is_causal = False
def forward(
@ -295,23 +296,31 @@ class Glm4vVisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs: Unpack[FlashAttentionKwargs],
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
query_states, key_states, value_states = (
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)
cos, sin = position_embeddings
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@ -322,13 +331,17 @@ class Glm4vVisionAttention(nn.Module):
query_states,
key_states,
value_states,
attention_mask,
attention_mask=attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scale,
is_causal=self.is_causal,
scaling=self.scaling,
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
attn_output = attn_output.squeeze(0)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output)
return attn_output
@ -348,6 +361,7 @@ class Glm4vVisionBlock(GradientCheckpointingLayer):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
@ -355,6 +369,7 @@ class Glm4vVisionBlock(GradientCheckpointingLayer):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@ -452,6 +467,25 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb, pos_ids
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and disallowing attention
# between blocks. It will not be a 100% exact match for FA2's `varlen` path
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
"""
Args:
@ -481,14 +515,15 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)
for blk in self.blocks:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
blk.__call__, hidden_states, cu_seqlens, None, position_embeddings
)
else:
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
)
hidden_states = self.post_layernorm(hidden_states)
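As a quick illustration of the block-diagonal additive mask that `_prepare_attention_mask` builds for the non-FA2 backends, the toy snippet below reproduces the pattern for a small `cu_seqlens` (a standalone sketch, not the model code):

```python
import torch

cu_seqlens = torch.tensor([0, 3, 5])  # two segments of lengths 3 and 2
seq_length, dtype = int(cu_seqlens[-1]), torch.float32

# Additive mask: large negative everywhere, 0 inside each segment's block
mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype)
for i in range(1, len(cu_seqlens)):
    mask[..., cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = 0

# Positions can only attend within their own segment (1 = allowed, 0 = blocked)
print((mask[0, 0] == 0).int())
```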

View File

@ -50,8 +50,8 @@ from ..qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLPreTrainedModel,
Qwen2_5_VLRotaryEmbedding,
Qwen2_5_VLTextModel,
Qwen2_5_VLVisionAttention,
Qwen2_5_VLVisionBlock,
apply_rotary_pos_emb_vision,
)
from ..qwen2_5_vl.processing_qwen2_5_vl import (
Qwen2_5_VLProcessor,
@ -505,62 +505,12 @@ class Glm4vVisionEmbeddings(nn.Module):
return embeddings
class Glm4vVisionAttention(nn.Module):
class Glm4vVisionAttention(Qwen2_5_VLVisionAttention):
def __init__(self, config: Glm4vVisionConfig) -> None:
super().__init__()
self.config = config
self.num_heads = config.num_heads
self.head_dim = config.hidden_size // self.num_heads
self.num_key_value_groups = 1
self.scale = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
query_states, key_states, value_states = (
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scale,
is_causal=self.is_causal,
**kwargs,
)
attn_output = attn_output.squeeze(0)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output)
return attn_output
class Glm4vVisionBlock(Qwen2_5_VLVisionBlock):
@ -653,6 +603,25 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb, pos_ids
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and disallowing attention
# between blocks. It will not be a 100% exact match for FA2's `varlen` path
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
"""
Args:
@ -682,14 +651,15 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)
for blk in self.blocks:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
blk.__call__, hidden_states, cu_seqlens, None, position_embeddings
)
else:
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
)
hidden_states = self.post_layernorm(hidden_states)
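To see why this block-diagonal mask mimics the ragged (`varlen`) computation used on the FA2 path, the sketch below compares a single masked attention call against independent per-segment attention, using plain `torch.nn.functional.scaled_dot_product_attention`. This is only an illustrative equivalence check under toy shapes, not the library's code path:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
cu = [0, 3, 5]                                 # two segments: lengths 3 and 2
q = k = v = torch.randn(1, 1, 5, 8)            # (batch, heads, seq, head_dim)

# (a) one call with a block-diagonal boolean mask (True = may attend)
block_mask = torch.zeros(5, 5, dtype=torch.bool)
for i in range(1, len(cu)):
    block_mask[cu[i - 1]:cu[i], cu[i - 1]:cu[i]] = True
masked = F.scaled_dot_product_attention(q, k, v, attn_mask=block_mask)

# (b) independent attention per segment, as the FA2 varlen path effectively does
per_segment = torch.cat(
    [F.scaled_dot_product_attention(q[..., s:e, :], k[..., s:e, :], v[..., s:e, :])
     for s, e in zip(cu[:-1], cu[1:])],
    dim=-2,
)
print(torch.allclose(masked, per_segment, atol=1e-6))  # True, up to numerics
```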

View File

@ -607,6 +607,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
f" and `num_heads`: {self.num_heads})."
)
self.scaling = self.head_dim**-0.5
self.attention_dropout = 0.0
self.is_decoder = False
self.is_causal = False
@ -619,6 +620,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
self,
hidden_states: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@ -634,15 +636,6 @@ class Qwen2_5OmniAudioAttention(nn.Module):
value_states = value_states.transpose(0, 1).unsqueeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_mask = torch.full(
[1, 1, seq_length, key_states.shape[-2]],
torch.finfo(query_states.dtype).min,
device=query_states.device,
dtype=query_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
@ -652,13 +645,13 @@ class Qwen2_5OmniAudioAttention(nn.Module):
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.dropout,
attention_mask=attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
@ -686,6 +679,7 @@ class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer):
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
@ -704,6 +698,7 @@ class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer):
hidden_states = self.self_attn(
hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = residual + hidden_states
@ -785,6 +780,25 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
def set_input_embeddings(self, value: nn.Module):
self.conv1 = value
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and disallowing attention
# between blocks. It will not be a 100% exact match for FA2's `varlen` path
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
@auto_docstring
def forward(
self,
@ -833,9 +847,15 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
padded_mask_after_cnn.sum(1).cumsum(0),
)
).to(torch.int32)
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)
for encoder_layer in self.layers:
layer_outputs = encoder_layer(hidden_states, cu_seqlens, **kwargs)
layer_outputs = encoder_layer(
hidden_states,
cu_seqlens=cu_seqlens,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = layer_outputs[0]
hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0)
@ -928,12 +948,15 @@ class Qwen2_5OmniVisionAttention(nn.Module):
self.scaling = self.head_dim**-0.5
self.num_key_value_groups = 1 # needed for eager attention
self.config = config
self.attention_dropout = 0.0
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
@ -943,18 +966,9 @@ class Qwen2_5OmniVisionAttention(nn.Module):
query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(query_states.dtype).min,
device=query_states.device,
dtype=query_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_interface: Callable = eager_attention_forward
@ -966,13 +980,13 @@ class Qwen2_5OmniVisionAttention(nn.Module):
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0,
attention_mask=attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
@ -1009,10 +1023,15 @@ class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer):
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, **kwargs
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@ -1171,6 +1190,25 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
return window_index, cu_window_seqlens
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and disallowing attention
# between blocks. It will not be a 100% exact match for FA2's `varlen` path
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
@ -1217,10 +1255,13 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens_now,
rotary_pos_emb=rotary_pos_emb,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = self.merger(hidden_states)
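The `max_seqlen` that these layers forward to FA2 (`max_length_q` / `max_length_k`) is simply the length of the longest segment implied by `cu_seqlens`. A minimal numeric example:

```python
import torch

cu_seqlens = torch.tensor([0, 4, 9, 11], dtype=torch.int32)
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]   # per-segment lengths: tensor([4, 5, 2])
max_seqlen = seqlens.max().item()            # 5, passed as max_length_q / max_length_k
print(seqlens.tolist(), max_seqlen)
```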

View File

@ -1611,6 +1611,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
f" and `num_heads`: {self.num_heads})."
)
self.scaling = self.head_dim**-0.5
self.attention_dropout = 0.0
self.is_decoder = False
self.is_causal = False
@ -1623,6 +1624,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
self,
hidden_states: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@ -1638,15 +1640,6 @@ class Qwen2_5OmniAudioAttention(nn.Module):
value_states = value_states.transpose(0, 1).unsqueeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_mask = torch.full(
[1, 1, seq_length, key_states.shape[-2]],
torch.finfo(query_states.dtype).min,
device=query_states.device,
dtype=query_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
@ -1656,13 +1649,13 @@ class Qwen2_5OmniAudioAttention(nn.Module):
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.dropout,
attention_mask=attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
@ -1682,6 +1675,7 @@ class Qwen2_5OmniAudioEncoderLayer(Qwen2AudioEncoderLayer):
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
@ -1689,6 +1683,7 @@ class Qwen2_5OmniAudioEncoderLayer(Qwen2AudioEncoderLayer):
hidden_states = self.self_attn(
hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = residual + hidden_states
@ -1770,6 +1765,25 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
def set_input_embeddings(self, value: nn.Module):
self.conv1 = value
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and disallowing attention
# between blocks. It will not be a 100% exact match for FA2's `varlen` path
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
@auto_docstring
def forward(
self,
@ -1818,9 +1832,15 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
padded_mask_after_cnn.sum(1).cumsum(0),
)
).to(torch.int32)
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)
for encoder_layer in self.layers:
layer_outputs = encoder_layer(hidden_states, cu_seqlens, **kwargs)
layer_outputs = encoder_layer(
hidden_states,
cu_seqlens=cu_seqlens,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = layer_outputs[0]
hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0)
@ -1906,12 +1926,15 @@ class Qwen2_5OmniVisionAttention(nn.Module):
self.scaling = self.head_dim**-0.5
self.num_key_value_groups = 1 # needed for eager attention
self.config = config
self.attention_dropout = 0.0
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
@ -1921,18 +1944,9 @@ class Qwen2_5OmniVisionAttention(nn.Module):
query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(query_states.dtype).min,
device=query_states.device,
dtype=query_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_interface: Callable = eager_attention_forward
@ -1944,13 +1958,13 @@ class Qwen2_5OmniVisionAttention(nn.Module):
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0,
attention_mask=attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
@ -1970,10 +1984,15 @@ class Qwen2_5OmniVisionBlock(Qwen2_5_VLVisionBlock):
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, **kwargs
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@ -1987,6 +2006,25 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
super().__init__(config, *inputs, **kwargs)
self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)])
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and disallowing attention
# between blocks. It will not be a 100% exact match for FA2's `varlen` path
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
@ -2033,10 +2071,13 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens_now,
rotary_pos_emb=rotary_pos_emb,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = self.merger(hidden_states)

View File

@ -206,6 +206,8 @@ class Qwen2_5_VLVisionAttention(nn.Module):
self.proj = nn.Linear(self.dim, self.dim)
self.scaling = self.head_dim**-0.5
self.config = config
self.attention_dropout = 0.0
self.is_causal = False
def forward(
self,
@ -213,6 +215,7 @@ class Qwen2_5_VLVisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
@ -233,18 +236,9 @@ class Qwen2_5_VLVisionAttention(nn.Module):
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(value_states.dtype).min,
device=value_states.device,
dtype=value_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_interface: Callable = eager_attention_forward
@ -256,13 +250,13 @@ class Qwen2_5_VLVisionAttention(nn.Module):
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0,
attention_mask=attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
@ -286,6 +280,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
@ -293,6 +288,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@ -426,6 +422,25 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
return window_index, cu_window_seqlens
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and disallowing attention
# between blocks. It will not be a 100% exact match for FA2's `varlen` path
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
@ -472,8 +487,14 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
hidden_states = blk(
hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, **kwargs
hidden_states,
cu_seqlens=cu_seqlens_now,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = self.merger(hidden_states)
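Since the vision tower flattens all patches into one sequence (sequence length is the first dimension, as the commit message notes), the per-head tensors get a fake batch dimension of 1 before calling the attention interface. A standalone shape check of that reshaping, with made-up sizes:

```python
import torch

seq_length, num_heads, head_dim = 10, 4, 16
qkv = torch.randn(seq_length, 3 * num_heads * head_dim)

# Split the fused QKV projection: (seq, heads, head_dim) each
q, k, v = qkv.reshape(seq_length, 3, num_heads, head_dim).permute(1, 0, 2, 3).unbind(0)
print(q.shape)                      # torch.Size([10, 4, 16])

# Move to (batch=1, heads, seq, head_dim) as expected by the attention interface
q = q.transpose(0, 1).unsqueeze(0)
print(q.shape)                      # torch.Size([1, 4, 10, 16])
```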

View File

@ -159,6 +159,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
@ -166,6 +167,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@ -287,6 +289,25 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
return window_index, cu_window_seqlens
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and disallowing attention
# between blocks. It will not be a 100% exact match for FA2's `varlen` path
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
@ -333,8 +354,14 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
hidden_states = blk(
hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, **kwargs
hidden_states,
cu_seqlens=cu_seqlens_now,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = self.merger(hidden_states)

View File

@ -324,6 +324,8 @@ class VisionAttention(nn.Module):
self.proj = nn.Linear(self.dim, self.dim)
self.scaling = self.head_dim**-0.5
self.config = config
self.attention_dropout = 0.0
self.is_causal = False
def forward(
self,
@ -331,6 +333,7 @@ class VisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
@ -351,18 +354,9 @@ class VisionAttention(nn.Module):
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(value_states.dtype).min,
device=value_states.device,
dtype=value_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_interface: Callable = eager_attention_forward
@ -374,13 +368,13 @@ class VisionAttention(nn.Module):
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0,
attention_mask=attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
@ -406,6 +400,7 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
@ -413,6 +408,7 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@ -725,6 +721,25 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and disallowing attention
# between blocks. It will not be a 100% exact match for FA2's `varlen` path
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
@auto_docstring
def forward(
self,
@ -750,10 +765,15 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)
for blk in self.blocks:
hidden_states = blk(
hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, **kwargs
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)
return self.merger(hidden_states)

View File

@ -184,7 +184,7 @@ class Qwen2_5_VLVisionText2TextModelTester:
input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id
inputs_dict = {
"pixel_values": pixel_values,
"image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size),
"image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device),
"input_ids": input_ids,
"attention_mask": attention_mask,
}

View File

@ -176,7 +176,7 @@ class Qwen2VLVisionText2TextModelTester:
inputs_dict = {
"pixel_values": pixel_values,
"image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size),
"image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device),
"input_ids": input_ids,
"attention_mask": attention_mask,
}
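
The test change simply ensures `image_grid_thw` is created on the same accelerator as the other inputs (which is how the tests had silently passed across devices before). A minimal sketch of the pattern, assuming `torch_device` is resolved as below:

```python
import torch

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 2
# Creating the tensor directly on `torch_device` avoids an implicit CPU/GPU mismatch
image_grid_thw = torch.tensor([[1, 1, 1]] * batch_size, device=torch_device)
print(image_grid_thw.device)
```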