mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Add Qwen2.5-Omni (#36752)
* Add qwen2.5-omni * Remove einops dependency * Add torchdiffeq dependency * Sort init * Add torchdiffeq to extras['diffeq'] * Fix repo consistency * use cached_file * del odeint * renew pytest * format * Remove torchdiffeq * format * fixed batch infer bug * Change positional_embedding to parameter * Change default speaker * Config revision * Use modular & code clean * code clean * decouple padding with model & code cleaning * sort init * fix * fix * Second code review * fix * fix * rename vars to full name + some comments * update pytest * Code clean & fix * fix * style * more clean up * fixup * smaller vision model in tests * fix processor test * deflake a bit the tests (still flaky though) * de-flake tests finally + add generation mixin * final nits i hope * make sure processor tests are complete * replace with Qwen2_5OmniForConditionalGeneration * fix tests after updating ckpt * fix typos when cleaning, also we can't change ckpt * fixup * images and videos kwargs for processor * thinker and talker loadable from hub ckpt * address comments and update tests after rebase * fixup * skip for now * fixup * fixup * remove torch dependency in processors --------- Co-authored-by: lvyuanjun.lyj <lvyuanjun.lyj@alibaba-inc.con> Co-authored-by: feizi.wx <feizi.wx@alibaba-inc.com> Co-authored-by: raushan <raushan@huggingface.co>
This commit is contained in:
parent
ac1df5fccd
commit
4b8c6d4cf8
@ -993,6 +993,8 @@
|
||||
title: Pix2Struct
|
||||
- local: model_doc/pixtral
|
||||
title: Pixtral
|
||||
- local: model_doc/qwen2_5_omni
|
||||
title: Qwen2.5-Omni
|
||||
- local: model_doc/qwen2_5_vl
|
||||
title: Qwen2.5-VL
|
||||
- local: model_doc/qwen2_audio
|
||||
|
400
docs/source/en/model_doc/qwen2_5_omni.md
Normal file
400
docs/source/en/model_doc/qwen2_5_omni.md
Normal file
@ -0,0 +1,400 @@
|
||||
<!--Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Qwen2.5-Omni
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
The [Qwen2.5-Omni](https://qwenlm.github.io/blog/) model is a unified multiple modalities model proposed in [Qwen2.5-Omni Technical Report]() from Qwen team, Alibaba Group.
|
||||
|
||||
The abstract from the technical report is the following:
|
||||
|
||||
*We present Qwen2.5-Omni, an end-to-end multimodal model designed to perceive diverse modalities, including text, images, audio, and video, while simultaneously generating text and natural speech responses in a streaming manner. To enable the streaming of multimodal information inputs, both audio and visual encoders utilize a block-wise processing approach. This strategy effectively decouples the handling of long sequences of multimodal data, assigning the perceptual responsibilities to the multimodal encoder and entrusting the modeling of extended sequences to a large language model. Such a division of labor enhances the fusion of different modalities via the shared attention mechanism. To synchronize the timestamps of video inputs with audio, we organized the audio and video sequentially in an interleaved manner and propose a novel position embedding approach, named TMRoPE (Time-aligned Multimodal RoPE). To concurrently generate text and speech while avoiding interference between the two modalities, we propose Thinker-Talker architecture. In this framework, Thinker functions as a large language model tasked with text generation, while Talker is a dual-track autoregressive model that directly utilizes the hidden representations from the Thinker to produce audio tokens as output. Both the Thinker and Talker models are designed to be trained and inferred in an end-to-end manner. For decoding audio tokens in a streaming manner, we introduce a sliding-window DiT that restricts the receptive field, aiming to reduce the initial package delay. Qwen2.5-Omni outperforms the similarly sized Qwen2-VL and Qwen2-Audio in both image and audio capabilities. Furthermore, Qwen2.5-Omni achieves state-of-the-art performance on multimodal benchmarks like Omni-Bench. Notably, Qwen2.5-Omni is the first open-source model to achieve a level of performance in end-to-end speech instruction following that is comparable to its capabilities with text inputs, as evidenced by benchmarks such as MMLU and GSM8K. As for speech generation, Qwen2.5-Omni’s streaming Talker outperform most existing streaming and non-streaming alternatives in robustness and naturalness.*
|
||||
|
||||
|
||||
|
||||
## Notes
|
||||
|
||||
- Use [`Qwen2_5OmniForConditionalGeneration`] to generate audio and text output. To generate only one output type, use [`Qwen2_5OmniThinkerForConditionalGeneration`] for text-only and [`Qwen2_5OmniTalkersForConditionalGeneration`] for audio-only outputs.
|
||||
- Audio generation with [`Qwen2_5OmniForConditionalGeneration`] supports only single batch size at the moment.
|
||||
- In case out out-of-memory errors hwen working with video input, decrease `processor.max_pixels`. By default the maximum is set to a very arge value and high resolution visuals will not be resized, unless resolution exceeds `processor.max_pixels`.
|
||||
- The processor has its own [`~ProcessorMixin.apply_chat_template`] method to convert chat messages to model inputs.
|
||||
|
||||
|
||||
## Usage example
|
||||
|
||||
`Qwen2.5-Omni` can be found on the [Huggingface Hub](https://huggingface.co/Qwen).
|
||||
|
||||
### Single Media inference
|
||||
|
||||
The model can accept text, images, audio and videos as input. Here's an example code for inference.
|
||||
|
||||
```python
|
||||
import soundfile as sf
|
||||
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
|
||||
|
||||
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-Omni-7B",
|
||||
torch_dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "video": "/path/to/video.mp4"},
|
||||
{"type": "text", "text": "What cant you hear and see in this video?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversations,
|
||||
load_audio_from_video=True,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
video_fps=1,
|
||||
|
||||
# kwargs to be passed to `Qwen2-5-OmniProcessor`
|
||||
padding=True,
|
||||
use_audio_in_video=True,
|
||||
).to(model.device)
|
||||
|
||||
# Generation params for audio or text can be different and have to be prefixed with `thinker_` or `talker_`
|
||||
text_ids, audio = model.generate(**inputs, use_audio_in_video=True, thinker_do_sample=False, talker_do_sample=True)
|
||||
text = processor.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
|
||||
sf.write(
|
||||
"output.wav",
|
||||
audio.reshape(-1).detach().cpu().numpy(),
|
||||
samplerate=24000,
|
||||
)
|
||||
print(text)
|
||||
```
|
||||
|
||||
### Text-only generation
|
||||
|
||||
To generate only text output and save compute by not loading the audio generation model, we can use `Qwen2_5OmniThinkerForConditionalGeneration` model.
|
||||
|
||||
```python
|
||||
from transformers import Qwen2_5OmniThinkerForConditionalGeneration, Qwen2_5OmniProcessor
|
||||
|
||||
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-Omni-7B",
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
)
|
||||
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "video": "/path/to/video.mp4"},
|
||||
{"type": "text", "text": "What cant you hear and see in this video?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversations,
|
||||
load_audio_from_video=True,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
video_fps=1,
|
||||
|
||||
# kwargs to be passed to `Qwen2-5-OmniProcessor`
|
||||
padding=True,
|
||||
use_audio_in_video=True,
|
||||
).to(model.device)
|
||||
|
||||
|
||||
text_ids = model.generate(**inputs, use_audio_in_video=True)
|
||||
text = processor.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
|
||||
sf.write(
|
||||
"output.wav",
|
||||
audio.reshape(-1).detach().cpu().numpy(),
|
||||
samplerate=24000,
|
||||
)
|
||||
print(text)
|
||||
```
|
||||
|
||||
### Batch Mixed Media Inference
|
||||
|
||||
The model can batch inputs composed of mixed samples of various types such as text, images, audio and videos as input when using `Qwen2_5OmniThinkerForConditionalGeneration` model. Here is an example.
|
||||
|
||||
```python
|
||||
import soundfile as sf
|
||||
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
|
||||
|
||||
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-Omni-7B",
|
||||
torch_dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
|
||||
|
||||
# Conversation with video only
|
||||
conversation1 = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "path": "/path/to/video.mp4"},
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Conversation with audio only
|
||||
conversation2 = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "path": "/path/to/audio.wav"},
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Conversation with pure text
|
||||
conversation3 = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "who are you?"}],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
# Conversation with mixed media
|
||||
conversation4 = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "path": "/path/to/image.jpg"},
|
||||
{"type": "video", "path": "/path/to/video.mp4"},
|
||||
{"type": "audio", "path": "/path/to/audio.wav"},
|
||||
{"type": "text", "text": "What are the elements can you see and hear in these medias?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
conversations = [conversation1, conversation2, conversation3, conversation4]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversations,
|
||||
load_audio_from_video=True,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
video_fps=1,
|
||||
|
||||
# kwargs to be passed to `Qwen2-5-OmniProcessor`
|
||||
padding=True,
|
||||
use_audio_in_video=True,
|
||||
).to(model.thinker.device)
|
||||
|
||||
text_ids = model.generate(**inputs, use_audio_in_video=True)
|
||||
text = processor.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
|
||||
print(text)
|
||||
```
|
||||
|
||||
### Usage Tips
|
||||
|
||||
#### Image Resolution trade-off
|
||||
|
||||
The model supports a wide range of resolution inputs. By default, it uses the native resolution for input, but higher resolutions can enhance performance at the cost of more computation. Users can set the minimum and maximum number of pixels to achieve an optimal configuration for their needs.
|
||||
|
||||
```python
|
||||
min_pixels = 128*28*28
|
||||
max_pixels = 768*28*28
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
```
|
||||
|
||||
#### Prompt for audio output
|
||||
If users need audio output, the system prompt must be set as "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.", otherwise the audio output may not work as expected.
|
||||
```
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.",
|
||||
}
|
||||
```
|
||||
|
||||
#### Use audio output or not
|
||||
|
||||
The model supports both text and audio outputs, if users do not need audio outputs, they can set `enable_audio_output` in the `from_pretrained` function. This option will save about `~2GB` of GPU memory but the `return_audio` option for `generate` function will only allow to be set at `False`.
|
||||
```python
|
||||
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-Omni-7B",
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
enable_audio_output=False,
|
||||
)
|
||||
```
|
||||
|
||||
In order to obtain a flexible experience, we recommend that users set `enable_audio_output` at `True` when initializing the model through `from_pretrained` function, and then decide whether to return audio when `generate` function is called. When `return_audio` is set to `False`, the model will only return text outputs to get text responses faster.
|
||||
|
||||
```python
|
||||
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-Omni-7B",
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
enable_audio_output=True,
|
||||
)
|
||||
...
|
||||
text_ids = model.generate(**inputs, return_audio=False)
|
||||
```
|
||||
|
||||
#### Change voice type of output audio
|
||||
Qwen2.5-Omni supports the ability to change the voice of the output audio. Users can use the `spk` parameter of `generate` function to specify the voice type. The `"Qwen/Qwen2.5-Omni-7B"` checkpoint support two voice types: `Chelsie` and `Ethan`, while `Chelsie` is a female voice and `Ethan` is a male voice. By defalut, if `spk` is not specified, the default voice type is `Chelsie`.
|
||||
|
||||
```python
|
||||
text_ids, audio = model.generate(**inputs, spk="Chelsie")
|
||||
```
|
||||
|
||||
```python
|
||||
text_ids, audio = model.generate(**inputs, spk="Ethan")
|
||||
```
|
||||
|
||||
#### Flash-Attention 2 to speed up generation
|
||||
|
||||
First, make sure to install the latest version of Flash Attention 2:
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Also, you should have hardware that is compatible with FlashAttention 2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). FlashAttention-2 can only be used when a model is loaded in `torch.float16` or `torch.bfloat16`.
|
||||
|
||||
To load and run a model using FlashAttention-2, add `attn_implementation="flash_attention_2"` when loading the model:
|
||||
|
||||
```python
|
||||
from transformers import Qwen2_5OmniForConditionalGeneration
|
||||
|
||||
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-Omni-7B",
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
|
||||
## Qwen2_5OmniConfig
|
||||
|
||||
[[autodoc]] Qwen2_5OmniConfig
|
||||
|
||||
## Qwen2_5OmniProcessor
|
||||
|
||||
[[autodoc]] Qwen2_5OmniProcessor
|
||||
|
||||
## Qwen2_5OmniForConditionalGeneration
|
||||
|
||||
[[autodoc]] Qwen2_5OmniForConditionalGeneration
|
||||
- forward
|
||||
|
||||
## Qwen2_5OmniPreTrainedModelForConditionalGeneration
|
||||
|
||||
[[autodoc]] Qwen2_5OmniPreTrainedModelForConditionalGeneration
|
||||
|
||||
## Qwen2_5OmniThinkerConfig
|
||||
|
||||
[[autodoc]] Qwen2_5OmniThinkerConfig
|
||||
|
||||
## Qwen2_5OmniThinkerForConditionalGeneration
|
||||
|
||||
[[autodoc]] Qwen2_5OmniThinkerForConditionalGeneration
|
||||
|
||||
## Qwen2_5OmniThinkerTextModel
|
||||
|
||||
[[autodoc]] Qwen2_5OmniThinkerTextModel
|
||||
|
||||
## Qwen2_5OmniTalkerConfig
|
||||
|
||||
[[autodoc]] Qwen2_5OmniTalkerConfig
|
||||
|
||||
## Qwen2_5OmniTalkerForConditionalGeneration
|
||||
|
||||
[[autodoc]] Qwen2_5OmniTalkerForConditionalGeneration
|
||||
|
||||
## Qwen2_5OmniTalkerModel
|
||||
|
||||
[[autodoc]] Qwen2_5OmniTalkerModel
|
||||
|
||||
## Qwen2_5OmniToken2WavConfig
|
||||
|
||||
[[autodoc]] Qwen2_5OmniToken2WavConfig
|
||||
|
||||
## Qwen2_5OmniToken2WavModel
|
||||
|
||||
[[autodoc]] Qwen2_5OmniToken2WavModel
|
||||
|
||||
## Qwen2_5OmniToken2WavDiTModel
|
||||
|
||||
[[autodoc]] Qwen2_5OmniToken2WavDiTModel
|
||||
|
||||
## Qwen2_5OmniToken2WavBigVGANModel
|
||||
|
||||
[[autodoc]] Qwen2_5OmniToken2WavBigVGANModel
|
@ -254,6 +254,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("pvt_v2", "PvtV2Config"),
|
||||
("qdqbert", "QDQBertConfig"),
|
||||
("qwen2", "Qwen2Config"),
|
||||
("qwen2_5_omni", "Qwen2_5OmniConfig"),
|
||||
("qwen2_5_vl", "Qwen2_5_VLConfig"),
|
||||
("qwen2_audio", "Qwen2AudioConfig"),
|
||||
("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"),
|
||||
@ -617,6 +618,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("pvt_v2", "PVTv2"),
|
||||
("qdqbert", "QDQBert"),
|
||||
("qwen2", "Qwen2"),
|
||||
("qwen2_5_omni", "Qwen2_5Omni"),
|
||||
("qwen2_5_vl", "Qwen2_5_VL"),
|
||||
("qwen2_audio", "Qwen2Audio"),
|
||||
("qwen2_audio_encoder", "Qwen2AudioEncoder"),
|
||||
|
@ -1429,6 +1429,7 @@ MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
|
||||
("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"),
|
||||
("musicgen", "MusicgenForConditionalGeneration"),
|
||||
("musicgen_melody", "MusicgenMelodyForConditionalGeneration"),
|
||||
("qwen2_5_omni", "Qwen2_5OmniForConditionalGeneration"),
|
||||
("seamless_m4t", "SeamlessM4TForTextToSpeech"),
|
||||
("seamless_m4t_v2", "SeamlessM4Tv2ForTextToSpeech"),
|
||||
("vits", "VitsModel"),
|
||||
|
@ -97,6 +97,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("pix2struct", "Pix2StructProcessor"),
|
||||
("pixtral", "PixtralProcessor"),
|
||||
("pop2piano", "Pop2PianoProcessor"),
|
||||
("qwen2_5_omni", "Qwen2_5OmniProcessor"),
|
||||
("qwen2_5_vl", "Qwen2_5_VLProcessor"),
|
||||
("qwen2_audio", "Qwen2AudioProcessor"),
|
||||
("qwen2_vl", "Qwen2VLProcessor"),
|
||||
|
@ -459,6 +459,7 @@ else:
|
||||
"Qwen2TokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
("qwen2_5_omni", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("qwen2_5_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("qwen2_audio", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
|
@ -678,7 +678,7 @@ class Glm4Model(Glm4PreTrainedModel):
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu"]
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
|
@ -1161,12 +1161,12 @@ class MimiTransformerModel(nn.Module):
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
if config.sliding_window is not None:
|
||||
if config.get_text_config().sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||
cache_position.reshape(-1, 1) - config.sliding_window
|
||||
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
|
@ -692,12 +692,12 @@ class MistralModel(MistralPreTrainedModel):
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
if config.sliding_window is not None:
|
||||
if config.get_text_config().sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||
cache_position.reshape(-1, 1) - config.sliding_window
|
||||
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
|
@ -246,12 +246,12 @@ class MistralModel(LlamaModel):
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
if config.sliding_window is not None:
|
||||
if config.get_text_config().sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||
cache_position.reshape(-1, 1) - config.sliding_window
|
||||
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
|
@ -821,12 +821,12 @@ class MixtralModel(MixtralPreTrainedModel):
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
if config.sliding_window is not None:
|
||||
if config.get_text_config().sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||
cache_position.reshape(-1, 1) - config.sliding_window
|
||||
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
|
@ -1391,12 +1391,12 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
if config.sliding_window is not None:
|
||||
if config.get_text_config().sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||
cache_position.reshape(-1, 1) - config.sliding_window
|
||||
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
@ -1705,12 +1705,12 @@ class MoshiModel(MoshiPreTrainedModel):
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
if config.sliding_window is not None:
|
||||
if config.get_text_config().sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||
cache_position.reshape(-1, 1) - config.sliding_window
|
||||
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
|
@ -747,12 +747,12 @@ class Phi3Model(Phi3PreTrainedModel):
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
if config.sliding_window is not None:
|
||||
if config.get_text_config().sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||
cache_position.reshape(-1, 1) - config.sliding_window
|
||||
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
|
@ -2041,12 +2041,12 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
if config.sliding_window is not None:
|
||||
if config.get_text_config().sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||
cache_position.reshape(-1, 1) - config.sliding_window
|
||||
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
|
@ -1294,12 +1294,12 @@ class PhimoeModel(PhimoePreTrainedModel):
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
if config.sliding_window is not None:
|
||||
if config.get_text_config().sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||
cache_position.reshape(-1, 1) - config.sliding_window
|
||||
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
|
@ -705,12 +705,12 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
if config.sliding_window is not None:
|
||||
if config.get_text_config().sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||
cache_position.reshape(-1, 1) - config.sliding_window
|
||||
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
|
28
src/transformers/models/qwen2_5_omni/__init__.py
Normal file
28
src/transformers/models/qwen2_5_omni/__init__.py
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_qwen2_5_omni import *
|
||||
from .modeling_qwen2_5_omni import *
|
||||
from .processing_qwen2_5_omni import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
1036
src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
Normal file
1036
src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
Normal file
File diff suppressed because it is too large
Load Diff
4639
src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
Normal file
4639
src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
Normal file
File diff suppressed because it is too large
Load Diff
4313
src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
Normal file
4313
src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
Normal file
File diff suppressed because it is too large
Load Diff
356
src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py
Normal file
356
src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py
Normal file
@ -0,0 +1,356 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for Qwen2.5Omni.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, VideoInput, make_batched_videos
|
||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
|
||||
from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput
|
||||
|
||||
|
||||
class Qwen2_5_OmniVideosKwargs(VideosKwargs):
|
||||
fps: Optional[List[int]] = None
|
||||
use_audio_in_video: Optional[bool] = None
|
||||
seconds_per_chunk: Optional[float] = None
|
||||
position_id_per_seconds: Optional[int] = None
|
||||
min_pixels: Optional[int]
|
||||
max_pixels: Optional[int]
|
||||
patch_size: Optional[int]
|
||||
temporal_patch_size: Optional[int]
|
||||
merge_size: Optional[int]
|
||||
|
||||
|
||||
class Qwen2_5_OmniImagesKwargs(ImagesKwargs):
|
||||
min_pixels: Optional[int]
|
||||
max_pixels: Optional[int]
|
||||
patch_size: Optional[int]
|
||||
temporal_patch_size: Optional[int]
|
||||
merge_size: Optional[int]
|
||||
|
||||
|
||||
class Qwen2_5OmniProcessorKwargs(ProcessingKwargs, total=False):
|
||||
videos_kwargs: Qwen2_5_OmniVideosKwargs
|
||||
images_kwargs: Qwen2_5_OmniImagesKwargs
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
"padding_side": "left",
|
||||
},
|
||||
"videos_kwargs": {
|
||||
"seconds_per_chunk": 2.0,
|
||||
"position_id_per_seconds": 25,
|
||||
"use_audio_in_video": False,
|
||||
},
|
||||
"audio_kwargs": {
|
||||
"sampling_rate": 16000,
|
||||
"padding": "max_length",
|
||||
"return_attention_mask": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class Qwen2_5OmniProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Qwen2.5Omni processor.
|
||||
[`Qwen2_5OmniProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`], [`WhisperFeatureExtractor`], and [`Qwen2TokenizerFast`]. See the
|
||||
[`~Qwen2_5OmniProcessor.__call__`] and [`~Qwen2_5OmniProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`Qwen2VLImageProcessor`], *optional*):
|
||||
The image processor.
|
||||
feature_extractor ([`WhisperFeatureExtractor`], *optional*):
|
||||
The audio feature extractor.
|
||||
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
||||
The text tokenizer.
|
||||
chat_template (`Optional[str]`, *optional*):
|
||||
The Jinja template to use for formatting the conversation. If not provided, the default chat template is used.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "feature_extractor", "tokenizer"]
|
||||
image_processor_class = "Qwen2VLImageProcessor"
|
||||
feature_extractor_class = "WhisperFeatureExtractor"
|
||||
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
|
||||
valid_kwargs = ["chat_template"]
|
||||
|
||||
def __init__(self, image_processor=None, feature_extractor=None, tokenizer=None, chat_template=None):
|
||||
super().__init__(image_processor, feature_extractor, tokenizer, chat_template=chat_template)
|
||||
self.image_token = self.tokenizer.image_token
|
||||
self.audio_token = self.tokenizer.audio_token
|
||||
self.video_token = self.tokenizer.video_token
|
||||
self.vision_bos_token = self.tokenizer.vision_bos_token
|
||||
self.vision_eos_token = self.tokenizer.vision_eos_token
|
||||
self.audio_bos_token = self.tokenizer.audio_bos_token
|
||||
self.audio_eos_token = self.tokenizer.audio_eos_token
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
||||
images: ImageInput = None,
|
||||
videos: VideoInput = None,
|
||||
audio: AudioInput = None,
|
||||
**kwargs: Unpack[Qwen2_5OmniProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text`
|
||||
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
|
||||
the text. To prepare the audio(s), this method forwards the `audio` and `kwargs` arguments to
|
||||
WhisperFeatureExtractor's [`~WhisperFeatureExtractor.__call__`] if `audio` is not `None`. To prepare the vision inputs,
|
||||
this method forwards the `vision_infos` and `kwargs` arguments to Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`]
|
||||
if `vision_infos` is not `None`. Please refer to the doctsring
|
||||
of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
|
||||
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
|
||||
audio (`np.ndarray`, `List[np.ndarray]`):
|
||||
The audio or batch of audio to be prepared. Each audio can be a NumPy array.
|
||||
"""
|
||||
|
||||
if text is None:
|
||||
raise ValueError("You need to specify either a `text` input to process.")
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
Qwen2_5OmniProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
seconds_per_chunk = output_kwargs["videos_kwargs"].pop("seconds_per_chunk")
|
||||
position_id_per_seconds = output_kwargs["videos_kwargs"].pop("position_id_per_seconds")
|
||||
use_audio_in_video = output_kwargs["videos_kwargs"].pop("use_audio_in_video")
|
||||
fps = output_kwargs["videos_kwargs"].pop("fps", None)
|
||||
|
||||
if audio is not None:
|
||||
output_kwargs["audio_kwargs"]["padding"] = "max_length" # Support "max_length" padding only here
|
||||
audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
|
||||
audio_inputs["feature_attention_mask"] = audio_inputs.pop(
|
||||
"attention_mask"
|
||||
) # rename feature_attention_mask to prevent conflicts later on
|
||||
audio_inputs["input_features"] = audio_inputs.pop(
|
||||
"input_features"
|
||||
) # rename input_features to prevent conflicts later on
|
||||
input_lengths = (audio_inputs["feature_attention_mask"].sum(-1) - 1) // 2 + 1
|
||||
audio_lengths = iter((input_lengths - 2) // 2 + 1)
|
||||
else:
|
||||
audio_inputs = {}
|
||||
audio_lengths = iter([])
|
||||
|
||||
if images is not None:
|
||||
images_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
|
||||
image_grid_thw = iter(images_inputs["image_grid_thw"])
|
||||
else:
|
||||
images_inputs = {}
|
||||
image_grid_thw = iter([])
|
||||
|
||||
if videos is not None:
|
||||
videos = make_batched_videos(videos)
|
||||
videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["videos_kwargs"])
|
||||
if fps is None:
|
||||
fps = [2.0] * len(videos)
|
||||
videos_inputs["video_second_per_grid"] = [
|
||||
self.image_processor.temporal_patch_size / fps[i] for i in range(len(fps))
|
||||
]
|
||||
video_grid_thw = iter(videos_inputs["video_grid_thw"])
|
||||
video_second_per_grid = iter(videos_inputs["video_second_per_grid"])
|
||||
else:
|
||||
videos_inputs = {}
|
||||
video_grid_thw = iter([])
|
||||
video_second_per_grid = iter([])
|
||||
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
text = self.replace_multimodal_special_tokens(
|
||||
text,
|
||||
audio_lengths,
|
||||
image_grid_thw,
|
||||
video_grid_thw,
|
||||
video_second_per_grid=video_second_per_grid,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
position_id_per_seconds=position_id_per_seconds,
|
||||
seconds_per_chunk=seconds_per_chunk,
|
||||
)
|
||||
|
||||
texts_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
|
||||
return BatchFeature(
|
||||
data={**texts_inputs, **images_inputs, **videos_inputs, **audio_inputs},
|
||||
tensor_type=kwargs.get("return_tensors"),
|
||||
)
|
||||
|
||||
def replace_multimodal_special_tokens(
|
||||
self,
|
||||
text,
|
||||
audio_lengths,
|
||||
image_grid_thw,
|
||||
video_grid_thw,
|
||||
video_second_per_grid,
|
||||
use_audio_in_video,
|
||||
position_id_per_seconds,
|
||||
seconds_per_chunk,
|
||||
):
|
||||
# Extend mm token length
|
||||
merge_length = self.image_processor.merge_size**2
|
||||
|
||||
processed_text = []
|
||||
for sample in text:
|
||||
positions = []
|
||||
special_tokens = [re.escape(tok) for tok in [self.audio_token, self.image_token, self.video_token]]
|
||||
pattern = "|".join(special_tokens)
|
||||
positions = sorted([(match.start(), match.group()) for match in re.finditer(pattern, sample)])
|
||||
positions.sort(key=lambda x: x[0])
|
||||
|
||||
for _, special_token in positions:
|
||||
if special_token == self.audio_token:
|
||||
sample = sample.replace(self.audio_token, "<|audio_placeholder|>" * next(audio_lengths), 1)
|
||||
elif special_token == self.image_token:
|
||||
image_seq_length = next(image_grid_thw).prod() // merge_length
|
||||
sample = sample.replace(self.image_token, "<|image_placeholder|>" * image_seq_length, 1)
|
||||
elif special_token == self.video_token:
|
||||
if not use_audio_in_video:
|
||||
video_seq_length = next(video_grid_thw).prod() // merge_length
|
||||
sample = sample.replace(self.video_token, "<|video_placeholder|>" * video_seq_length, 1)
|
||||
else:
|
||||
audio_token_indices = np.arange(next(audio_lengths))
|
||||
curr_video_grid_thw = next(video_grid_thw)
|
||||
height = curr_video_grid_thw[1] // self.image_processor.merge_size
|
||||
width = curr_video_grid_thw[2] // self.image_processor.merge_size
|
||||
video_token_indices = np.arange(curr_video_grid_thw[0]).view(-1, 1, 1)
|
||||
video_token_indices = video_token_indices.expand(-1, height, width).flatten()
|
||||
video_token_indices = (
|
||||
video_token_indices * next(video_second_per_grid) * position_id_per_seconds
|
||||
).long()
|
||||
|
||||
tokens_per_chunk = int(position_id_per_seconds * seconds_per_chunk)
|
||||
video_chunk_indexes = self.get_chunked_index(video_token_indices, tokens_per_chunk)
|
||||
audio_chunk_indexes = self.get_chunked_index(audio_token_indices, tokens_per_chunk)
|
||||
|
||||
placeholder_string = self.vision_bos_token + self.audio_bos_token
|
||||
for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
|
||||
if j < len(video_chunk_indexes):
|
||||
video_seq_length = video_chunk_indexes[j][1] - video_chunk_indexes[j][0]
|
||||
placeholder_string += "<|video_placeholder|>" * video_seq_length
|
||||
if j < len(audio_chunk_indexes):
|
||||
audio_seq_length = audio_chunk_indexes[j][1] - audio_chunk_indexes[j][0]
|
||||
placeholder_string += "<|audio_placeholder|>" * audio_seq_length
|
||||
placeholder_string += self.audio_eos_token + self.vision_eos_token
|
||||
sample = sample.replace(
|
||||
self.vision_bos_token + self.video_token + self.vision_eos_token,
|
||||
placeholder_string,
|
||||
1,
|
||||
)
|
||||
|
||||
sample = sample.replace("<|audio_placeholder|>", self.audio_token)
|
||||
sample = sample.replace("<|image_placeholder|>", self.image_token)
|
||||
sample = sample.replace("<|video_placeholder|>", self.video_token)
|
||||
processed_text.append(sample)
|
||||
return processed_text
|
||||
|
||||
def get_chunked_index(self, token_indices: np.ndarray, tokens_per_chunk: int) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Splits token index list into chunks based on token value ranges.
|
||||
|
||||
Given a list of token indices, returns a list of (start, end) index tuples representing
|
||||
slices of the list where the token values fall within successive ranges of `t_ntoken_per_chunk`.
|
||||
|
||||
For example, if `t_ntoken_per_chunk` is 1000, the function will create chunks such that:
|
||||
- the first chunk contains token values < 1000,
|
||||
- the second chunk contains values >= 1000 and < 2000, and so on.
|
||||
|
||||
Parameters:
|
||||
token_indices (`List[int]`): A monotonically increasing list of token index values.
|
||||
t_ntoken_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
|
||||
|
||||
Returns:
|
||||
`List[Tuple[int, int]]`: A list of tuples, each representing the start (inclusive)
|
||||
and end (exclusive) indices of a chunk in `token_indices`.
|
||||
"""
|
||||
|
||||
def _iter():
|
||||
i, start_idx = 0, 0 # skip bos token
|
||||
current_chunk = 1
|
||||
while i < len(token_indices): # skip eos token
|
||||
if token_indices[i] >= current_chunk * tokens_per_chunk:
|
||||
yield (start_idx, i)
|
||||
start_idx = i
|
||||
current_chunk += 1
|
||||
i += 1
|
||||
yield (start_idx, len(token_indices))
|
||||
|
||||
return list(_iter())
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
def apply_chat_template(self, conversations, chat_template=None, **kwargs):
|
||||
if isinstance(conversations[0], dict):
|
||||
conversations = [conversations]
|
||||
for conversation in conversations:
|
||||
if (
|
||||
conversation[0]["role"] != "system"
|
||||
or conversation[0]["content"][0]["text"]
|
||||
!= "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."
|
||||
):
|
||||
logging.warning(
|
||||
"System prompt modified, audio output may not work as expected. "
|
||||
+ "Audio output mode only works when using default system prompt 'You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.'"
|
||||
)
|
||||
return super().apply_chat_template(conversations, chat_template, **kwargs)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
feature_extractor_input_names = self.feature_extractor.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(
|
||||
dict.fromkeys(
|
||||
tokenizer_input_names
|
||||
+ feature_extractor_input_names
|
||||
+ image_processor_input_names
|
||||
+ ["feature_attention_mask"]
|
||||
+ ["video_second_per_grid"]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Qwen2_5OmniProcessor"]
|
@ -868,7 +868,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
@ -1332,12 +1332,12 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
if config.sliding_window is not None:
|
||||
if config.get_text_config().sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||
cache_position.reshape(-1, 1) - config.sliding_window
|
||||
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
|
@ -59,7 +59,7 @@ class Qwen2AudioEncoderConfig(PretrainedConfig):
|
||||
The dropout ratio for activations inside the fully connected layer.
|
||||
scale_embedding (`bool`, *optional*, defaults to `False`):
|
||||
Scale embeddings by diving by sqrt(d_model).
|
||||
init_std (`float`, *optional*, defaults to 0.02):
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
max_source_positions (`int`, *optional*, defaults to 1500):
|
||||
The maximum sequence length of log-mel filter-bank features that this model might ever be used with.
|
||||
@ -94,7 +94,7 @@ class Qwen2AudioEncoderConfig(PretrainedConfig):
|
||||
activation_function="gelu",
|
||||
activation_dropout=0.0,
|
||||
scale_embedding=False,
|
||||
init_std=0.02,
|
||||
initializer_range=0.02,
|
||||
max_source_positions=1500,
|
||||
**kwargs,
|
||||
):
|
||||
@ -111,7 +111,7 @@ class Qwen2AudioEncoderConfig(PretrainedConfig):
|
||||
self.activation_dropout = activation_dropout
|
||||
self.encoder_layerdrop = encoder_layerdrop
|
||||
self.num_hidden_layers = encoder_layers
|
||||
self.init_std = init_std
|
||||
self.initializer_range = initializer_range
|
||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
self.max_source_positions = max_source_positions
|
||||
|
||||
|
@ -471,7 +471,11 @@ class Qwen2AudioPreTrainedModel(PreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of Qwen2Audio isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed
|
||||
std = self.config.init_std if hasattr(self.config, "init_std") else self.config.audio_config.init_std
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.audio_config.initializer_range
|
||||
)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
|
@ -1157,12 +1157,12 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
if config.sliding_window is not None:
|
||||
if config.get_text_config().sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||
cache_position.reshape(-1, 1) - config.sliding_window
|
||||
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
|
@ -689,7 +689,7 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention):
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
@ -1287,12 +1287,12 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
if config.sliding_window is not None:
|
||||
if config.get_text_config().sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||
cache_position.reshape(-1, 1) - config.sliding_window
|
||||
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
|
@ -732,12 +732,12 @@ class Qwen3Model(Qwen3PreTrainedModel):
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
if config.sliding_window is not None:
|
||||
if config.get_text_config().sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||
cache_position.reshape(-1, 1) - config.sliding_window
|
||||
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
|
@ -835,12 +835,12 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel):
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
if config.sliding_window is not None:
|
||||
if config.get_text_config().sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||
cache_position.reshape(-1, 1) - config.sliding_window
|
||||
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
|
@ -681,12 +681,12 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
if config.sliding_window is not None:
|
||||
if config.get_text_config().sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||
cache_position.reshape(-1, 1) - config.sliding_window
|
||||
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
|
@ -129,6 +129,7 @@ VLM_CLASS_NAMES = [
|
||||
"gemma3",
|
||||
"mistral3",
|
||||
"chameleon",
|
||||
"qwen2_5_omni",
|
||||
]
|
||||
|
||||
|
||||
|
0
tests/models/qwen2_5_omni/__init__.py
Normal file
0
tests/models/qwen2_5_omni/__init__.py
Normal file
606
tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py
Normal file
606
tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py
Normal file
@ -0,0 +1,606 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Qwen2.5-Omni model."""
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
from io import BytesIO
|
||||
from urllib.request import urlopen
|
||||
|
||||
import librosa
|
||||
import requests
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
Qwen2_5OmniForConditionalGeneration,
|
||||
Qwen2_5OmniThinkerConfig,
|
||||
Qwen2_5OmniThinkerForConditionalGeneration,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
cleanup,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_sdpa,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
ModelTesterMixin,
|
||||
floats_tensor,
|
||||
ids_tensor,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class Qwen2_5OmniThinkerForConditionalGenerationTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=3,
|
||||
feat_seq_length=30,
|
||||
num_channels=3,
|
||||
image_size=14,
|
||||
seq_length=39,
|
||||
vision_config={
|
||||
"depth": 2,
|
||||
"embed_dim": 32,
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 32,
|
||||
"out_hidden_size": 32,
|
||||
"intermediate_size": 24,
|
||||
"mlp_ratio": 4,
|
||||
"num_heads": 4,
|
||||
"patch_size": 14,
|
||||
"spatial_merge_size": 1,
|
||||
"temporal_patch_size": 2,
|
||||
"fullatt_block_indexes": [0],
|
||||
"initializer_range": 0.02,
|
||||
},
|
||||
audio_config={
|
||||
"model_type": "qwen_omni_thinker_audio_encoder",
|
||||
"d_model": 32,
|
||||
"encoder_attention_heads": 4,
|
||||
"encoder_ffn_dim": 32,
|
||||
"encoder_layers": 2,
|
||||
"num_mel_bins": 20,
|
||||
"max_source_positions": 1500,
|
||||
"initializer_range": 0.02,
|
||||
"n_window": 100,
|
||||
"output_dim": 32,
|
||||
},
|
||||
text_config={
|
||||
"rope_scaling": {"mrope_section": [1, 1, 2], "rope_type": "default", "type": "default"},
|
||||
"vocab_size": 99,
|
||||
"hidden_size": 32,
|
||||
"intermediate_size": 37,
|
||||
"num_hidden_layers": 4,
|
||||
"num_attention_heads": 4,
|
||||
"num_key_value_heads": 2,
|
||||
"hidden_act": "silu",
|
||||
"max_position_embeddings": 1024,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"use_cache": True,
|
||||
"tie_word_embeddings": False,
|
||||
"rope_theta": 1000000.0,
|
||||
"use_sliding_window": False,
|
||||
"sliding_window": 50,
|
||||
"max_window_layers": 3,
|
||||
"attention_dropout": 0.0,
|
||||
"pad_token_id": 0,
|
||||
"initializer_range": 0.02,
|
||||
},
|
||||
audio_token_index=1,
|
||||
image_token_index=2,
|
||||
video_token_index=3,
|
||||
position_id_per_seconds=25,
|
||||
seconds_per_chunk=2,
|
||||
audio_start_token_id=4,
|
||||
audio_end_token_id=5,
|
||||
user_token_id=6,
|
||||
vision_start_token_id=7,
|
||||
vision_end_token_id=8,
|
||||
initializer_range=0.02,
|
||||
):
|
||||
self.parent = parent
|
||||
self.audio_config = audio_config
|
||||
self.vision_config = vision_config
|
||||
self.text_config = text_config
|
||||
self.audio_token_index = audio_token_index
|
||||
self.image_token_index = image_token_index
|
||||
self.video_token_index = video_token_index
|
||||
self.position_id_per_seconds = position_id_per_seconds
|
||||
self.seconds_per_chunk = seconds_per_chunk
|
||||
self.audio_start_token_id = audio_start_token_id
|
||||
self.audio_end_token_id = audio_end_token_id
|
||||
self.vision_start_token_id = vision_start_token_id
|
||||
self.vision_end_token_id = vision_end_token_id
|
||||
self.user_token_id = user_token_id
|
||||
self.initializer_range = initializer_range
|
||||
self.batch_size = batch_size
|
||||
self.feat_seq_length = feat_seq_length
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = False
|
||||
|
||||
# Used from `self.model_tester` by common model tests
|
||||
self.num_hidden_layers = self.text_config["num_hidden_layers"]
|
||||
self.hidden_size = self.text_config["hidden_size"]
|
||||
self.num_attention_heads = self.text_config["num_attention_heads"]
|
||||
self.vocab_size = self.text_config["vocab_size"]
|
||||
|
||||
def get_config(self):
|
||||
return Qwen2_5OmniThinkerConfig(
|
||||
audio_config=self.audio_config,
|
||||
vision_config=self.vision_config,
|
||||
text_config=self.text_config,
|
||||
audio_token_index=self.audio_token_index,
|
||||
image_token_index=self.image_token_index,
|
||||
video_token_index=self.video_token_index,
|
||||
position_id_per_seconds=self.position_id_per_seconds,
|
||||
seconds_per_chunk=self.seconds_per_chunk,
|
||||
audio_start_token_id=self.audio_start_token_id,
|
||||
audio_end_token_id=self.audio_end_token_id,
|
||||
vision_start_token_id=self.vision_start_token_id,
|
||||
vision_end_token_id=self.vision_end_token_id,
|
||||
user_token_id=self.user_token_id,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
config = self.get_config()
|
||||
patch_size = config.vision_config.patch_size
|
||||
temporal_patch_size = config.vision_config.temporal_patch_size
|
||||
pixel_values = floats_tensor(
|
||||
[
|
||||
self.batch_size * (self.image_size**2) // (patch_size**2),
|
||||
self.num_channels * (patch_size**2) * temporal_patch_size,
|
||||
]
|
||||
)
|
||||
pixel_grid_thw = torch.LongTensor(
|
||||
[[1, self.image_size / patch_size, self.image_size / patch_size]] * self.batch_size
|
||||
).to(pixel_values.device)
|
||||
input_features_values = floats_tensor(
|
||||
[self.batch_size, self.audio_config["num_mel_bins"], self.feat_seq_length]
|
||||
)
|
||||
feature_attention_mask = torch.ones([self.batch_size, self.feat_seq_length], dtype=torch.long).to(torch_device)
|
||||
return config, pixel_values, pixel_grid_thw, input_features_values, feature_attention_mask
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values, pixel_grid_thw, input_features_values, feature_attention_mask = config_and_inputs
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.get_text_config().vocab_size - 3) + 3
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
|
||||
|
||||
# Make sure no other tokens are set to special, to prevetn flakiness
|
||||
tokens_to_replace = torch.tensor(
|
||||
[
|
||||
config.image_token_index,
|
||||
config.audio_token_index,
|
||||
config.audio_start_token_id,
|
||||
config.audio_end_token_id,
|
||||
config.vision_start_token_id,
|
||||
config.vision_end_token_id,
|
||||
],
|
||||
device=input_ids.device,
|
||||
)
|
||||
input_ids[torch.isin(input_ids, tokens_to_replace)] = config.text_config.pad_token_id
|
||||
|
||||
attention_mask[:, :1] = 0
|
||||
|
||||
# Audio token placeholders should be wrapped in start and end token ids
|
||||
audio_feat_length = ((self.feat_seq_length - 1) // 2 + 1 - 2) // 2 + 1
|
||||
input_ids[:, 1] = config.audio_start_token_id
|
||||
input_ids[:, 2 : (2 + audio_feat_length)] = config.audio_token_index
|
||||
input_ids[:, 2 + audio_feat_length] = config.audio_end_token_id
|
||||
|
||||
# Image token placeholders should be wrapped in start and end token ids
|
||||
input_ids[:, -4:-1] = torch.tensor(
|
||||
[config.vision_start_token_id, config.image_token_index, config.vision_end_token_id]
|
||||
)
|
||||
inputs_dict = {
|
||||
"input_features": input_features_values,
|
||||
"feature_attention_mask": feature_attention_mask,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"image_grid_thw": pixel_grid_thw,
|
||||
"pixel_values": pixel_values,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_qwenomnithinker_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask):
|
||||
model = Qwen2_5OmniThinkerForConditionalGeneration(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.autocast(device_type=torch_device, dtype=torch.float16):
|
||||
logits = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
pixel_values=pixel_values.to(torch.bfloat16),
|
||||
return_dict=True,
|
||||
)["logits"]
|
||||
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||
|
||||
|
||||
@require_torch
|
||||
class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Model tester for `Qwen2_5OmniThinkerForConditionalGeneration`.
|
||||
"""
|
||||
|
||||
all_model_classes = (Qwen2_5OmniThinkerForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (Qwen2_5OmniThinkerForConditionalGeneration,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
_is_composite = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Qwen2_5OmniThinkerForConditionalGenerationTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=Qwen2_5OmniThinkerConfig, has_text_modality=False)
|
||||
|
||||
@unittest.skip(reason="Cpu not yet supported because in QwenOmniThinker models")
|
||||
def test_disk_offload_bin(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Disk offload bin not yet supported because in QwenOmniThinker models")
|
||||
def test_cpu_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Disk offload safetensors not yet supported because in QwenOmniThinker models")
|
||||
def test_disk_offload_safetensors(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Correct missing keys not yet supported because in QwenOmniThinker models")
|
||||
def test_correct_missing_keys(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because in QwenOmniThinker models")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Sdpa dispatch not yet supported because in QwenOmniThinker models")
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="QwenOmniThinker does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="QwenOmniThinker does not support output_hidden_states test")
|
||||
def test_model_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@require_torch_sdpa
|
||||
def test_sdpa_can_dispatch_composite_models(self):
|
||||
# overwrite because Qwen2 is audio+text model (not vision+text)
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
if not self._is_composite:
|
||||
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_sdpa = model_class.from_pretrained(tmpdirname)
|
||||
model_sdpa = model_sdpa.eval().to(torch_device)
|
||||
|
||||
text_attn = "sdpa" if model.model._supports_sdpa else "eager"
|
||||
audio_attn = "sdpa" if model.audio_tower._supports_sdpa else "eager"
|
||||
vision_attn = "sdpa" if model.visual._supports_sdpa else "eager"
|
||||
# `None` as it is the requested one which will be assigned to each sub-config
|
||||
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
|
||||
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
|
||||
self.assertTrue(model.model.config._attn_implementation == text_attn)
|
||||
self.assertTrue(model.audio_tower.config._attn_implementation == audio_attn)
|
||||
self.assertTrue(model.visual.config._attn_implementation == vision_attn)
|
||||
|
||||
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
|
||||
model_eager = model_eager.eval().to(torch_device)
|
||||
self.assertTrue(model_eager.config._attn_implementation == "eager")
|
||||
self.assertTrue(model_eager.model.config._attn_implementation == "eager")
|
||||
self.assertTrue(model_eager.audio_tower.config._attn_implementation == "eager")
|
||||
self.assertTrue(model_eager.visual.config._attn_implementation == "eager")
|
||||
|
||||
for name, submodule in model_eager.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
@parameterized.expand([("greedy", 1), ("beam search", 2)])
|
||||
@unittest.skip("Cannot generate from inputs embeds")
|
||||
def test_generate_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cannot do contrastive generation, has custom `generate()`")
|
||||
def test_contrastive_generate(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cannot do contrastive generation, has custom `generate()`")
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cannot do contrastive generation, has custom `generate()`")
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cannot do constraint generation, has custom `generate()`")
|
||||
def test_constrained_beam_search_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cannot do dola generation, has custom `generate()`")
|
||||
def test_dola_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cannot generate from inputs embeds")
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cannot handle 4D attention mask")
|
||||
def test_generate_compile_model_forward(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cannot handle 4D attention mask")
|
||||
def test_generate_compilation_all_outputs(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cannot handle 4D attention mask")
|
||||
def test_generate_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cannot handle 4D attention mask")
|
||||
def test_custom_4d_attention_mask(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
|
||||
self.audio_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
|
||||
self.audio_url_additional = (
|
||||
"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav"
|
||||
)
|
||||
self.image_url = "https://qianwen-res.oss-accelerate-overseas.aliyuncs.com/Qwen2-VL/demo_small.jpg"
|
||||
self.messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "audio_url": self.audio_url},
|
||||
{"type": "image", "image_url": self.image_url},
|
||||
{"type": "text", "text": "What's that sound and what kind of dog is this?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
self.raw_audio, _ = librosa.load(
|
||||
BytesIO(urlopen(self.audio_url).read()), sr=self.processor.feature_extractor.sampling_rate
|
||||
)
|
||||
self.raw_audio_additional, _ = librosa.load(
|
||||
BytesIO(urlopen(self.audio_url_additional).read()), sr=self.processor.feature_extractor.sampling_rate
|
||||
)
|
||||
self.raw_image = Image.open(requests.get(self.image_url, stream=True).raw)
|
||||
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
@slow
|
||||
def test_small_model_integration_test(self):
|
||||
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.float32, device_map="auto"
|
||||
)
|
||||
|
||||
text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
|
||||
inputs = self.processor(
|
||||
text=[text], audio=[self.raw_audio], images=[self.raw_image], return_tensors="pt", padding=True
|
||||
)
|
||||
|
||||
expected_input_ids = torch.tensor(
|
||||
[
|
||||
151644,
|
||||
8948,
|
||||
198,
|
||||
2610,
|
||||
525,
|
||||
264,
|
||||
10950,
|
||||
17847,
|
||||
13,
|
||||
151645,
|
||||
198,
|
||||
151644,
|
||||
872,
|
||||
198,
|
||||
151647,
|
||||
151646,
|
||||
151648,
|
||||
]
|
||||
)
|
||||
assert torch.allclose(expected_input_ids, inputs.input_ids[0][:17], atol=3e-3)
|
||||
|
||||
expected_pixel_slice = torch.tensor(
|
||||
[
|
||||
[0.8792, 0.8792, 0.9084],
|
||||
[1.1858, 1.1858, 1.2296],
|
||||
[1.2004, 1.2004, 1.2150],
|
||||
[1.4340, 1.4340, 1.4194],
|
||||
[1.3902, 1.4048, 1.4194],
|
||||
[1.5216, 1.5362, 1.5362],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
)
|
||||
assert torch.allclose(expected_pixel_slice, inputs.pixel_values[:6, :3], atol=3e-3)
|
||||
|
||||
# verify generation
|
||||
inputs = inputs.to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, thinker_temperature=0, thinker_do_sample=False, return_audio=False)
|
||||
|
||||
EXPECTED_DECODED_TEXT = "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog appears to be a Labrador Retriever."
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_small_model_integration_test_batch(self):
|
||||
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.float32, device_map="auto"
|
||||
)
|
||||
text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
|
||||
inputs = self.processor(
|
||||
text=[text, text],
|
||||
audio=[self.raw_audio, self.raw_audio],
|
||||
images=[self.raw_image, self.raw_image],
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, thinker_temperature=0, thinker_do_sample=False, return_audio=False)
|
||||
|
||||
EXPECTED_DECODED_TEXT = [
|
||||
"system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog appears to be a Labrador Retriever.",
|
||||
"system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog appears to be a Labrador Retriever.",
|
||||
]
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.batch_decode(output, skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_small_model_integration_test_multiturn(self):
|
||||
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.float32, device_map="auto"
|
||||
)
|
||||
|
||||
messages = [
|
||||
self.messages[0],
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The sound is glass shattering, and the dog appears to be a Labrador Retriever.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "audio_url": self.audio_url_additional},
|
||||
{"type": "text", "text": "How about this one?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
inputs = self.processor(
|
||||
text=[text],
|
||||
audio=[self.raw_audio, self.raw_audio_additional],
|
||||
images=[self.raw_image],
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, thinker_temperature=0, thinker_do_sample=False, return_audio=False)
|
||||
|
||||
EXPECTED_DECODED_TEXT = "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog appears to be a Labrador Retriever.\nuser\nHow about this one?\nassistant\nThe sound is a cough."
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_small_model_integration_test_w_audio(self):
|
||||
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.float32, device_map="auto"
|
||||
)
|
||||
audio_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/guess_age_gender.wav"
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "audio", "audio": audio_url}],
|
||||
},
|
||||
]
|
||||
audio, _ = librosa.load(BytesIO(urlopen(audio_url).read()), sr=self.processor.feature_extractor.sampling_rate)
|
||||
|
||||
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
inputs = self.processor(text=[text], audio=[audio], return_tensors="pt", padding=True).to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, thinker_temperature=0, thinker_do_sample=False)
|
||||
|
||||
EXPECTED_DECODED_TEXT = "system\nYou are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.\nuser\n\nassistant\nWell, I can't really guess your age and gender just from your voice. There are so many factors that can affect how a voice sounds, like the environment you're in, how you're feeling at the moment, and even the microphone you're using. But if you want to share more about your voice, like if it's high - pitched or low - pitched, that might give me a bit of an idea. So, what can you tell me about your voice?"
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0][0], skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
self.assertFalse(torch.isnan(output[1]).any().item())
|
||||
|
||||
@slow
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
def test_small_model_integration_test_batch_flashatt2(self):
|
||||
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-Omni-7B",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
device_map="auto",
|
||||
)
|
||||
text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
|
||||
inputs = self.processor(
|
||||
text=[text, text],
|
||||
audio=[self.raw_audio, self.raw_audio],
|
||||
images=[self.raw_image, self.raw_image],
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, thinker_temperature=0, thinker_do_sample=False, return_audio=False)
|
||||
|
||||
EXPECTED_DECODED_TEXT = [
|
||||
"system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog appears to be a Labrador Retriever.",
|
||||
"system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog appears to be a Labrador Retriever.",
|
||||
]
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.batch_decode(output, skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
self.assertEqual(
|
||||
self.processor.batch_decode(output, skip_special_tokens=True)[0],
|
||||
self.processor.batch_decode(output, skip_special_tokens=True)[1],
|
||||
)
|
616
tests/models/qwen2_5_omni/test_processor_qwen2_5_omni.py
Normal file
616
tests/models/qwen2_5_omni/test_processor_qwen2_5_omni.py
Normal file
@ -0,0 +1,616 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
Qwen2_5OmniProcessor,
|
||||
Qwen2Tokenizer,
|
||||
WhisperFeatureExtractor,
|
||||
)
|
||||
from transformers.testing_utils import require_av, require_librosa, require_torch, require_torchaudio, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import Qwen2VLImageProcessor
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
class Qwen2_5OmniProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor_class = Qwen2_5OmniProcessor
|
||||
|
||||
# text + audio kwargs testing
|
||||
@require_torch
|
||||
def test_tokenizer_defaults_preserved_by_kwargs_audio(self):
|
||||
if "feature_extractor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
|
||||
feature_extractor = self.get_component("feature_extractor")
|
||||
if hasattr(self, "get_tokenizer"):
|
||||
tokenizer = self.get_tokenizer(max_length=800, padding="max_length")
|
||||
elif hasattr(self, "get_component"):
|
||||
tokenizer = self.get_component("tokenizer", max_length=800, padding="max_length")
|
||||
else:
|
||||
self.assertTrue(False, "Processor doesn't have get_tokenizer or get_component defined")
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = "[TEST_PAD]"
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
|
||||
)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
input_str = "lower newer"
|
||||
raw_speech = self.prepare_audio_inputs()
|
||||
inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt")
|
||||
if "input_ids" in inputs:
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 800)
|
||||
elif "labels" in inputs:
|
||||
self.assertEqual(len(inputs["labels"][0]), 800)
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_structured_kwargs_audio_nested(self):
|
||||
if "feature_extractor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
|
||||
feature_extractor = self.get_component("feature_extractor")
|
||||
if hasattr(self, "get_tokenizer"):
|
||||
tokenizer = self.get_tokenizer()
|
||||
elif hasattr(self, "get_component"):
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = "[TEST_PAD]"
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
|
||||
)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = ["lower newer"]
|
||||
raw_speech = self.prepare_audio_inputs()
|
||||
|
||||
# Define the kwargs for each modality
|
||||
all_kwargs = {
|
||||
"common_kwargs": {"return_tensors": "pt"},
|
||||
"audio_kwargs": {"max_length": 800},
|
||||
}
|
||||
|
||||
inputs = processor(text=input_str, audio=raw_speech, **all_kwargs)
|
||||
if "input_ids" in inputs:
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 2)
|
||||
elif "labels" in inputs:
|
||||
self.assertEqual(len(inputs["labels"][0]), 2)
|
||||
|
||||
@require_torch
|
||||
def test_unstructured_kwargs_audio(self):
|
||||
if "feature_extractor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
|
||||
feature_extractor = self.get_component("feature_extractor")
|
||||
if hasattr(self, "get_tokenizer"):
|
||||
tokenizer = self.get_tokenizer(max_length=117)
|
||||
elif hasattr(self, "get_component"):
|
||||
tokenizer = self.get_component("tokenizer", max_length=117)
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = "[TEST_PAD]"
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
|
||||
)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
raw_speech = self.prepare_audio_inputs()
|
||||
inputs = processor(
|
||||
text=input_str,
|
||||
audio=raw_speech,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=800,
|
||||
)
|
||||
|
||||
if "input_ids" in inputs:
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 800)
|
||||
elif "labels" in inputs:
|
||||
self.assertEqual(len(inputs["labels"][0]), 800)
|
||||
|
||||
@require_torch
|
||||
def test_doubly_passed_kwargs_audio(self):
|
||||
if "feature_extractor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
|
||||
feature_extractor = self.get_component("feature_extractor")
|
||||
if hasattr(self, "get_tokenizer"):
|
||||
tokenizer = self.get_tokenizer()
|
||||
elif hasattr(self, "get_component"):
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = "[TEST_PAD]"
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor)
|
||||
|
||||
@require_torch
|
||||
def test_kwargs_overrides_default_tokenizer_kwargs_audio(self):
|
||||
if "feature_extractor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
|
||||
feature_extractor = self.get_component("feature_extractor")
|
||||
if hasattr(self, "get_tokenizer"):
|
||||
tokenizer = self.get_tokenizer(max_length=117)
|
||||
elif hasattr(self, "get_component"):
|
||||
tokenizer = self.get_component("tokenizer", max_length=117)
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = "[TEST_PAD]"
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor)
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.tmpdirname = tempfile.mkdtemp()
|
||||
processor_kwargs = cls.prepare_processor_dict()
|
||||
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B", **processor_kwargs)
|
||||
processor.save_pretrained(cls.tmpdirname)
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
|
||||
|
||||
def get_image_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||
|
||||
def get_feature_extractor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).feature_extractor
|
||||
|
||||
@staticmethod
|
||||
def prepare_processor_dict():
|
||||
return {
|
||||
"chat_template": "{% set audio_count = namespace(value=0) %}{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_bos|><|IMAGE|><|vision_eos|>{% elif content['type'] == 'audio' or 'audio' in content or 'audio_url' in content %}{% set audio_count.value = audio_count.value + 1 %}{% if add_audio_id %}Audio {{ audio_count.value }}: {% endif %}<|audio_bos|><|AUDIO|><|audio_eos|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_bos|><|VIDEO|><|vision_eos|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
shutil.rmtree(cls.tmpdirname, ignore_errors=True)
|
||||
|
||||
def prepare_audio_inputs(self):
|
||||
"""This function prepares a list of numpy audios."""
|
||||
audio_inputs = [np.random.rand(160000) * 2 - 1] * 3 # batch-size=3
|
||||
return audio_inputs
|
||||
|
||||
def test_save_load_pretrained_default(self):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
|
||||
processor = Qwen2_5OmniProcessor(
|
||||
image_processor=image_processor, feature_extractor=feature_extractor, tokenizer=tokenizer
|
||||
)
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
processor = Qwen2_5OmniProcessor.from_pretrained(self.tmpdirname, use_fast=False)
|
||||
|
||||
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
|
||||
self.assertEqual(processor.image_processor.to_json_string(), image_processor.to_json_string())
|
||||
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
|
||||
self.assertIsInstance(processor.tokenizer, Qwen2Tokenizer)
|
||||
self.assertIsInstance(processor.image_processor, Qwen2VLImageProcessor)
|
||||
self.assertIsInstance(processor.feature_extractor, WhisperFeatureExtractor)
|
||||
|
||||
def test_image_processor(self):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
|
||||
processor = Qwen2_5OmniProcessor(
|
||||
image_processor=image_processor, feature_extractor=feature_extractor, tokenizer=tokenizer
|
||||
)
|
||||
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
input_image_proc = image_processor(image_input, return_tensors="np")
|
||||
input_processor = processor(images=image_input, text="dummy", return_tensors="np")
|
||||
|
||||
for key in input_image_proc.keys():
|
||||
self.assertAlmostEqual(input_image_proc[key].sum(), input_processor[key].sum(), delta=1e-2)
|
||||
|
||||
def test_processor(self):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
|
||||
processor = Qwen2_5OmniProcessor(
|
||||
image_processor=image_processor, feature_extractor=feature_extractor, tokenizer=tokenizer
|
||||
)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
audio_input = self.prepare_audio_inputs()
|
||||
inputs = processor(text=input_str, images=image_input, audio=audio_input)
|
||||
keys = list(inputs.keys())
|
||||
self.assertListEqual(
|
||||
keys,
|
||||
[
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"pixel_values",
|
||||
"image_grid_thw",
|
||||
"feature_attention_mask",
|
||||
"input_features",
|
||||
],
|
||||
)
|
||||
|
||||
# test if it raises when no input is passed
|
||||
with pytest.raises(ValueError):
|
||||
processor()
|
||||
|
||||
# test if it raises when no text is passed
|
||||
with pytest.raises(ValueError):
|
||||
processor(images=image_input)
|
||||
|
||||
def test_model_input_names(self):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
|
||||
processor = Qwen2_5OmniProcessor(
|
||||
image_processor=image_processor, feature_extractor=feature_extractor, tokenizer=tokenizer
|
||||
)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
video_inputs = self.prepare_video_inputs()
|
||||
audio_input = self.prepare_audio_inputs()
|
||||
|
||||
inputs = processor(text=input_str, images=image_input, videos=video_inputs, audio=audio_input)
|
||||
self.assertListEqual(sorted(inputs.keys()), sorted(processor.model_input_names))
|
||||
|
||||
@require_torch
|
||||
def _test_apply_chat_template(
|
||||
self,
|
||||
modality: str,
|
||||
batch_size: int,
|
||||
return_tensors: str,
|
||||
input_name: str,
|
||||
processor_name: str,
|
||||
input_data: list[str],
|
||||
):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
if processor_name not in self.processor_class.attributes:
|
||||
self.skipTest(f"{processor_name} attribute not present in {self.processor_class}")
|
||||
|
||||
batch_messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Describe this."}],
|
||||
},
|
||||
]
|
||||
] * batch_size
|
||||
|
||||
# Test that jinja can be applied
|
||||
formatted_prompt = processor.apply_chat_template(batch_messages, add_generation_prompt=True, tokenize=False)
|
||||
self.assertEqual(len(formatted_prompt), batch_size)
|
||||
|
||||
# Test that tokenizing with template and directly with `self.tokenizer` gives same output
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(
|
||||
batch_messages, add_generation_prompt=True, tokenize=True, return_tensors=return_tensors
|
||||
)
|
||||
add_special_tokens = True
|
||||
if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token):
|
||||
add_special_tokens = False
|
||||
tok_output = processor.tokenizer(
|
||||
formatted_prompt, return_tensors=return_tensors, add_special_tokens=add_special_tokens
|
||||
)
|
||||
expected_output = tok_output.input_ids
|
||||
self.assertListEqual(expected_output.tolist(), formatted_prompt_tokenized.tolist())
|
||||
|
||||
# Test that kwargs passed to processor's `__call__` are actually used
|
||||
tokenized_prompt_100 = processor.apply_chat_template(
|
||||
batch_messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=100,
|
||||
)
|
||||
self.assertEqual(len(tokenized_prompt_100[0]), 100)
|
||||
|
||||
# Test that `return_dict=True` returns text related inputs in the dict
|
||||
out_dict_text = processor.apply_chat_template(
|
||||
batch_messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
self.assertTrue(all(key in out_dict_text for key in ["input_ids", "attention_mask"]))
|
||||
self.assertEqual(len(out_dict_text["input_ids"]), batch_size)
|
||||
self.assertEqual(len(out_dict_text["attention_mask"]), batch_size)
|
||||
|
||||
# Test that with modality URLs and `return_dict=True`, we get modality inputs in the dict
|
||||
for idx, url in enumerate(input_data[:batch_size]):
|
||||
batch_messages[idx][0]["content"] = [batch_messages[idx][0]["content"][0], {"type": modality, "url": url}]
|
||||
|
||||
out_dict = processor.apply_chat_template(
|
||||
batch_messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors=return_tensors,
|
||||
num_frames=4, # by default no more than 4 frames, otherwise too slow
|
||||
)
|
||||
input_name = getattr(self, input_name)
|
||||
self.assertTrue(input_name in out_dict)
|
||||
self.assertEqual(len(out_dict["input_ids"]), batch_size)
|
||||
self.assertEqual(len(out_dict["attention_mask"]), batch_size)
|
||||
self.assertEqual(len(out_dict[input_name]), batch_size * 1564)
|
||||
|
||||
return_tensor_to_type = {"pt": torch.Tensor, "np": np.ndarray, None: list}
|
||||
for k in out_dict:
|
||||
self.assertIsInstance(out_dict[k], return_tensor_to_type[return_tensors])
|
||||
|
||||
@require_av
|
||||
def test_apply_chat_template_video_frame_sampling(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
signature = inspect.signature(processor.__call__)
|
||||
if "videos" not in {*signature.parameters.keys()} or (
|
||||
signature.parameters.get("videos") is not None
|
||||
and signature.parameters["videos"].annotation == inspect._empty
|
||||
):
|
||||
self.skipTest("Processor doesn't accept videos at input")
|
||||
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is shown in this video?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
self.assertEqual(len(formatted_prompt), 1)
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
|
||||
expected_output = processor.tokenizer(formatted_prompt, return_tensors=None).input_ids
|
||||
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||
|
||||
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||
|
||||
# Add video URL for return dict and load with `num_frames` arg
|
||||
messages[0][0]["content"].append(
|
||||
{
|
||||
"type": "video",
|
||||
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
|
||||
}
|
||||
)
|
||||
num_frames = 3
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 9568)
|
||||
|
||||
# Load with `video_fps` arg
|
||||
video_fps = 1
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
video_fps=video_fps,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 23920)
|
||||
|
||||
# Load with `video_fps` and `num_frames` args, should raise an error
|
||||
with self.assertRaises(ValueError):
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
video_fps=video_fps,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
|
||||
# Load without any arg should load the whole video
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 717600)
|
||||
|
||||
# Load video as a list of frames (i.e. images). NOTE: each frame should have same size
|
||||
# because we assume they come from one video
|
||||
messages[0][0]["content"][-1] = {
|
||||
"type": "video",
|
||||
"url": [
|
||||
"https://www.ilankelman.org/stopsigns/australia.jpg",
|
||||
"https://www.ilankelman.org/stopsigns/australia.jpg",
|
||||
],
|
||||
}
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 5704)
|
||||
|
||||
@require_av
|
||||
def test_apply_chat_template_video_special_processing(self):
|
||||
"""
|
||||
Tests that models can use their own preprocessing to preprocess conversations.
|
||||
"""
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
signature = inspect.signature(processor.__call__)
|
||||
if "videos" not in {*signature.parameters.keys()} or (
|
||||
signature.parameters.get("videos") is not None
|
||||
and signature.parameters["videos"].annotation == inspect._empty
|
||||
):
|
||||
self.skipTest("Processor doesn't accept videos at input")
|
||||
|
||||
video_file_path = hf_hub_download(
|
||||
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
|
||||
)
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "path": video_file_path},
|
||||
{"type": "text", "text": "What is shown in this video?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
def _process_messages_for_chat_template(
|
||||
conversation,
|
||||
batch_images,
|
||||
batch_videos,
|
||||
batch_video_metadata,
|
||||
**chat_template_kwargs,
|
||||
):
|
||||
# Let us just always return a dummy prompt
|
||||
new_msg = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video"}, # no need to use path, video is loaded already by this moment
|
||||
{"type": "text", "text": "Dummy prompt for preprocess testing"},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
return new_msg
|
||||
|
||||
processor._process_messages_for_chat_template = _process_messages_for_chat_template
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
|
||||
# Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
|
||||
formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
|
||||
self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 145912)
|
||||
|
||||
@require_librosa
|
||||
@require_av
|
||||
@unittest.skip(
|
||||
"@raushan: librosa can'r decode this audio in CI runner, fix after adding moviepy or another decoder"
|
||||
)
|
||||
def test_chat_template_audio_from_video(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
signature = inspect.signature(processor.__call__)
|
||||
if "videos" not in {*signature.parameters.keys()} or (
|
||||
signature.parameters.get("videos") is not None
|
||||
and signature.parameters["videos"].annotation == inspect._empty
|
||||
):
|
||||
self.skipTest(f"{self.processor_class} does not suport video inputs")
|
||||
|
||||
if "feature_extractor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
|
||||
|
||||
video_file_path = hf_hub_download(
|
||||
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "path": video_file_path},
|
||||
{"type": "text", "text": "Which of these animals is making the sound?"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "It is a cow."}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Tell me all about this animal."},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
formatted_prompt = processor.apply_chat_template([messages], add_generation_prompt=True, tokenize=False)
|
||||
self.assertEqual(len(formatted_prompt), 1) # batch size=1
|
||||
|
||||
out_dict = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="np",
|
||||
load_audio_from_video=True,
|
||||
)
|
||||
self.assertTrue(self.audio_input_name in out_dict)
|
||||
self.assertTrue(self.videos_input_name in out_dict)
|
||||
|
||||
# should always have input_ids and attention_mask
|
||||
self.assertEqual(len(out_dict["input_ids"]), 1) # batch-size=1
|
||||
self.assertEqual(len(out_dict["attention_mask"]), 1) # batch-size=1
|
||||
self.assertEqual(len(out_dict[self.audio_input_name]), 1) # 1 audio in the conversation
|
||||
self.assertEqual(len(out_dict[self.videos_input_name]), 145912) # 1 video in the conversation
|
@ -719,11 +719,9 @@ class ProcessorTesterMixin:
|
||||
if "feature_extractor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
|
||||
|
||||
feature_extractor = self.get_component("feature_extractor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
processor_components = self.prepare_components()
|
||||
processor_kwargs = self.prepare_processor_dict()
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
|
||||
processor = self.processor_class(**processor_components, **processor_kwargs)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = self.prepare_text_inputs(batch_size=3, modality="audio")
|
||||
@ -1128,11 +1126,7 @@ class ProcessorTesterMixin:
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "audio",
|
||||
"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3",
|
||||
},
|
||||
{"type": "text", "text": "Is it the same sound?"},
|
||||
{"type": "text", "text": "Tell me all about this animal."},
|
||||
],
|
||||
},
|
||||
]
|
||||
@ -1154,5 +1148,5 @@ class ProcessorTesterMixin:
|
||||
# should always have input_ids and attention_mask
|
||||
self.assertEqual(len(out_dict["input_ids"]), 1) # batch-size=1
|
||||
self.assertEqual(len(out_dict["attention_mask"]), 1) # batch-size=1
|
||||
self.assertEqual(len(out_dict[self.audio_input_name]), 2) # 2 audios in the conversation
|
||||
self.assertEqual(len(out_dict[self.audio_input_name]), 1) # 1 audio in the conversation
|
||||
self.assertEqual(len(out_dict[self.videos_input_name]), 1) # 1 video in the conversation
|
||||
|
@ -143,6 +143,13 @@ IGNORE_NON_TESTED = (
|
||||
"ChameleonVQVAE", # VQVAE here is used only for encoding (discretizing) and is tested as part of bigger model
|
||||
"Qwen2VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2VLForConditionalGeneration.
|
||||
"Qwen2_5_VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5_VLForConditionalGeneration.
|
||||
"Qwen2_5OmniForConditionalGeneration", # Not a regular model. Testted in Qwen2_5OmniModelIntergrationTest
|
||||
"Qwen2_5OmniTalkerForConditionalGeneration", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest.
|
||||
"Qwen2_5OmniTalkerModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest.
|
||||
"Qwen2_5OmniThinkerTextModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest.
|
||||
"Qwen2_5OmniToken2WavModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest.
|
||||
"Qwen2_5OmniToken2WavDiTModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest.
|
||||
"Qwen2_5OmniToken2WavBigVGANModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest.
|
||||
"MllamaTextModel", # Building part of bigger (tested) model. # TODO: add tests
|
||||
"MllamaVisionModel", # Building part of bigger (tested) model. # TODO: add tests
|
||||
"Llama4TextModel", # Building part of bigger (tested) model. # TODO: add tests
|
||||
@ -348,6 +355,13 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
||||
"MoshiForConditionalGeneration", # no auto class for speech-to-speech
|
||||
"Emu3VQVAE", # no autoclass for VQ-VAE models
|
||||
"Emu3TextModel", # Building part of bigger (tested) model
|
||||
"Qwen2_5OmniTalkerForConditionalGeneration", # Building part of a bigger model
|
||||
"Qwen2_5OmniTalkerModel", # Building part of a bigger model
|
||||
"Qwen2_5OmniThinkerForConditionalGeneration", # Building part of a bigger model
|
||||
"Qwen2_5OmniThinkerTextModel", # Building part of a bigger model
|
||||
"Qwen2_5OmniToken2WavModel", # Building part of a bigger model
|
||||
"Qwen2_5OmniToken2WavBigVGANModel", # Building part of a bigger model
|
||||
"Qwen2_5OmniToken2WavDiTModel", # Building part of a bigger model
|
||||
]
|
||||
|
||||
# DO NOT edit this list!
|
||||
|
Loading…
Reference in New Issue
Block a user