Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)

Merge 95cfafe348 into ebfbcd42da
Commit e48161a471
@@ -296,7 +296,6 @@ class Glm4vVisionAttention(nn.Module):
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
@@ -320,27 +319,51 @@ class Glm4vVisionAttention(nn.Module):
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
            cu_seq_lens_k=cu_seqlens,
            max_length_q=max_seqlen,
            max_length_k=max_seqlen,
            is_causal=False,
            **kwargs,
        )
        if self.config._attn_implementation == "flash_attention_2":
            # Flash Attention 2: Use cu_seqlens for variable length attention
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # Other implementations: Process each chunk separately
            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            splits = [
                torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
            ]

            attn_outputs = [
                attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                for q, k, v in zip(*splits)
            ]
            attn_output = torch.cat(attn_outputs, dim=1)

        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
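Reviewer note (not part of the diff): a minimal, self-contained sketch of what the non-FA2 branch above computes. `cu_seqlens` holds cumulative token counts per packed image, so splitting q/k/v on the per-image lengths and running attention chunk by chunk keeps each image attending only to its own tokens. Shapes and values below are invented, and plain SDPA stands in for `attention_interface`, so the token axis here is dim 2 rather than dim 1 as in the patched code.

    import torch
    import torch.nn.functional as F

    cu_seqlens = torch.tensor([0, 3, 8])          # two images packed: 3 + 5 tokens
    q = k = v = torch.randn(1, 2, 8, 4)           # (batch=1, heads, total_tokens, head_dim)

    lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()    # [3, 5]
    splits = [torch.split(t, lengths, dim=2) for t in (q, k, v)]

    chunks = [
        F.scaled_dot_product_attention(qc, kc, vc, is_causal=False)  # bidirectional within a chunk
        for qc, kc, vc in zip(*splits)
    ]
    attn = torch.cat(chunks, dim=2)               # back to (1, heads, 8, head_dim)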
@@ -361,7 +384,6 @@ class Glm4vVisionBlock(GradientCheckpointingLayer):
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
@@ -369,7 +391,6 @@ class Glm4vVisionBlock(GradientCheckpointingLayer):
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -467,25 +488,6 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb, pos_ids

    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
        # NOTE: the created attention mask only approximates the ragged FA2 attention by
        # allowing bidirectional attention within `cu_seqlens` blocks and not attending between
        # blocks, so it will not be a 100% match for FA2's `varlen` path
        if self.config._attn_implementation == "flash_attention_2":
            return None

        seq_length = inputs_tensor.shape[0]
        attention_mask = torch.full(
            [1, 1, seq_length, seq_length],
            torch.finfo(inputs_tensor.dtype).min,
            device=inputs_tensor.device,
            dtype=inputs_tensor.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
        return attention_mask

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
        """
        Args:
@@ -515,14 +517,12 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
        attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)

        for blk in self.blocks:
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
            )

        hidden_states = self.post_layernorm(hidden_states)
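Reviewer note (not part of the diff): a small standalone illustration of the additive mask that `_prepare_attention_mask` returns for non-FA2 backends, with invented sizes. Positions inside the same `cu_seqlens` block get 0, everything else gets the dtype minimum, which softmax turns into zero attention across images.

    import torch

    dtype = torch.float16
    cu_seqlens = torch.tensor([0, 2, 5])   # image 0 -> tokens 0..1, image 1 -> tokens 2..4
    seq_length = 5

    mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype)
    for i in range(1, len(cu_seqlens)):
        mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

    print((mask[0, 0] == 0).int())
    # tensor([[1, 1, 0, 0, 0],
    #         [1, 1, 0, 0, 0],
    #         [0, 0, 1, 1, 1],
    #         [0, 0, 1, 1, 1],
    #         [0, 0, 1, 1, 1]], dtype=torch.int32)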
@@ -603,25 +603,6 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb, pos_ids

    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
        # NOTE: the created attention mask only approximates the ragged FA2 attention by
        # allowing bidirectional attention within `cu_seqlens` blocks and not attending between
        # blocks, so it will not be a 100% match for FA2's `varlen` path
        if self.config._attn_implementation == "flash_attention_2":
            return None

        seq_length = inputs_tensor.shape[0]
        attention_mask = torch.full(
            [1, 1, seq_length, seq_length],
            torch.finfo(inputs_tensor.dtype).min,
            device=inputs_tensor.device,
            dtype=inputs_tensor.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
        return attention_mask

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
        """
        Args:
@@ -651,14 +632,12 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
        attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)

        for blk in self.blocks:
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
            )

        hidden_states = self.post_layernorm(hidden_states)
@@ -956,7 +956,6 @@ class Qwen2_5OmniVisionAttention(nn.Module):
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
@@ -969,27 +968,51 @@ class Qwen2_5OmniVisionAttention(nn.Module):
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
            cu_seq_lens_k=cu_seqlens,
            max_length_q=max_seqlen,
            max_length_k=max_seqlen,
            is_causal=False,
            **kwargs,
        )
        if self.config._attn_implementation == "flash_attention_2":
            # Flash Attention 2: Use cu_seqlens for variable length attention
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # Other implementations: Process each chunk separately
            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            splits = [
                torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
            ]

            attn_outputs = [
                attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                for q, k, v in zip(*splits)
            ]
            attn_output = torch.cat(attn_outputs, dim=1)

        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
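Reviewer note (not part of the diff): on the FA2 branch no mask is built; the ragged layout is carried entirely by `cu_seq_lens_q/k` and `max_length_q/k`. A rough sketch of how those kwargs map onto flash-attn's varlen entry point, assuming the `flash-attn` package and a CUDA device; in the model the call is routed through `ALL_ATTENTION_FUNCTIONS["flash_attention_2"]` rather than invoked directly.

    import torch
    from flash_attn import flash_attn_varlen_func  # assumes flash-attn is installed

    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())

    # varlen layout: (total_tokens, num_heads, head_dim), no batch/padding dimension
    q = k = v = torch.randn(8, 2, 64, device="cuda", dtype=torch.float16)

    out = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
        causal=False,  # bidirectional attention within each cu_seqlens block
    )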
@@ -1023,14 +1046,12 @@ class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer):
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -1190,25 +1211,6 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):

        return window_index, cu_window_seqlens

    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
        # NOTE: the created attention mask only approximates the ragged FA2 attention by
        # allowing bidirectional attention within `cu_seqlens` blocks and not attending between
        # blocks, so it will not be a 100% match for FA2's `varlen` path
        if self.config._attn_implementation == "flash_attention_2":
            return None

        seq_length = inputs_tensor.shape[0]
        attention_mask = torch.full(
            [1, 1, seq_length, seq_length],
            torch.finfo(inputs_tensor.dtype).min,
            device=inputs_tensor.device,
            dtype=inputs_tensor.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
        return attention_mask

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
@@ -1256,12 +1258,10 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
            else:
                cu_seqlens_now = cu_window_seqlens

            attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens_now,
                rotary_pos_emb=rotary_pos_emb,
                attention_mask=attention_mask,
                **kwargs,
            )
        hidden_states = self.merger(hidden_states)
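Reviewer note (not part of the diff): this encoder alternates between the full per-image `cu_seqlens` and the finer `cu_window_seqlens`, which is why the mask is rebuilt inside the layer loop instead of once up front. A self-contained sketch with made-up sizes showing how the windowed boundaries shrink the block-diagonal mask:

    import torch

    def block_diag_mask(cu, n, dtype=torch.float32):
        # same construction as _prepare_attention_mask, minus the model plumbing
        mask = torch.full([1, 1, n, n], torch.finfo(dtype).min, dtype=dtype)
        for i in range(1, len(cu)):
            mask[..., cu[i - 1] : cu[i], cu[i - 1] : cu[i]] = 0
        return mask

    n = 8
    cu_seqlens = torch.tensor([0, 8])             # one image: every patch can attend to every patch
    cu_window_seqlens = torch.tensor([0, 4, 8])   # two windows: attention stays inside each window

    full_mask = block_diag_mask(cu_seqlens, n)
    window_mask = block_diag_mask(cu_window_seqlens, n)
    print((full_mask == 0).sum().item(), (window_mask == 0).sum().item())  # 64 vs 32 attendable pairs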
@@ -1934,7 +1934,6 @@ class Qwen2_5OmniVisionAttention(nn.Module):
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
@@ -1947,27 +1946,51 @@ class Qwen2_5OmniVisionAttention(nn.Module):
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
            cu_seq_lens_k=cu_seqlens,
            max_length_q=max_seqlen,
            max_length_k=max_seqlen,
            is_causal=False,
            **kwargs,
        )
        if self.config._attn_implementation == "flash_attention_2":
            # Flash Attention 2: Use cu_seqlens for variable length attention
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # Other implementations: Process each chunk separately
            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            splits = [
                torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
            ]

            attn_outputs = [
                attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                for q, k, v in zip(*splits)
            ]
            attn_output = torch.cat(attn_outputs, dim=1)

        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
@@ -1984,14 +2007,12 @@ class Qwen2_5OmniVisionBlock(Qwen2_5_VLVisionBlock):
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -2006,25 +2027,6 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
        super().__init__(config, *inputs, **kwargs)
        self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)])

    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
        # NOTE: the created attention mask only approximates the ragged FA2 attention by
        # allowing bidirectional attention within `cu_seqlens` blocks and not attending between
        # blocks, so it will not be a 100% match for FA2's `varlen` path
        if self.config._attn_implementation == "flash_attention_2":
            return None

        seq_length = inputs_tensor.shape[0]
        attention_mask = torch.full(
            [1, 1, seq_length, seq_length],
            torch.finfo(inputs_tensor.dtype).min,
            device=inputs_tensor.device,
            dtype=inputs_tensor.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
        return attention_mask

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
@@ -2072,12 +2074,10 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
            else:
                cu_seqlens_now = cu_window_seqlens

            attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens_now,
                rotary_pos_emb=rotary_pos_emb,
                attention_mask=attention_mask,
                **kwargs,
            )
        hidden_states = self.merger(hidden_states)
@@ -215,7 +215,6 @@ class Qwen2_5_VLVisionAttention(nn.Module):
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
@@ -239,27 +238,51 @@ class Qwen2_5_VLVisionAttention(nn.Module):
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
            cu_seq_lens_k=cu_seqlens,
            max_length_q=max_seqlen,
            max_length_k=max_seqlen,
            is_causal=False,
            **kwargs,
        )
        if self.config._attn_implementation == "flash_attention_2":
            # Flash Attention 2: Use cu_seqlens for variable length attention
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # Other implementations: Process each chunk separately
            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            splits = [
                torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
            ]

            attn_outputs = [
                attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                for q, k, v in zip(*splits)
            ]
            attn_output = torch.cat(attn_outputs, dim=1)

        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
@@ -280,7 +303,6 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
@@ -288,7 +310,6 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -422,25 +443,6 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):

        return window_index, cu_window_seqlens

    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
        # NOTE: the created attention mask only approximates the ragged FA2 attention by
        # allowing bidirectional attention within `cu_seqlens` blocks and not attending between
        # blocks, so it will not be a 100% match for FA2's `varlen` path
        if self.config._attn_implementation == "flash_attention_2":
            return None

        seq_length = inputs_tensor.shape[0]
        attention_mask = torch.full(
            [1, 1, seq_length, seq_length],
            torch.finfo(inputs_tensor.dtype).min,
            device=inputs_tensor.device,
            dtype=inputs_tensor.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
        return attention_mask

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
@@ -488,12 +490,10 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
            else:
                cu_seqlens_now = cu_window_seqlens

            attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens_now,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
                **kwargs,
            )
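Reviewer note (not part of the diff): the masked single-call path and the chunked fallback are intended to compute the same result, which is what the NOTE about only "approximating" the FA2 varlen path refers to. A quick standalone check (invented shapes, plain SDPA in place of the model's attention interface) that full attention under the block-diagonal mask matches per-chunk attention:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    cu = torch.tensor([0, 3, 8])
    q, k, v = (torch.randn(1, 2, 8, 4) for _ in range(3))

    # (a) one call over the packed sequence with a block-diagonal additive mask
    mask = torch.full((1, 1, 8, 8), float("-inf"))
    for i in range(1, len(cu)):
        mask[..., cu[i - 1] : cu[i], cu[i - 1] : cu[i]] = 0
    masked_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    # (b) one call per chunk, concatenated along the token axis
    lengths = (cu[1:] - cu[:-1]).tolist()
    chunks = [
        F.scaled_dot_product_attention(qc, kc, vc)
        for qc, kc, vc in zip(*(torch.split(t, lengths, dim=2) for t in (q, k, v)))
    ]
    chunked_out = torch.cat(chunks, dim=2)

    print(torch.allclose(masked_out, chunked_out, atol=1e-6))  # True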
@@ -159,7 +159,6 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
@@ -167,7 +166,6 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -289,25 +287,6 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):

        return window_index, cu_window_seqlens

    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
        # NOTE: the created attention mask only approximates the ragged FA2 attention by
        # allowing bidirectional attention within `cu_seqlens` blocks and not attending between
        # blocks, so it will not be a 100% match for FA2's `varlen` path
        if self.config._attn_implementation == "flash_attention_2":
            return None

        seq_length = inputs_tensor.shape[0]
        attention_mask = torch.full(
            [1, 1, seq_length, seq_length],
            torch.finfo(inputs_tensor.dtype).min,
            device=inputs_tensor.device,
            dtype=inputs_tensor.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
        return attention_mask

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
@@ -355,12 +334,10 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
            else:
                cu_seqlens_now = cu_window_seqlens

            attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens_now,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
                **kwargs,
            )
@@ -333,7 +333,6 @@ class VisionAttention(nn.Module):
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
@@ -357,27 +356,51 @@ class VisionAttention(nn.Module):
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
            cu_seq_lens_k=cu_seqlens,
            max_length_q=max_seqlen,
            max_length_k=max_seqlen,
            is_causal=False,
            **kwargs,
        )
        if self.config._attn_implementation == "flash_attention_2":
            # Flash Attention 2: Use cu_seqlens for variable length attention
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # Other implementations: Process each chunk separately
            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            splits = [
                torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
            ]

            attn_outputs = [
                attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                for q, k, v in zip(*splits)
            ]
            attn_output = torch.cat(attn_outputs, dim=1)

        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
@@ -400,7 +423,6 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
@@ -408,7 +430,6 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -721,25 +742,6 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
        # NOTE: the created attention mask only approximates the ragged FA2 attention by
        # allowing bidirectional attention within `cu_seqlens` blocks and not attending between
        # blocks, so it will not be a 100% match for FA2's `varlen` path
        if self.config._attn_implementation == "flash_attention_2":
            return None

        seq_length = inputs_tensor.shape[0]
        attention_mask = torch.full(
            [1, 1, seq_length, seq_length],
            torch.finfo(inputs_tensor.dtype).min,
            device=inputs_tensor.device,
            dtype=inputs_tensor.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
        return attention_mask

    @auto_docstring
    def forward(
        self,
@@ -765,14 +767,12 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)

        for blk in self.blocks:
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
                **kwargs,
            )
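Reviewer note (not part of the diff): the `F.pad(cu_seqlens, (1, 0), value=0)` line above just prepends the leading zero to the cumulative patch counts. A hedged sketch of how those boundaries are typically derived from `grid_thw` (values invented; one row per image/video holding its (t, h, w) patch grid, following the Qwen2-VL convention):

    import torch
    import torch.nn.functional as F

    grid_thw = torch.tensor([[1, 4, 4], [2, 2, 2]])   # a 4x4 image and a 2-frame 2x2 video

    # tokens per temporal slice = h * w, repeated t times per input, then accumulated
    cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
        dim=0, dtype=torch.int32
    )
    cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
    print(cu_seqlens)   # tensor([ 0, 16, 20, 24], dtype=torch.int32)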