Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-03 12:50:06 +06:00
[performance_optim] reduce frequency of declaring attention_mask in Ascend NPU flash attention (#38278)
parent f5d45d89c4
commit 3c289e2104
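The diff below swaps a per-call mask allocation for a lazily initialized module-level cache. A minimal sketch of that pattern in plain PyTorch, using hypothetical names (_CACHED_MASK, causal_mask_for) rather than the actual transformers module:

import torch

# Hypothetical module-level cache, mirroring the ATTN_MASK_NPU global added below.
_CACHED_MASK = None


def causal_mask_for(q: torch.Tensor) -> torch.Tensor:
    """Return a shared upper-triangular bool mask, building it only on the first call."""
    global _CACHED_MASK
    if _CACHED_MASK is None:
        # Allocated once, on the device of the first query tensor seen;
        # later calls reuse the same tensor instead of re-declaring it.
        _CACHED_MASK = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
    return _CACHED_MASK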
@@ -37,6 +37,8 @@ if SPARSE_MODE not in [TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE, DOWN_RIGHT_ALIGNED_CAU
         "or 3 (down-right aligned causal mask)."
     )

+ATTN_MASK_NPU = None
+

 def is_npu_fa2_top_left_aligned_causal_mask():
     return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False
@@ -171,7 +173,9 @@ def npu_flash_attn_func(
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
     else:
-        attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
+        global ATTN_MASK_NPU
+        if ATTN_MASK_NPU is None:
+            ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(
             q,
@@ -181,7 +185,7 @@ def npu_flash_attn_func(
             "BSND",
             keep_prob=keep_prob,
             scale=softmax_scale,
-            atten_mask=attn_mask_npu,
+            atten_mask=ATTN_MASK_NPU,
             sparse_mode=SPARSE_MODE,
         )[0]

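For reference, the mask being cached is the usual upper-triangular causal pattern, where True marks positions a query must not attend to. A small CPU-only illustration of the same construction, shrunk from 2048x2048 to 4x4 so the pattern is visible:

import torch

mask = torch.triu(torch.ones([4, 4]), diagonal=1).bool()
print(mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])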
@@ -222,7 +226,9 @@ def npu_flash_attn_varlen_func(
             actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
         )[0]
     else:
-        attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
+        global ATTN_MASK_NPU
+        if ATTN_MASK_NPU is None:
+            ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
         head_num = q.shape[1]
         output = torch_npu.npu_fusion_attention(
             q,
@@ -231,7 +237,7 @@ def npu_flash_attn_varlen_func(
             head_num,
             pse=None,
             padding_mask=None,
-            atten_mask=attn_mask_npu,
+            atten_mask=ATTN_MASK_NPU,
             scale=softmax_scale,
             keep_prob=keep_prob,
             input_layout="TND",
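The per-call saving is easy to estimate: a 2048 x 2048 torch.bool mask occupies roughly 4 MiB, which was allocated (plus the ones/triu kernels run) on every causal call before this change and only once per process after it. A back-of-the-envelope check:

# Rough size of the mask this commit stops re-allocating on every call.
numel = 2048 * 2048      # 4,194,304 elements
bytes_per_elem = 1       # torch.bool stores one byte per element
print(numel * bytes_per_elem / 2**20)  # 4.0 (MiB)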