Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
[Qwen2.5-Omni] Fix dtype of cos,sin when used with flash attention (#38453)
* Fix dtype of cos,sin when used with flash attention
This commit is contained in:
parent 81cff7ad34
commit 42ef218b58
@@ -1034,8 +1034,8 @@ class Qwen2_5OmniVisionFlashAttention2(nn.Module):
     def _apply_rotary_pos_emb_flashatt(self, tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
         tensor_ = tensor.float()
-        cos = freqs.cos()  # .type_as(tensor_)
-        sin = freqs.sin()  # .type_as(tensor_)
+        cos = freqs.cos().type_as(tensor_)
+        sin = freqs.sin().type_as(tensor_)
         output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
         return output
@@ -2022,8 +2022,8 @@ class Qwen2_5OmniVisionFlashAttention2(nn.Module):
     def _apply_rotary_pos_emb_flashatt(self, tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
         tensor_ = tensor.float()
-        cos = freqs.cos()  # .type_as(tensor_)
-        sin = freqs.sin()  # .type_as(tensor_)
+        cos = freqs.cos().type_as(tensor_)
+        sin = freqs.sin().type_as(tensor_)
         output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
         return output
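For context, here is a minimal pure-PyTorch sketch of the patched casting behaviour. The rotary math below stands in for flash_attn's apply_rotary_emb kernel, and the helper name and shapes are hypothetical; the point is that cos and sin are now cast with .type_as(tensor_) so they match the float32 working copy rather than inheriting the dtype of freqs, which could differ under mixed precision.

import torch

def _apply_rotary_pos_emb_sketch(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference mirroring the patched casting logic.
    tensor_ = tensor.float()  # work in float32 for numerical stability
    # The fix: cast cos/sin to tensor_'s dtype (float32) so the rotary
    # application sees matching dtypes; previously they kept freqs' dtype.
    cos = freqs.cos().type_as(tensor_)  # freqs: (seq_len, head_dim // 2)
    sin = freqs.sin().type_as(tensor_)
    # Pure-PyTorch rotation standing in for flash_attn's apply_rotary_emb
    # (non-interleaved "half" variant, illustrative only).
    x1, x2 = tensor_.chunk(2, dim=-1)
    out = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return out.type_as(tensor)  # cast back to the activation dtype

# Usage with bfloat16 activations (hypothetical shapes):
x = torch.randn(16, 64, dtype=torch.bfloat16)  # (seq_len, head_dim)
freqs = torch.randn(16, 32)                    # (seq_len, head_dim // 2), float32
y = _apply_rotary_pos_emb_sketch(x, freqs)
assert y.dtype == torch.bfloat16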