[performance_optim] define flash attention mask on NPU device directly (#37698)

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Zhen 2025-04-24 20:06:47 +08:00 committed by GitHub
parent 14e28bd721
commit 0327d0f7f2

@@ -171,7 +171,7 @@ def npu_flash_attn_func(
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
     else:
-        attn_mask_npu = torch.triu(torch.ones([2048, 2048]), diagonal=1).bool().to(q.device)
+        attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(
             q,
@@ -222,7 +222,7 @@ def npu_flash_attn_varlen_func(
             actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
         )[0]
     else:
-        attn_mask_npu = torch.triu(torch.ones([2048, 2048]), diagonal=1).bool().to(q.device)
+        attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
         head_num = q.shape[1]
         output = torch_npu.npu_fusion_attention(
             q,
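
For context, a minimal sketch of the pattern this diff adopts: the causal mask is allocated directly on the target device via the device= argument of torch.ones, instead of being built on the host and then copied with .to(q.device), which skips one CPU allocation and one host-to-device transfer per call. The sketch below is written against plain PyTorch so it runs without torch_npu; build_causal_mask and the device-selection fallback are illustrative, not part of the patched file.

import torch

def build_causal_mask(device: torch.device, size: int = 2048) -> torch.Tensor:
    """Upper-triangular boolean mask used to enforce causal attention."""
    # Old pattern: allocate on the CPU, then copy to the accelerator:
    #   torch.triu(torch.ones([size, size]), diagonal=1).bool().to(device)
    # New pattern: allocate on the target device directly, avoiding the
    # intermediate host tensor and the host-to-device copy.
    return torch.triu(torch.ones([size, size], device=device), diagonal=1).bool()

if __name__ == "__main__":
    # Fall back to CPU so the sketch stays runnable without an NPU; with
    # torch_npu installed, torch.device("npu") is used instead (assumption).
    device = torch.device("npu") if hasattr(torch, "npu") and torch.npu.is_available() else torch.device("cpu")
    mask = build_causal_mask(device)
    print(mask.shape, mask.device, mask.dtype)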