mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
[tests] make test_sdpa_can_compile_dynamic
device-agnostic (#32519)
* enable * fix
This commit is contained in:
parent
54b7703682
commit
e55b33ceb4
@ -79,6 +79,7 @@ from transformers.testing_utils import (
|
||||
require_read_token,
|
||||
require_safetensors,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_accelerator,
|
||||
require_torch_multi_gpu,
|
||||
@ -4105,17 +4106,17 @@ class ModelTesterMixin:
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
@require_torch_sdpa
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@slow
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
if "cuda" in torch_device:
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
major, _ = compute_capability
|
||||
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
major, _ = compute_capability
|
||||
|
||||
if not torch.version.cuda or major < 8:
|
||||
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
||||
if not torch.version.cuda or major < 8:
|
||||
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_sdpa:
|
||||
|
Loading…
Reference in New Issue
Block a user