Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 18:22:34 +06:00)
Update test_flash_attn_2_can_dispatch_composite_models (#36050)

* update
* update
* update

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 37faa97d9b
commit dce9970884
@@ -4436,10 +4436,15 @@ class ModelTesterMixin:
             model.save_pretrained(tmpdirname)
             model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
 
-            supports_fa2_all_modules = all(
+            sub_models_supporting_fa2 = [
                 module._supports_flash_attn_2
                 for name, module in model.named_modules()
                 if isinstance(module, PreTrainedModel) and name != ""
-            )
+            ]
+            supports_fa2_all_modules = (
+                all(sub_models_supporting_fa2)
+                if len(sub_models_supporting_fa2) > 0
+                else model._supports_flash_attn_2
+            )
             if not supports_fa2_all_modules:
                 with self.assertRaises(ValueError):
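The reshaped check matters because Python's all() is vacuously True for an empty iterable: a model with no PreTrainedModel submodules would previously be reported as FA2-capable no matter what its own flag says. The commit message itself only says "update", so the sketch below is a plausible reading of the change, not an excerpt from the test; DummyModel and its named_modules() stub are hypothetical stand-ins for transformers objects.

class DummyModel:
    """Hypothetical stand-in for a non-composite model."""
    _supports_flash_attn_2 = False

    def named_modules(self):
        # No qualifying sub-models, so the comprehension below stays empty.
        return iter([])


model = DummyModel()

# Old logic: all() over an empty generator is vacuously True, which would
# wrongly report FA2 support for this model.
old_result = all(m._supports_flash_attn_2 for _, m in model.named_modules())

# New logic from the commit: collect the flags first and fall back to the
# model's own flag when no sub-models were found.
sub_models_supporting_fa2 = [m._supports_flash_attn_2 for _, m in model.named_modules()]
supports_fa2_all_modules = (
    all(sub_models_supporting_fa2)
    if len(sub_models_supporting_fa2) > 0
    else model._supports_flash_attn_2
)

print(old_result)                # True  (vacuously true, misleading)
print(supports_fa2_all_modules)  # False (respects the model's own flag)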