Fix qwen_2_5 omni (#38658)

* fix

* fix

* break style

* break style

* Apply style fixes

* break style

* Apply style fixes

* fix modular

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Yih-Dar 2025-06-12 14:43:54 +02:00 committed by GitHub
parent e1812864ab
commit d4e7aa5526
3 changed files with 105 additions and 75 deletions


@@ -4029,51 +4029,51 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
             return thinker_result
         # 2. Generate speech tokens from talker module
-        embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(self.talker.device)
+        embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(input_ids.device)
         if thinker_kwargs.get("input_features", None) is not None:
             audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index
-            audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
             audio_mask_tensor = torch.zeros(
                 [audio_ids_mask.sum(), embeds_to_talker.shape[-1]],
                 dtype=embeds_to_talker.dtype,
-                device=self.talker.device,
+                device=input_ids.device,
             )
             embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor)
         if thinker_kwargs.get("pixel_values", None) is not None:
             image_ids_mask = input_ids == self.config.thinker_config.image_token_index
-            image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
             image_mask_tensor = torch.zeros(
                 [image_ids_mask.sum(), embeds_to_talker.shape[-1]],
                 dtype=embeds_to_talker.dtype,
-                device=self.talker.device,
+                device=input_ids.device,
             )
             embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor)
         if thinker_kwargs.get("pixel_values_videos", None) is not None:
             video_ids_mask = input_ids == self.config.thinker_config.video_token_index
-            video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
             video_mask_tensor = torch.zeros(
                 [video_ids_mask.sum(), embeds_to_talker.shape[-1]],
                 dtype=embeds_to_talker.dtype,
-                device=self.talker.device,
+                device=input_ids.device,
             )
             embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor)
         processed_thinker_hidden = (
             (embeds_to_talker,) + thinker_result.hidden_states[0][1:],
         ) + thinker_result.hidden_states[1:]
-        thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(self.talker.device)
+        thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(input_ids.device)
         thinker_token_embeds = [
-            token_hidden_states[0].to(self.talker.device) for token_hidden_states in processed_thinker_hidden
+            token_hidden_states[0].to(input_ids.device) for token_hidden_states in processed_thinker_hidden
         ]
         thinker_hidden_states = [
-            token_hidden_states[-1].to(self.talker.device) for token_hidden_states in processed_thinker_hidden
+            token_hidden_states[-1].to(input_ids.device) for token_hidden_states in processed_thinker_hidden
         ]
         talker_text_bos_token = speaker_params["bos_token"]
         talker_input_text_ids = torch.cat(
             [
-                input_ids.to(self.talker.device),
-                torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.talker.device),
+                input_ids,
+                torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=input_ids.device),
                 thinker_generate_ids[:, :1],
             ],
             dim=-1,
@@ -4081,9 +4081,9 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
         talker_input_ids = torch.cat(
             [
-                torch.full_like(input_ids, fill_value=self.talker.codec_mask_token, device=self.talker.device),
-                torch.tensor([[self.talker.codec_pad_token]], dtype=torch.long, device=self.talker.device),
-                torch.tensor([[self.talker.codec_bos_token]], dtype=torch.long, device=self.talker.device),
+                torch.full_like(input_ids, fill_value=self.talker.codec_mask_token),
+                torch.tensor([[self.talker.codec_pad_token]], dtype=torch.long, device=input_ids.device),
+                torch.tensor([[self.talker.codec_bos_token]], dtype=torch.long, device=input_ids.device),
             ],
             dim=1,
         )
@@ -4091,8 +4091,8 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
         thinker_embed_tokens = self.thinker.get_input_embeddings()
         thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
         talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0]
-        talker_text_bos_token = torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.thinker.device)
-        talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to(self.talker.device)
+        talker_text_bos_token = torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=input_ids.device)
+        talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to(input_ids.device)
         talker_inputs_embeds = torch.cat(
             [
                 talker_inputs_embeds,
@@ -4103,12 +4103,12 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
         )
         eos_embedding = thinker_embed_tokens(
-            torch.tensor([[self.talker.text_eos_token]], dtype=torch.long, device=self.thinker.device)
-        ).to(self.talker.device)
+            torch.tensor([[self.talker.text_eos_token]], dtype=torch.long, device=input_ids.device)
+        )
         pad_embedding = thinker_embed_tokens(
-            torch.tensor([[self.talker.text_pad_token]], dtype=torch.long, device=self.thinker.device)
-        ).to(self.talker.device)
+            torch.tensor([[self.talker.text_pad_token]], dtype=torch.long, device=input_ids.device)
+        )
         thinker_reply_part = torch.cat(
             [
@@ -4123,7 +4123,7 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
         if "attention_mask" in kwargs:
             talker_attention_mask = torch.cat(
                 [kwargs["attention_mask"], kwargs["attention_mask"].new_ones((1, 2))], dim=1
-            ).to(self.talker.device)
+            ).to(input_ids.device)
         talker_result = self.talker.generate(
             input_ids=talker_input_ids,
@@ -4132,7 +4132,7 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
             inputs_embeds=talker_inputs_embeds,
             attention_mask=talker_attention_mask,
             suppress_tokens=[self.talker.codec_bos_token],
-            **{k: (v.to(self.talker.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()},
+            **{k: (v.to(input_ids.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()},
         )
         talker_generate_codes = talker_result[:, talker_input_ids.shape[1] : -1]
@@ -4141,9 +4141,9 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
             self.token2wav.float()
         wav = self.token2wav(
-            talker_generate_codes.to(self.token2wav.device),
-            conditioning=speaker_params["cond"].to(self.token2wav.device).float(),
-            reference_mel=speaker_params["ref_mel"].to(self.token2wav.device).float(),
+            talker_generate_codes.to(input_ids.device),
+            conditioning=speaker_params["cond"].to(input_ids.device).float(),
+            reference_mel=speaker_params["ref_mel"].to(input_ids.device).float(),
             **token2wav_kwargs,
         )
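
The masked_scatter_ blocks in the hunk above zero out the talker-side embeddings at the audio/image/video placeholder positions, now keeping every intermediate tensor on input_ids.device instead of a hard-coded submodule device. A minimal standalone sketch of that masking pattern, with toy shapes and a made-up placeholder id (3 here is illustrative, not the model's real audio_token_index):

import torch

# Toy stand-ins; shapes and the placeholder id are illustrative only.
batch, seq_len, hidden = 1, 6, 4
audio_token_index = 3
input_ids = torch.tensor([[5, 3, 3, 7, 3, 9]])            # (batch, seq_len)
embeds_to_talker = torch.randn(batch, seq_len, hidden)

# Boolean mask over token positions, expanded to the embedding dimension.
audio_ids_mask = input_ids == audio_token_index
audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)

# One row of zeros per masked position; masked_scatter_ writes them in place,
# and everything stays on input_ids.device rather than a module's .device.
audio_mask_tensor = torch.zeros(
    [audio_ids_mask.sum(), hidden], dtype=embeds_to_talker.dtype, device=input_ids.device
)
embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor)

assert embeds_to_talker[audio_ids_mask].abs().sum() == 0  # placeholder rows are now zero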


@@ -4295,51 +4295,51 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
             return thinker_result
         # 2. Generate speech tokens from talker module
-        embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(self.talker.device)
+        embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(input_ids.device)
         if thinker_kwargs.get("input_features", None) is not None:
             audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index
-            audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
            audio_mask_tensor = torch.zeros(
                 [audio_ids_mask.sum(), embeds_to_talker.shape[-1]],
                 dtype=embeds_to_talker.dtype,
-                device=self.talker.device,
+                device=input_ids.device,
             )
             embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor)
         if thinker_kwargs.get("pixel_values", None) is not None:
             image_ids_mask = input_ids == self.config.thinker_config.image_token_index
-            image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
             image_mask_tensor = torch.zeros(
                 [image_ids_mask.sum(), embeds_to_talker.shape[-1]],
                 dtype=embeds_to_talker.dtype,
-                device=self.talker.device,
+                device=input_ids.device,
             )
             embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor)
         if thinker_kwargs.get("pixel_values_videos", None) is not None:
             video_ids_mask = input_ids == self.config.thinker_config.video_token_index
-            video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
             video_mask_tensor = torch.zeros(
                 [video_ids_mask.sum(), embeds_to_talker.shape[-1]],
                 dtype=embeds_to_talker.dtype,
-                device=self.talker.device,
+                device=input_ids.device,
             )
             embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor)
         processed_thinker_hidden = (
             (embeds_to_talker,) + thinker_result.hidden_states[0][1:],
         ) + thinker_result.hidden_states[1:]
-        thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(self.talker.device)
+        thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(input_ids.device)
         thinker_token_embeds = [
-            token_hidden_states[0].to(self.talker.device) for token_hidden_states in processed_thinker_hidden
+            token_hidden_states[0].to(input_ids.device) for token_hidden_states in processed_thinker_hidden
         ]
         thinker_hidden_states = [
-            token_hidden_states[-1].to(self.talker.device) for token_hidden_states in processed_thinker_hidden
+            token_hidden_states[-1].to(input_ids.device) for token_hidden_states in processed_thinker_hidden
         ]
         talker_text_bos_token = speaker_params["bos_token"]
         talker_input_text_ids = torch.cat(
             [
-                input_ids.to(self.talker.device),
-                torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.talker.device),
+                input_ids,
+                torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=input_ids.device),
                 thinker_generate_ids[:, :1],
             ],
             dim=-1,
@@ -4347,9 +4347,9 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
         talker_input_ids = torch.cat(
             [
-                torch.full_like(input_ids, fill_value=self.talker.codec_mask_token, device=self.talker.device),
-                torch.tensor([[self.talker.codec_pad_token]], dtype=torch.long, device=self.talker.device),
-                torch.tensor([[self.talker.codec_bos_token]], dtype=torch.long, device=self.talker.device),
+                torch.full_like(input_ids, fill_value=self.talker.codec_mask_token),
+                torch.tensor([[self.talker.codec_pad_token]], dtype=torch.long, device=input_ids.device),
+                torch.tensor([[self.talker.codec_bos_token]], dtype=torch.long, device=input_ids.device),
             ],
             dim=1,
         )
@@ -4357,8 +4357,8 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
         thinker_embed_tokens = self.thinker.get_input_embeddings()
         thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
         talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0]
-        talker_text_bos_token = torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.thinker.device)
-        talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to(self.talker.device)
+        talker_text_bos_token = torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=input_ids.device)
+        talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to(input_ids.device)
         talker_inputs_embeds = torch.cat(
             [
                 talker_inputs_embeds,
@@ -4369,12 +4369,12 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
         )
         eos_embedding = thinker_embed_tokens(
-            torch.tensor([[self.talker.text_eos_token]], dtype=torch.long, device=self.thinker.device)
-        ).to(self.talker.device)
+            torch.tensor([[self.talker.text_eos_token]], dtype=torch.long, device=input_ids.device)
+        )
         pad_embedding = thinker_embed_tokens(
-            torch.tensor([[self.talker.text_pad_token]], dtype=torch.long, device=self.thinker.device)
-        ).to(self.talker.device)
+            torch.tensor([[self.talker.text_pad_token]], dtype=torch.long, device=input_ids.device)
+        )
         thinker_reply_part = torch.cat(
             [
@@ -4389,7 +4389,7 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
         if "attention_mask" in kwargs:
             talker_attention_mask = torch.cat(
                 [kwargs["attention_mask"], kwargs["attention_mask"].new_ones((1, 2))], dim=1
-            ).to(self.talker.device)
+            ).to(input_ids.device)
         talker_result = self.talker.generate(
             input_ids=talker_input_ids,
@@ -4398,7 +4398,7 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
             inputs_embeds=talker_inputs_embeds,
             attention_mask=talker_attention_mask,
             suppress_tokens=[self.talker.codec_bos_token],
-            **{k: (v.to(self.talker.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()},
+            **{k: (v.to(input_ids.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()},
         )
         talker_generate_codes = talker_result[:, talker_input_ids.shape[1] : -1]
@@ -4407,9 +4407,9 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
             self.token2wav.float()
         wav = self.token2wav(
-            talker_generate_codes.to(self.token2wav.device),
-            conditioning=speaker_params["cond"].to(self.token2wav.device).float(),
-            reference_mel=speaker_params["ref_mel"].to(self.token2wav.device).float(),
+            talker_generate_codes.to(input_ids.device),
+            conditioning=speaker_params["cond"].to(input_ids.device).float(),
+            reference_mel=speaker_params["ref_mel"].to(input_ids.device).float(),
            **token2wav_kwargs,
         )
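
The diff above appears twice with different line offsets, presumably because the same change is applied to both the modeling file and the modular source it is generated from (the commit message's "fix modular" step). The recurring pattern is that every tensor handed to the talker is now created on, or moved to, input_ids.device rather than a submodule's .device attribute. A small self-contained sketch of the talker_input_ids construction shown in the hunks, with made-up codec token ids (100/101/102 are placeholders, not the model's real values):

import torch

codec_mask_token, codec_pad_token, codec_bos_token = 100, 101, 102  # illustrative ids
input_ids = torch.tensor([[5, 6, 7, 8]])  # (batch=1, seq_len=4)

# One codec mask token per text position, then a pad and a bos token,
# all created directly on input_ids.device.
talker_input_ids = torch.cat(
    [
        torch.full_like(input_ids, fill_value=codec_mask_token),
        torch.tensor([[codec_pad_token]], dtype=torch.long, device=input_ids.device),
        torch.tensor([[codec_bos_token]], dtype=torch.long, device=input_ids.device),
    ],
    dim=1,
)
print(talker_input_ids)  # tensor([[100, 100, 100, 100, 101, 102]])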


@@ -33,6 +33,7 @@ from transformers import (
     is_vision_available,
 )
 from transformers.testing_utils import (
+    Expectations,
     cleanup,
     require_flash_attn,
     require_torch,
@@ -555,13 +556,13 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
     @slow
     def test_small_model_integration_test(self):
         model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
-            "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.float32, device_map="auto"
+            "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.bfloat16, device_map="auto"
         )
         text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
         inputs = self.processor(
-            text=[text], audio=[self.raw_audio], images=[self.raw_image], return_tensors="pt", padding=True
-        )
+            text=text, audio=[self.raw_audio], images=[self.raw_image], return_tensors="pt", padding=True
+        ).to(torch.bfloat16)
         expected_input_ids = torch.tensor(
             [
@@ -581,7 +582,7 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
                 198,
                 151647,
                 151646,
-                151648,
+                151646,
             ]
         )
         assert torch.allclose(expected_input_ids, inputs.input_ids[0][:17], atol=3e-3)
@@ -595,7 +596,7 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
                 [1.3902, 1.4048, 1.4194],
                 [1.5216, 1.5362, 1.5362],
             ],
-            dtype=torch.float32,
+            dtype=torch.bfloat16,
             device="cpu",
         )
         assert torch.allclose(expected_pixel_slice, inputs.pixel_values[:6, :3], atol=3e-3)
@@ -603,9 +604,11 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
         # verify generation
         inputs = inputs.to(torch_device)
-        output = model.generate(**inputs, thinker_temperature=0, thinker_do_sample=False, return_audio=False)
+        output = model.generate(
+            **inputs, thinker_temperature=0, thinker_do_sample=False, return_audio=False, thinker_max_new_tokens=20
+        )
-        EXPECTED_DECODED_TEXT = "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog appears to be a Labrador Retriever."
+        EXPECTED_DECODED_TEXT = "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog is a Labrador Retriever."
         self.assertEqual(
             self.processor.decode(output[0], skip_special_tokens=True),
@@ -615,23 +618,34 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
     @slow
     def test_small_model_integration_test_batch(self):
         model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
-            "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.float32, device_map="auto"
+            "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.bfloat16, device_map="auto"
        )
         text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
         inputs = self.processor(
-            text=[text, text],
+            text=text * 2,
             audio=[self.raw_audio, self.raw_audio],
             images=[self.raw_image, self.raw_image],
             return_tensors="pt",
             padding=True,
-        ).to(torch_device)
+        ).to(torch_device, dtype=torch.bfloat16)
-        output = model.generate(**inputs, thinker_temperature=0, thinker_do_sample=False, return_audio=False)
+        output = model.generate(
+            **inputs, thinker_temperature=0, thinker_do_sample=False, return_audio=False, thinker_max_new_tokens=20
+        )
-        EXPECTED_DECODED_TEXT = [
-            "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog appears to be a Labrador Retriever.",
-            "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog appears to be a Labrador Retriever.",
-        ]
+        EXPECTED_DECODED_TEXTS = Expectations(
+            {
+                ("cuda", 7) : [
+                    "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is of glass shattering, and the dog in the picture is a Labrador Retriever",
+                    "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is of glass shattering, and the dog in the picture is a Labrador Retriever",
+                ],
+                ("cuda", 8): [
+                    "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog is a Labrador Retriever.",
+                    "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog is a Labrador Retriever.",
+                ],
+            }
+        ) # fmt: skip
+        EXPECTED_DECODED_TEXT = EXPECTED_DECODED_TEXTS.get_expectation()
         self.assertEqual(
             self.processor.batch_decode(output, skip_special_tokens=True),
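
The Expectations helper imported at the top of the test file lets a single test carry hardware-specific references; get_expectation() returns the entry matching the machine the test runs on. A minimal sketch, assuming (as the hunk above suggests) that keys are (device type, major compute capability) pairs and that the running machine matches one of them; the strings are placeholders:

from transformers.testing_utils import Expectations

EXPECTED_GREETINGS = Expectations(
    {
        ("cuda", 7): "reference output recorded on an sm_7x GPU",
        ("cuda", 8): "reference output recorded on an sm_8x GPU",
    }
)
# Picks the value whose key best matches the device the test is running on.
expected = EXPECTED_GREETINGS.get_expectation()
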
@@ -641,7 +655,7 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
     @slow
     def test_small_model_integration_test_multiturn(self):
         model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
-            "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.float32, device_map="auto"
+            "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.bfloat16, device_map="auto"
         )
         messages = [
@@ -666,14 +680,16 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
         text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
         inputs = self.processor(
-            text=[text],
+            text=text,
             audio=[self.raw_audio, self.raw_audio_additional],
             images=[self.raw_image],
             return_tensors="pt",
             padding=True,
-        ).to(torch_device)
+        ).to(torch_device, dtype=torch.bfloat16)
-        output = model.generate(**inputs, thinker_temperature=0, thinker_do_sample=False, return_audio=False)
+        output = model.generate(
+            **inputs, thinker_temperature=0, thinker_do_sample=False, return_audio=False, thinker_max_new_tokens=20
+        )
         EXPECTED_DECODED_TEXT = "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog appears to be a Labrador Retriever.\nuser\nHow about this one?\nassistant\nThe sound is a cough."
@@ -685,7 +701,7 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
     @slow
     def test_small_model_integration_test_w_audio(self):
         model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
-            "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.float32, device_map="auto"
+            "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.bfloat16, device_map="auto"
         )
         audio_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/guess_age_gender.wav"
@@ -707,11 +723,25 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
         audio, _ = librosa.load(BytesIO(urlopen(audio_url).read()), sr=self.processor.feature_extractor.sampling_rate)
         text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = self.processor(text=text, audio=[audio], return_tensors="pt", padding=True).to(torch_device)
+        inputs = self.processor(text=text, audio=[audio], return_tensors="pt", padding=True).to(
+            torch_device, dtype=torch.bfloat16
+        )
-        output = model.generate(**inputs, thinker_temperature=0, thinker_do_sample=False)
+        output = model.generate(
+            **inputs,
+            thinker_temperature=0,
+            thinker_do_sample=False,
+            thinker_max_new_tokens=20,
+            talker_max_new_tokens=10,
+        )
-        EXPECTED_DECODED_TEXT = "system\nYou are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.\nuser\n\nassistant\nWell, I can't really guess your age and gender just from your voice. There are so many factors that can affect how a voice sounds, like the environment you're in, how you're feeling at the moment, and even the microphone you're using. But if you want to share more about your voice, like if it's high - pitched or low - pitched, that might give me a bit of an idea. So, what can you tell me about your voice?"
+        EXPECTED_DECODED_TEXTS = Expectations(
+            {
+                ("cuda", 7): "system\nYou are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.\nuser\n\nassistant\nWell, I can try. But it's not always that accurate. I might be able to make",
+                ("cuda", 8): "system\nYou are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.\nuser\n\nassistant\nWell, I can't really guess your age and gender just from your voice. There are are a",
+            }
+        ) # fmt: skip
+        EXPECTED_DECODED_TEXT = EXPECTED_DECODED_TEXTS.get_expectation()
         self.assertEqual(
             self.processor.decode(output[0][0], skip_special_tokens=True),