diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 7fedc4e7544..6a47324a333 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4436,10 +4436,15 @@ class ModelTesterMixin: model.save_pretrained(tmpdirname) model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) - supports_fa2_all_modules = all( + sub_models_supporting_fa2 = [ module._supports_flash_attn_2 for name, module in model.named_modules() if isinstance(module, PreTrainedModel) and name != "" + ] + supports_fa2_all_modules = ( + all(sub_models_supporting_fa2) + if len(sub_models_supporting_fa2) > 0 + else model._supports_flash_attn_2 ) if not supports_fa2_all_modules: with self.assertRaises(ValueError):