From ae34bd75fdfe4fc9f773f0bfebab6fe163513dba Mon Sep 17 00:00:00 2001
From: Guang Yang <42389959+guangy10@users.noreply.github.com>
Date: Tue, 1 Apr 2025 02:47:38 -0700
Subject: [PATCH] Use public export API on torch 2.5 and future (#36781)

Co-authored-by: Guang Yang
Co-authored-by: Pavel Iakubovskii
---
 src/transformers/integrations/executorch.py | 31 +++++++++++++--------
 1 file changed, 20 insertions(+), 11 deletions(-)

diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index 09fd0c387fd..b0a7f904c96 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -19,7 +19,7 @@ from ..utils.import_utils import is_torch_available
 
 if is_torch_available():
     from transformers import PreTrainedModel, StaticCache
-    from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
+    from transformers.pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
 
 
 class TorchExportableModuleWithStaticCache(torch.nn.Module):
@@ -193,7 +193,6 @@ def convert_and_export_with_cache(
     Returns:
         Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
     """
-
     if not is_torch_greater_or_equal_than_2_3:
         raise ImportError("torch >= 2.3 is required.")
 
@@ -208,15 +207,25 @@ def convert_and_export_with_cache(
         example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
     )
 
-    # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
-    # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
-    exported_program = torch.export._trace._export(
-        TorchExportableModuleWithStaticCache(model),
-        args=(example_input_ids,),
-        kwargs={"cache_position": example_cache_position},
-        pre_dispatch=False,
-        strict=True,
-    )
+    if is_torch_greater_or_equal("2.5.0"):
+        exported_program = torch.export.export(
+            TorchExportableModuleWithStaticCache(model),
+            args=(example_input_ids,),
+            kwargs={"cache_position": example_cache_position},
+            strict=True,
+        )
+    else:
+        # We have to keep this path for BC.
+        #
+        # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
+        # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
+        exported_program = torch.export._trace._export(
+            TorchExportableModuleWithStaticCache(model),
+            args=(example_input_ids,),
+            kwargs={"cache_position": example_cache_position},
+            pre_dispatch=False,
+            strict=True,
+        )
 
     return exported_program
 
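
Note: the sketch below is a minimal, illustrative way to exercise the patched
convert_and_export_with_cache end to end. The checkpoint name and the
cache_config sizes are assumptions made for the example, not part of this
patch; the only behavioral change in the patch is that on torch >= 2.5 the
call routes through the public torch.export.export instead of the internal
torch.export._trace._export.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
    from transformers.integrations.executorch import convert_and_export_with_cache

    # Assumed checkpoint; any decoder-only model that supports a static cache should work.
    model_id = "meta-llama/Llama-3.2-1B"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        # The exportable wrapper requires use_cache=True and a static cache implementation.
        generation_config=GenerationConfig(
            use_cache=True,
            cache_implementation="static",
            cache_config={"batch_size": 1, "max_cache_len": 128},  # illustrative sizes
        ),
    )

    # Export a single token at cache position 0; on torch >= 2.5 this now goes through torch.export.export.
    example_input_ids = tokenizer("Hello", return_tensors="pt").input_ids[:, :1]
    example_cache_position = torch.tensor([0], dtype=torch.long)
    exported_program = convert_and_export_with_cache(model, example_input_ids, example_cache_position)
    print(exported_program)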