Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[qwen2-vl] fix FA2 inference (#39121)
* fix FA2
* update is causal flag and remove mask for FA2
* update for FA2 with varlen path
* how the tests were passing with different devices?
* add comment and ref to the PR
* move mask preparation to base pretrained model
* seq len is the first dim, not second
* fix copies to fix GLM4V
This commit is contained in: parent def9663239, commit 7a25f8dfdb
@ -508,6 +508,22 @@ def _flash_attention_forward(
query_states, key_states, value_states, target_dtype
)

# We will use `flash_attn_varlen_func` to prevent cross-example attention and also allow a padding-free approach
# under two cases:
# Case 1. If position_ids is provided and not all examples contain exactly one sequence. If the tensor is monotonically increasing
# we probably have one sequence, otherwise it is packed. Additionally, check that we are in the pre-fill/training stage.
# Case 2. Some models directly pass pre-computed `cu_seqlens`, so we don't need to infer them from position ids. It is safe to
# use `flash_attn_varlen_func` knowing we already have all the necessary kwargs. NOTE: it is the user's responsibility
# to take care of flattening `position_ids` if that's needed by the model. See #39121 for more information
is_fa2_with_position_ids = (
position_ids is not None
and query_states.shape[0] == 1
and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
)
is_fa2_with_varlen_kwargs = all(
kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
)

# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
@ -531,14 +547,7 @@ def _flash_attention_forward(
)
attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length)

# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
elif (
position_ids is not None
and query_states.shape[0] == 1
and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
):
elif is_fa2_with_varlen_kwargs or is_fa2_with_position_ids:
batch_size = query_states.size(0)

if cu_seq_lens_q is None or cu_seq_lens_k is None:
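To make the new varlen detection concrete, here is a minimal, self-contained sketch (plain PyTorch, outside of Transformers; tensor values are illustrative) of how packed `position_ids` differ from a single sequence and how `is_fa2_with_position_ids` evaluates:

import torch

# Packed example: two sequences of lengths 3 and 4 flattened into one row (batch dim == 1).
packed_position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3]])
# Single-sequence example: position ids increase monotonically.
single_position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6]])

def looks_packed(position_ids):
    # A drop in position ids (diff < 0) means a new sequence started mid-row.
    return not (torch.diff(position_ids, dim=-1) >= 0).all()

query_length = 7     # pre-fill/training: more than one query token
max_length_q = None  # no pre-computed varlen kwargs in this sketch

# Mirrors the check above, with position_ids.shape[0] standing in for query_states.shape[0].
is_fa2_with_position_ids = (
    packed_position_ids is not None
    and packed_position_ids.shape[0] == 1
    and (max_length_q is not None or (query_length != 1 and looks_packed(packed_position_ids)))
)
print(bool(is_fa2_with_position_ids))           # True  -> take the `flash_attn_varlen_func` path
print(bool(looks_packed(single_position_ids)))  # False -> regular FA2 path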
@ -279,14 +279,15 @@ def eager_attention_forward(
|
||||
class Glm4vVisionAttention(nn.Module):
|
||||
def __init__(self, config: Glm4vVisionConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.dim = config.hidden_size
|
||||
self.num_heads = config.num_heads
|
||||
self.head_dim = config.hidden_size // self.num_heads
|
||||
self.num_key_value_groups = 1
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.head_dim = self.dim // self.num_heads
|
||||
self.num_key_value_groups = 1 # needed for eager attention
|
||||
self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
|
||||
self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.config = config
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
|
||||
def forward(
|
||||
@ -295,23 +296,31 @@ class Glm4vVisionAttention(nn.Module):
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
query_states, key_states, value_states = (
|
||||
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
||||
)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
|
||||
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0)
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
|
||||
attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -322,13 +331,17 @@ class Glm4vVisionAttention(nn.Module):
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
attention_mask=attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scale,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scaling,
|
||||
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
cu_seq_lens_k=cu_seqlens,
|
||||
max_length_q=max_seqlen,
|
||||
max_length_k=max_seqlen,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.squeeze(0)
|
||||
|
||||
attn_output = attn_output.reshape(seq_length, -1).contiguous()
|
||||
attn_output = self.proj(attn_output)
|
||||
return attn_output
|
||||
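The commit note "seq len is the first dim, not second" refers to the shape handling above: vision hidden states arrive as (seq_len, hidden_size), so q/k/v have to be transposed and given a dummy batch dimension before they reach the attention interface. A standalone sketch with made-up sizes:

import torch

# Toy dimensions: 10 vision patches, 2 heads, head_dim 8 (illustrative only).
seq_length, num_heads, head_dim = 10, 2, 8
qkv_out = torch.randn(seq_length, 3 * num_heads * head_dim)  # stand-in for self.qkv(hidden_states)

# (seq, 3, heads, head_dim) -> three tensors of shape (seq, heads, head_dim)
query_states, key_states, value_states = (
    qkv_out.reshape(seq_length, 3, num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)

# The attention interface expects (batch, heads, seq, head_dim): the sequence is the
# first dim here, so swap it with the head dim and add a batch dim of 1.
query_states = query_states.transpose(0, 1).unsqueeze(0)
print(query_states.shape)  # torch.Size([1, 2, 10, 8])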
@ -348,6 +361,7 @@ class Glm4vVisionBlock(GradientCheckpointingLayer):
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = hidden_states + self.attn(
|
||||
@ -355,6 +369,7 @@ class Glm4vVisionBlock(GradientCheckpointingLayer):
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||
@ -452,6 +467,25 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
return rotary_pos_emb, pos_ids
|
||||
|
||||
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and not attending between
# blocks; it will not be a 100% match for FA2's `varlen` path
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return None
|
||||
|
||||
seq_length = inputs_tensor.shape[0]
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(inputs_tensor.dtype).min,
|
||||
device=inputs_tensor.device,
|
||||
dtype=inputs_tensor.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
return attention_mask
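As a worked example of the helper above (toy values, run outside the model): for `cu_seqlens = [0, 3, 7]` the additive mask is 0 inside the 3x3 and 4x4 diagonal blocks and the dtype minimum elsewhere, which is how eager/SDPA backends emulate FA2's ragged blocking; under `flash_attention_2` the helper returns None because the kernels consume `cu_seqlens`/`max_seqlen` directly.

import torch

inputs_tensor = torch.randn(7, 16)                        # 7 tokens, hidden size 16 (illustrative)
cu_seqlens = torch.tensor([0, 3, 7], dtype=torch.int32)   # two blocks of lengths 3 and 4

seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
    [1, 1, seq_length, seq_length],
    torch.finfo(inputs_tensor.dtype).min,
    dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
    attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

# 0 -> allowed, large negative -> suppressed after softmax
print((attention_mask[0, 0] == 0).int())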
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -481,14 +515,15 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
|
||||
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)
|
||||
|
||||
for blk in self.blocks:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
blk.__call__, hidden_states, cu_seqlens, None, position_embeddings
|
||||
)
|
||||
else:
|
||||
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
|
||||
hidden_states = blk(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = self.post_layernorm(hidden_states)
|
||||
|
||||
|
@ -50,8 +50,8 @@ from ..qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||
Qwen2_5_VLPreTrainedModel,
|
||||
Qwen2_5_VLRotaryEmbedding,
|
||||
Qwen2_5_VLTextModel,
|
||||
Qwen2_5_VLVisionAttention,
|
||||
Qwen2_5_VLVisionBlock,
|
||||
apply_rotary_pos_emb_vision,
|
||||
)
|
||||
from ..qwen2_5_vl.processing_qwen2_5_vl import (
|
||||
Qwen2_5_VLProcessor,
|
||||
@ -505,62 +505,12 @@ class Glm4vVisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class Glm4vVisionAttention(nn.Module):
|
||||
class Glm4vVisionAttention(Qwen2_5_VLVisionAttention):
|
||||
def __init__(self, config: Glm4vVisionConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_heads = config.num_heads
|
||||
self.head_dim = config.hidden_size // self.num_heads
|
||||
self.num_key_value_groups = 1
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
|
||||
self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
||||
self.is_causal = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
query_states, key_states, value_states = (
|
||||
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
||||
)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
|
||||
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0)
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
|
||||
attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, _ = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scale,
|
||||
is_causal=self.is_causal,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.squeeze(0)
|
||||
attn_output = attn_output.reshape(seq_length, -1).contiguous()
|
||||
attn_output = self.proj(attn_output)
|
||||
return attn_output
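In this modular file the class now only keeps its own `__init__` (qkv with a configurable bias, proj without bias) and inherits the shared forward, including the new attention-mask/cu_seqlens handling, from Qwen2_5_VLVisionAttention; the generated modeling file is then refreshed by the repo's fix-copies step. A generic, hypothetical sketch of that override pattern (not the actual Transformers classes):

import torch
from torch import nn

class ParentVisionAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.qkv = nn.Linear(config["hidden_size"], config["hidden_size"] * 3)
        self.proj = nn.Linear(config["hidden_size"], config["hidden_size"])
        self.is_causal = False

    def forward(self, hidden_states, attention_mask=None):
        # the shared implementation (mask / cu_seqlens handling) would live here
        return self.proj(hidden_states)

class ChildVisionAttention(ParentVisionAttention):
    # Only the constructor differs (bias settings); forward is inherited, so fixes to
    # the parent attention propagate automatically when copies are regenerated.
    def __init__(self, config):
        super().__init__(config)
        self.qkv = nn.Linear(config["hidden_size"], config["hidden_size"] * 3, bias=config["attention_bias"])
        self.proj = nn.Linear(config["hidden_size"], config["hidden_size"], bias=False)

attn = ChildVisionAttention({"hidden_size": 8, "attention_bias": True})
print(attn(torch.randn(4, 8)).shape)  # torch.Size([4, 8])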
|
||||
|
||||
|
||||
class Glm4vVisionBlock(Qwen2_5_VLVisionBlock):
|
||||
@ -653,6 +603,25 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
return rotary_pos_emb, pos_ids
|
||||
|
||||
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and not attending between
# blocks; it will not be a 100% match for FA2's `varlen` path
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return None
|
||||
|
||||
seq_length = inputs_tensor.shape[0]
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(inputs_tensor.dtype).min,
|
||||
device=inputs_tensor.device,
|
||||
dtype=inputs_tensor.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
return attention_mask
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -682,14 +651,15 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
|
||||
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)
|
||||
|
||||
for blk in self.blocks:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
blk.__call__, hidden_states, cu_seqlens, None, position_embeddings
|
||||
)
|
||||
else:
|
||||
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
|
||||
hidden_states = blk(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = self.post_layernorm(hidden_states)
|
||||
|
||||
|
@ -607,6 +607,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attention_dropout = 0.0
|
||||
self.is_decoder = False
|
||||
self.is_causal = False
|
||||
|
||||
@ -619,6 +620,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
@ -634,15 +636,6 @@ class Qwen2_5OmniAudioAttention(nn.Module):
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, key_states.shape[-2]],
|
||||
torch.finfo(query_states.dtype).min,
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
@ -652,13 +645,13 @@ class Qwen2_5OmniAudioAttention(nn.Module):
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
attention_mask=attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen,
|
||||
max_seqlen_k=max_seqlen,
|
||||
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
cu_seq_lens_k=cu_seqlens,
|
||||
max_length_q=max_seqlen,
|
||||
max_length_k=max_seqlen,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
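The rename from `cu_seqlens_q`/`max_seqlen_q` to `cu_seq_lens_q`/`max_length_q` matters because the attention interface picks these values out of `**kwargs` by name; misspelled keys are silently ignored and FA2 would fall back to the non-varlen path. A hedged sketch of the expected key names (assumed to match `FlashAttentionKwargs` in the installed Transformers version; verify before relying on it):

import torch

cu_seqlens = torch.tensor([0, 4, 10], dtype=torch.int32)
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())

# Keyword names the FA2 code path looks for (assumption: they match
# transformers.modeling_flash_attention_utils.FlashAttentionKwargs).
fa2_kwargs = {
    "cu_seq_lens_q": cu_seqlens,
    "cu_seq_lens_k": cu_seqlens,
    "max_length_q": max_seqlen,
    "max_length_k": max_seqlen,
}
# e.g. attention_interface(self, q, k, v, attention_mask, is_causal=False, **fa2_kwargs)
print(sorted(fa2_kwargs))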
@ -686,6 +679,7 @@ class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@ -704,6 +698,7 @@ class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer):
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
@ -785,6 +780,25 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
|
||||
def set_input_embeddings(self, value: nn.Module):
|
||||
self.conv1 = value
|
||||
|
||||
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and not attending between
# blocks; it will not be a 100% match for FA2's `varlen` path
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return None
|
||||
|
||||
seq_length = inputs_tensor.shape[0]
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(inputs_tensor.dtype).min,
|
||||
device=inputs_tensor.device,
|
||||
dtype=inputs_tensor.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
return attention_mask
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -833,9 +847,15 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
|
||||
padded_mask_after_cnn.sum(1).cumsum(0),
|
||||
)
|
||||
).to(torch.int32)
|
||||
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
layer_outputs = encoder_layer(hidden_states, cu_seqlens, **kwargs)
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0)
|
||||
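For the audio encoder, `cu_seqlens` comes from the post-CNN padding mask: summing the valid frames per example and taking the cumulative sum gives the block boundaries consumed by `_prepare_attention_mask` and by the FA2 varlen kwargs. A toy example with made-up lengths (padding the cumulative sum with a leading 0, equivalent to the concatenation above):

import torch
import torch.nn.functional as F

# Two audio examples with 5 and 3 valid frames after the conv front-end (illustrative).
padded_mask_after_cnn = torch.tensor([
    [True, True, True, True, True],
    [True, True, True, False, False],
])

cu_seqlens = F.pad(padded_mask_after_cnn.sum(1).cumsum(0), (1, 0), value=0).to(torch.int32)
print(cu_seqlens)  # tensor([0, 5, 8], dtype=torch.int32)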
@ -928,12 +948,15 @@ class Qwen2_5OmniVisionAttention(nn.Module):
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.num_key_value_groups = 1 # needed for eager attention
|
||||
self.config = config
|
||||
self.attention_dropout = 0.0
|
||||
self.is_causal = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
@ -943,18 +966,9 @@ class Qwen2_5OmniVisionAttention(nn.Module):
|
||||
query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(query_states.dtype).min,
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0)
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
@ -966,13 +980,13 @@ class Qwen2_5OmniVisionAttention(nn.Module):
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0,
|
||||
attention_mask=attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen,
|
||||
max_seqlen_k=max_seqlen,
|
||||
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
cu_seq_lens_k=cu_seqlens,
|
||||
max_length_q=max_seqlen,
|
||||
max_length_k=max_seqlen,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
@ -1009,10 +1023,15 @@ class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer):
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = hidden_states + self.attn(
|
||||
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, **kwargs
|
||||
self.norm1(hidden_states),
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||
return hidden_states
|
||||
@ -1171,6 +1190,25 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
|
||||
|
||||
return window_index, cu_window_seqlens
|
||||
|
||||
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and not attending between
# blocks; it will not be a 100% match for FA2's `varlen` path
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return None
|
||||
|
||||
seq_length = inputs_tensor.shape[0]
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(inputs_tensor.dtype).min,
|
||||
device=inputs_tensor.device,
|
||||
dtype=inputs_tensor.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
return attention_mask
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -1217,10 +1255,13 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
|
||||
cu_seqlens_now = cu_seqlens
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
|
||||
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
|
||||
hidden_states = blk(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens_now,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = self.merger(hidden_states)
|
||||
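In the windowed vision encoders the mask is rebuilt inside the block loop because window-attention layers use `cu_window_seqlens` while the designated full-attention layers use the global `cu_seqlens`. A runnable toy illustration of why the two block structures produce different masks (the list of full-attention layer indices is an assumption for the sketch):

import torch

def block_diagonal_mask(seq_length, cu_seqlens, dtype=torch.float32):
    mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype)
    for i in range(1, len(cu_seqlens)):
        mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
    return mask

seq_length = 8
cu_seqlens = torch.tensor([0, 8])            # one image -> one global block
cu_window_seqlens = torch.tensor([0, 4, 8])  # the same image split into two windows
fullatt_block_indexes = [1]                  # assumed: which layers use full attention

for layer_num in range(2):
    cu_now = cu_seqlens if layer_num in fullatt_block_indexes else cu_window_seqlens
    mask = block_diagonal_mask(seq_length, cu_now)
    print(layer_num, int((mask == 0).sum()))  # layer 0: 32 allowed pairs, layer 1: 64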
|
@ -1611,6 +1611,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attention_dropout = 0.0
|
||||
self.is_decoder = False
|
||||
self.is_causal = False
|
||||
|
||||
@ -1623,6 +1624,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
@ -1638,15 +1640,6 @@ class Qwen2_5OmniAudioAttention(nn.Module):
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, key_states.shape[-2]],
|
||||
torch.finfo(query_states.dtype).min,
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
@ -1656,13 +1649,13 @@ class Qwen2_5OmniAudioAttention(nn.Module):
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
attention_mask=attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen,
|
||||
max_seqlen_k=max_seqlen,
|
||||
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
cu_seq_lens_k=cu_seqlens,
|
||||
max_length_q=max_seqlen,
|
||||
max_length_k=max_seqlen,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
@ -1682,6 +1675,7 @@ class Qwen2_5OmniAudioEncoderLayer(Qwen2AudioEncoderLayer):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
@ -1689,6 +1683,7 @@ class Qwen2_5OmniAudioEncoderLayer(Qwen2AudioEncoderLayer):
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
@ -1770,6 +1765,25 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
|
||||
def set_input_embeddings(self, value: nn.Module):
|
||||
self.conv1 = value
|
||||
|
||||
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and not attending between
# blocks; it will not be a 100% match for FA2's `varlen` path
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return None
|
||||
|
||||
seq_length = inputs_tensor.shape[0]
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(inputs_tensor.dtype).min,
|
||||
device=inputs_tensor.device,
|
||||
dtype=inputs_tensor.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
return attention_mask
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1818,9 +1832,15 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
|
||||
padded_mask_after_cnn.sum(1).cumsum(0),
|
||||
)
|
||||
).to(torch.int32)
|
||||
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
layer_outputs = encoder_layer(hidden_states, cu_seqlens, **kwargs)
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0)
|
||||
@ -1906,12 +1926,15 @@ class Qwen2_5OmniVisionAttention(nn.Module):
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.num_key_value_groups = 1 # needed for eager attention
|
||||
self.config = config
|
||||
self.attention_dropout = 0.0
|
||||
self.is_causal = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
@ -1921,18 +1944,9 @@ class Qwen2_5OmniVisionAttention(nn.Module):
|
||||
query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(query_states.dtype).min,
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0)
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
@ -1944,13 +1958,13 @@ class Qwen2_5OmniVisionAttention(nn.Module):
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0,
|
||||
attention_mask=attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen,
|
||||
max_seqlen_k=max_seqlen,
|
||||
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
cu_seq_lens_k=cu_seqlens,
|
||||
max_length_q=max_seqlen,
|
||||
max_length_k=max_seqlen,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
@ -1970,10 +1984,15 @@ class Qwen2_5OmniVisionBlock(Qwen2_5_VLVisionBlock):
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = hidden_states + self.attn(
|
||||
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, **kwargs
|
||||
self.norm1(hidden_states),
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||
return hidden_states
|
||||
@ -1987,6 +2006,25 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)])
|
||||
|
||||
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and not attending between
# blocks; it will not be a 100% match for FA2's `varlen` path
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return None
|
||||
|
||||
seq_length = inputs_tensor.shape[0]
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(inputs_tensor.dtype).min,
|
||||
device=inputs_tensor.device,
|
||||
dtype=inputs_tensor.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
return attention_mask
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -2033,10 +2071,13 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
|
||||
cu_seqlens_now = cu_seqlens
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
|
||||
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
|
||||
hidden_states = blk(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens_now,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = self.merger(hidden_states)
|
||||
|
@ -206,6 +206,8 @@ class Qwen2_5_VLVisionAttention(nn.Module):
|
||||
self.proj = nn.Linear(self.dim, self.dim)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.config = config
|
||||
self.attention_dropout = 0.0
|
||||
self.is_causal = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -213,6 +215,7 @@ class Qwen2_5_VLVisionAttention(nn.Module):
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
@ -233,18 +236,9 @@ class Qwen2_5_VLVisionAttention(nn.Module):
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
|
||||
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(value_states.dtype).min,
|
||||
device=value_states.device,
|
||||
dtype=value_states.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0)
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
@ -256,13 +250,13 @@ class Qwen2_5_VLVisionAttention(nn.Module):
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0,
|
||||
attention_mask=attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen,
|
||||
max_seqlen_k=max_seqlen,
|
||||
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
cu_seq_lens_k=cu_seqlens,
|
||||
max_length_q=max_seqlen,
|
||||
max_length_k=max_seqlen,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
@ -286,6 +280,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = hidden_states + self.attn(
|
||||
@ -293,6 +288,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||
@ -426,6 +422,25 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
|
||||
|
||||
return window_index, cu_window_seqlens
|
||||
|
||||
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and not attending between
# blocks; it will not be a 100% match for FA2's `varlen` path
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return None
|
||||
|
||||
seq_length = inputs_tensor.shape[0]
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(inputs_tensor.dtype).min,
|
||||
device=inputs_tensor.device,
|
||||
dtype=inputs_tensor.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
return attention_mask
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -472,8 +487,14 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
|
||||
cu_seqlens_now = cu_seqlens
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
|
||||
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
|
||||
hidden_states = blk(
|
||||
hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, **kwargs
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens_now,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = self.merger(hidden_states)
|
||||
|
@ -159,6 +159,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = hidden_states + self.attn(
|
||||
@ -166,6 +167,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||
@ -287,6 +289,25 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
|
||||
|
||||
return window_index, cu_window_seqlens
|
||||
|
||||
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and not attending between
# blocks; it will not be a 100% match for FA2's `varlen` path
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return None
|
||||
|
||||
seq_length = inputs_tensor.shape[0]
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(inputs_tensor.dtype).min,
|
||||
device=inputs_tensor.device,
|
||||
dtype=inputs_tensor.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
return attention_mask
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -333,8 +354,14 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
|
||||
cu_seqlens_now = cu_seqlens
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
|
||||
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
|
||||
hidden_states = blk(
|
||||
hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, **kwargs
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens_now,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = self.merger(hidden_states)
|
||||
|
@ -324,6 +324,8 @@ class VisionAttention(nn.Module):
|
||||
self.proj = nn.Linear(self.dim, self.dim)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.config = config
|
||||
self.attention_dropout = 0.0
|
||||
self.is_causal = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -331,6 +333,7 @@ class VisionAttention(nn.Module):
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
@ -351,18 +354,9 @@ class VisionAttention(nn.Module):
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
|
||||
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(value_states.dtype).min,
|
||||
device=value_states.device,
|
||||
dtype=value_states.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0)
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
@ -374,13 +368,13 @@ class VisionAttention(nn.Module):
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0,
|
||||
attention_mask=attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen,
|
||||
max_seqlen_k=max_seqlen,
|
||||
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
cu_seq_lens_k=cu_seqlens,
|
||||
max_length_q=max_seqlen,
|
||||
max_length_k=max_seqlen,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
@ -406,6 +400,7 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = hidden_states + self.attn(
|
||||
@ -413,6 +408,7 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||
@ -725,6 +721,25 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
return rotary_pos_emb
|
||||
|
||||
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and not attending between
# blocks; it will not be a 100% match for FA2's `varlen` path
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return None
|
||||
|
||||
seq_length = inputs_tensor.shape[0]
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(inputs_tensor.dtype).min,
|
||||
device=inputs_tensor.device,
|
||||
dtype=inputs_tensor.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
return attention_mask
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -750,10 +765,15 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
|
||||
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
|
||||
)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)
|
||||
|
||||
for blk in self.blocks:
|
||||
hidden_states = blk(
|
||||
hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, **kwargs
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return self.merger(hidden_states)
|
||||
|
@ -184,7 +184,7 @@ class Qwen2_5_VLVisionText2TextModelTester:
|
||||
input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id
|
||||
inputs_dict = {
|
||||
"pixel_values": pixel_values,
|
||||
"image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size),
|
||||
"image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device),
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
|
@ -176,7 +176,7 @@ class Qwen2VLVisionText2TextModelTester:
|
||||
|
||||
inputs_dict = {
|
||||
"pixel_values": pixel_values,
|
||||
"image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size),
|
||||
"image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device),
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
|
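The test changes only move `image_grid_thw` onto `torch_device` so every tensor in `inputs_dict` lives on the same device as the model, which is presumably the device mismatch the commit message asks about. A standalone sanity check (with `torch_device` shown as a plain variable rather than the test-suite import):

import torch

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 2

inputs_dict = {
    "image_grid_thw": torch.tensor([[1, 1, 1]] * batch_size, device=torch_device),
    "input_ids": torch.ones(batch_size, 4, dtype=torch.long, device=torch_device),
}
# All inputs share one device, so model(**inputs_dict) cannot hit a CPU/GPU mismatch here.
assert len({v.device.type for v in inputs_dict.values()}) == 1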