mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
[AttentionMaskConverter] fix sdpa unmask unattended (#28369)
fix tensor device
This commit is contained in:
parent
98dba52ccd
commit
87a6cf41d0
@ -234,8 +234,8 @@ class AttentionMaskConverter:
|
||||
|
||||
# Get the index of the first non-zero value for every sample in the batch.
|
||||
# In the above example, indices = [[2], [0], [1]]
|
||||
tmp = torch.arange(attention_mask.shape[1], 0, -1)
|
||||
indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)
|
||||
tmp = torch.arange(attention_mask.shape[1], 0, -1, device=attention_mask.device)
|
||||
indices = torch.argmax(attention_mask * tmp, 1, keepdim=True)
|
||||
|
||||
# Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
|
||||
# expanded mask will be completely unattended.
|
||||
|
Loading…
Reference in New Issue
Block a user