Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-23 06:20:22 +06:00
Use public export API on torch 2.5 and future (#36781)
Co-authored-by: Guang Yang <guangyang@fb.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
parent 8f6b27eb5c
commit ae34bd75fd
@@ -19,7 +19,7 @@ from ..utils.import_utils import is_torch_available
 
 if is_torch_available():
     from transformers import PreTrainedModel, StaticCache
-    from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
+    from transformers.pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
 
 
 class TorchExportableModuleWithStaticCache(torch.nn.Module):
@@ -193,7 +193,6 @@ def convert_and_export_with_cache(
     Returns:
         Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
     """
-
     if not is_torch_greater_or_equal_than_2_3:
         raise ImportError("torch >= 2.3 is required.")
 
@@ -208,15 +207,25 @@ def convert_and_export_with_cache(
         example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
     )
 
-    # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
-    # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
-    exported_program = torch.export._trace._export(
-        TorchExportableModuleWithStaticCache(model),
-        args=(example_input_ids,),
-        kwargs={"cache_position": example_cache_position},
-        pre_dispatch=False,
-        strict=True,
-    )
+    if is_torch_greater_or_equal("2.5.0"):
+        exported_program = torch.export.export(
+            TorchExportableModuleWithStaticCache(model),
+            args=(example_input_ids,),
+            kwargs={"cache_position": example_cache_position},
+            strict=True,
+        )
+    else:
+        # We have to keep this path for BC.
+        #
+        # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
+        # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
+        exported_program = torch.export._trace._export(
+            TorchExportableModuleWithStaticCache(model),
+            args=(example_input_ids,),
+            kwargs={"cache_position": example_cache_position},
+            pre_dispatch=False,
+            strict=True,
+        )
     return exported_program