Fix inference bugs in Qwen2.5 Omni (#37701)

* Init `SinusoidsPositionEmbedding` with float to avoid precision problem

* fix hidden_state for talker

* Update modular_qwen2_5_omni.py

* Move hidden processing out from thinker

* fixup

---------

Co-authored-by: lvyuanjun.lyj <lvyuanjun.lyj@alibaba-inc.com>
BakerBunker authored on 2025-04-24 16:51:44 +08:00 · committed by GitHub
parent b7f7aa78a0
commit 3ed56bea0f
2 changed files with 70 additions and 6 deletions

src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py

@@ -815,7 +815,7 @@ class SinusoidsPositionEmbedding(nn.Module):
         if channels % 2 != 0:
             raise ValueError("SinusoidsPositionEmbedding needs even channels input")
         log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
-        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)).float()
         scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
         self.register_buffer(
             "positional_embedding",
@@ -4560,12 +4560,44 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, GenerationMixin):
             return thinker_result
 
         # 2. Generate speech tokens from talker module
+        embeds_to_talker = thinker_result.hidden_states[0][0].clone()
+        if thinker_kwargs.get("input_features", None) is not None:
+            audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index
+            audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            audio_mask_tensor = torch.zeros(
+                [audio_ids_mask.sum(), embeds_to_talker.shape[-1]],
+                dtype=embeds_to_talker.dtype,
+                device=self.talker.device,
+            )
+            embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor)
+        if thinker_kwargs.get("pixel_values", None) is not None:
+            image_ids_mask = input_ids == self.config.thinker_config.image_token_index
+            image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            image_mask_tensor = torch.zeros(
+                [image_ids_mask.sum(), embeds_to_talker.shape[-1]],
+                dtype=embeds_to_talker.dtype,
+                device=self.talker.device,
+            )
+            embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor)
+        if thinker_kwargs.get("pixel_values_videos", None) is not None:
+            video_ids_mask = input_ids == self.config.thinker_config.video_token_index
+            video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            video_mask_tensor = torch.zeros(
+                [video_ids_mask.sum(), embeds_to_talker.shape[-1]],
+                dtype=embeds_to_talker.dtype,
+                device=self.talker.device,
+            )
+            embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor)
+
+        processed_thinker_hidden = (
+            (embeds_to_talker,) + thinker_result.hidden_states[0][1:],
+        ) + thinker_result.hidden_states[1:]
         thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(self.talker.device)
         thinker_token_embeds = [
-            token_hidden_states[0].to(self.talker.device) for token_hidden_states in thinker_result.hidden_states
+            token_hidden_states[0].to(self.talker.device) for token_hidden_states in processed_thinker_hidden
         ]
         thinker_hidden_states = [
-            token_hidden_states[-1].to(self.talker.device) for token_hidden_states in thinker_result.hidden_states
+            token_hidden_states[-1].to(self.talker.device) for token_hidden_states in processed_thinker_hidden
         ]
 
         talker_text_bos_token = speaker_params["bos_token"]
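
The masking half of the fix is easier to see in isolation: the prompt-step embeddings handed to the talker should be zeroed at every multimodal placeholder position, and masked_scatter_ writes those zeros in place. A minimal sketch under made-up inputs; AUDIO_TOKEN, the tensor shapes, and the single-device setup are illustrative stand-ins, not the model's real configuration:

import torch

AUDIO_TOKEN = 7  # hypothetical id; the real one is thinker_config.audio_token_index
input_ids = torch.tensor([[1, 7, 7, 2, 3]])  # prompt with two audio placeholders
embeds = torch.randn(1, 5, 4)                # (batch, seq, hidden) prompt embeddings

ids_mask = input_ids == AUDIO_TOKEN                 # (batch, seq) bool
mask = ids_mask.unsqueeze(-1).expand_as(embeds)     # broadcast over the hidden dim
zeros = torch.zeros(int(ids_mask.sum()), embeds.shape[-1], dtype=embeds.dtype)
embeds.masked_scatter_(mask, zeros)                 # zero the masked rows in place

print(embeds[0, 1].abs().sum().item())  # 0.0: audio position was zeroed
print(embeds[0, 3].abs().sum().item())  # > 0: text position untouched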

src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py

@@ -1711,7 +1711,7 @@ class SinusoidsPositionEmbedding(nn.Module):
         if channels % 2 != 0:
             raise ValueError("SinusoidsPositionEmbedding needs even channels input")
         log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
-        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)).float()
         scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
         self.register_buffer(
             "positional_embedding",
@@ -4243,12 +4243,44 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, GenerationMixin):
             return thinker_result
 
         # 2. Generate speech tokens from talker module
+        embeds_to_talker = thinker_result.hidden_states[0][0].clone()
+        if thinker_kwargs.get("input_features", None) is not None:
+            audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index
+            audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            audio_mask_tensor = torch.zeros(
+                [audio_ids_mask.sum(), embeds_to_talker.shape[-1]],
+                dtype=embeds_to_talker.dtype,
+                device=self.talker.device,
+            )
+            embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor)
+        if thinker_kwargs.get("pixel_values", None) is not None:
+            image_ids_mask = input_ids == self.config.thinker_config.image_token_index
+            image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            image_mask_tensor = torch.zeros(
+                [image_ids_mask.sum(), embeds_to_talker.shape[-1]],
+                dtype=embeds_to_talker.dtype,
+                device=self.talker.device,
+            )
+            embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor)
+        if thinker_kwargs.get("pixel_values_videos", None) is not None:
+            video_ids_mask = input_ids == self.config.thinker_config.video_token_index
+            video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            video_mask_tensor = torch.zeros(
+                [video_ids_mask.sum(), embeds_to_talker.shape[-1]],
+                dtype=embeds_to_talker.dtype,
+                device=self.talker.device,
+            )
+            embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor)
+
+        processed_thinker_hidden = (
+            (embeds_to_talker,) + thinker_result.hidden_states[0][1:],
+        ) + thinker_result.hidden_states[1:]
         thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(self.talker.device)
         thinker_token_embeds = [
-            token_hidden_states[0].to(self.talker.device) for token_hidden_states in thinker_result.hidden_states
+            token_hidden_states[0].to(self.talker.device) for token_hidden_states in processed_thinker_hidden
         ]
         thinker_hidden_states = [
-            token_hidden_states[-1].to(self.talker.device) for token_hidden_states in thinker_result.hidden_states
+            token_hidden_states[-1].to(self.talker.device) for token_hidden_states in processed_thinker_hidden
         ]
 
         talker_text_bos_token = speaker_params["bos_token"]
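
One detail worth spelling out for the tuple surgery in both files: generate(..., output_hidden_states=True) returns hidden_states as one tuple of per-layer tensors per decoding step, so processed_thinker_hidden swaps in the masked clone only at the prompt step's embedding layer (hidden_states[0][0]) and leaves every other entry untouched. A toy sketch with 0-dim tensors standing in for layer outputs; all shapes and values are made up:

import torch

# Hypothetical stand-ins for hidden_states[step][layer].
prompt_step = (torch.tensor(0.0), torch.tensor(1.0))  # (embeddings, last layer)
gen_step = (torch.tensor(2.0), torch.tensor(3.0))     # one generated token
hidden_states = (prompt_step, gen_step)

embeds_to_talker = torch.tensor(9.0)  # stands in for the masked clone

# Rebuild the nested tuple, replacing only hidden_states[0][0].
processed = ((embeds_to_talker,) + hidden_states[0][1:],) + hidden_states[1:]

token_embeds = [step[0] for step in processed]   # per-step embedding layer
last_hidden = [step[-1] for step in processed]   # per-step final layer
print(token_embeds[0].item())  # 9.0: the talker now sees the masked embeddings
print(last_hidden[0].item())   # 1.0: deeper layers are unchanged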