diff --git a/src/transformers/integrations/npu_flash_attention.py b/src/transformers/integrations/npu_flash_attention.py
index b6a6001729c..dd8a6dc5d07 100644
--- a/src/transformers/integrations/npu_flash_attention.py
+++ b/src/transformers/integrations/npu_flash_attention.py
@@ -38,7 +38,14 @@ if SPARSE_MODE not in [TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE, DOWN_RIGHT_ALIGNED_CAU
         "or 3 (down-right aligned causal mask)."
     )
 
-ATTN_MASK_NPU = None
+ATTN_MASK_NPU_CACHE = {}
+
+
+def get_attn_mask_npu(device):
+    """Get or create attention mask for the specified device."""
+    if device not in ATTN_MASK_NPU_CACHE:
+        ATTN_MASK_NPU_CACHE[device] = torch.triu(torch.ones([2048, 2048], device=device), diagonal=1).bool()
+    return ATTN_MASK_NPU_CACHE[device]
 
 
 def is_npu_fa2_top_left_aligned_causal_mask():
@@ -174,9 +181,7 @@ def npu_flash_attn_func(
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
     else:
-        global ATTN_MASK_NPU
-        if ATTN_MASK_NPU is None:
-            ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
+        attn_mask_npu = get_attn_mask_npu(q.device)
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(
             q,
@@ -186,7 +191,7 @@ def npu_flash_attn_func(
             "BSND",
             keep_prob=keep_prob,
             scale=softmax_scale,
-            atten_mask=ATTN_MASK_NPU,
+            atten_mask=attn_mask_npu,
             sparse_mode=SPARSE_MODE,
         )[0]
 
@@ -227,9 +232,7 @@ def npu_flash_attn_varlen_func(
             actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
         )[0]
     else:
-        global ATTN_MASK_NPU
-        if ATTN_MASK_NPU is None:
-            ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
+        attn_mask_npu = get_attn_mask_npu(q.device)
         head_num = q.shape[1]
         output = torch_npu.npu_fusion_attention(
             q,
@@ -238,7 +241,7 @@ def npu_flash_attn_varlen_func(
             k,
             v,
             head_num,
             pse=None,
             padding_mask=None,
-            atten_mask=ATTN_MASK_NPU,
+            atten_mask=attn_mask_npu,
             scale=softmax_scale,
             keep_prob=keep_prob,
             input_layout="TND",
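
For context, below is a minimal, self-contained sketch of the per-device caching pattern this diff introduces: instead of a single ATTN_MASK_NPU global pinned to whichever device first created it, the mask is keyed by device and rebuilt at most once per device. The sketch uses plain PyTorch on CPU and a __main__ demo, both of which are illustrative assumptions (torch_npu and NPU devices are not required to show the caching behaviour); the get_attn_mask_npu body itself mirrors the diff.

# Standalone sketch of the per-device mask cache added by this diff.
# Assumption: plain PyTorch on CPU is used here purely for illustration;
# on real hardware the keys would be NPU devices such as q.device.
import torch

ATTN_MASK_NPU_CACHE = {}


def get_attn_mask_npu(device):
    """Get or create the 2048x2048 upper-triangular causal mask for `device`."""
    if device not in ATTN_MASK_NPU_CACHE:
        ATTN_MASK_NPU_CACHE[device] = torch.triu(torch.ones([2048, 2048], device=device), diagonal=1).bool()
    return ATTN_MASK_NPU_CACHE[device]


if __name__ == "__main__":
    dev = torch.device("cpu")
    mask_a = get_attn_mask_npu(dev)
    mask_b = get_attn_mask_npu(dev)
    # The second call reuses the cached tensor rather than rebuilding it,
    # and each distinct device gets its own mask on the device it runs on.
    assert mask_a is mask_b
    print(mask_a.shape, mask_a.dtype, mask_a.device)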