diff --git a/src/transformers/integrations/npu_flash_attention.py b/src/transformers/integrations/npu_flash_attention.py
index e32af9f4bc9..b6a6001729c 100644
--- a/src/transformers/integrations/npu_flash_attention.py
+++ b/src/transformers/integrations/npu_flash_attention.py
@@ -23,6 +23,7 @@
 if is_torch_npu_available():
     import torch_npu
     from einops import rearrange, repeat
+    from torch_npu import npu_rotary_mul
 
 
 # FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
@@ -247,3 +248,19 @@ def npu_flash_attn_varlen_func(
     )[0]
 
     return output
+
+
+def npu_apply_rotary_emb(x, cos, sin, **kwargs):
+    # cos tensor after chunk should be repeated through chunked dimension to original shape on Ascend NPU
+    if len(cos.shape) == 2 and cos.shape[-1] == x.shape[-1] // 2:
+        cos = cos.repeat(1, 2)
+        # cos tensor with [S,D] shape should be unsqueezed to 4-d tensor with shape [1,S,1,D]
+        cos = cos.unsqueeze(0).unsqueeze(2)
+
+    # sin tensor after chunk should be repeated through chunked dimension to original shape on Ascend NPU
+    if len(sin.shape) == 2 and sin.shape[-1] == x.shape[-1] // 2:
+        sin = sin.repeat(1, 2)
+        # sin tensor with [S,D] shape should be unsqueezed to 4-d tensor with shape [1,S,1,D]
+        sin = sin.unsqueeze(0).unsqueeze(2)
+
+    return npu_rotary_mul(x, cos, sin)
diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index 648995eed7b..f01db89f5b1 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -40,9 +40,8 @@ if is_flash_attn_2_available():
 
 # patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
 if is_torch_npu_available():
-    from torch_npu import npu_rotary_mul as apply_rotary_emb  # noqa
-
     from .integrations.npu_flash_attention import index_first_axis, pad_input, unpad_input
+    from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb  # noqa
     from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
     from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
 
diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py
index da62ecbea2a..640b8429f00 100755
--- a/src/transformers/models/esm/modeling_esm.py
+++ b/src/transformers/models/esm/modeling_esm.py
@@ -23,6 +23,7 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -31,11 +32,11 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import auto_docstring, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
+from ...utils import auto_docstring, logging
 from .configuration_esm import EsmConfig
 
 
-if is_flash_attn_2_available():
+if is_flash_attn_available():
     from ...modeling_flash_attention_utils import _flash_attention_forward
 
 
@@ -413,7 +414,7 @@ class EsmFlashAttention2(EsmSelfAttention):
         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
         self.dropout_prob = config.attention_probs_dropout_prob
 
     def forward(
diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
index 13b641c4681..8164ad0f4ba 100644
--- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -34,19 +34,17 @@ from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import (
+    FlashAttentionKwargs,
+    flash_attn_supports_top_left_mask,
+    is_flash_attn_available,
+)
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import (
-    auto_docstring,
-    check_torch_load_is_safe,
-    is_flash_attn_2_available,
-    is_flash_attn_greater_or_equal_2_10,
-    logging,
-)
+from ...utils import auto_docstring, check_torch_load_is_safe, logging
 from ...utils.hub import cached_file
 from .configuration_qwen2_5_omni import (
     Qwen2_5OmniAudioEncoderConfig,
@@ -61,9 +59,8 @@ from .configuration_qwen2_5_omni import (
 )
 
 
-if is_flash_attn_2_available():
-    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
-    from flash_attn.layers.rotary import apply_rotary_emb
+if is_flash_attn_available():
+    from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
 else:
     flash_attn_varlen_func = None
     apply_rotary_emb = None
@@ -653,7 +650,7 @@ class Qwen2_5OmniAudioFlashAttention2(Qwen2_5OmniAudioAttention):
         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
 
     def forward(
         self,
diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
index 70c0490d95b..6f40803f803 100644
--- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -43,22 +43,20 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbeddin
 
 from ...configuration_utils import PretrainedConfig, layer_type_validation
 from ...generation import GenerationMixin
+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
 from ...modeling_outputs import BaseModelOutput, ModelOutput
 from ...modeling_rope_utils import rope_config_validation
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...utils import (
     auto_docstring,
     check_torch_load_is_safe,
-    is_flash_attn_2_available,
-    is_flash_attn_greater_or_equal_2_10,
     logging,
 )
 from ...utils.hub import cached_file
 
 
-if is_flash_attn_2_available():
-    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
-    from flash_attn.layers.rotary import apply_rotary_emb
+if is_flash_attn_available():
+    from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
 else:
     flash_attn_varlen_func = None
     apply_rotary_emb = None
@@ -1667,7 +1665,7 @@ class Qwen2_5OmniAudioFlashAttention2(Qwen2_5OmniAudioAttention):
         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
 
     def forward(
         self,
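Note for reviewers without Ascend hardware: the only non-trivial logic in the patch is the cos/sin reshaping inside `npu_apply_rotary_emb`, which adapts the chunked `[S, D/2]` tensors used on the flash-attention rotary path to the `[1, S, 1, D]` layout that `torch_npu.npu_rotary_mul` broadcasts against `x` of shape `[B, S, H, D]`. The snippet below is a minimal CPU-only sketch of that reshaping; the helper name `expand_for_npu`, the tensor sizes, and the eager `rotate_half` formula standing in for the NPU kernel are illustrative assumptions, not part of the patch.

```python
import torch

# Illustrative sizes only: batch, sequence length, heads, head dim.
B, S, H, D = 2, 16, 4, 64

x = torch.randn(B, S, H, D)
cos = torch.randn(S, D // 2)  # chunked cos/sin, as passed on the flash-attn rotary path
sin = torch.randn(S, D // 2)


def expand_for_npu(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Mirrors the reshaping in npu_apply_rotary_emb: repeat the chunked half back to the
    # full head dim, then unsqueeze to [1, S, 1, D] so it broadcasts against [B, S, H, D].
    if t.dim() == 2 and t.shape[-1] == x.shape[-1] // 2:
        t = t.repeat(1, 2)
        t = t.unsqueeze(0).unsqueeze(2)
    return t


cos4d = expand_for_npu(cos, x)
sin4d = expand_for_npu(sin, x)
assert cos4d.shape == (1, S, 1, D) and sin4d.shape == (1, S, 1, D)


def rotate_half(t: torch.Tensor) -> torch.Tensor:
    t1, t2 = t.chunk(2, dim=-1)
    return torch.cat((-t2, t1), dim=-1)


# Eager-mode stand-in for torch_npu.npu_rotary_mul(x, cos4d, sin4d), assuming the
# non-interleaved (rotate-half) rotary convention used on this code path.
out = x * cos4d + rotate_half(x) * sin4d
print(out.shape)  # torch.Size([2, 16, 4, 64])
```

Keeping the reshape inside the NPU wrapper means the call sites patched above (the NPU branch of `modeling_flash_attention_utils.py` and the Qwen2.5-Omni audio attention) can keep calling `apply_rotary_emb(x, cos, sin)` with the same arguments as the GPU `flash_attn` path.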