fix qwen2vl vision eager-attention (#33213)

* fix-qwen2vl-vision-eager-attention

* code-quality

* Update src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* code-quality

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Shijie 2024-09-06 19:42:17 +08:00 committed by GitHub
parent 51d15eb1c1
commit 1bd9d1c899
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -275,6 +275,7 @@ class VisionAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)
@ -286,9 +287,11 @@ class VisionAttention(nn.Module):
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
attention_mask = torch.full(
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
q = q.transpose(0, 1)
k = k.transpose(0, 1)