Remove backend check for torch.compile (#22140)

* Remove backend enforcement for torch.compile

* Update error

* Update src/transformers/training_args.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Style

---------

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Sylvain Gugger 2023-03-13 16:34:00 -04:00 committed by GitHub
parent 618697ef53
commit 3a35937ede
3 changed files with 15 additions and 23 deletions


@@ -672,7 +672,7 @@ class Trainer:
         # torch.compile
         if args.torch_compile and not is_torch_compile_available():
-            raise RuntimeError("Using torch.compile requires a nightly install of PyTorch.")
+            raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")

     def add_callback(self, callback):
         """


@@ -85,20 +85,6 @@ log_levels = logging.get_log_levels_dict().copy()
 trainer_log_levels = dict(**log_levels, passive=-1)

-TORCH_COMPILE_BACKENDS = [
-    "eager",
-    "aot_eager",
-    "inductor",
-    "nvfuser",
-    "aot_nvfuser",
-    "aot_cudagraphs",
-    "ofi",
-    "fx2trt",
-    "onnxrt",
-    "ipex",
-]

 def default_logdir() -> str:
     """
     Same default as PyTorch
@@ -571,17 +557,24 @@ class TrainingArguments:
             Whether or not to compile the model using PyTorch 2.0
-            [`torch.compile`](https://pytorch.org/get-started/pytorch-2.0/) (requires a nightly install of PyTorch).
-            If set, the backend will default to `"inductor"` (can be customized with `torch_compile_backend`) and the
-            mode will default to `"default"` (can be customized with `torch_compile_mode`).
+            [`torch.compile`](https://pytorch.org/get-started/pytorch-2.0/).
+
+            This will use the best defaults for the [`torch.compile`
+            API](https://pytorch.org/docs/2.0/generated/torch.compile.html?highlight=torch+compile#torch.compile). You
+            can customize the defaults with the arguments `torch_compile_backend` and `torch_compile_mode`, but we
+            don't guarantee any of them will work as support is progressively rolled out in PyTorch.
+
+            This flag and the whole compile API are experimental and subject to change in future releases.
         torch_compile_backend (`str`, *optional*):
             The backend to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.
-            Possible choices are `"eager"`, `"aot_eager"`, `"inductor"`, `"nvfuser"`, `"aot_nvfuser"`,
-            `"aot_cudagraphs"`, `"ofi"`, `"fx2trt"`, `"onnxrt"` and `"ipex"`.
+
+            Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.
+
+            This flag is experimental and subject to change in future releases.
         torch_compile_mode (`str`, *optional*):
             The mode to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.
-            Possible choices are `"default"`, `"reduce-overhead"` and `"max-autotune"`.
+
+            Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.
+
+            This flag is experimental and subject to change in future releases.
         """

     framework = "pt"
@@ -1061,7 +1054,6 @@
         default=None,
         metadata={
             "help": "This argument is deprecated, use `--torch_compile_backend` instead.",
-            "choices": TORCH_COMPILE_BACKENDS,
         },
     )
     ray_scope: Optional[str] = field(
@@ -1090,14 +1082,12 @@
         default=None,
         metadata={
             "help": "Which backend to use with `torch.compile`, passing one will trigger a model compilation.",
-            "choices": TORCH_COMPILE_BACKENDS,
         },
     )
     torch_compile_mode: Optional[str] = field(
         default=None,
         metadata={
             "help": "Which mode to use with `torch.compile`, passing one will trigger a model compilation.",
-            "choices": ["default", "reduce-overhead", "max-autotune"],
         },
     )
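To show what dropping `"choices"` means in practice, a hedged sketch with `HfArgumentParser`: a backend the old hard-coded list would have rejected at argparse time now parses cleanly, and any validation happens inside `torch.compile` itself (`"my-run"` is again a placeholder).

```python
from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)

# "onnxrt" (or any backend added in a later PyTorch) is no longer rejected
# by argparse; torch.compile will complain at runtime if it is unsupported.
(args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "my-run", "--torch_compile_backend", "onnxrt"]
)
assert args.torch_compile  # passing a backend turns compilation on
```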


@@ -478,6 +478,8 @@ def is_torch_compile_available():
     import torch

+    # We don't do any version check here, in order to support nightlies marked as 1.14. Ultimately this needs to
+    # check the version against 2.0, but let's do it later.
     return hasattr(torch, "compile")
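Not part of the commit: a sketch of the version check the comment defers, which also shows why it was postponed (the function name is assumed).

```python
from packaging import version

import torch

def is_torch_2_or_newer() -> bool:
    # Nightlies of the 2.0 era still reported versions like 1.14.0.devYYYYMMDD,
    # which this comparison would wrongly reject; hence the hasattr check above.
    return version.parse(torch.__version__).release >= (2, 0)
```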