mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
[tests] make test_sdpa_equivalence device-agnostic (#32520)
* fix on xpu * [run_all]
This commit is contained in:
parent
70d5df6107
commit
8f9fa3b081
@ -27,6 +27,7 @@ from transformers.testing_utils import (
|
||||
require_flash_attn,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
require_torch_sdpa,
|
||||
slow,
|
||||
@ -460,7 +461,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
self.skipTest(reason="Gemma flash attention does not support right padding")
|
||||
|
||||
@require_torch_sdpa
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@slow
|
||||
def test_sdpa_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
|
Loading…
Reference in New Issue
Block a user