Update test_flash_attn_2_can_dispatch_composite_models (#36050)

* update

* update

* update

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2025-02-06 12:09:49 +01:00 committed by GitHub
parent 37faa97d9b
commit dce9970884
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4436,10 +4436,15 @@ class ModelTesterMixin:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
supports_fa2_all_modules = all(
sub_models_supporting_fa2 = [
module._supports_flash_attn_2
for name, module in model.named_modules()
if isinstance(module, PreTrainedModel) and name != ""
]
supports_fa2_all_modules = (
all(sub_models_supporting_fa2)
if len(sub_models_supporting_fa2) > 0
else model._supports_flash_attn_2
)
if not supports_fa2_all_modules:
with self.assertRaises(ValueError):