Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)
Fix inference bugs in Qwen2.5 Omni (#37701)
* Init `SinusoidsPositionEmbedding` with float to avoid precision problems
* Fix hidden_states for talker
* Update modular_qwen2_5_omni.py
* Move hidden-state processing out of the thinker
* fixup

Co-authored-by: lvyuanjun.lyj <lvyuanjun.lyj@alibaba-inc.com>
Parent: b7f7aa78a0
Commit: 3ed56bea0f
@@ -815,7 +815,7 @@ class SinusoidsPositionEmbedding(nn.Module):
         if channels % 2 != 0:
             raise ValueError("SinusoidsPositionEmbedding needs even channels input")
         log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
-        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)).float()
         scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
         self.register_buffer(
             "positional_embedding",
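Why the added `.float()` matters: `torch.arange(channels // 2)` is integer-typed, so multiplying it by a float scalar promotes to the *default* float dtype. If the module is built under a reduced-precision default (e.g. a bf16/fp16 `torch_dtype`), the whole sinusoid table gets computed in that dtype before being registered. A minimal sketch of the effect, assuming only the formula shown in the hunk (the helper name `sinusoid_table` and the 1500x384 shape are illustrative, not from the repo):

import numpy as np
import torch

def sinusoid_table(length, channels, max_timescale=10000, dtype=torch.float32):
    # Same formula as SinusoidsPositionEmbedding, but computed entirely in `dtype`.
    log_timescale_increment = float(np.log(max_timescale) / (channels // 2 - 1))
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2, dtype=dtype))
    scaled_time = torch.arange(length, dtype=dtype)[:, None] * inv_timescales[None, :]
    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)

# Compare each precision against a float64 reference table.
ref = sinusoid_table(1500, 384, dtype=torch.float64)
for dtype in (torch.float32, torch.float16):
    err = (sinusoid_table(1500, 384, dtype=dtype).double() - ref).abs().max()
    print(dtype, err.item())  # the half-precision table drifts orders of magnitude further

Building the buffer in float32 and letting any later cast happen on the finished values keeps the error at cast-level rounding, instead of compounding through `exp`, the position products, and `sin`/`cos`.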
@@ -4560,12 +4560,44 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
             return thinker_result

         # 2. Generate speech tokens from talker module
+        embeds_to_talker = thinker_result.hidden_states[0][0].clone()
+        if thinker_kwargs.get("input_features", None) is not None:
+            audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index
+            audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            audio_mask_tensor = torch.zeros(
+                [audio_ids_mask.sum(), embeds_to_talker.shape[-1]],
+                dtype=embeds_to_talker.dtype,
+                device=self.talker.device,
+            )
+            embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor)
+        if thinker_kwargs.get("pixel_values", None) is not None:
+            image_ids_mask = input_ids == self.config.thinker_config.image_token_index
+            image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            image_mask_tensor = torch.zeros(
+                [image_ids_mask.sum(), embeds_to_talker.shape[-1]],
+                dtype=embeds_to_talker.dtype,
+                device=self.talker.device,
+            )
+            embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor)
+        if thinker_kwargs.get("pixel_values_videos", None) is not None:
+            video_ids_mask = input_ids == self.config.thinker_config.video_token_index
+            video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            video_mask_tensor = torch.zeros(
+                [video_ids_mask.sum(), embeds_to_talker.shape[-1]],
+                dtype=embeds_to_talker.dtype,
+                device=self.talker.device,
+            )
+            embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor)
+
+        processed_thinker_hidden = (
+            (embeds_to_talker,) + thinker_result.hidden_states[0][1:],
+        ) + thinker_result.hidden_states[1:]
         thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(self.talker.device)
         thinker_token_embeds = [
-            token_hidden_states[0].to(self.talker.device) for token_hidden_states in thinker_result.hidden_states
+            token_hidden_states[0].to(self.talker.device) for token_hidden_states in processed_thinker_hidden
         ]
         thinker_hidden_states = [
-            token_hidden_states[-1].to(self.talker.device) for token_hidden_states in thinker_result.hidden_states
+            token_hidden_states[-1].to(self.talker.device) for token_hidden_states in processed_thinker_hidden
         ]

         talker_text_bos_token = speaker_params["bos_token"]
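What the added block does, as read from the diff: the talker reuses the thinker's step-0 embedding output (`hidden_states[0][0]`), but the positions occupied by audio/image/video placeholder tokens, which at that point hold merged multimodal features, are blanked to zeros on a copy before hand-off. A self-contained toy of that masking pattern (the token id and tensor sizes are invented; `masked_scatter_` fills the `True` positions, in order, from the source tensor):

import torch

AUDIO_TOKEN = 7                                    # hypothetical placeholder id
input_ids = torch.tensor([[1, 7, 7, 2, 3]])        # (batch, seq)
embeds = torch.randn(1, 5, 4)                      # stand-in for hidden_states[0][0]

embeds_to_talker = embeds.clone()
ids_mask = input_ids == AUDIO_TOKEN                # (batch, seq) bool
mask = ids_mask.unsqueeze(-1).expand_as(embeds_to_talker)
zeros = torch.zeros(int(ids_mask.sum()), embeds_to_talker.shape[-1], dtype=embeds_to_talker.dtype)
embeds_to_talker.masked_scatter_(mask, zeros)

print(embeds_to_talker[0, 1:3])                           # both audio rows are now zero
print(torch.equal(embeds_to_talker[0, 0], embeds[0, 0]))  # True: other rows untouched

Using `clone()` plus in-place `masked_scatter_` leaves `thinker_result.hidden_states` itself untouched, which the `processed_thinker_hidden` splice relies on.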
The same two changes are mirrored in the second file touched by the commit; only the hunk positions differ:

@@ -1711,7 +1711,7 @@ class SinusoidsPositionEmbedding(nn.Module):
         if channels % 2 != 0:
             raise ValueError("SinusoidsPositionEmbedding needs even channels input")
         log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
-        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)).float()
         scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
         self.register_buffer(
             "positional_embedding",
@@ -4243,12 +4243,44 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
             return thinker_result

         # 2. Generate speech tokens from talker module
+        embeds_to_talker = thinker_result.hidden_states[0][0].clone()
+        if thinker_kwargs.get("input_features", None) is not None:
+            audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index
+            audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            audio_mask_tensor = torch.zeros(
+                [audio_ids_mask.sum(), embeds_to_talker.shape[-1]],
+                dtype=embeds_to_talker.dtype,
+                device=self.talker.device,
+            )
+            embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor)
+        if thinker_kwargs.get("pixel_values", None) is not None:
+            image_ids_mask = input_ids == self.config.thinker_config.image_token_index
+            image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            image_mask_tensor = torch.zeros(
+                [image_ids_mask.sum(), embeds_to_talker.shape[-1]],
+                dtype=embeds_to_talker.dtype,
+                device=self.talker.device,
+            )
+            embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor)
+        if thinker_kwargs.get("pixel_values_videos", None) is not None:
+            video_ids_mask = input_ids == self.config.thinker_config.video_token_index
+            video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)
+            video_mask_tensor = torch.zeros(
+                [video_ids_mask.sum(), embeds_to_talker.shape[-1]],
+                dtype=embeds_to_talker.dtype,
+                device=self.talker.device,
+            )
+            embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor)
+
+        processed_thinker_hidden = (
+            (embeds_to_talker,) + thinker_result.hidden_states[0][1:],
+        ) + thinker_result.hidden_states[1:]
         thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(self.talker.device)
         thinker_token_embeds = [
-            token_hidden_states[0].to(self.talker.device) for token_hidden_states in thinker_result.hidden_states
+            token_hidden_states[0].to(self.talker.device) for token_hidden_states in processed_thinker_hidden
         ]
         thinker_hidden_states = [
-            token_hidden_states[-1].to(self.talker.device) for token_hidden_states in thinker_result.hidden_states
+            token_hidden_states[-1].to(self.talker.device) for token_hidden_states in processed_thinker_hidden
         ]

         talker_text_bos_token = speaker_params["bos_token"]
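The `processed_thinker_hidden` lines in both hunks splice the blanked copy back into the nested structure returned by `generate(..., output_hidden_states=True)`: a tuple with one entry per decoding step, each entry itself a tuple holding the embedding output plus one tensor per layer. A toy sketch of that structure and the splice (step, layer, and tensor sizes are invented):

import torch

steps, layers = 3, 2
hidden_states = tuple(
    tuple(torch.randn(1, 4, 8) for _ in range(layers + 1)) for _ in range(steps)
)
embeds_to_talker = hidden_states[0][0].clone().zero_()  # stand-in for the blanked copy

# Same tuple surgery as the diff: replace only step 0 / layer 0.
processed = ((embeds_to_talker,) + hidden_states[0][1:],) + hidden_states[1:]

assert processed[0][0] is embeds_to_talker                                  # swapped in
assert all(a is b for a, b in zip(processed[0][1:], hidden_states[0][1:]))  # rest of step 0 kept
assert all(a is b for a, b in zip(processed[1:], hidden_states[1:]))        # later steps untouched

Only step 0 / layer 0 is replaced, so the `thinker_token_embeds` comprehension (indexing `[0]`) picks up the blanked embeddings, while `thinker_hidden_states` (indexing `[-1]`) still sees the original last-layer states.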