[qwen2 audio] remove redundant code and update docs (#36282)

Joao Gante 2025-03-20 10:54:51 +00:00 committed by GitHub
parent f0d5b2ff04
commit 957b05b413
2 changed files with 56 additions and 199 deletions

View File

@ -29,7 +29,7 @@ The Qwen2-Audio is the new model series of large audio-language models from the
* voice chat: users can freely engage in voice interactions with Qwen2-Audio without text input
* audio analysis: users could provide audio and text instructions for analysis during the interaction
It was proposed in [Qwen2-Audio Technical Report](https://arxiv.org/abs/2407.10759) by Yunfei Chu, Jin Xu, Qian Yang, Haojie Wei, Xipin Wei, Zhifang Guo, Yichong Leng, Yuanjun Lv, Jinzheng He, Junyang Lin, Chang Zhou, Jingren Zhou.
It was proposed in [Qwen2-Audio Technical Report](https://arxiv.org/abs/2407.10759) by Yunfei Chu, Jin Xu, Qian Yang, Haojie Wei, Xipin Wei, Zhifang Guo, Yichong Leng, Yuanjun Lv, Jinzheng He, Junyang Lin, Chang Zhou, Jingren Zhou.
The abstract from the paper is the following:
@ -100,7 +100,7 @@ for message in conversation:
for ele in message["content"]:
if ele["type"] == "audio":
audios.append(librosa.load(
BytesIO(urlopen(ele['audio_url']).read()),
BytesIO(urlopen(ele['audio_url']).read()),
sr=processor.feature_extractor.sampling_rate)[0]
)
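Only trailing whitespace changes in this hunk and in the similar ones below. For readers skimming the diff, a minimal self-contained sketch of the audio-loading loop they touch, assuming the `conversation` list defined in the surrounding docs example:

```python
from io import BytesIO
from urllib.request import urlopen

import librosa
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")

audios = []
for message in conversation:  # `conversation` comes from the docs example above
    if isinstance(message["content"], list):
        for ele in message["content"]:
            if ele["type"] == "audio":
                audios.append(
                    librosa.load(
                        BytesIO(urlopen(ele["audio_url"]).read()),
                        sr=processor.feature_extractor.sampling_rate,
                    )[0]
                )
```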
@ -125,7 +125,7 @@ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct", device_map="auto")
conversation = [
{'role': 'system', 'content': 'You are a helpful assistant.'},
{'role': 'system', 'content': 'You are a helpful assistant.'},
{"role": "user", "content": [
{"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"},
{"type": "text", "text": "What's that sound?"},
@ -148,7 +148,7 @@ for message in conversation:
if ele["type"] == "audio":
audios.append(
librosa.load(
BytesIO(urlopen(ele['audio_url']).read()),
BytesIO(urlopen(ele['audio_url']).read()),
sr=processor.feature_extractor.sampling_rate)[0]
)
@ -203,7 +203,7 @@ for conversation in conversations:
if ele["type"] == "audio":
audios.append(
librosa.load(
BytesIO(urlopen(ele['audio_url']).read()),
BytesIO(urlopen(ele['audio_url']).read()),
sr=processor.feature_extractor.sampling_rate)[0]
)
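The next hunk ends at the decode step of the batch-inference example. A rough end-to-end sketch of that step, reusing `conversations`, `audios`, `processor`, and `model` from the docs; note that the processor's audio keyword has been `audios` in older releases and `audio` in newer ones, so adjust to your installed version:

```python
text = [
    processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    for conversation in conversations
]
# Older transformers versions expect `audios=audios` instead of `audio=audios`.
inputs = processor(text=text, audio=audios, return_tensors="pt", padding=True)
inputs = inputs.to(model.device)

generate_ids = model.generate(**inputs, max_new_tokens=256)
generate_ids = generate_ids[:, inputs.input_ids.size(1):]  # strip the prompt tokens

response = processor.batch_decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
```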
@ -221,7 +221,7 @@ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_
[[autodoc]] Qwen2AudioConfig
## Qwen2AudioConfig
## Qwen2AudioEncoderConfig
[[autodoc]] Qwen2AudioEncoderConfig
@ -229,6 +229,11 @@ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_
[[autodoc]] Qwen2AudioProcessor
## Qwen2AudioEncoder
[[autodoc]] Qwen2AudioEncoder
- forward
## Qwen2AudioForConditionalGeneration
[[autodoc]] Qwen2AudioForConditionalGeneration
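The new `Qwen2AudioEncoder` entry documents the audio tower that lives inside the conditional-generation model. A quick way to inspect it, reusing the `model` from the examples above:

```python
# `audio_tower` is set in Qwen2AudioForConditionalGeneration.__init__ (see the
# modeling diff below); its class is expected to be Qwen2AudioEncoder.
print(type(model.audio_tower).__name__)
print(model.audio_tower.config)  # a Qwen2AudioEncoderConfig
```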

View File

@ -16,14 +16,14 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, EncoderDecoderCache, StaticCache
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_utils import PreTrainedModel
@ -35,6 +35,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig
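The newly imported `deprecate_kwarg` is used below to flag `key_value_states`, `past_key_value`, and `cache_position` for removal in v4.52. Conceptually, such a decorator only raises a `FutureWarning` when the named keyword is still passed; a simplified stand-in (not the transformers implementation) to illustrate the behaviour:

```python
import functools
import warnings


def deprecate_kwarg_sketch(name: str, version: str):
    """Simplified stand-in for transformers' `deprecate_kwarg` decorator."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if kwargs.get(name) is not None:
                warnings.warn(
                    f"`{name}` is deprecated and will be removed in transformers v{version}.",
                    FutureWarning,
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator
```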
@ -58,12 +59,15 @@ class Qwen2AudioCausalLMOutputWithPast(ModelOutput):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are
two sets of pre-computed hidden-states: key and values states in the self-attention blocks.
The `past_key_values` are returned when `use_cache=True` is passed or when `config.use_cache=True`.
It is a [`~cache_utils.Cache`] instance.
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `input_ids` of shape `(batch_size, sequence_length)`.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
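The docstring above tracks the type change of `past_key_values` from a tuple of per-layer tuples to a single `Cache` object. What a caller sees after the change, sketched with the `model` and `inputs` from the docs examples:

```python
from transformers.cache_utils import Cache

outputs = model(**inputs, use_cache=True)

# Previously a tuple of (key, value) tuples per layer; now a Cache instance
# with helpers such as get_seq_length().
print(isinstance(outputs.past_key_values, Cache))
print(outputs.past_key_values.get_seq_length())
```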
@ -81,16 +85,16 @@ class Qwen2AudioCausalLMOutputWithPast(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
past_key_values: Optional[Cache] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
attention_mask: Optional[torch.FloatTensor] = None
# Copied from transformers.models.whisper.modeling_whisper.WhisperAttention with Whisper->Qwen2Audio
class Qwen2AudioAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
# Copied from transformers.models.whisper.modeling_whisper.WhisperAttention.__init__ with Whisper->Qwen2Audio
def __init__(
self,
embed_dim: int,
@ -135,11 +139,14 @@ class Qwen2AudioAttention(nn.Module):
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@deprecate_kwarg("key_value_states", version="4.52")
@deprecate_kwarg("past_key_value", version="4.52")
@deprecate_kwarg("cache_position", version="4.52")
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
past_key_value: Optional[Cache] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
@ -147,38 +154,12 @@ class Qwen2AudioAttention(nn.Module):
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
past_key_value.is_updated[self.layer_idx] = True
past_key_value = past_key_value.cross_attention_cache
else:
past_key_value = past_key_value.self_attention_cache
# use key_value_states if cross attention
current_states = key_value_states if key_value_states is not None else hidden_states
if is_cross_attention and past_key_value and is_updated:
# reuse k,v, cross_attentions
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
else:
key_states = self._shape(self.k_proj(current_states), -1, bsz)
value_states = self._shape(self.v_proj(current_states), -1, bsz)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
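The branches deleted above implemented cross-attention and KV-cache reuse that this class never exercises: it only backs the audio encoder, where queries, keys, and values all come from the same `hidden_states` in a single pass (the "redundant code" of the PR title). A standalone sketch of the simplified score computation (function and helper names here are illustrative only); the same simplification is applied to the flash-attention and SDPA variants in the hunks below:

```python
import torch
from torch import nn


def encoder_self_attention_scores(
    hidden_states: torch.Tensor,  # (batch, seq_len, embed_dim)
    q_proj: nn.Linear,
    k_proj: nn.Linear,
    num_heads: int,
    head_dim: int,
    scaling: float,
) -> torch.Tensor:
    """Queries, keys, and values are all projected from `hidden_states`;
    there is no cross-attention input and no KV cache to update."""
    bsz = hidden_states.size(0)

    def shape(tensor: torch.Tensor) -> torch.Tensor:
        return tensor.view(bsz, -1, num_heads, head_dim).transpose(1, 2)

    query_states = shape(q_proj(hidden_states) * scaling)
    key_states = shape(k_proj(hidden_states))

    # (batch, num_heads, seq_len, seq_len) attention scores, as in the method above.
    return torch.matmul(query_states, key_states.transpose(2, 3))
```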
@ -212,10 +193,9 @@ class Qwen2AudioAttention(nn.Module):
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value
return attn_output, attn_weights, None
# Copied from transformers.models.whisper.modeling_whisper.WhisperFlashAttention2 with Whisper->Qwen2Audio
class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
"""
Qwen2Audio flash attention module. This module inherits from `Qwen2AudioAttention` as the weights of the module stays
@ -223,6 +203,7 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
# Copied from transformers.models.whisper.modeling_whisper.WhisperFlashAttention2.__init__ with Whisper->Qwen2Audio
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -231,57 +212,29 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
@deprecate_kwarg("key_value_states", version="4.52")
@deprecate_kwarg("past_key_value", version="4.52")
@deprecate_kwarg("cache_position", version="4.52")
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
past_key_value: Optional[Cache] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
"The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. "
"Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
)
# Qwen2AudioFlashAttention2 attention does not support output_attentions
if output_attentions:
raise ValueError("Qwen2AudioFlashAttention2 attention does not support output_attentions")
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
past_key_value.is_updated[self.layer_idx] = True
past_key_value = past_key_value.cross_attention_cache
else:
past_key_value = past_key_value.self_attention_cache
# use key_value_states if cross attention
current_states = key_value_states if key_value_states is not None else hidden_states
if is_cross_attention and past_key_value and is_updated:
# reuse k,v, cross_attentions
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
else:
key_states = self._shape(self.k_proj(current_states), -1, bsz)
value_states = self._shape(self.v_proj(current_states), -1, bsz)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
# We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
@ -335,16 +288,18 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
return attn_output, attn_weights, None
# Copied from transformers.models.whisper.modeling_whisper.WhisperSdpaAttention with Whisper->Qwen2Audio
class Qwen2AudioSdpaAttention(Qwen2AudioAttention):
@deprecate_kwarg("key_value_states", version="4.52")
@deprecate_kwarg("past_key_value", version="4.52")
@deprecate_kwarg("cache_position", version="4.52")
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
past_key_value: Optional[Cache] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
@ -359,46 +314,17 @@ class Qwen2AudioSdpaAttention(Qwen2AudioAttention):
)
return super().forward(
hidden_states,
key_value_states=key_value_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
cache_position=cache_position,
)
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
past_key_value.is_updated[self.layer_idx] = True
past_key_value = past_key_value.cross_attention_cache
else:
past_key_value = past_key_value.self_attention_cache
# use key_value_states if cross attention
current_states = key_value_states if key_value_states is not None else hidden_states
if is_cross_attention and past_key_value and is_updated:
# reuse k,v, cross_attentions
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
else:
key_states = self._shape(self.k_proj(current_states), -1, bsz)
value_states = self._shape(self.v_proj(current_states), -1, bsz)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
causal_mask = attention_mask
if attention_mask is not None: # no matter the length, we just slice it
@ -434,7 +360,7 @@ class Qwen2AudioSdpaAttention(Qwen2AudioAttention):
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
return attn_output, None, None
QWEN2AUDIO_ATTENTION_CLASSES = {
@ -815,16 +741,15 @@ QWEN2AUDIO_INPUTS_DOCSTRING = r"""
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are
two sets of pre-computed hidden-states: key and values states in the self-attention blocks.
The `past_key_values` are returned when `use_cache=True` is passed or when `config.use_cache=True`.
It is a [`~cache_utils.Cache`] instance.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
@ -851,7 +776,7 @@ QWEN2AUDIO_INPUTS_DOCSTRING = r"""
class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin):
def __init__(self, config: Qwen2AudioConfig):
super().__init__(config)
self.audio_tower = AutoModel.from_config(config.audio_config)
self.audio_tower = AutoModel.from_config(config.audio_config) # Usually a `Qwen2AudioEncoder` instance
self.multi_modal_projector = Qwen2AudioMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
@ -1103,7 +1028,7 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
attention_mask: Optional[torch.Tensor] = None,
feature_attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@ -1258,78 +1183,5 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
attention_mask=attention_mask,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
input_features=None,
attention_mask=None,
**kwargs,
):
# Overwritten -- custom processing (note: might not be needed, but there are no generation tests running atm)
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
else:
cache_length = past_length = past_key_values[0][0].shape[2]
# Here, we get the attention_mask, which was previously stored in the state after _merge_input_ids_with_audio_features.
if input_features is not None and kwargs.get("attention_mask") is not None:
attention_mask = kwargs["attention_mask"]
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
elif self.config.audio_token_index in input_ids:
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# older attention values, as their corresponding values are not part of the input.
if cache_length < past_length and attention_mask is not None:
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
feature_attention_mask = kwargs.get("feature_attention_mask", None)
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"input_features": input_features,
"feature_attention_mask": feature_attention_mask,
}
)
return model_inputs
def _reorder_cache(self, *args, **kwargs):
return self.language_model._reorder_cache(*args, **kwargs)
__all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder"]
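With the custom `prepare_inputs_for_generation` and `_reorder_cache` overrides gone, the model relies on `GenerationMixin`'s default input preparation and cache handling. A short usage sketch to illustrate that the user-facing API is unchanged, reusing `processor`, `model`, and `inputs` from the docs examples (passing an explicit `DynamicCache` is optional and shown only because generation now goes through the standard cache path):

```python
from transformers import DynamicCache

# Default path: no custom prepare_inputs_for_generation is needed anymore.
generate_ids = model.generate(**inputs, max_new_tokens=64)

# An explicit Cache can also be supplied through the standard GenerationMixin machinery.
generate_ids = model.generate(**inputs, max_new_tokens=64, past_key_values=DynamicCache())

print(processor.batch_decode(generate_ids[:, inputs.input_ids.size(1):], skip_special_tokens=True))
```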