Use public export API on torch 2.5 and future (#36781)

Co-authored-by: Guang Yang <guangyang@fb.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
Guang Yang 2025-04-01 02:47:38 -07:00 committed by GitHub
parent 8f6b27eb5c
commit ae34bd75fd


@@ -19,7 +19,7 @@ from ..utils.import_utils import is_torch_available
 if is_torch_available():
     from transformers import PreTrainedModel, StaticCache
-    from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
+    from transformers.pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
 class TorchExportableModuleWithStaticCache(torch.nn.Module):
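The newly imported `is_torch_greater_or_equal` helper is what gates the export path on the installed torch version later in this diff. A minimal stand-in sketch of such a version check (not the actual transformers implementation, and assuming `packaging` is available) looks like this:

# Minimal stand-in for the version gate; the real helper ships with transformers
# and may differ in details.
import torch
from packaging import version


def is_torch_greater_or_equal(min_version: str) -> bool:
    # Compare the installed torch version (e.g. "2.5.1+cu121") against the requested minimum.
    return version.parse(torch.__version__) >= version.parse(min_version)


# Example: prefer the public torch.export API only on torch >= 2.5.
use_public_export_api = is_torch_greater_or_equal("2.5.0")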
@@ -193,7 +193,6 @@ def convert_and_export_with_cache(
     Returns:
         Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
     """
     if not is_torch_greater_or_equal_than_2_3:
         raise ImportError("torch >= 2.3 is required.")
@@ -208,6 +207,16 @@ def convert_and_export_with_cache(
             example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
         )
+        if is_torch_greater_or_equal("2.5.0"):
+            exported_program = torch.export.export(
+                TorchExportableModuleWithStaticCache(model),
+                args=(example_input_ids,),
+                kwargs={"cache_position": example_cache_position},
+                strict=True,
+            )
+        else:
+            # We have to keep this path for BC.
+            #
+            # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
+            # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
+            exported_program = torch.export._trace._export(
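For completeness, a hypothetical usage sketch of `convert_and_export_with_cache` after this change (the checkpoint name, cache sizes, and static-cache generation config below are illustrative assumptions, not part of the commit):

# Hypothetical usage sketch; the checkpoint, cache sizes, and static-cache
# generation config are illustrative assumptions only.
import torch
from transformers import AutoModelForCausalLM, GenerationConfig
from transformers.integrations.executorch import convert_and_export_with_cache

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        cache_config={"batch_size": 1, "max_cache_len": 128},
    ),
)
model.eval()

# A single prompt token and its cache position are enough to trace the static-cache forward pass.
example_input_ids = torch.tensor([[1]], dtype=torch.long)
example_cache_position = torch.tensor([0], dtype=torch.long)

# On torch >= 2.5 this now goes through the public torch.export.export API;
# on older versions it falls back to the internal torch.export._trace._export path.
exported_program = convert_and_export_with_cache(model, example_input_ids, example_cache_position)
print(type(exported_program))  # torch.export.ExportedProgram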