diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
index be1c55051be..13b641c4681 100644
--- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -1034,8 +1034,8 @@ class Qwen2_5OmniVisionFlashAttention2(nn.Module):
 
     def _apply_rotary_pos_emb_flashatt(self, tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
         tensor_ = tensor.float()
-        cos = freqs.cos()  # .type_as(tensor_)
-        sin = freqs.sin()  # .type_as(tensor_)
+        cos = freqs.cos().type_as(tensor_)
+        sin = freqs.sin().type_as(tensor_)
         output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
         return output
 
diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
index 3ced88af768..70c0490d95b 100644
--- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -2022,8 +2022,8 @@ class Qwen2_5OmniVisionFlashAttention2(nn.Module):
 
     def _apply_rotary_pos_emb_flashatt(self, tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
         tensor_ = tensor.float()
-        cos = freqs.cos()  # .type_as(tensor_)
-        sin = freqs.sin()  # .type_as(tensor_)
+        cos = freqs.cos().type_as(tensor_)
+        sin = freqs.sin().type_as(tensor_)
         output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
         return output
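
Note: both hunks apply the same fix; `modular_qwen2_5_omni.py` is the modular source from which `modeling_qwen2_5_omni.py` is generated, so the two must stay in sync. The change restores the `.type_as(tensor_)` cast that had been commented out, so `cos` and `sin` are explicitly cast to the dtype of the upcast input (`tensor_` is float32 via `tensor.float()`) before being passed to `apply_rotary_emb`.

A minimal standalone sketch of the dtype behavior, assuming `freqs` may be held in reduced precision (e.g. bfloat16) at runtime; `rotary_dtypes` is a hypothetical helper for illustration only, not part of the model:

```python
import torch

def rotary_dtypes(tensor: torch.Tensor, freqs: torch.Tensor):
    # Mirrors the dtype handling in _apply_rotary_pos_emb_flashatt.
    tensor_ = tensor.float()                 # upcast to float32
    cos_no_cast = freqs.cos()                # inherits freqs' dtype (old behavior)
    cos_cast = freqs.cos().type_as(tensor_)  # explicitly float32 (fixed behavior)
    return tensor_.dtype, cos_no_cast.dtype, cos_cast.dtype

q = torch.randn(2, 4, 8, dtype=torch.bfloat16)   # input tensor in half precision
freqs = torch.randn(4, 4, dtype=torch.bfloat16)  # rotary freqs in half precision
print(rotary_dtypes(q, freqs))
# (torch.float32, torch.bfloat16, torch.float32)
```

Without the cast, a reduced-precision `cos`/`sin` could reach `apply_rotary_emb` alongside a float32 input; the explicit `.type_as(tensor_)` rules that mismatch out.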