mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Skip sdpa tests if submodule does not support sdpa (#38907)
This commit is contained in:
parent
5d26a38735
commit
af6120b3eb
@ -3799,8 +3799,20 @@ class ModelTesterMixin:
|
||||
self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
|
||||
if config.model_type in ["sam"]:
|
||||
self.skipTest(reason="SAM requires an attention_mask input for relative positional embeddings")
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
sub_models_supporting_sdpa = [
|
||||
module._supports_sdpa
|
||||
for name, module in model.named_modules()
|
||||
if isinstance(module, PreTrainedModel) and name != ""
|
||||
]
|
||||
supports_sdpa_all_modules = (
|
||||
all(sub_models_supporting_sdpa) if len(sub_models_supporting_sdpa) > 0 else model._supports_sdpa
|
||||
)
|
||||
if not supports_sdpa_all_modules:
|
||||
self.skipTest(reason="This models' submodels does not support sdpa")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa")
|
||||
@ -3848,8 +3860,20 @@ class ModelTesterMixin:
|
||||
"Cannot compile forward without an existing cache with Hybrid, as `torch._dynamo.mark_static_address` "
|
||||
"is a forbidden call."
|
||||
)
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
sub_models_supporting_sdpa = [
|
||||
module._supports_sdpa
|
||||
for name, module in model.named_modules()
|
||||
if isinstance(module, PreTrainedModel) and name != ""
|
||||
]
|
||||
supports_sdpa_all_modules = (
|
||||
all(sub_models_supporting_sdpa) if len(sub_models_supporting_sdpa) > 0 else model._supports_sdpa
|
||||
)
|
||||
if not supports_sdpa_all_modules:
|
||||
self.skipTest(reason="This models' submodels does not support sdpa")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa")
|
||||
|
Loading…
Reference in New Issue
Block a user