Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-04 13:20:12 +06:00
Fix qwen_2_5 omni (#38658)
* fix
* fix
* break style
* break style
* Apply style fixes
* break style
* Apply style fixes
* fix modular

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent e1812864ab
commit d4e7aa5526
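In the modeling hunks below, intermediate tensors that were previously placed on submodule device attributes (self.talker.device, self.thinker.device, self.token2wav.device) are now created on or moved to input_ids.device; the integration tests switch from torch.float32 to torch.bfloat16 and wrap expected outputs in Expectations keyed by CUDA compute capability. As a rough illustration of the device-placement pattern, here is a minimal sketch with a hypothetical helper name, not code from the repository:

import torch

# Minimal sketch of the pattern applied in this commit (hypothetical helper,
# not repository code): new tensors are created on input_ids.device rather
# than on a submodule's .device attribute.
def build_talker_input_text_ids(input_ids: torch.Tensor, bos_token_id: int, first_talker_id: torch.Tensor) -> torch.Tensor:
    # Place the BOS token on the same device as the incoming input_ids.
    bos = torch.tensor([[bos_token_id]], dtype=torch.long, device=input_ids.device)
    # Concatenate prompt ids, BOS, and the first generated id along the sequence dim.
    return torch.cat([input_ids, bos, first_talker_id.to(input_ids.device)], dim=-1)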
@@ -4029,51 +4029,51 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
             return thinker_result
 
         # 2. Generate speech tokens from talker module
-        embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(self.talker.device)
+        embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(input_ids.device)
         if thinker_kwargs.get("input_features", None) is not None:
             audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index
-            audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
             audio_mask_tensor = torch.zeros(
                 [audio_ids_mask.sum(), embeds_to_talker.shape[-1]],
                 dtype=embeds_to_talker.dtype,
-                device=self.talker.device,
+                device=input_ids.device,
             )
             embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor)
         if thinker_kwargs.get("pixel_values", None) is not None:
             image_ids_mask = input_ids == self.config.thinker_config.image_token_index
-            image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
             image_mask_tensor = torch.zeros(
                 [image_ids_mask.sum(), embeds_to_talker.shape[-1]],
                 dtype=embeds_to_talker.dtype,
-                device=self.talker.device,
+                device=input_ids.device,
             )
             embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor)
         if thinker_kwargs.get("pixel_values_videos", None) is not None:
             video_ids_mask = input_ids == self.config.thinker_config.video_token_index
-            video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
             video_mask_tensor = torch.zeros(
                 [video_ids_mask.sum(), embeds_to_talker.shape[-1]],
                 dtype=embeds_to_talker.dtype,
-                device=self.talker.device,
+                device=input_ids.device,
             )
             embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor)
 
         processed_thinker_hidden = (
             (embeds_to_talker,) + thinker_result.hidden_states[0][1:],
         ) + thinker_result.hidden_states[1:]
-        thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(self.talker.device)
+        thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(input_ids.device)
         thinker_token_embeds = [
-            token_hidden_states[0].to(self.talker.device) for token_hidden_states in processed_thinker_hidden
+            token_hidden_states[0].to(input_ids.device) for token_hidden_states in processed_thinker_hidden
         ]
         thinker_hidden_states = [
-            token_hidden_states[-1].to(self.talker.device) for token_hidden_states in processed_thinker_hidden
+            token_hidden_states[-1].to(input_ids.device) for token_hidden_states in processed_thinker_hidden
         ]
 
         talker_text_bos_token = speaker_params["bos_token"]
         talker_input_text_ids = torch.cat(
             [
-                input_ids.to(self.talker.device),
-                torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.talker.device),
+                input_ids,
+                torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=input_ids.device),
                 thinker_generate_ids[:, :1],
             ],
             dim=-1,
@@ -4081,9 +4081,9 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
 
         talker_input_ids = torch.cat(
             [
-                torch.full_like(input_ids, fill_value=self.talker.codec_mask_token, device=self.talker.device),
-                torch.tensor([[self.talker.codec_pad_token]], dtype=torch.long, device=self.talker.device),
-                torch.tensor([[self.talker.codec_bos_token]], dtype=torch.long, device=self.talker.device),
+                torch.full_like(input_ids, fill_value=self.talker.codec_mask_token),
+                torch.tensor([[self.talker.codec_pad_token]], dtype=torch.long, device=input_ids.device),
+                torch.tensor([[self.talker.codec_bos_token]], dtype=torch.long, device=input_ids.device),
             ],
             dim=1,
         )
@@ -4091,8 +4091,8 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
         thinker_embed_tokens = self.thinker.get_input_embeddings()
         thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
         talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0]
-        talker_text_bos_token = torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.thinker.device)
-        talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to(self.talker.device)
+        talker_text_bos_token = torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=input_ids.device)
+        talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to(input_ids.device)
         talker_inputs_embeds = torch.cat(
             [
                 talker_inputs_embeds,
@@ -4103,12 +4103,12 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
         )
 
         eos_embedding = thinker_embed_tokens(
-            torch.tensor([[self.talker.text_eos_token]], dtype=torch.long, device=self.thinker.device)
-        ).to(self.talker.device)
+            torch.tensor([[self.talker.text_eos_token]], dtype=torch.long, device=input_ids.device)
+        )
 
         pad_embedding = thinker_embed_tokens(
-            torch.tensor([[self.talker.text_pad_token]], dtype=torch.long, device=self.thinker.device)
-        ).to(self.talker.device)
+            torch.tensor([[self.talker.text_pad_token]], dtype=torch.long, device=input_ids.device)
+        )
 
         thinker_reply_part = torch.cat(
             [
@@ -4123,7 +4123,7 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
         if "attention_mask" in kwargs:
             talker_attention_mask = torch.cat(
                 [kwargs["attention_mask"], kwargs["attention_mask"].new_ones((1, 2))], dim=1
-            ).to(self.talker.device)
+            ).to(input_ids.device)
 
         talker_result = self.talker.generate(
             input_ids=talker_input_ids,
@@ -4132,7 +4132,7 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
             inputs_embeds=talker_inputs_embeds,
             attention_mask=talker_attention_mask,
             suppress_tokens=[self.talker.codec_bos_token],
-            **{k: (v.to(self.talker.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()},
+            **{k: (v.to(input_ids.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()},
         )
         talker_generate_codes = talker_result[:, talker_input_ids.shape[1] : -1]
 
@@ -4141,9 +4141,9 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
             self.token2wav.float()
 
         wav = self.token2wav(
-            talker_generate_codes.to(self.token2wav.device),
-            conditioning=speaker_params["cond"].to(self.token2wav.device).float(),
-            reference_mel=speaker_params["ref_mel"].to(self.token2wav.device).float(),
+            talker_generate_codes.to(input_ids.device),
+            conditioning=speaker_params["cond"].to(input_ids.device).float(),
+            reference_mel=speaker_params["ref_mel"].to(input_ids.device).float(),
             **token2wav_kwargs,
         )
 
@@ -4295,51 +4295,51 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
             return thinker_result
 
         # 2. Generate speech tokens from talker module
-        embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(self.talker.device)
+        embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(input_ids.device)
         if thinker_kwargs.get("input_features", None) is not None:
             audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index
-            audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
             audio_mask_tensor = torch.zeros(
                 [audio_ids_mask.sum(), embeds_to_talker.shape[-1]],
                 dtype=embeds_to_talker.dtype,
-                device=self.talker.device,
+                device=input_ids.device,
             )
             embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor)
         if thinker_kwargs.get("pixel_values", None) is not None:
             image_ids_mask = input_ids == self.config.thinker_config.image_token_index
-            image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
             image_mask_tensor = torch.zeros(
                 [image_ids_mask.sum(), embeds_to_talker.shape[-1]],
                 dtype=embeds_to_talker.dtype,
-                device=self.talker.device,
+                device=input_ids.device,
             )
             embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor)
         if thinker_kwargs.get("pixel_values_videos", None) is not None:
             video_ids_mask = input_ids == self.config.thinker_config.video_token_index
-            video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
             video_mask_tensor = torch.zeros(
                 [video_ids_mask.sum(), embeds_to_talker.shape[-1]],
                 dtype=embeds_to_talker.dtype,
-                device=self.talker.device,
+                device=input_ids.device,
             )
             embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor)
 
         processed_thinker_hidden = (
             (embeds_to_talker,) + thinker_result.hidden_states[0][1:],
         ) + thinker_result.hidden_states[1:]
-        thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(self.talker.device)
+        thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(input_ids.device)
         thinker_token_embeds = [
-            token_hidden_states[0].to(self.talker.device) for token_hidden_states in processed_thinker_hidden
+            token_hidden_states[0].to(input_ids.device) for token_hidden_states in processed_thinker_hidden
         ]
         thinker_hidden_states = [
-            token_hidden_states[-1].to(self.talker.device) for token_hidden_states in processed_thinker_hidden
+            token_hidden_states[-1].to(input_ids.device) for token_hidden_states in processed_thinker_hidden
         ]
 
         talker_text_bos_token = speaker_params["bos_token"]
         talker_input_text_ids = torch.cat(
             [
-                input_ids.to(self.talker.device),
-                torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.talker.device),
+                input_ids,
+                torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=input_ids.device),
                 thinker_generate_ids[:, :1],
             ],
             dim=-1,
@@ -4347,9 +4347,9 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
 
         talker_input_ids = torch.cat(
             [
-                torch.full_like(input_ids, fill_value=self.talker.codec_mask_token, device=self.talker.device),
-                torch.tensor([[self.talker.codec_pad_token]], dtype=torch.long, device=self.talker.device),
-                torch.tensor([[self.talker.codec_bos_token]], dtype=torch.long, device=self.talker.device),
+                torch.full_like(input_ids, fill_value=self.talker.codec_mask_token),
+                torch.tensor([[self.talker.codec_pad_token]], dtype=torch.long, device=input_ids.device),
+                torch.tensor([[self.talker.codec_bos_token]], dtype=torch.long, device=input_ids.device),
             ],
             dim=1,
         )
@@ -4357,8 +4357,8 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
         thinker_embed_tokens = self.thinker.get_input_embeddings()
         thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
         talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0]
-        talker_text_bos_token = torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.thinker.device)
-        talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to(self.talker.device)
+        talker_text_bos_token = torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=input_ids.device)
+        talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to(input_ids.device)
         talker_inputs_embeds = torch.cat(
             [
                 talker_inputs_embeds,
@@ -4369,12 +4369,12 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
         )
 
         eos_embedding = thinker_embed_tokens(
-            torch.tensor([[self.talker.text_eos_token]], dtype=torch.long, device=self.thinker.device)
-        ).to(self.talker.device)
+            torch.tensor([[self.talker.text_eos_token]], dtype=torch.long, device=input_ids.device)
+        )
 
         pad_embedding = thinker_embed_tokens(
-            torch.tensor([[self.talker.text_pad_token]], dtype=torch.long, device=self.thinker.device)
-        ).to(self.talker.device)
+            torch.tensor([[self.talker.text_pad_token]], dtype=torch.long, device=input_ids.device)
+        )
 
         thinker_reply_part = torch.cat(
             [
@@ -4389,7 +4389,7 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
         if "attention_mask" in kwargs:
             talker_attention_mask = torch.cat(
                 [kwargs["attention_mask"], kwargs["attention_mask"].new_ones((1, 2))], dim=1
-            ).to(self.talker.device)
+            ).to(input_ids.device)
 
         talker_result = self.talker.generate(
             input_ids=talker_input_ids,
@@ -4398,7 +4398,7 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
             inputs_embeds=talker_inputs_embeds,
             attention_mask=talker_attention_mask,
             suppress_tokens=[self.talker.codec_bos_token],
-            **{k: (v.to(self.talker.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()},
+            **{k: (v.to(input_ids.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()},
         )
         talker_generate_codes = talker_result[:, talker_input_ids.shape[1] : -1]
 
@@ -4407,9 +4407,9 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
             self.token2wav.float()
 
         wav = self.token2wav(
-            talker_generate_codes.to(self.token2wav.device),
-            conditioning=speaker_params["cond"].to(self.token2wav.device).float(),
-            reference_mel=speaker_params["ref_mel"].to(self.token2wav.device).float(),
+            talker_generate_codes.to(input_ids.device),
+            conditioning=speaker_params["cond"].to(input_ids.device).float(),
+            reference_mel=speaker_params["ref_mel"].to(input_ids.device).float(),
            **token2wav_kwargs,
         )
 
@@ -33,6 +33,7 @@ from transformers import (
     is_vision_available,
 )
 from transformers.testing_utils import (
+    Expectations,
     cleanup,
     require_flash_attn,
     require_torch,
@@ -555,13 +556,13 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
     @slow
     def test_small_model_integration_test(self):
         model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
-            "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.float32, device_map="auto"
+            "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.bfloat16, device_map="auto"
         )
 
         text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
         inputs = self.processor(
-            text=[text], audio=[self.raw_audio], images=[self.raw_image], return_tensors="pt", padding=True
-        )
+            text=text, audio=[self.raw_audio], images=[self.raw_image], return_tensors="pt", padding=True
+        ).to(torch.bfloat16)
 
         expected_input_ids = torch.tensor(
             [
@@ -581,7 +582,7 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
                 198,
                 151647,
-                151646,
+                151648,
                 151646,
             ]
         )
         assert torch.allclose(expected_input_ids, inputs.input_ids[0][:17], atol=3e-3)
@@ -595,7 +596,7 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
                 [1.3902, 1.4048, 1.4194],
                 [1.5216, 1.5362, 1.5362],
             ],
-            dtype=torch.float32,
+            dtype=torch.bfloat16,
             device="cpu",
         )
         assert torch.allclose(expected_pixel_slice, inputs.pixel_values[:6, :3], atol=3e-3)
@@ -603,9 +604,11 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
         # verify generation
         inputs = inputs.to(torch_device)
 
-        output = model.generate(**inputs, thinker_temperature=0, thinker_do_sample=False, return_audio=False)
+        output = model.generate(
+            **inputs, thinker_temperature=0, thinker_do_sample=False, return_audio=False, thinker_max_new_tokens=20
+        )
 
-        EXPECTED_DECODED_TEXT = "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog appears to be a Labrador Retriever."
+        EXPECTED_DECODED_TEXT = "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog is a Labrador Retriever."
 
         self.assertEqual(
             self.processor.decode(output[0], skip_special_tokens=True),
@@ -615,23 +618,34 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
     @slow
     def test_small_model_integration_test_batch(self):
         model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
-            "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.float32, device_map="auto"
+            "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.bfloat16, device_map="auto"
        )
         text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
         inputs = self.processor(
-            text=[text, text],
+            text=text * 2,
             audio=[self.raw_audio, self.raw_audio],
             images=[self.raw_image, self.raw_image],
             return_tensors="pt",
             padding=True,
-        ).to(torch_device)
+        ).to(torch_device, dtype=torch.bfloat16)
 
-        output = model.generate(**inputs, thinker_temperature=0, thinker_do_sample=False, return_audio=False)
+        output = model.generate(
+            **inputs, thinker_temperature=0, thinker_do_sample=False, return_audio=False, thinker_max_new_tokens=20
+        )
 
-        EXPECTED_DECODED_TEXT = [
-            "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog appears to be a Labrador Retriever.",
-            "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog appears to be a Labrador Retriever.",
-        ]
+        EXPECTED_DECODED_TEXTS = Expectations(
+            {
+                ("cuda", 7) : [
+                    "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is of glass shattering, and the dog in the picture is a Labrador Retriever",
+                    "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is of glass shattering, and the dog in the picture is a Labrador Retriever",
+                ],
+                ("cuda", 8): [
+                    "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog is a Labrador Retriever.",
+                    "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog is a Labrador Retriever.",
+                ],
+            }
+        )  # fmt: skip
+        EXPECTED_DECODED_TEXT = EXPECTED_DECODED_TEXTS.get_expectation()
 
         self.assertEqual(
             self.processor.batch_decode(output, skip_special_tokens=True),
@@ -641,7 +655,7 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
     @slow
     def test_small_model_integration_test_multiturn(self):
         model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
-            "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.float32, device_map="auto"
+            "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.bfloat16, device_map="auto"
         )
 
         messages = [
@@ -666,14 +680,16 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
 
         text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
         inputs = self.processor(
-            text=[text],
+            text=text,
             audio=[self.raw_audio, self.raw_audio_additional],
             images=[self.raw_image],
             return_tensors="pt",
             padding=True,
-        ).to(torch_device)
+        ).to(torch_device, dtype=torch.bfloat16)
 
-        output = model.generate(**inputs, thinker_temperature=0, thinker_do_sample=False, return_audio=False)
+        output = model.generate(
+            **inputs, thinker_temperature=0, thinker_do_sample=False, return_audio=False, thinker_max_new_tokens=20
+        )
 
         EXPECTED_DECODED_TEXT = "system\nYou are a helpful assistant.\nuser\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog appears to be a Labrador Retriever.\nuser\nHow about this one?\nassistant\nThe sound is a cough."
 
@@ -685,7 +701,7 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
     @slow
     def test_small_model_integration_test_w_audio(self):
         model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
-            "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.float32, device_map="auto"
+            "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.bfloat16, device_map="auto"
         )
         audio_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/guess_age_gender.wav"
 
@@ -707,11 +723,25 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
         audio, _ = librosa.load(BytesIO(urlopen(audio_url).read()), sr=self.processor.feature_extractor.sampling_rate)
 
         text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = self.processor(text=text, audio=[audio], return_tensors="pt", padding=True).to(torch_device)
+        inputs = self.processor(text=text, audio=[audio], return_tensors="pt", padding=True).to(
+            torch_device, dtype=torch.bfloat16
+        )
 
-        output = model.generate(**inputs, thinker_temperature=0, thinker_do_sample=False)
+        output = model.generate(
+            **inputs,
+            thinker_temperature=0,
+            thinker_do_sample=False,
+            thinker_max_new_tokens=20,
+            talker_max_new_tokens=10,
+        )
 
-        EXPECTED_DECODED_TEXT = "system\nYou are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.\nuser\n\nassistant\nWell, I can't really guess your age and gender just from your voice. There are so many factors that can affect how a voice sounds, like the environment you're in, how you're feeling at the moment, and even the microphone you're using. But if you want to share more about your voice, like if it's high - pitched or low - pitched, that might give me a bit of an idea. So, what can you tell me about your voice?"
+        EXPECTED_DECODED_TEXTS = Expectations(
+            {
+                ("cuda", 7): "system\nYou are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.\nuser\n\nassistant\nWell, I can try. But it's not always that accurate. I might be able to make",
+                ("cuda", 8): "system\nYou are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.\nuser\n\nassistant\nWell, I can't really guess your age and gender just from your voice. There are are a",
+            }
+        )  # fmt: skip
+        EXPECTED_DECODED_TEXT = EXPECTED_DECODED_TEXTS.get_expectation()
 
         self.assertEqual(
             self.processor.decode(output[0][0], skip_special_tokens=True),