Skip sdpa tests if submodule does not support sdpa (#38907)

This commit is contained in:
ivarflakstad 2025-06-19 15:11:01 +02:00 committed by GitHub
parent 5d26a38735
commit af6120b3eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -3799,8 +3799,20 @@ class ModelTesterMixin:
self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
if config.model_type in ["sam"]:
self.skipTest(reason="SAM requires an attention_mask input for relative positional embeddings")
model = model_class(config)
sub_models_supporting_sdpa = [
module._supports_sdpa
for name, module in model.named_modules()
if isinstance(module, PreTrainedModel) and name != ""
]
supports_sdpa_all_modules = (
all(sub_models_supporting_sdpa) if len(sub_models_supporting_sdpa) > 0 else model._supports_sdpa
)
if not supports_sdpa_all_modules:
self.skipTest(reason="This models' submodels does not support sdpa")
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa")
@@ -3848,8 +3860,20 @@ class ModelTesterMixin:
"Cannot compile forward without an existing cache with Hybrid, as `torch._dynamo.mark_static_address` "
"is a forbidden call."
)
model = model_class(config)
sub_models_supporting_sdpa = [
module._supports_sdpa
for name, module in model.named_modules()
if isinstance(module, PreTrainedModel) and name != ""
]
supports_sdpa_all_modules = (
all(sub_models_supporting_sdpa) if len(sub_models_supporting_sdpa) > 0 else model._supports_sdpa
)
if not supports_sdpa_all_modules:
self.skipTest(reason="This models' submodels does not support sdpa")
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa")