[bugfix] [WIP] fix apply_rotary_emb error on Ascend NPU (#38491)

[bugfix] fix apply_rotary_emb error on Ascend NPU
This commit is contained in:
Zhen 2025-06-03 17:31:49 +08:00 committed by GitHub
parent ca0a682796
commit fdf86fb440
5 changed files with 35 additions and 23 deletions

View File

@@ -23,6 +23,7 @@ if is_torch_npu_available():
    import torch_npu
    from einops import rearrange, repeat
    from torch_npu import npu_rotary_mul
# FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
@@ -247,3 +248,19 @@ def npu_flash_attn_varlen_func(
    )[0]
    return output
def npu_apply_rotary_emb(x, cos, sin, **kwargs):
    # On Ascend NPU, a cos tensor that was chunked to half the head dimension must be
    # repeated along the chunked dimension back to its original shape.
    if len(cos.shape) == 2 and cos.shape[-1] == x.shape[-1] // 2:
        cos = cos.repeat(1, 2)
        # A cos tensor of shape [S, D] must be unsqueezed to a 4-D tensor of shape [1, S, 1, D].
        cos = cos.unsqueeze(0).unsqueeze(2)
    # Likewise, a sin tensor that was chunked must be repeated back to its original shape on Ascend NPU.
    if len(sin.shape) == 2 and sin.shape[-1] == x.shape[-1] // 2:
        sin = sin.repeat(1, 2)
        # A sin tensor of shape [S, D] must be unsqueezed to a 4-D tensor of shape [1, S, 1, D].
        sin = sin.unsqueeze(0).unsqueeze(2)
    return npu_rotary_mul(x, cos, sin)
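
For reference, the shape handling in npu_apply_rotary_emb can be checked on any backend. Below is a minimal sketch with illustrative sizes (the batch/sequence/head dimensions are assumptions); it reproduces only the repeat/unsqueeze logic and does not call npu_rotary_mul, which needs torch_npu and an Ascend device.

import torch

B, S, H, D = 2, 16, 4, 64          # illustrative batch, sequence, heads, head dim
x = torch.randn(B, S, H, D)        # query/key states as passed by the attention layer
cos = torch.randn(S, D // 2)       # chunked cos cache, flash-attn rotary convention
sin = torch.randn(S, D // 2)

# Same reshaping as npu_apply_rotary_emb: [S, D/2] -> [S, D] -> [1, S, 1, D]
if len(cos.shape) == 2 and cos.shape[-1] == x.shape[-1] // 2:
    cos = cos.repeat(1, 2).unsqueeze(0).unsqueeze(2)
if len(sin.shape) == 2 and sin.shape[-1] == x.shape[-1] // 2:
    sin = sin.repeat(1, 2).unsqueeze(0).unsqueeze(2)

# The reshaped cos/sin now broadcast cleanly against x, which is what npu_rotary_mul expects.
assert cos.shape == (1, S, 1, D)
assert (x * cos + x * sin).shape == x.shape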

View File

@@ -40,9 +40,8 @@ if is_flash_attn_2_available():
# patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
if is_torch_npu_available():
    from torch_npu import npu_rotary_mul as apply_rotary_emb # noqa
    from .integrations.npu_flash_attention import index_first_axis, pad_input, unpad_input
    from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb # noqa
    from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
    from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
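
After this change, model code does not need to special-case the backend: importing apply_rotary_emb from modeling_flash_attention_utils yields npu_apply_rotary_emb on Ascend NPU (patched above) and, inferring from the imports removed later in this commit, flash-attn's rotary kernel on CUDA. A minimal sketch of the consuming side, mirroring the guarded-import pattern in the Qwen2.5-Omni files below (absolute imports are used here purely for illustration):

from transformers.modeling_flash_attention_utils import is_flash_attn_available

if is_flash_attn_available():
    # On Ascend NPU this resolves to npu_apply_rotary_emb; on CUDA it is expected
    # to be flash-attn's fused rotary kernel.
    from transformers.modeling_flash_attention_utils import apply_rotary_emb
else:
    apply_rotary_emb = None  # callers fall back to a non-fused rotary implementation

Keeping the alias in one place means NPU-specific shape handling, such as the repeat/unsqueeze above, only has to be implemented once instead of in every model file.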

View File

@@ -23,6 +23,7 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@@ -31,11 +32,11 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import auto_docstring, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
from ...utils import auto_docstring, logging
from .configuration_esm import EsmConfig
if is_flash_attn_2_available():
if is_flash_attn_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward
@@ -413,7 +414,7 @@ class EsmFlashAttention2(EsmSelfAttention):
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
self.dropout_prob = config.attention_probs_dropout_prob
def forward(
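
flash_attn_supports_top_left_mask() centralizes the check that each model previously spelled out as not is_flash_attn_greater_or_equal_2_10(). The alignment difference the comment above describes can be made concrete with a small sketch (illustrative only; flash-attn and the Ascend NPU kernels build these masks internally, and the sizes below are assumptions):

import torch

q_len, k_len = 2, 5  # e.g. scoring 2 new queries against 5 cached keys

# Bottom-right aligned causal mask (flash_attn >= 2.1 default, and the Ascend NPU default):
# query i may attend to keys j <= k_len - q_len + i.
bottom_right = torch.tril(torch.ones(q_len, k_len, dtype=torch.bool), diagonal=k_len - q_len)

# Top-left aligned causal mask (flash_attn < 2.1): query i may attend to keys j <= i,
# which is wrong whenever q_seqlen != k_seqlen (except q_seqlen == 1).
top_left = torch.tril(torch.ones(q_len, k_len, dtype=torch.bool))

print(bottom_right.int())  # [[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]
print(top_left.int())      # [[1, 0, 0, 0, 0], [1, 1, 0, 0, 0]]

The boolean stored in self._flash_attn_uses_top_left_mask is then used in forward() to pick the convention that matches the installed kernel.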

View File

@@ -34,19 +34,17 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_flash_attention_utils import (
FlashAttentionKwargs,
flash_attn_supports_top_left_mask,
is_flash_attn_available,
)
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
auto_docstring,
check_torch_load_is_safe,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
)
from ...utils import auto_docstring, check_torch_load_is_safe, logging
from ...utils.hub import cached_file
from .configuration_qwen2_5_omni import (
Qwen2_5OmniAudioEncoderConfig,
@@ -61,9 +59,8 @@ from .configuration_qwen2_5_omni import (
)
if is_flash_attn_2_available():
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
    from flash_attn.layers.rotary import apply_rotary_emb
if is_flash_attn_available():
    from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
else:
    flash_attn_varlen_func = None
    apply_rotary_emb = None
@@ -653,7 +650,7 @@ class Qwen2_5OmniAudioFlashAttention2(Qwen2_5OmniAudioAttention):
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
def forward(
self,

View File

@@ -43,22 +43,20 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbeddin
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...utils import (
auto_docstring,
check_torch_load_is_safe,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
)
from ...utils.hub import cached_file
if is_flash_attn_2_available():
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
    from flash_attn.layers.rotary import apply_rotary_emb
if is_flash_attn_available():
    from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
else:
    flash_attn_varlen_func = None
    apply_rotary_emb = None
@@ -1667,7 +1665,7 @@ class Qwen2_5OmniAudioFlashAttention2(Qwen2_5OmniAudioAttention):
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
def forward(
self,