Repurpose torchdynamo training args towards torch._dynamo (#20498)

* Repurpose torchdynamo training args towards torch._dynamo

* Add doc
This commit is contained in:
Sylvain Gugger 2022-11-30 11:10:45 -05:00 committed by GitHub
parent 829374e4fc
commit 08b4621899
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 53 additions and 73 deletions

View File

@ -720,16 +720,25 @@ Another use case for training on many GPUs is if the model does not fit on a sin
## Inference with torchdynamo
TorchDynamo is a new tracer that uses Pythons frame evaluation API to automatically create FX traces from existing PyTorch programs. After capturing the FX graph, different backends can be deployed to lower the graph to an optimized engine. One solution is using the [TensorRT](https://developer.nvidia.com/tensorrt) or NVFuser as backend. You can choose one option below for performance boost.
TorchDynamo is a new tracer that uses Pythons frame evaluation API to automatically create FX traces from existing PyTorch programs. After capturing the FX graph, different backends can be deployed to lower the graph to an optimized engine. You can choose one option below for performance boost.
```
TrainingArguments(torchdynamo="eager") #enable eager model GPU. No performance boost
TrainingArguments(torchdynamo="nvfuser") #enable nvfuser
TrainingArguments(torchdynamo="fx2trt") #enable tensorRT fp32
TrainingArguments(torchdynamo="fx2trt-f16") #enable tensorRT fp16
```
TorchDynamo has a growing list of backends, which can be found in [backends.py](https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/backends.py)
or `torchdynamo.list_backends()` each of which with its optional dependencies.
This feature involves 3 different libraries. To install them, please follow the instructions below:
- [Torchdynamo installation](https://github.com/pytorch/torchdynamo#requirements-and-setup)
- [Functorch installation](https://github.com/pytorch/functorch#install)
- [Torch-TensorRT(FX) installation](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst#installation)
Some of the most commonly used backends are
**Debugging backends**:
* `dynamo.optimize("eager")` - Uses PyTorch to run the extracted GraphModule. This is quite useful in debugging TorchDynamo issues.
* `dynamo.optimize("aot_eager")` - Uses AotAutograd with no compiler, i.e, just using PyTorch eager for the AotAutograd's extracted forward and backward graphs. This is useful for debugging, and unlikely to give speedups.
**Training & inference backends**:
* `dynamo.optimize("inductor")` - Uses TorchInductor backend with AotAutograd and cudagraphs by leveraging codegened Triton kernels [Read more](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747)
* `dynamo.optimize("nvfuser")` - nvFuser with TorchScript. [Read more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593)
* `dynamo.optimize("aot_nvfuser")` - nvFuser with AotAutograd. [Read more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593)
* `dynamo.optimize("aot_cudagraphs")` - cudagraphs with AotAutograd. [Read more](https://github.com/pytorch/torchdynamo/pull/757)
**Inference-only backend**s:
* `dynamo.optimize("ofi")` - Uses Torchscript optimize_for_inference. [Read more](https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html)
* `dynamo.optimize("fx2trt")` - Uses Nvidia TensorRT for inference optimizations. [Read more](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst)
* `dynamo.optimize("onnxrt")` - Uses ONNXRT for inference on CPU/GPU. [Read more](https://onnxruntime.ai/)
* `dynamo.optimize("ipex")` - Uses IPEX for inference on CPU. [Read more](https://github.com/intel/intel-extension-for-pytorch)

View File

@ -144,7 +144,6 @@ from .utils import (
is_ipex_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_tensorrt_fx_available,
is_torch_tpu_available,
is_torchdynamo_available,
logging,
@ -637,32 +636,8 @@ class Trainer:
self._memory_tracker.stop_and_update_metrics()
# torchdynamo
if args.torchdynamo:
if not is_torchdynamo_available():
raise RuntimeError("Torchdynamo is not installed.")
import torchdynamo
from torchdynamo.optimizations import backends
def get_ctx():
# Normal
if args.torchdynamo == "eager":
return torchdynamo.optimize("eager")
elif args.torchdynamo == "nvfuser":
return torchdynamo.optimize("aot_nvfuser")
# TensorRT
if args.torchdynamo in ["fx2trt-fp16", "fx2trt"]:
if not is_torch_tensorrt_fx_available():
raise RuntimeError("Torch-TensorRT FX path is not installed.")
if args.torchdynamo == "fx2trt-fp16":
return torchdynamo.optimize(backends.fx2trt_compiler_fp16)
elif args.torchdynamo == "fx2trt":
return torchdynamo.optimize(backends.fx2trt_compiler)
else:
raise RuntimeError(f"Torchdynamo backend {args.torchdynamo} is not supported.")
self.ctx_manager_torchdynamo = get_ctx()
else:
self.ctx_manager_torchdynamo = contextlib.nullcontext()
if args.torchdynamo is not None and not is_torchdynamo_available():
raise RuntimeError("Using torchdynamo requires a nighly install of PyTorch.")
def add_callback(self, callback):
"""
@ -1339,6 +1314,10 @@ class Trainer:
return model
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torchdynamo is not None:
import torch._dynamo as dynamo
model = dynamo.optimize(self.args.torchdynamo)(model)
if self.args.use_ipex:
dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
model = self.ipex_optimize_model(model, training, dtype=dtype)
@ -2494,18 +2473,7 @@ class Trainer:
"""
A helper wrapper to group together context managers.
"""
return ContextManagers(
[
self.torchdynamo_smart_context_manager(),
self.autocast_smart_context_manager(),
]
)
def torchdynamo_smart_context_manager(self):
"""
A helper wrapper that creates an appropriate context manager for `torchdynamo`.
"""
return self.ctx_manager_torchdynamo
return self.autocast_smart_context_manager()
def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
"""

View File

@ -73,6 +73,20 @@ log_levels = logging.get_log_levels_dict().copy()
trainer_log_levels = dict(**log_levels, passive=-1)
DYNAMO_BACKENDS = [
"eager",
"aot_eager",
"inductor",
"nvfuser",
"aot_nvfuser",
"aot_cudagraphs",
"ofi",
"fx2trt",
"onnxrt",
"ipex",
]
def default_logdir() -> str:
"""
Same default as PyTorch
@ -485,8 +499,8 @@ class TrainingArguments:
If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
distributed training
torchdynamo (`str`, *optional*):
The token that is used to set the backend compiler for TorchDynamo. Possible choices are ["eager",
"nvfuser]. This is an experimental API and subject to change.
If set, the backend compiler for TorchDynamo. Possible choices are `"eager"`, `"aot_eager"`, `"inductor"`,
`"nvfuser"`, `"aot_nvfuser"`, `"aot_cudagraphs"`, `"ofi"`, `"fx2trt"`, `"onnxrt"` and `"ipex"`.
ray_scope (`str`, *optional*, defaults to `"last"`):
The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray will
then use the last checkpoint of all trials, compare those, and select the best one. However, other options
@ -969,15 +983,8 @@ class TrainingArguments:
torchdynamo: Optional[str] = field(
default=None,
metadata={
"help": (
"Sets up the backend compiler for TorchDynamo. TorchDynamo is a Python level JIT compiler designed to"
" make unmodified PyTorch programs faster. TorchDynamo dynamically modifies the Python bytecode right"
" before its executed. It rewrites Python bytecode to extract sequences of PyTorch operations"
" and lifts them up into Fx graph. We can then pass these Fx graphs to other backend compilers. There"
" are two options - eager and nvfuser. Eager defaults to pytorch eager and is useful for debugging."
" nvfuser path uses AOT Autograd and nvfuser compiler to optimize the models."
),
"choices": ["eager", "nvfuser", "fx2trt", "fx2trt-fp16"],
"help": "Sets up the backend compiler for TorchDynamo.",
"choices": DYNAMO_BACKENDS,
},
)
ray_scope: Optional[str] = field(

View File

@ -445,7 +445,14 @@ def is_torch_tpu_available(check_device=True):
def is_torchdynamo_available():
return importlib.util.find_spec("torchdynamo") is not None
if not is_torch_available():
return False
try:
import torch._dynamo as dynamo # noqa: F401
return True
except Exception:
return False
def is_torch_tensorrt_fx_available():

View File

@ -1839,20 +1839,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# 4. TorchDynamo fx2trt
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt")
metrics = trainer.evaluate()
t1 = metrics["eval_loss"]
t2 = original_eval_loss
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
torchdynamo.reset()
# 5. TorchDynamo fx2trt-fp16
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt-fp16")
metrics = trainer.evaluate()
t1 = metrics["eval_loss"]
t2 = original_eval_loss
# fp16 has accuracy accuracy degradation
self.assertLess(np.max(np.abs(t1 - t2)), 1e-3)
torchdynamo.reset()
@require_torch_non_multi_gpu
@require_torchdynamo
def test_torchdynamo_memory(self):