Allow customization of sdpa in executorch.py (#38827)

Earlier PR put executorch specific sdpa and mask function in the export function. This prevent any customization that can be done to sdpa, prior to export. By moving this to __init__, we still keep the original behavior but allow users like optimum-executorch to override sdpa by setting model.config._attn_implementation.
This commit is contained in:
Kimish Patel 2025-06-17 01:38:20 -07:00 committed by GitHub
parent 9c878d2f64
commit 37367c7d9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -65,6 +65,10 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
"Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
)
self.model = TorchExportableModuleWithStaticCache(model)
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
self.model.model.config._attn_implementation = "sdpa_without_vmap"
def forward(
self,
@ -103,10 +107,6 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
strict(`Optional[bool]`):
Flag to instruct `torch.export` to use `torchdynamo`.
"""
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
self.model.model.config._attn_implementation = "sdpa_without_vmap"
example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)