mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
[tests] make test_sdpa_can_compile_dynamic
device-agnostic (#32519)
* enable * fix
This commit is contained in:
parent
54b7703682
commit
e55b33ceb4
@ -79,6 +79,7 @@ from transformers.testing_utils import (
|
|||||||
require_read_token,
|
require_read_token,
|
||||||
require_safetensors,
|
require_safetensors,
|
||||||
require_torch,
|
require_torch,
|
||||||
|
require_torch_accelerator,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
require_torch_multi_accelerator,
|
require_torch_multi_accelerator,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
@ -4105,17 +4106,17 @@ class ModelTesterMixin:
|
|||||||
_ = model(**inputs_dict)
|
_ = model(**inputs_dict)
|
||||||
|
|
||||||
@require_torch_sdpa
|
@require_torch_sdpa
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
@slow
|
@slow
|
||||||
def test_sdpa_can_compile_dynamic(self):
|
def test_sdpa_can_compile_dynamic(self):
|
||||||
if not self.has_attentions:
|
if not self.has_attentions:
|
||||||
self.skipTest(reason="Model architecture does not support attentions")
|
self.skipTest(reason="Model architecture does not support attentions")
|
||||||
|
if "cuda" in torch_device:
|
||||||
|
compute_capability = torch.cuda.get_device_capability()
|
||||||
|
major, _ = compute_capability
|
||||||
|
|
||||||
compute_capability = torch.cuda.get_device_capability()
|
if not torch.version.cuda or major < 8:
|
||||||
major, _ = compute_capability
|
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
||||||
|
|
||||||
if not torch.version.cuda or major < 8:
|
|
||||||
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_sdpa:
|
if not model_class._supports_sdpa:
|
||||||
|
Loading…
Reference in New Issue
Block a user