mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Skip sdpa tests if submodule does not support sdpa (#38907)
This commit is contained in:
parent
5d26a38735
commit
af6120b3eb
@ -3799,8 +3799,20 @@ class ModelTesterMixin:
|
|||||||
self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
|
self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
|
||||||
if config.model_type in ["sam"]:
|
if config.model_type in ["sam"]:
|
||||||
self.skipTest(reason="SAM requires an attention_mask input for relative positional embeddings")
|
self.skipTest(reason="SAM requires an attention_mask input for relative positional embeddings")
|
||||||
|
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
|
||||||
|
sub_models_supporting_sdpa = [
|
||||||
|
module._supports_sdpa
|
||||||
|
for name, module in model.named_modules()
|
||||||
|
if isinstance(module, PreTrainedModel) and name != ""
|
||||||
|
]
|
||||||
|
supports_sdpa_all_modules = (
|
||||||
|
all(sub_models_supporting_sdpa) if len(sub_models_supporting_sdpa) > 0 else model._supports_sdpa
|
||||||
|
)
|
||||||
|
if not supports_sdpa_all_modules:
|
||||||
|
self.skipTest(reason="This models' submodels does not support sdpa")
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
model.save_pretrained(tmpdirname)
|
model.save_pretrained(tmpdirname)
|
||||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa")
|
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa")
|
||||||
@ -3848,8 +3860,20 @@ class ModelTesterMixin:
|
|||||||
"Cannot compile forward without an existing cache with Hybrid, as `torch._dynamo.mark_static_address` "
|
"Cannot compile forward without an existing cache with Hybrid, as `torch._dynamo.mark_static_address` "
|
||||||
"is a forbidden call."
|
"is a forbidden call."
|
||||||
)
|
)
|
||||||
|
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
|
||||||
|
sub_models_supporting_sdpa = [
|
||||||
|
module._supports_sdpa
|
||||||
|
for name, module in model.named_modules()
|
||||||
|
if isinstance(module, PreTrainedModel) and name != ""
|
||||||
|
]
|
||||||
|
supports_sdpa_all_modules = (
|
||||||
|
all(sub_models_supporting_sdpa) if len(sub_models_supporting_sdpa) > 0 else model._supports_sdpa
|
||||||
|
)
|
||||||
|
if not supports_sdpa_all_modules:
|
||||||
|
self.skipTest(reason="This models' submodels does not support sdpa")
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
model.save_pretrained(tmpdirname)
|
model.save_pretrained(tmpdirname)
|
||||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa")
|
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa")
|
||||||
|
Loading…
Reference in New Issue
Block a user