[tests] make test_sdpa_equivalence device-agnostic (#32520)

* fix on xpu

* [run_all]
This commit is contained in:
Fanli Lin 2024-08-16 18:34:13 +08:00 committed by GitHub
parent 70d5df6107
commit 8f9fa3b081
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -27,6 +27,7 @@ from transformers.testing_utils import (
require_flash_attn,
require_read_token,
require_torch,
require_torch_accelerator,
require_torch_gpu,
require_torch_sdpa,
slow,
@ -460,7 +461,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
self.skipTest(reason="Gemma flash attention does not support right padding")
@require_torch_sdpa
@require_torch_gpu
@require_torch_accelerator
@slow
def test_sdpa_equivalence(self):
for model_class in self.all_model_classes: