mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
[cleanup] delete deprecated kwargs in qwen2_audio 🧹 (#38404)
delete deprecated
This commit is contained in:
parent
b9f8f863d9
commit
f85fd90407
@ -29,7 +29,6 @@ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask,
|
|||||||
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import auto_docstring, logging
|
from ...utils import auto_docstring, logging
|
||||||
from ...utils.deprecation import deprecate_kwarg
|
|
||||||
from ..auto import AutoModel, AutoModelForCausalLM
|
from ..auto import AutoModel, AutoModelForCausalLM
|
||||||
from .configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig
|
from .configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig
|
||||||
|
|
||||||
@ -130,18 +129,12 @@ class Qwen2AudioAttention(nn.Module):
|
|||||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||||
|
|
||||||
@deprecate_kwarg("key_value_states", version="4.52")
|
|
||||||
@deprecate_kwarg("past_key_value", version="4.52")
|
|
||||||
@deprecate_kwarg("cache_position", version="4.52")
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
key_value_states: Optional[torch.Tensor] = None,
|
|
||||||
past_key_value: Optional[Cache] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
layer_head_mask: Optional[torch.Tensor] = None,
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
@ -203,18 +196,12 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
|
|||||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||||
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
||||||
|
|
||||||
@deprecate_kwarg("key_value_states", version="4.52")
|
|
||||||
@deprecate_kwarg("past_key_value", version="4.52")
|
|
||||||
@deprecate_kwarg("cache_position", version="4.52")
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
key_value_states: Optional[torch.Tensor] = None,
|
|
||||||
past_key_value: Optional[Cache] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
layer_head_mask: Optional[torch.Tensor] = None,
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
# Qwen2AudioFlashAttention2 attention does not support output_attentions
|
# Qwen2AudioFlashAttention2 attention does not support output_attentions
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
@ -283,18 +270,12 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
|
|||||||
|
|
||||||
|
|
||||||
class Qwen2AudioSdpaAttention(Qwen2AudioAttention):
|
class Qwen2AudioSdpaAttention(Qwen2AudioAttention):
|
||||||
@deprecate_kwarg("key_value_states", version="4.52")
|
|
||||||
@deprecate_kwarg("past_key_value", version="4.52")
|
|
||||||
@deprecate_kwarg("cache_position", version="4.52")
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
key_value_states: Optional[torch.Tensor] = None,
|
|
||||||
past_key_value: Optional[Cache] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
layer_head_mask: Optional[torch.Tensor] = None,
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
|
Loading…
Reference in New Issue
Block a user