mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Remove backend check for torch.compile (#22140)
* Remove backend enforcment for torch.compile * Update error * Update src/transformers/training_args.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Style --------- Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
parent
618697ef53
commit
3a35937ede
@ -672,7 +672,7 @@ class Trainer:
|
||||
|
||||
# torch.compile
|
||||
if args.torch_compile and not is_torch_compile_available():
|
||||
raise RuntimeError("Using torch.compile requires a nightly install of PyTorch.")
|
||||
raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")
|
||||
|
||||
def add_callback(self, callback):
|
||||
"""
|
||||
|
@ -85,20 +85,6 @@ log_levels = logging.get_log_levels_dict().copy()
|
||||
trainer_log_levels = dict(**log_levels, passive=-1)
|
||||
|
||||
|
||||
TORCH_COMPILE_BACKENDS = [
|
||||
"eager",
|
||||
"aot_eager",
|
||||
"inductor",
|
||||
"nvfuser",
|
||||
"aot_nvfuser",
|
||||
"aot_cudagraphs",
|
||||
"ofi",
|
||||
"fx2trt",
|
||||
"onnxrt",
|
||||
"ipex",
|
||||
]
|
||||
|
||||
|
||||
def default_logdir() -> str:
|
||||
"""
|
||||
Same default as PyTorch
|
||||
@ -571,17 +557,24 @@ class TrainingArguments:
|
||||
Whether or not to compile the model using PyTorch 2.0
|
||||
[`torch.compile`](https://pytorch.org/get-started/pytorch-2.0/) (requires a nighlty install of PyTorch).
|
||||
|
||||
If set, the backend will default to `"inductor"` (can be customized with `torch_compile_backend`) and the
|
||||
mode will default to `"default"` (can be customized with `torch_compile_mode`).
|
||||
This will use the best defaults for the [`torch.compile`
|
||||
API](https://pytorch.org/docs/2.0/generated/torch.compile.html?highlight=torch+compile#torch.compile). You
|
||||
can customize the defaults with the argument `torch_compile_backend` and `torch_compile_mode` but we don't
|
||||
guarantee any of them will work as the support is progressively rolled in in PyTorch.
|
||||
|
||||
This flag and the whole compile API is experimental and subject to change in future releases.
|
||||
torch_compile_backend (`str`, *optional*):
|
||||
The backend to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.
|
||||
|
||||
Possible choices are `"eager"`, `"aot_eager"`, `"inductor"`, `"nvfuser"`, `"aot_nvfuser"`,
|
||||
`"aot_cudagraphs"`, `"ofi"`, `"fx2trt"`, `"onnxrt"` and `"ipex"`.
|
||||
Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.
|
||||
|
||||
This flag is experimental and subject to change in future releases.
|
||||
torch_compile_mode (`str`, *optional*):
|
||||
The mode to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.
|
||||
|
||||
Possible choices are `"default"`, `"reduce-overhead"` and `"max-autotune"`.
|
||||
Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.
|
||||
|
||||
This flag is experimental and subject to change in future releases.
|
||||
"""
|
||||
|
||||
framework = "pt"
|
||||
@ -1061,7 +1054,6 @@ class TrainingArguments:
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "This argument is deprecated, use `--torch_compile_backend` instead.",
|
||||
"choices": TORCH_COMPILE_BACKENDS,
|
||||
},
|
||||
)
|
||||
ray_scope: Optional[str] = field(
|
||||
@ -1090,14 +1082,12 @@ class TrainingArguments:
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Which backend to use with `torch.compile`, passing one will trigger a model compilation.",
|
||||
"choices": TORCH_COMPILE_BACKENDS,
|
||||
},
|
||||
)
|
||||
torch_compile_mode: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Which mode to use with `torch.compile`, passing one will trigger a model compilation.",
|
||||
"choices": ["default", "reduce-overhead", "max-autotune"],
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -478,6 +478,8 @@ def is_torch_compile_available():
|
||||
|
||||
import torch
|
||||
|
||||
# We don't do any version check here to support nighlies marked as 1.14. Ultimately needs to check version against
|
||||
# 2.0 but let's do it later.
|
||||
return hasattr(torch, "compile")
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user