mirror of https://github.com/huggingface/transformers.git
[bugfix] [WIP] fix apply_rotary_emb error on Ascend NPU (#38491)
This commit is contained in:
parent ca0a682796
commit fdf86fb440
@@ -23,6 +23,7 @@ if is_torch_npu_available():
     import torch_npu
     from einops import rearrange, repeat
+    from torch_npu import npu_rotary_mul


 # FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
@@ -247,3 +248,19 @@ def npu_flash_attn_varlen_func(
     )[0]

     return output
+
+
+def npu_apply_rotary_emb(x, cos, sin, **kwargs):
+    # A cos tensor that is still chunked should be repeated along the chunked dimension back to its original shape on Ascend NPU.
+    if len(cos.shape) == 2 and cos.shape[-1] == x.shape[-1] // 2:
+        cos = cos.repeat(1, 2)
+        # A cos tensor of shape [S, D] should be unsqueezed to a 4-d tensor of shape [1, S, 1, D].
+        cos = cos.unsqueeze(0).unsqueeze(2)
+
+    # A sin tensor that is still chunked should be repeated along the chunked dimension back to its original shape on Ascend NPU.
+    if len(sin.shape) == 2 and sin.shape[-1] == x.shape[-1] // 2:
+        sin = sin.repeat(1, 2)
+        # A sin tensor of shape [S, D] should be unsqueezed to a 4-d tensor of shape [1, S, 1, D].
+        sin = sin.unsqueeze(0).unsqueeze(2)
+
+    return npu_rotary_mul(x, cos, sin)
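For context on the shape handling above, here is a minimal standalone sketch of what npu_apply_rotary_emb does to a half-width rotary table before calling the kernel; the batch size, sequence length, head count and head dimension below are illustrative assumptions, not values taken from the patch.

import torch

B, S, H, D = 2, 16, 4, 64                # assumed [batch, seq, heads, head_dim] layout
x = torch.randn(B, S, H, D)
cos = torch.randn(S, D // 2)             # chunked table as produced upstream: [S, D/2]

# Same logic as the wrapper: restore the full head dimension, then make the
# table broadcastable against x.
if len(cos.shape) == 2 and cos.shape[-1] == x.shape[-1] // 2:
    cos = cos.repeat(1, 2)               # [S, D/2] -> [S, D]
    cos = cos.unsqueeze(0).unsqueeze(2)  # [S, D]   -> [1, S, 1, D]

print(cos.shape)                         # torch.Size([1, 16, 1, 64])

With cos and sin widened to [1, S, 1, D], torch_npu.npu_rotary_mul can broadcast them directly against x, which is why the wrapper is introduced instead of aliasing the kernel as before.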
@@ -40,9 +40,8 @@ if is_flash_attn_2_available():

 # patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
 if is_torch_npu_available():
-    from torch_npu import npu_rotary_mul as apply_rotary_emb  # noqa
-
     from .integrations.npu_flash_attention import index_first_axis, pad_input, unpad_input
+    from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb  # noqa
     from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
     from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func

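The effect of this hunk is that, on Ascend NPU, the public name apply_rotary_emb resolves to the shape-adapting wrapper rather than to the raw torch_npu kernel. A rough sketch of the resulting binding, assuming absolute import paths that mirror the relative ones in the diff:

from transformers.utils import is_torch_npu_available

if is_torch_npu_available():
    # Before this patch, apply_rotary_emb was torch_npu.npu_rotary_mul itself,
    # which errors on the half-width [S, D/2] cos/sin tensors callers pass in.
    # It now points at the wrapper that widens and reshapes them first.
    from transformers.integrations.npu_flash_attention import (
        npu_apply_rotary_emb as apply_rotary_emb,  # noqa: F401
    )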
@@ -23,6 +23,7 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -31,11 +32,11 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import auto_docstring, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
+from ...utils import auto_docstring, logging
 from .configuration_esm import EsmConfig


-if is_flash_attn_2_available():
+if is_flash_attn_available():
     from ...modeling_flash_attention_utils import _flash_attention_forward

@@ -413,7 +414,7 @@ class EsmFlashAttention2(EsmSelfAttention):
         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
         self.dropout_prob = config.attention_probs_dropout_prob

     def forward(
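flash_attn_supports_top_left_mask() centralizes the version probe that each model previously inlined as not is_flash_attn_greater_or_equal_2_10(). A hedged sketch of what such a helper amounts to, assuming it combines that same probe with the NPU default noted in the npu_flash_attention integration (down-right aligned causal masks); this is an illustration, not the library's actual implementation:

from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_torch_npu_available

def flash_attn_supports_top_left_mask_sketch():
    # flash-attn < 2.1 produced top-left aligned causal masks, so only older
    # GPU builds need the top-left handling; the Ascend NPU kernels default to
    # down-right alignment and do not.
    if is_torch_npu_available():
        return False
    return not is_flash_attn_greater_or_equal_2_10()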
@@ -34,19 +34,17 @@ from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import (
+    FlashAttentionKwargs,
+    flash_attn_supports_top_left_mask,
+    is_flash_attn_available,
+)
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import (
-    auto_docstring,
-    check_torch_load_is_safe,
-    is_flash_attn_2_available,
-    is_flash_attn_greater_or_equal_2_10,
-    logging,
-)
+from ...utils import auto_docstring, check_torch_load_is_safe, logging
 from ...utils.hub import cached_file
 from .configuration_qwen2_5_omni import (
     Qwen2_5OmniAudioEncoderConfig,
@@ -61,9 +59,8 @@ from .configuration_qwen2_5_omni import (
 )


-if is_flash_attn_2_available():
-    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
-    from flash_attn.layers.rotary import apply_rotary_emb
+if is_flash_attn_available():
+    from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
 else:
     flash_attn_varlen_func = None
     apply_rotary_emb = None
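Both the modeling and modular Qwen2.5-Omni files now obtain apply_rotary_emb and flash_attn_varlen_func from transformers' own modeling_flash_attention_utils, so the Ascend NPU wrappers are picked up automatically and the names fall back to None when no backend is installed. A minimal usage sketch of that guard; the packed [total_tokens, heads, head_dim] layout and the error path are assumptions for illustration, not code from the diff:

from transformers.modeling_flash_attention_utils import is_flash_attn_available

if is_flash_attn_available():
    from transformers.modeling_flash_attention_utils import flash_attn_varlen_func
else:
    flash_attn_varlen_func = None  # callers must check before dispatching

def varlen_attention(q, k, v, cu_seqlens, max_seqlen):
    # q, k, v are packed as [total_tokens, num_heads, head_dim]; cu_seqlens holds
    # the cumulative sequence lengths expected by the varlen interface.
    if flash_attn_varlen_func is None:
        raise RuntimeError("no flash-attention backend (flash-attn or torch_npu) is available")
    return flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen)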
@@ -653,7 +650,7 @@ class Qwen2_5OmniAudioFlashAttention2(Qwen2_5OmniAudioAttention):
         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()

     def forward(
         self,
@@ -43,22 +43,20 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbeddin

 from ...configuration_utils import PretrainedConfig, layer_type_validation
 from ...generation import GenerationMixin
+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
 from ...modeling_outputs import BaseModelOutput, ModelOutput
 from ...modeling_rope_utils import rope_config_validation
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...utils import (
     auto_docstring,
     check_torch_load_is_safe,
-    is_flash_attn_2_available,
-    is_flash_attn_greater_or_equal_2_10,
     logging,
 )
 from ...utils.hub import cached_file


-if is_flash_attn_2_available():
-    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
-    from flash_attn.layers.rotary import apply_rotary_emb
+if is_flash_attn_available():
+    from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
 else:
     flash_attn_varlen_func = None
     apply_rotary_emb = None
@@ -1667,7 +1665,7 @@ class Qwen2_5OmniAudioFlashAttention2(Qwen2_5OmniAudioAttention):
         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()

     def forward(
         self,