diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 8a69b2e0a3a..4c7cef05c35 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -3799,8 +3799,20 @@ class ModelTesterMixin:
             self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
         if config.model_type in ["sam"]:
             self.skipTest(reason="SAM requires an attention_mask input for relative positional embeddings")
+
         model = model_class(config)
+        sub_models_supporting_sdpa = [
+            module._supports_sdpa
+            for name, module in model.named_modules()
+            if isinstance(module, PreTrainedModel) and name != ""
+        ]
+        supports_sdpa_all_modules = (
+            all(sub_models_supporting_sdpa) if len(sub_models_supporting_sdpa) > 0 else model._supports_sdpa
+        )
+        if not supports_sdpa_all_modules:
+            self.skipTest(reason="Not all of this model's submodules support SDPA")
+
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname)
             model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa")
@@ -3848,8 +3860,20 @@ class ModelTesterMixin:
                 "Cannot compile forward without an existing cache with Hybrid, as `torch._dynamo.mark_static_address` "
                 "is a forbidden call."
             )
+
         model = model_class(config)
+        sub_models_supporting_sdpa = [
+            module._supports_sdpa
+            for name, module in model.named_modules()
+            if isinstance(module, PreTrainedModel) and name != ""
+        ]
+        supports_sdpa_all_modules = (
+            all(sub_models_supporting_sdpa) if len(sub_models_supporting_sdpa) > 0 else model._supports_sdpa
+        )
+        if not supports_sdpa_all_modules:
+            self.skipTest(reason="Not all of this model's submodules support SDPA")
+
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname)
             model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa")
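
Both hunks add the same twelve-line guard: before loading the checkpoint with `attn_implementation="sdpa"`, the test now skips composite models (models that nest other `PreTrainedModel` instances as submodules) when any nested submodel lacks SDPA support. Read in isolation, the guard amounts to the standalone helper below. This is an illustrative sketch only, not part of the patch; the function name `supports_sdpa_in_all_submodules` is invented here for clarity.

```python
from transformers import PreTrainedModel


def supports_sdpa_in_all_submodules(model: PreTrainedModel) -> bool:
    # Collect the SDPA-support flag of every PreTrainedModel nested inside
    # `model`. The `name != ""` filter excludes the root module itself,
    # which named_modules() yields first under the empty name.
    sub_models_supporting_sdpa = [
        module._supports_sdpa
        for name, module in model.named_modules()
        if isinstance(module, PreTrainedModel) and name != ""
    ]
    # A model with no nested PreTrainedModel falls back to its own
    # class-level flag; a composite model supports SDPA only if every
    # nested submodel does.
    if len(sub_models_supporting_sdpa) == 0:
        return model._supports_sdpa
    return all(sub_models_supporting_sdpa)
```

With this guard in place, the tests only proceed to `from_pretrained(..., attn_implementation="sdpa")` when the whole module tree can actually dispatch to SDPA, instead of failing on the first submodel that cannot.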