Mirror of https://github.com/huggingface/transformers.git
fix qwen2vl vision eager-attention (#33213)
* fix-qwen2vl-vision-eager-attention

* code-quality

* Update src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* code-quality

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 51d15eb1c1
commit 1bd9d1c899
src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

@@ -275,6 +275,7 @@ class VisionAttention(nn.Module):
     def __init__(self, dim: int, num_heads: int = 16) -> None:
         super().__init__()
         self.num_heads = num_heads
+        self.head_dim = dim // num_heads
         self.qkv = nn.Linear(dim, dim * 3, bias=True)
         self.proj = nn.Linear(dim, dim)

@@ -286,9 +287,11 @@ class VisionAttention(nn.Module):
         q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
         k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

-        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
+        attention_mask = torch.full(
+            [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
+        )
         for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
+            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

         q = q.transpose(0, 1)
         k = k.transpose(0, 1)
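Why the change: the eager attention path adds the mask directly to the raw attention scores before the softmax, so it needs an additive float mask (0 inside each image's block-diagonal segment, the dtype minimum elsewhere) rather than a boolean one. Below is a minimal sketch of that mechanism, not the library's code, with illustrative shapes and a hypothetical cu_seqlens:

import torch

seq_length = 6
num_heads, head_dim = 4, 8
cu_seqlens = torch.tensor([0, 2, 6])  # hypothetical cumulative sequence lengths
q = torch.randn(seq_length, num_heads, head_dim)
k = torch.randn_like(q)

# Additive mask: 0 where attention is allowed, dtype-min where it is blocked.
attention_mask = torch.full(
    [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
)
for i in range(1, len(cu_seqlens)):
    attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

# Eager attention: scores + mask, then softmax. Blocked positions get a large
# negative score and hence ~0 probability. A bool mask added here would be
# promoted to 0.0/1.0 and barely perturb the scores instead of masking them.
scores = torch.matmul(q.transpose(0, 1), k.transpose(0, 1).transpose(1, 2)) / head_dim**0.5
probs = torch.softmax(scores + attention_mask, dim=-1)

The self.head_dim attribute added in the first hunk supports this kind of 1/sqrt(head_dim) scaling of the eager score computation.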