Use public export API on torch 2.5 and future (#36781)

Co-authored-by: Guang Yang <guangyang@fb.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
Guang Yang 2025-04-01 02:47:38 -07:00 committed by GitHub
parent 8f6b27eb5c
commit ae34bd75fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -19,7 +19,7 @@ from ..utils.import_utils import is_torch_available
if is_torch_available(): if is_torch_available():
from transformers import PreTrainedModel, StaticCache from transformers import PreTrainedModel, StaticCache
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3 from transformers.pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
class TorchExportableModuleWithStaticCache(torch.nn.Module): class TorchExportableModuleWithStaticCache(torch.nn.Module):
@ -193,7 +193,6 @@ def convert_and_export_with_cache(
Returns: Returns:
Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`. Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
""" """
if not is_torch_greater_or_equal_than_2_3: if not is_torch_greater_or_equal_than_2_3:
raise ImportError("torch >= 2.3 is required.") raise ImportError("torch >= 2.3 is required.")
@ -208,15 +207,25 @@ def convert_and_export_with_cache(
example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long) example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
) )
# Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal if is_torch_greater_or_equal("2.5.0"):
# export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release. exported_program = torch.export.export(
exported_program = torch.export._trace._export( TorchExportableModuleWithStaticCache(model),
TorchExportableModuleWithStaticCache(model), args=(example_input_ids,),
args=(example_input_ids,), kwargs={"cache_position": example_cache_position},
kwargs={"cache_position": example_cache_position}, strict=True,
pre_dispatch=False, )
strict=True, else:
) # We have to keep this path for BC.
#
# Due to issue https://github.com/pytorch/pytorch/issues/128394, torch < 2.5 has to use the internal
# export API with pre_dispatch=False. The public API is used once the fix is included in the 2.5 release.
exported_program = torch.export._trace._export(
TorchExportableModuleWithStaticCache(model),
args=(example_input_ids,),
kwargs={"cache_position": example_cache_position},
pre_dispatch=False,
strict=True,
)
return exported_program return exported_program