[cleanup] delete deprecated kwargs in qwen2_audio 🧹 (#38404)

delete deprecated
This commit is contained in:
Joao Gante 2025-05-27 16:08:53 +01:00 committed by GitHub
parent b9f8f863d9
commit f85fd90407
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -29,7 +29,6 @@ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask,
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, logging
from ...utils.deprecation import deprecate_kwarg
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig
@ -130,18 +129,12 @@ class Qwen2AudioAttention(nn.Module):
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@deprecate_kwarg("key_value_states", version="4.52")
@deprecate_kwarg("past_key_value", version="4.52")
@deprecate_kwarg("cache_position", version="4.52")
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@ -203,18 +196,12 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
@deprecate_kwarg("key_value_states", version="4.52")
@deprecate_kwarg("past_key_value", version="4.52")
@deprecate_kwarg("cache_position", version="4.52")
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# Qwen2AudioFlashAttention2 attention does not support output_attentions
if output_attentions:
@ -283,18 +270,12 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
class Qwen2AudioSdpaAttention(Qwen2AudioAttention):
@deprecate_kwarg("key_value_states", version="4.52")
@deprecate_kwarg("past_key_value", version="4.52")
@deprecate_kwarg("cache_position", version="4.52")
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
if output_attentions: