Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-16 19:18:24 +06:00)
[bugfix] fix ATTN_MASK_NPU device mismatch error on multi-device NPU setups (#38876)
This commit is contained in:
parent 9cd7570f34
commit c55d806355
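
For context, a minimal sketch of the failure mode named in the title: before this patch, the mask was a single module-level global created lazily on whichever device made the first call, so a later call with tensors on a different NPU received a mask living on the wrong device. The sketch uses plain torch instead of torch_npu, and the "meta" device stands in for a second NPU; all names below are illustrative, not taken from the patch.

import torch

# Pre-fix pattern: one module-level mask, created lazily on whichever device
# issues the first call, then reused for every later call.
ATTN_MASK = None


def old_get_mask(q):
    global ATTN_MASK
    if ATTN_MASK is None:
        ATTN_MASK = torch.triu(torch.ones(2048, 2048, device=q.device), diagonal=1).bool()
    return ATTN_MASK


q0 = torch.empty(1, 16, 4, 64)                 # first call: CPU stands in for npu:0
q1 = torch.empty(1, 16, 4, 64, device="meta")  # second call: "meta" stands in for npu:1
old_get_mask(q0)
mask = old_get_mask(q1)
assert mask.device != q1.device  # cached mask is stuck on the first device -> mismatch downstream
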
@@ -38,7 +38,14 @@ if SPARSE_MODE not in [TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE, DOWN_RIGHT_ALIGNED_CAU
         "or 3 (down-right aligned causal mask)."
     )
 
-ATTN_MASK_NPU = None
+ATTN_MASK_NPU_CACHE = {}
+
+
+def get_attn_mask_npu(device):
+    """Get or create attention mask for the specified device."""
+    if device not in ATTN_MASK_NPU_CACHE:
+        ATTN_MASK_NPU_CACHE[device] = torch.triu(torch.ones([2048, 2048], device=device), diagonal=1).bool()
+    return ATTN_MASK_NPU_CACHE[device]
 
 
 def is_npu_fa2_top_left_aligned_causal_mask():
@@ -174,9 +181,7 @@ def npu_flash_attn_func(
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
     else:
-        global ATTN_MASK_NPU
-        if ATTN_MASK_NPU is None:
-            ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
+        attn_mask_npu = get_attn_mask_npu(q.device)
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(
             q,
@@ -186,7 +191,7 @@ def npu_flash_attn_func(
             "BSND",
             keep_prob=keep_prob,
             scale=softmax_scale,
-            atten_mask=ATTN_MASK_NPU,
+            atten_mask=attn_mask_npu,
             sparse_mode=SPARSE_MODE,
         )[0]
 
@@ -227,9 +232,7 @@ def npu_flash_attn_varlen_func(
             actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
         )[0]
     else:
-        global ATTN_MASK_NPU
-        if ATTN_MASK_NPU is None:
-            ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
+        attn_mask_npu = get_attn_mask_npu(q.device)
         head_num = q.shape[1]
         output = torch_npu.npu_fusion_attention(
             q,
@@ -238,7 +241,7 @@ def npu_flash_attn_varlen_func(
             head_num,
             pse=None,
             padding_mask=None,
-            atten_mask=ATTN_MASK_NPU,
+            atten_mask=attn_mask_npu,
             scale=softmax_scale,
             keep_prob=keep_prob,
             input_layout="TND",
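
After the patch, the per-device cache keeps one mask per device, so every call receives a mask on the same device as its query tensor. Below is a standalone sketch of the same pattern as the new get_attn_mask_npu helper, again with plain torch and illustrative names rather than the actual torch_npu code path.

import torch

# Post-fix pattern: one mask per device, mirroring ATTN_MASK_NPU_CACHE / get_attn_mask_npu.
MASK_CACHE = {}


def get_mask(device, size=2048):
    """Return a cached upper-triangular boolean mask that lives on the given device."""
    device = torch.device(device)
    if device not in MASK_CACHE:
        MASK_CACHE[device] = torch.triu(torch.ones(size, size, device=device), diagonal=1).bool()
    return MASK_CACHE[device]


q0 = torch.empty(1, 16, 4, 64)                 # CPU stands in for npu:0
q1 = torch.empty(1, 16, 4, 64, device="meta")  # "meta" stands in for npu:1
assert get_mask(q0.device).device == q0.device
assert get_mask(q1.device).device == q1.device
assert get_mask(q0.device) is get_mask(q0.device)  # built once per device, then reused

Keying the cache on the device also preserves the original lazy-init behavior: the 2048x2048 mask is constructed once per device and reused on subsequent calls.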