mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-23 14:29:01 +06:00
Use public export API on torch 2.5 and future (#36781)
Co-authored-by: Guang Yang <guangyang@fb.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
parent
8f6b27eb5c
commit
ae34bd75fd
@ -19,7 +19,7 @@ from ..utils.import_utils import is_torch_available
|
|||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers import PreTrainedModel, StaticCache
|
from transformers import PreTrainedModel, StaticCache
|
||||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
|
from transformers.pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
|
||||||
|
|
||||||
|
|
||||||
class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
||||||
@ -193,7 +193,6 @@ def convert_and_export_with_cache(
|
|||||||
Returns:
|
Returns:
|
||||||
Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
|
Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not is_torch_greater_or_equal_than_2_3:
|
if not is_torch_greater_or_equal_than_2_3:
|
||||||
raise ImportError("torch >= 2.3 is required.")
|
raise ImportError("torch >= 2.3 is required.")
|
||||||
|
|
||||||
@ -208,15 +207,25 @@ def convert_and_export_with_cache(
|
|||||||
example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
|
example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
|
if is_torch_greater_or_equal("2.5.0"):
|
||||||
# export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
|
exported_program = torch.export.export(
|
||||||
exported_program = torch.export._trace._export(
|
TorchExportableModuleWithStaticCache(model),
|
||||||
TorchExportableModuleWithStaticCache(model),
|
args=(example_input_ids,),
|
||||||
args=(example_input_ids,),
|
kwargs={"cache_position": example_cache_position},
|
||||||
kwargs={"cache_position": example_cache_position},
|
strict=True,
|
||||||
pre_dispatch=False,
|
)
|
||||||
strict=True,
|
else:
|
||||||
)
|
# We have to keep this path for BC.
|
||||||
|
#
|
||||||
|
# Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
|
||||||
|
# export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
|
||||||
|
exported_program = torch.export._trace._export(
|
||||||
|
TorchExportableModuleWithStaticCache(model),
|
||||||
|
args=(example_input_ids,),
|
||||||
|
kwargs={"cache_position": example_cache_position},
|
||||||
|
pre_dispatch=False,
|
||||||
|
strict=True,
|
||||||
|
)
|
||||||
return exported_program
|
return exported_program
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user