diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index a16114fe953..0df283c83b7 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -65,6 +65,10 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
                 "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
             )
             self.model = TorchExportableModuleWithStaticCache(model)
+        # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
+        ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
+        ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
+        self.model.model.config._attn_implementation = "sdpa_without_vmap"
 
     def forward(
         self,
@@ -103,10 +107,6 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
             strict(`Optional[bool]`): Flag to instruct `torch.export` to use `torchdynamo`.
         """
-        # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
-        ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
-        ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
-        self.model.model.config._attn_implementation = "sdpa_without_vmap"
         example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
         example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
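
A minimal usage sketch of the patched wrapper, assuming a decoder-only checkpoint configured for a
static cache; the "HuggingFaceTB/SmolLM2-135M" checkpoint name and the cache sizes below are
illustrative assumptions, not part of this diff. The point of the change is that the
"sdpa_without_vmap" attention implementation is registered when the wrapper is constructed, so the
wrapped model is already export-friendly before export() is called:

    import torch
    from transformers import AutoModelForCausalLM, GenerationConfig
    from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

    # Illustrative checkpoint and static-cache configuration (assumptions, not from the diff).
    model = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceTB/SmolLM2-135M",
        generation_config=GenerationConfig(
            use_cache=True,
            cache_implementation="static",
            cache_config={"batch_size": 1, "max_cache_len": 1024},
        ),
    )
    model.eval()

    # With this patch, constructing the wrapper registers "sdpa_without_vmap" and points the
    # wrapped model's config at it, so export() no longer has to do that itself.
    exportable = TorchExportableModuleForDecoderOnlyLM(model)
    exported_program = exportable.export(
        input_ids=torch.tensor([[1]], dtype=torch.long),
        cache_position=torch.tensor([0], dtype=torch.long),
    )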