JJJYmmm 2025-07-03 00:37:13 +02:00 committed by GitHub
commit e48161a471
7 changed files with 200 additions and 244 deletions


@@ -296,7 +296,6 @@ class Glm4vVisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
@@ -320,27 +319,51 @@ class Glm4vVisionAttention(nn.Module):
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask=attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
if self.config._attn_implementation == "flash_attention_2":
# Flash Attention 2: Use cu_seqlens for variable length attention
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask=None,
scaling=self.scaling,
dropout=0.0 if not self.training else self.attention_dropout,
cu_seq_lens_q=cu_seqlens,
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
else:
# Other implementations: Process each chunk separately
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
splits = [
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
]
attn_outputs = [
attention_interface(
self,
q,
k,
v,
attention_mask=None,
scaling=self.scaling,
dropout=0.0 if not self.training else self.attention_dropout,
is_causal=False,
**kwargs,
)[0]
for q, k, v in zip(*splits)
]
attn_output = torch.cat(attn_outputs, dim=1)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output)
@@ -361,7 +384,6 @@ class Glm4vVisionBlock(GradientCheckpointingLayer):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
@@ -369,7 +391,6 @@ class Glm4vVisionBlock(GradientCheckpointingLayer):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -467,25 +488,6 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb, pos_ids
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
# blocks. Though it will not be a 100% match for FA2's `varlen` path
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
"""
Args:
@@ -515,14 +517,12 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)
for blk in self.blocks:
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
)
hidden_states = self.post_layernorm(hidden_states)
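
A note on the two paths added above: running attention separately on each cu_seqlens chunk (the non-FA2 branch) gives bidirectional attention within chunks and none across them, which is exactly what the removed _prepare_attention_mask helper approximated with a block-diagonal additive mask. A minimal, self-contained sketch of that equivalence, using torch's scaled_dot_product_attention directly with made-up shapes rather than the library's attention interface:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
heads, head_dim = 4, 8
cu_seqlens = torch.tensor([0, 3, 7, 12])  # three chunks of lengths 3, 4, 5
seq_len = int(cu_seqlens[-1])
q, k, v = (torch.randn(1, heads, seq_len, head_dim) for _ in range(3))

# Path 1: one call with a block-diagonal boolean mask (bidirectional inside each chunk).
mask = torch.zeros(1, 1, seq_len, seq_len, dtype=torch.bool)
for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
    mask[..., start:end, start:end] = True
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Path 2: split along the sequence dimension and attend within each chunk separately.
lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
out_chunked = torch.cat(
    [
        F.scaled_dot_product_attention(qc, kc, vc)
        for qc, kc, vc in zip(q.split(lengths, dim=2), k.split(lengths, dim=2), v.split(lengths, dim=2))
    ],
    dim=2,  # SDPA returns (batch, heads, seq, dim) here, so concatenate on the seq axis
)

print(torch.allclose(out_masked, out_chunked, atol=1e-6))  # expected: True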


@@ -603,25 +603,6 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb, pos_ids
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
# blocks. Though it will not be a 100% match for FA2's `varlen` path
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
"""
Args:
@@ -651,14 +632,12 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)
for blk in self.blocks:
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
)
hidden_states = self.post_layernorm(hidden_states)


@@ -956,7 +956,6 @@ class Qwen2_5OmniVisionAttention(nn.Module):
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
@@ -969,27 +968,51 @@ class Qwen2_5OmniVisionAttention(nn.Module):
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask=attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
if self.config._attn_implementation == "flash_attention_2":
# Flash Attention 2: Use cu_seqlens for variable length attention
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask=None,
scaling=self.scaling,
dropout=0.0 if not self.training else self.attention_dropout,
cu_seq_lens_q=cu_seqlens,
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
else:
# Other implementations: Process each chunk separately
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
splits = [
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
]
attn_outputs = [
attention_interface(
self,
q,
k,
v,
attention_mask=None,
scaling=self.scaling,
dropout=0.0 if not self.training else self.attention_dropout,
is_causal=False,
**kwargs,
)[0]
for q, k, v in zip(*splits)
]
attn_output = torch.cat(attn_outputs, dim=1)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output)
@@ -1023,14 +1046,12 @@ class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer):
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -1190,25 +1211,6 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
return window_index, cu_window_seqlens
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
# blocks. Though it will not be a 100% match for FA2's `varlen` path
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
@@ -1256,12 +1258,10 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
else:
cu_seqlens_now = cu_window_seqlens
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens_now,
rotary_pos_emb=rotary_pos_emb,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = self.merger(hidden_states)
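
For context on the cu_seqlens argument threaded through these blocks: it holds cumulative patch counts with a leading zero, so each pair of consecutive entries delimits one image, frame, or window. A small sketch of how such a tensor can be built from a (t, h, w) patch grid, following the cumsum-and-pad pattern visible in the forward passes; the grid values are invented for the example:

import torch
import torch.nn.functional as F

# two inputs: 1 frame of 4x4 patches, then 2 frames of 2x3 patches
grid_thw = torch.tensor([[1, 4, 4], [2, 2, 3]])
seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
cu_seqlens = F.pad(seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0), value=0)
print(cu_seqlens)                                        # tensor([ 0, 16, 22, 28], dtype=torch.int32)
print((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())   # 16 -> max_seqlen for the FA2 branch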


@@ -1934,7 +1934,6 @@ class Qwen2_5OmniVisionAttention(nn.Module):
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
@@ -1947,27 +1946,51 @@ class Qwen2_5OmniVisionAttention(nn.Module):
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask=attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
if self.config._attn_implementation == "flash_attention_2":
# Flash Attention 2: Use cu_seqlens for variable length attention
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask=None,
scaling=self.scaling,
dropout=0.0 if not self.training else self.attention_dropout,
cu_seq_lens_q=cu_seqlens,
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
else:
# Other implementations: Process each chunk separately
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
splits = [
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
]
attn_outputs = [
attention_interface(
self,
q,
k,
v,
attention_mask=None,
scaling=self.scaling,
dropout=0.0 if not self.training else self.attention_dropout,
is_causal=False,
**kwargs,
)[0]
for q, k, v in zip(*splits)
]
attn_output = torch.cat(attn_outputs, dim=1)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output)
@@ -1984,14 +2007,12 @@ class Qwen2_5OmniVisionBlock(Qwen2_5_VLVisionBlock):
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -2006,25 +2027,6 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
super().__init__(config, *inputs, **kwargs)
self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)])
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
# blocks. Though it will not be a 100% match for FA2's `varlen` path
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
@@ -2072,12 +2074,10 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
else:
cu_seqlens_now = cu_window_seqlens
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens_now,
rotary_pos_emb=rotary_pos_emb,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = self.merger(hidden_states)
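
The removed _prepare_attention_mask helper materialized a dense [1, 1, seq_len, seq_len] tensor in the activation dtype, so its footprint grows quadratically with the number of vision patches, whereas the chunked path above never builds it. A rough back-of-the-envelope with an example sequence length (the numbers are illustrative only):

seq_len = 14336              # example: a large image/video producing ~14k vision patches
bytes_per_element = 2        # bfloat16 / float16
mask_bytes = seq_len * seq_len * bytes_per_element
print(f"{mask_bytes / 2**30:.2f} GiB")   # ~0.38 GiB just for the dense mask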


@@ -215,7 +215,6 @@ class Qwen2_5_VLVisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
@@ -239,27 +238,51 @@ class Qwen2_5_VLVisionAttention(nn.Module):
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask=attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
if self.config._attn_implementation == "flash_attention_2":
# Flash Attention 2: Use cu_seqlens for variable length attention
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask=None,
scaling=self.scaling,
dropout=0.0 if not self.training else self.attention_dropout,
cu_seq_lens_q=cu_seqlens,
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
else:
# Other implementations: Process each chunk separately
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
splits = [
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
]
attn_outputs = [
attention_interface(
self,
q,
k,
v,
attention_mask=None,
scaling=self.scaling,
dropout=0.0 if not self.training else self.attention_dropout,
is_causal=False,
**kwargs,
)[0]
for q, k, v in zip(*splits)
]
attn_output = torch.cat(attn_outputs, dim=1)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output)
@@ -280,7 +303,6 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
@@ -288,7 +310,6 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -422,25 +443,6 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
return window_index, cu_window_seqlens
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
# blocks. Though it will not be a 100% match for FA2's `varlen` path
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
@@ -488,12 +490,10 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
else:
cu_seqlens_now = cu_window_seqlens
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens_now,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)
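
On the cu_seqlens_now selection above: in the windowed vision encoders most layers attend only inside local windows (cu_window_seqlens), while a few designated layers use the per-image boundaries (cu_seqlens) for full attention. A toy sketch of that per-layer dispatch; the index set and boundary values are illustrative and assume a config field like fullatt_block_indexes rather than being taken from this diff:

fullatt_block_indexes = {7, 15, 23, 31}                      # assumed example config values
cu_seqlens = [0, 64, 128]                                    # per-image boundaries
cu_window_seqlens = [0, 16, 32, 48, 64, 80, 96, 112, 128]    # per-window boundaries

for layer_num in range(32):
    if layer_num in fullatt_block_indexes:
        cu_seqlens_now = cu_seqlens
    else:
        cu_seqlens_now = cu_window_seqlens
    # each block then runs the chunked / FA2-varlen attention with cu_seqlens_now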


@@ -159,7 +159,6 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
@@ -167,7 +166,6 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -289,25 +287,6 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
return window_index, cu_window_seqlens
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
# blocks. Though it will not be a 100% match for FA2's `varlen` path
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
@@ -355,12 +334,10 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
else:
cu_seqlens_now = cu_window_seqlens
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens_now,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)


@@ -333,7 +333,6 @@ class VisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
@@ -357,27 +356,51 @@ class VisionAttention(nn.Module):
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask=attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
if self.config._attn_implementation == "flash_attention_2":
# Flash Attention 2: Use cu_seqlens for variable length attention
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask=None,
scaling=self.scaling,
dropout=0.0 if not self.training else self.attention_dropout,
cu_seq_lens_q=cu_seqlens,
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
else:
# Other implementations: Process each chunk separately
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
splits = [
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
]
attn_outputs = [
attention_interface(
self,
q,
k,
v,
attention_mask=None,
scaling=self.scaling,
dropout=0.0 if not self.training else self.attention_dropout,
is_causal=False,
**kwargs,
)[0]
for q, k, v in zip(*splits)
]
attn_output = torch.cat(attn_outputs, dim=1)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output)
@@ -400,7 +423,6 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
@@ -408,7 +430,6 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -721,25 +742,6 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
# blocks. Though it will not be a 100% match for FA2's `varlen` path
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
@auto_docstring
def forward(
self,
@@ -765,14 +767,12 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)
for blk in self.blocks:
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)
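
Finally, when _attn_implementation is "flash_attention_2", the cu_seq_lens_q/k and max_length_q/k keyword arguments passed above are forwarded down to the variable-length kernel, so no dense mask is ever constructed. A hedged sketch of that call, assuming the flash-attn 2.x package's flash_attn_varlen_func with q/k/v packed as (total_tokens, num_heads, head_dim) and int32 cu_seqlens on the GPU; the wrapper name and defaults here are illustrative:

# assumes: pip install flash-attn (2.x)
from flash_attn import flash_attn_varlen_func

def varlen_bidirectional_attention(q, k, v, cu_seqlens, scaling, dropout_p=0.0):
    # each (cu_seqlens[i], cu_seqlens[i+1]) pair delimits one image/window
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        dropout_p=dropout_p,
        softmax_scale=scaling,
        causal=False,  # vision patches attend bidirectionally within each chunk
    )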