mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
[cleanup] delete deprecated kwargs in qwen2_audio 🧹 (#38404)
delete deprecated
This commit is contained in:
parent
b9f8f863d9
commit
f85fd90407
@ -29,7 +29,6 @@ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask,
|
||||
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from ..auto import AutoModel, AutoModelForCausalLM
|
||||
from .configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig
|
||||
|
||||
@ -130,18 +129,12 @@ class Qwen2AudioAttention(nn.Module):
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
@deprecate_kwarg("key_value_states", version="4.52")
|
||||
@deprecate_kwarg("past_key_value", version="4.52")
|
||||
@deprecate_kwarg("cache_position", version="4.52")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
@ -203,18 +196,12 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
||||
|
||||
@deprecate_kwarg("key_value_states", version="4.52")
|
||||
@deprecate_kwarg("past_key_value", version="4.52")
|
||||
@deprecate_kwarg("cache_position", version="4.52")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# Qwen2AudioFlashAttention2 attention does not support output_attentions
|
||||
if output_attentions:
|
||||
@ -283,18 +270,12 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
|
||||
|
||||
|
||||
class Qwen2AudioSdpaAttention(Qwen2AudioAttention):
|
||||
@deprecate_kwarg("key_value_states", version="4.52")
|
||||
@deprecate_kwarg("past_key_value", version="4.52")
|
||||
@deprecate_kwarg("cache_position", version="4.52")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions:
|
||||
|
Loading…
Reference in New Issue
Block a user