Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[performance_optim] define flash attention mask on NPU device directly (#37698)
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
commit 0327d0f7f2
parent 14e28bd721
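The change below, applied identically in `npu_flash_attn_func` and `npu_flash_attn_varlen_func`, builds the 2048x2048 causal mask directly on the query's device instead of building it in host memory and copying it over with `.to(q.device)`. A minimal sketch of the two patterns in plain PyTorch (the `device` placeholder is illustrative so the snippet runs without `torch_npu`; the real code passes `q.device`):

import torch

device = torch.device("cpu")  # stands in for q.device (an NPU device in the real code)

# Before: the mask is built in host memory, then copied to the device.
mask_old = torch.triu(torch.ones([2048, 2048]), diagonal=1).bool().to(device)

# After: the mask is allocated on the device from the start, skipping the
# intermediate host tensor and the host-to-device transfer.
mask_new = torch.triu(torch.ones([2048, 2048], device=device), diagonal=1).bool()

assert torch.equal(mask_old, mask_new)

Both expressions produce the same boolean upper-triangular mask; the new form simply avoids the extra allocation and copy on every call.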
@@ -171,7 +171,7 @@ def npu_flash_attn_func(
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
     else:
-        attn_mask_npu = torch.triu(torch.ones([2048, 2048]), diagonal=1).bool().to(q.device)
+        attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(
             q,
@@ -222,7 +222,7 @@ def npu_flash_attn_varlen_func(
             actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
         )[0]
     else:
-        attn_mask_npu = torch.triu(torch.ones([2048, 2048]), diagonal=1).bool().to(q.device)
+        attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
         head_num = q.shape[1]
         output = torch_npu.npu_fusion_attention(
             q,
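To gauge the effect of the change locally, here is a rough, hypothetical timing harness (not part of the PR; names are illustrative). On a CPU-only build both variants do the same work, so run it on an accelerator, e.g. `torch.device("npu")` on an Ascend machine with torch_npu installed, or CUDA, to see the copy overhead:

import time

import torch

def build_mask(device: torch.device, on_device: bool) -> torch.Tensor:
    """Build the 2048x2048 causal mask either directly on `device` or via a host copy."""
    if on_device:
        # New pattern from this commit: allocate on the target device.
        return torch.triu(torch.ones([2048, 2048], device=device), diagonal=1).bool()
    # Old pattern: allocate on the host, then copy to the device.
    return torch.triu(torch.ones([2048, 2048]), diagonal=1).bool().to(device)

# Swap in torch.device("npu") on an Ascend machine with torch_npu installed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for on_device in (False, True):
    start = time.perf_counter()
    for _ in range(100):
        build_mask(device, on_device)
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait for queued device work before reading the clock
    print(f"on_device={on_device}: {time.perf_counter() - start:.4f}s")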