From 3c289e2104bf16b60607d626c6fc2e3eceb84f45 Mon Sep 17 00:00:00 2001
From: Zhen <295632982@qq.com>
Date: Fri, 23 May 2025 23:24:51 +0800
Subject: [PATCH] [performance_optim] reduce frequency of declaring attention_mask in Ascend NPU flash attention (#38278)

[performance_optim] reduce frequency of declaring attention_mask in ASCEND NPU flash attention
---
 .../integrations/npu_flash_attention.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/transformers/integrations/npu_flash_attention.py b/src/transformers/integrations/npu_flash_attention.py
index bb515540d14..e32af9f4bc9 100644
--- a/src/transformers/integrations/npu_flash_attention.py
+++ b/src/transformers/integrations/npu_flash_attention.py
@@ -37,6 +37,8 @@ if SPARSE_MODE not in [TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE, DOWN_RIGHT_ALIGNED_CAU
         "or 3 (down-right aligned causal mask)."
     )
 
+ATTN_MASK_NPU = None
+
 
 def is_npu_fa2_top_left_aligned_causal_mask():
     return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False
@@ -171,7 +173,9 @@ def npu_flash_attn_func(
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
     else:
-        attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
+        global ATTN_MASK_NPU
+        if ATTN_MASK_NPU is None:
+            ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(
             q,
@@ -181,7 +185,7 @@ def npu_flash_attn_func(
             "BSND",
             keep_prob=keep_prob,
             scale=softmax_scale,
-            atten_mask=attn_mask_npu,
+            atten_mask=ATTN_MASK_NPU,
             sparse_mode=SPARSE_MODE,
         )[0]
 
@@ -222,7 +226,9 @@ def npu_flash_attn_varlen_func(
             actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
         )[0]
     else:
-        attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
+        global ATTN_MASK_NPU
+        if ATTN_MASK_NPU is None:
+            ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
         head_num = q.shape[1]
         output = torch_npu.npu_fusion_attention(
             q,
             k,
             v,
             head_num,
             pse=None,
             padding_mask=None,
-            atten_mask=attn_mask_npu,
+            atten_mask=ATTN_MASK_NPU,
             scale=softmax_scale,
             keep_prob=keep_prob,
             input_layout="TND",
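
For reference, the change caches the 2048x2048 upper-triangular causal mask in the module-level ATTN_MASK_NPU global, so it is allocated once per process instead of being rebuilt on every attention call. Below is a minimal standalone sketch of that lazy-initialization pattern, not part of the patch: the helper name get_attn_mask_npu is hypothetical, and plain CPU torch stands in for the torch_npu.npu_fusion_attention call that consumes the mask in the patch.

# Editorial sketch of the caching pattern applied by the patch above;
# `get_attn_mask_npu` is a hypothetical helper, not part of the patch.
import torch

ATTN_MASK_NPU = None  # module-level cache, populated lazily on first use


def get_attn_mask_npu(device):
    """Return the shared 2048x2048 upper-triangular bool mask, creating it only once."""
    global ATTN_MASK_NPU
    if ATTN_MASK_NPU is None:
        ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=device), diagonal=1).bool()
    return ATTN_MASK_NPU


# The first call allocates the mask; later calls reuse the same tensor object.
mask_a = get_attn_mask_npu("cpu")
mask_b = get_attn_mask_npu("cpu")
assert mask_a is mask_b

As in the patch, the cached tensor keeps the device it was first created on, which assumes all subsequent calls run on the same device.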