Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
[qwen2 audio] remove redundant code and update docs (#36282)
parent f0d5b2ff04
commit 957b05b413
docs/source/en/model_doc/qwen2_audio.md
@@ -29,7 +29,7 @@ The Qwen2-Audio is the new model series of large audio-language models from the
 * voice chat: users can freely engage in voice interactions with Qwen2-Audio without text input
 * audio analysis: users could provide audio and text instructions for analysis during the interaction

 It was proposed in [Qwen2-Audio Technical Report](https://arxiv.org/abs/2407.10759) by Yunfei Chu, Jin Xu, Qian Yang, Haojie Wei, Xipin Wei, Zhifang Guo, Yichong Leng, Yuanjun Lv, Jinzheng He, Junyang Lin, Chang Zhou, Jingren Zhou.

 The abstract from the paper is the following:
@@ -100,7 +100,7 @@ for message in conversation:
     for ele in message["content"]:
         if ele["type"] == "audio":
             audios.append(librosa.load(
                 BytesIO(urlopen(ele['audio_url']).read()),
                 sr=processor.feature_extractor.sampling_rate)[0]
             )
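For context outside the diff, the loading loop this hunk touches can be run as a self-contained snippet; a minimal sketch that assumes network access and `librosa` installed, mirroring the documented example (the sample URL comes from the docs below):

from io import BytesIO
from urllib.request import urlopen

import librosa
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")

conversation = [
    {"role": "user", "content": [
        {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"},
        {"type": "text", "text": "What's that sound?"},
    ]},
]

audios = []
for message in conversation:
    if isinstance(message["content"], list):
        for ele in message["content"]:
            if ele["type"] == "audio":
                # download the clip and resample it to the rate the feature extractor expects
                audios.append(librosa.load(
                    BytesIO(urlopen(ele["audio_url"]).read()),
                    sr=processor.feature_extractor.sampling_rate)[0]
                )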
@@ -125,7 +125,7 @@ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
 model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct", device_map="auto")

 conversation = [
     {'role': 'system', 'content': 'You are a helpful assistant.'},
     {"role": "user", "content": [
         {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"},
         {"type": "text", "text": "What's that sound?"},
@@ -148,7 +148,7 @@ for message in conversation:
         if ele["type"] == "audio":
             audios.append(
                 librosa.load(
                     BytesIO(urlopen(ele['audio_url']).read()),
                     sr=processor.feature_extractor.sampling_rate)[0]
             )
@@ -203,7 +203,7 @@ for conversation in conversations:
         if ele["type"] == "audio":
             audios.append(
                 librosa.load(
                     BytesIO(urlopen(ele['audio_url']).read()),
                     sr=processor.feature_extractor.sampling_rate)[0]
             )
@@ -221,7 +221,7 @@ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_
 [[autodoc]] Qwen2AudioConfig

-## Qwen2AudioConfig
+## Qwen2AudioEncoderConfig

 [[autodoc]] Qwen2AudioEncoderConfig
@@ -229,6 +229,11 @@ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_
 [[autodoc]] Qwen2AudioProcessor

+## Qwen2AudioEncoder
+
+[[autodoc]] Qwen2AudioEncoder
+    - forward
+
 ## Qwen2AudioForConditionalGeneration

 [[autodoc]] Qwen2AudioForConditionalGeneration
src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
@@ -16,14 +16,14 @@
 import math
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
 from torch import nn

 from ...activations import ACT2FN
-from ...cache_utils import Cache, EncoderDecoderCache, StaticCache
+from ...cache_utils import Cache
 from ...generation import GenerationMixin
 from ...modeling_outputs import BaseModelOutput, ModelOutput
 from ...modeling_utils import PreTrainedModel
@@ -35,6 +35,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
+from ...utils.deprecation import deprecate_kwarg
 from ..auto import AutoModel, AutoModelForCausalLM
 from .configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig
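The newly imported `deprecate_kwarg` is what keeps the old attention signatures alive for one more release; a minimal sketch of the pattern with a toy function (illustrative only — the helper emits a deprecation warning when the named kwarg is still passed):

from transformers.utils.deprecation import deprecate_kwarg

@deprecate_kwarg("past_key_value", version="4.52")
def forward(hidden_states, past_key_value=None):
    # the kwarg is still accepted but ignored; callers are nudged off it before v4.52
    return hidden_states

This is how the attention classes below phase out `key_value_states`, `past_key_value`, and `cache_position` without breaking existing call sites.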
@@ -58,12 +59,15 @@ class Qwen2AudioCausalLMOutputWithPast(ModelOutput):
             Language modeling loss (for next-token prediction).
         logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are
+            two sets of pre-computed hidden-states: key and values states in the self-attention blocks.
+            The `past_key_values` are returned when `use_cache=True` is passed or when `config.use_cache=True`.
+            It is a [`~cache_utils.Cache`] instance.

-            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
-            `past_key_values` input) to speed up sequential decoding.
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those
+            that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+            all `input_ids` of shape `(batch_size, sequence_length)`.
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
             one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
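Concretely, after this change a cached forward pass hands back a `Cache` object rather than a tuple of per-layer tuples; a minimal sketch, reusing `model` and `inputs` from the doc examples above:

import torch
from transformers.cache_utils import Cache

with torch.no_grad():
    out = model(**inputs, use_cache=True)

# previously a tuple of per-layer (key, value) tuples; now a Cache instance
assert isinstance(out.past_key_values, Cache)
print(out.past_key_values.get_seq_length())  # number of tokens already cached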
@@ -81,16 +85,16 @@ class Qwen2AudioCausalLMOutputWithPast(ModelOutput):
     loss: Optional[torch.FloatTensor] = None
     logits: torch.FloatTensor = None
-    past_key_values: Optional[List[torch.FloatTensor]] = None
+    past_key_values: Optional[Cache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None
     attention_mask: Optional[torch.FloatTensor] = None


-# Copied from transformers.models.whisper.modeling_whisper.WhisperAttention with Whisper->Qwen2Audio
 class Qwen2AudioAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

+    # Copied from transformers.models.whisper.modeling_whisper.WhisperAttention.__init__ with Whisper->Qwen2Audio
     def __init__(
         self,
         embed_dim: int,
@@ -135,11 +139,14 @@ class Qwen2AudioAttention(nn.Module):
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

+    @deprecate_kwarg("key_value_states", version="4.52")
+    @deprecate_kwarg("past_key_value", version="4.52")
+    @deprecate_kwarg("cache_position", version="4.52")
     def forward(
         self,
         hidden_states: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[EncoderDecoderCache] = None,
+        past_key_value: Optional[Cache] = None,
         attention_mask: Optional[torch.Tensor] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
@@ -147,38 +154,12 @@ class Qwen2AudioAttention(nn.Module):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""

-        # if key_value_states are provided this layer is used as a cross-attention layer
-        # for the decoder
-        is_cross_attention = key_value_states is not None
         bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
-
-        if past_key_value is not None:
-            is_updated = past_key_value.is_updated.get(self.layer_idx)
-            if is_cross_attention:
-                # after the first generated id, we can subsequently re-use all key/value_states from cache
-                past_key_value.is_updated[self.layer_idx] = True
-                past_key_value = past_key_value.cross_attention_cache
-            else:
-                past_key_value = past_key_value.self_attention_cache
-
-        # use key_value_states if cross attention
-        current_states = key_value_states if key_value_states is not None else hidden_states
-        if is_cross_attention and past_key_value and is_updated:
-            # reuse k,v, cross_attentions
-            key_states = past_key_value.key_cache[self.layer_idx]
-            value_states = past_key_value.value_cache[self.layer_idx]
-        else:
-            key_states = self._shape(self.k_proj(current_states), -1, bsz)
-            value_states = self._shape(self.v_proj(current_states), -1, bsz)
-            if past_key_value is not None:
-                # save all key/value_states to cache to be re-used for fast auto-regressive generation
-                cache_position = cache_position if not is_cross_attention else None
-                key_states, value_states = past_key_value.update(
-                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
-                )
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
@@ -212,10 +193,9 @@ class Qwen2AudioAttention(nn.Module):
         attn_output = self.out_proj(attn_output)

-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights, None


-# Copied from transformers.models.whisper.modeling_whisper.WhisperFlashAttention2 with Whisper->Qwen2Audio
 class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
     """
     Qwen2Audio flash attention module. This module inherits from `Qwen2AudioAttention` as the weights of the module stays
@@ -223,6 +203,7 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
     flash attention and deal with padding tokens in case the input contains any of them.
     """

+    # Copied from transformers.models.whisper.modeling_whisper.WhisperFlashAttention2.__init__ with Whisper->Qwen2Audio
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -231,57 +212,29 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

+    @deprecate_kwarg("key_value_states", version="4.52")
+    @deprecate_kwarg("past_key_value", version="4.52")
+    @deprecate_kwarg("cache_position", version="4.52")
     def forward(
         self,
         hidden_states: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[EncoderDecoderCache] = None,
+        past_key_value: Optional[Cache] = None,
         attention_mask: Optional[torch.Tensor] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if isinstance(past_key_value, StaticCache):
-            raise ValueError(
-                "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. "
-                "Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
-            )
         # Qwen2AudioFlashAttention2 attention does not support output_attentions
         if output_attentions:
             raise ValueError("Qwen2AudioFlashAttention2 attention does not support output_attentions")

-        # if key_value_states are provided this layer is used as a cross-attention layer
-        # for the decoder
-        is_cross_attention = key_value_states is not None
         bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
-
-        if past_key_value is not None:
-            is_updated = past_key_value.is_updated.get(self.layer_idx)
-            if is_cross_attention:
-                # after the first generated id, we can subsequently re-use all key/value_states from cache
-                past_key_value.is_updated[self.layer_idx] = True
-                past_key_value = past_key_value.cross_attention_cache
-            else:
-                past_key_value = past_key_value.self_attention_cache
-
-        # use key_value_states if cross attention
-        current_states = key_value_states if key_value_states is not None else hidden_states
-        if is_cross_attention and past_key_value and is_updated:
-            # reuse k,v, cross_attentions
-            key_states = past_key_value.key_cache[self.layer_idx]
-            value_states = past_key_value.value_cache[self.layer_idx]
-        else:
-            key_states = self._shape(self.k_proj(current_states), -1, bsz)
-            value_states = self._shape(self.v_proj(current_states), -1, bsz)
-            if past_key_value is not None:
-                # save all key/value_states to cache to be re-used for fast auto-regressive generation
-                cache_position = cache_position if not is_cross_attention else None
-                key_states, value_states = past_key_value.update(
-                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
-                )
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
         # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
@@ -335,16 +288,18 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
         if not output_attentions:
             attn_weights = None

-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights, None


-# Copied from transformers.models.whisper.modeling_whisper.WhisperSdpaAttention with Whisper->Qwen2Audio
 class Qwen2AudioSdpaAttention(Qwen2AudioAttention):
+    @deprecate_kwarg("key_value_states", version="4.52")
+    @deprecate_kwarg("past_key_value", version="4.52")
+    @deprecate_kwarg("cache_position", version="4.52")
     def forward(
         self,
         hidden_states: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[EncoderDecoderCache] = None,
+        past_key_value: Optional[Cache] = None,
         attention_mask: Optional[torch.Tensor] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
@@ -359,46 +314,17 @@ class Qwen2AudioSdpaAttention(Qwen2AudioAttention):
             )
             return super().forward(
                 hidden_states,
                 key_value_states=key_value_states,
                 past_key_value=past_key_value,
                 attention_mask=attention_mask,
                 layer_head_mask=layer_head_mask,
                 output_attentions=output_attentions,
                 cache_position=cache_position,
             )

-        # if key_value_states are provided this layer is used as a cross-attention layer
-        # for the decoder
-        is_cross_attention = key_value_states is not None
         bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
-
-        if past_key_value is not None:
-            is_updated = past_key_value.is_updated.get(self.layer_idx)
-            if is_cross_attention:
-                # after the first generated id, we can subsequently re-use all key/value_states from cache
-                past_key_value.is_updated[self.layer_idx] = True
-                past_key_value = past_key_value.cross_attention_cache
-            else:
-                past_key_value = past_key_value.self_attention_cache
-
-        # use key_value_states if cross attention
-        current_states = key_value_states if key_value_states is not None else hidden_states
-        if is_cross_attention and past_key_value and is_updated:
-            # reuse k,v, cross_attentions
-            key_states = past_key_value.key_cache[self.layer_idx]
-            value_states = past_key_value.value_cache[self.layer_idx]
-        else:
-            key_states = self._shape(self.k_proj(current_states), -1, bsz)
-            value_states = self._shape(self.v_proj(current_states), -1, bsz)
-            if past_key_value is not None:
-                # save all key/value_states to cache to be re-used for fast auto-regressive generation
-                cache_position = cache_position if not is_cross_attention else None
-                key_states, value_states = past_key_value.update(
-                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
-                )
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

         causal_mask = attention_mask
         if attention_mask is not None:  # no matter the length, we just slice it
@@ -434,7 +360,7 @@ class Qwen2AudioSdpaAttention(Qwen2AudioAttention):
         attn_output = self.out_proj(attn_output)

-        return attn_output, None, past_key_value
+        return attn_output, None, None


 QWEN2AUDIO_ATTENTION_CLASSES = {
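All three backends now reduce to the same cache-free encoder self-attention. A standalone sketch of the eager math under illustrative dimensions (not the class itself, just the technique the diff converges on):

import torch
from torch import nn

def encoder_self_attention(hidden_states, q_proj, k_proj, v_proj, out_proj, num_heads):
    # hidden_states: (batch, seq_len, embed_dim); no KV cache and no cross-attention
    bsz, seq_len, embed_dim = hidden_states.size()
    head_dim = embed_dim // num_heads
    shape = lambda t: t.view(bsz, -1, num_heads, head_dim).transpose(1, 2)

    query = shape(q_proj(hidden_states)) * head_dim**-0.5  # queries are pre-scaled
    key = shape(k_proj(hidden_states))                     # always recomputed, never cached
    value = shape(v_proj(hidden_states))

    attn_weights = torch.softmax(query @ key.transpose(2, 3), dim=-1)
    attn_output = (attn_weights @ value).transpose(1, 2).reshape(bsz, seq_len, embed_dim)
    return out_proj(attn_output), attn_weights, None  # third return slot is now always None

proj = lambda: nn.Linear(64, 64)
out, weights, _ = encoder_self_attention(torch.randn(2, 10, 64), proj(), proj(), proj(), proj(), num_heads=4)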
@@ -815,16 +741,15 @@ QWEN2AUDIO_INPUTS_DOCSTRING = r"""
         position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
             config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
-            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are
+            two sets of pre-computed hidden-states: key and values states in the self-attention blocks.
+            The `past_key_values` are returned when `use_cache=True` is passed or when `config.use_cache=True`.
+            It is a [`~cache_utils.Cache`] instance.

-            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
-            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
-
-            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
-            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
-            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those
+            that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+            all `input_ids` of shape `(batch_size, sequence_length)`.
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
@@ -851,7 +776,7 @@ QWEN2AUDIO_INPUTS_DOCSTRING = r"""
 class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin):
     def __init__(self, config: Qwen2AudioConfig):
         super().__init__(config)
-        self.audio_tower = AutoModel.from_config(config.audio_config)
+        self.audio_tower = AutoModel.from_config(config.audio_config)  # Usually a `Qwen2AudioEncoder` instance

         self.multi_modal_projector = Qwen2AudioMultiModalProjector(config)
         self.vocab_size = config.text_config.vocab_size
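Because the tower is instantiated through `AutoModel.from_config`, its concrete class follows `config.audio_config`; a quick hedged check of the usual case, assuming the checkpoint from the docs above:

from transformers import Qwen2AudioEncoder, Qwen2AudioForConditionalGeneration

model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
# with the stock config, audio_config resolves to the Qwen2-Audio encoder class
assert isinstance(model.audio_tower, Qwen2AudioEncoder)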
@ -1103,7 +1028,7 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
feature_attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
@@ -1258,78 +1183,5 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin):
             attention_mask=attention_mask,
         )

-    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        past_key_values=None,
-        inputs_embeds=None,
-        input_features=None,
-        attention_mask=None,
-        **kwargs,
-    ):
-        # Overwritten -- custom processing (note: might not be needed, but there are no generation tests running atm)
-
-        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
-                cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-
-        # Here, we get the attention_mask, which was previously stored in the state after _merge_input_ids_with_audio_features.
-        if input_features is not None and kwargs.get("attention_mask") is not None:
-            attention_mask = kwargs["attention_mask"]
-            attention_mask = torch.cat(
-                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
-            )
-
-        # Keep only the unprocessed tokens:
-        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-        #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-        #     input)
-        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-        #     input_ids based on the past_length.
-        elif past_length < input_ids.shape[1]:
-            input_ids = input_ids[:, past_length:]
-        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-        elif self.config.audio_token_index in input_ids:
-            input_ids = input_ids[:, input_ids.shape[1] - 1 :]
-        # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-        # older attention values, as their corresponding values are not part of the input.
-        if cache_length < past_length and attention_mask is not None:
-            attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        feature_attention_mask = kwargs.get("feature_attention_mask", None)
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-                "input_features": input_features,
-                "feature_attention_mask": feature_attention_mask,
-            }
-        )
-        return model_inputs
-
-    def _reorder_cache(self, *args, **kwargs):
-        return self.language_model._reorder_cache(*args, **kwargs)
-

 __all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder"]
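With the override deleted, Qwen2-Audio falls back to `GenerationMixin.prepare_inputs_for_generation`, which already performs the cache and attention-mask bookkeeping removed above, so callers see no behavioral change. A minimal sketch, reusing `model` and `inputs` from the doc examples:

# the default GenerationMixin path now trims input_ids against the Cache and
# builds position_ids on the fly, so plain generate() keeps working unchanged
generate_ids = model.generate(**inputs, max_new_tokens=64)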