mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
Enable torchdynamo with torch_tensorrt(fx path) (#17765)
* enable fx2trt * Update perf_train_gpu_one.mdx * Update perf_train_gpu_one.mdx * add lib check * update * format * update * fix import check * fix isort * improve doc * refactor ctx manager * fix isort * black format * isort fix * fix format * update args * update black * cleanups * Update perf_train_gpu_one.mdx * code refactor * code refactor to init * remove redundancy * isort * replace self.args with args Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
parent
37aeb5787a
commit
7ea6ccc2b3
@ -718,3 +718,15 @@ For some applications, such as pretraining large language models, applying all t
|
||||
|
||||
Another use case for training on many GPUs is if the model does not fit on a single GPU with all the mentioned tricks. There are still more methods we can apply although life starts to get a bit more complicated. This usually involves some form of pipeline or tensor parallelism where the model itself is distributed across several GPUs. One can also make use of DeepSpeed which implements some of these parallelism strategies along with some more optimization to reduce the memory footprint such as partitioning the optimizer states. You can read more about this in the ["Multi-GPU training" section](perf_train_gpu_many).
|
||||
|
||||
## Inference with torchdynamo
|
||||
TorchDynamo is a new tracer that uses Python’s frame evaluation API to automatically create FX traces from existing PyTorch programs. After capturing the FX graph, different backends can be deployed to lower the graph to an optimized engine. One solution is using the [TensorRT](https://developer.nvidia.com/tensorrt) or NVFuser as backend. You can choose one option below for performance boost.
|
||||
```
|
||||
TrainingArguments(torchdynamo="eager") #enable eager model GPU. No performance boost
|
||||
TrainingArguments(torchdynamo="nvfuser") #enable nvfuser
|
||||
TrainingArguments(torchdynamo="fx2trt") #enable tensorRT fp32
|
||||
TrainingArguments(torchdynamo="fx2trt-f16") #enable tensorRT fp16
|
||||
```
|
||||
This feature involves 3 different libraries. To install them, please follow the instructions below:
|
||||
- [Torchdynamo installation](https://github.com/pytorch/torchdynamo#requirements-and-setup)
|
||||
- [Functorch installation](https://github.com/pytorch/functorch#install)
|
||||
- [Torch-TensorRT(FX) installation](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst#installation)
|
||||
|
@ -71,6 +71,7 @@ from .utils import (
|
||||
is_torch_available,
|
||||
is_torch_bf16_cpu_available,
|
||||
is_torch_bf16_gpu_available,
|
||||
is_torch_tensorrt_fx_available,
|
||||
is_torch_tf32_available,
|
||||
is_torch_tpu_available,
|
||||
is_torchaudio_available,
|
||||
@ -499,6 +500,11 @@ def require_torchdynamo(test_case):
|
||||
return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
|
||||
|
||||
|
||||
def require_torch_tensorrt_fx(test_case):
|
||||
"""Decorator marking a test that requires Torch-TensorRT FX"""
|
||||
return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
|
||||
|
||||
|
||||
def require_torch_gpu(test_case):
|
||||
"""Decorator marking a test that requires CUDA and PyTorch."""
|
||||
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
|
||||
|
@ -141,6 +141,7 @@ from .utils import (
|
||||
is_ipex_available,
|
||||
is_sagemaker_dp_enabled,
|
||||
is_sagemaker_mp_enabled,
|
||||
is_torch_tensorrt_fx_available,
|
||||
is_torch_tpu_available,
|
||||
is_torchdynamo_available,
|
||||
logging,
|
||||
@ -598,6 +599,35 @@ class Trainer:
|
||||
# very last
|
||||
self._memory_tracker.stop_and_update_metrics()
|
||||
|
||||
# torchdynamo
|
||||
if args.torchdynamo:
|
||||
if not is_torchdynamo_available():
|
||||
raise RuntimeError("Torchdynamo is not installed.")
|
||||
import torchdynamo
|
||||
from torchdynamo.optimizations import backends
|
||||
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
|
||||
|
||||
def get_ctx():
|
||||
# Normal
|
||||
if args.torchdynamo == "eager":
|
||||
return torchdynamo.optimize("eager")
|
||||
elif args.torchdynamo == "nvfuser":
|
||||
return torchdynamo.optimize(aot_autograd_speedup_strategy)
|
||||
# TensorRT
|
||||
if args.torchdynamo in ["fx2trt-fp16", "fx2trt"]:
|
||||
if not is_torch_tensorrt_fx_available():
|
||||
raise RuntimeError("Torch-TensorRT FX path is not installed.")
|
||||
if args.torchdynamo == "fx2trt-fp16":
|
||||
return torchdynamo.optimize(backends.fx2trt_compiler_fp16)
|
||||
elif args.torchdynamo == "fx2trt":
|
||||
return torchdynamo.optimize(backends.fx2trt_compiler)
|
||||
else:
|
||||
raise RuntimeError(f"Torchdynamo backend {args.torchdynamo} is not supported.")
|
||||
|
||||
self.ctx_manager_torchdynamo = get_ctx()
|
||||
else:
|
||||
self.ctx_manager_torchdynamo = contextlib.nullcontext()
|
||||
|
||||
def add_callback(self, callback):
|
||||
"""
|
||||
Add a callback to the current list of [`~transformer.TrainerCallback`].
|
||||
@ -2291,16 +2321,7 @@ class Trainer:
|
||||
"""
|
||||
A helper wrapper that creates an appropriate context manager for `torchdynamo`.
|
||||
"""
|
||||
ctx_manager = contextlib.nullcontext()
|
||||
if is_torchdynamo_available():
|
||||
import torchdynamo
|
||||
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
|
||||
|
||||
if self.args.torchdynamo == "eager":
|
||||
ctx_manager = torchdynamo.optimize("eager")
|
||||
elif self.args.torchdynamo == "nvfuser":
|
||||
ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy)
|
||||
return ctx_manager
|
||||
return self.ctx_manager_torchdynamo
|
||||
|
||||
def autocast_smart_context_manager(self):
|
||||
"""
|
||||
|
@ -935,7 +935,7 @@ class TrainingArguments:
|
||||
" are two options - eager and nvfuser. Eager defaults to pytorch eager and is useful for debugging."
|
||||
" nvfuser path uses AOT Autograd and nvfuser compiler to optimize the models."
|
||||
),
|
||||
"choices": ["eager", "nvfuser"],
|
||||
"choices": ["eager", "nvfuser", "fx2trt", "fx2trt-fp16"],
|
||||
},
|
||||
)
|
||||
ray_scope: Optional[str] = field(
|
||||
|
@ -132,6 +132,7 @@ from .import_utils import (
|
||||
is_torch_fx_available,
|
||||
is_torch_fx_proxy,
|
||||
is_torch_onnx_dict_inputs_support_available,
|
||||
is_torch_tensorrt_fx_available,
|
||||
is_torch_tf32_available,
|
||||
is_torch_tpu_available,
|
||||
is_torchaudio_available,
|
||||
|
@ -418,6 +418,12 @@ def is_torchdynamo_available():
|
||||
return importlib.util.find_spec("torchdynamo") is not None
|
||||
|
||||
|
||||
def is_torch_tensorrt_fx_available():
|
||||
if importlib.util.find_spec("torch_tensorrt") is None:
|
||||
return False
|
||||
return importlib.util.find_spec("torch_tensorrt.fx") is not None
|
||||
|
||||
|
||||
def is_datasets_available():
|
||||
return _datasets_available
|
||||
|
||||
|
@ -62,6 +62,7 @@ from transformers.testing_utils import (
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_non_multi_gpu,
|
||||
require_torch_tensorrt_fx,
|
||||
require_torch_tf32,
|
||||
require_torch_up_to_2_gpus,
|
||||
require_torchdynamo,
|
||||
@ -1796,6 +1797,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
|
||||
@require_torch_non_multi_gpu
|
||||
@require_torchdynamo
|
||||
@require_torch_tensorrt_fx
|
||||
def test_torchdynamo_full_eval(self):
|
||||
# torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
|
||||
n_gpus = get_gpu_count()
|
||||
@ -1824,6 +1826,21 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
metrics = trainer.evaluate()
|
||||
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
|
||||
|
||||
# 4. TorchDynamo fx2trt
|
||||
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt")
|
||||
metrics = trainer.evaluate()
|
||||
t1 = metrics["eval_loss"]
|
||||
t2 = original_eval_loss
|
||||
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
|
||||
|
||||
# 5. TorchDynamo fx2trt-fp16
|
||||
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt-fp16")
|
||||
metrics = trainer.evaluate()
|
||||
t1 = metrics["eval_loss"]
|
||||
t2 = original_eval_loss
|
||||
# fp16 has accuracy accuracy degradation
|
||||
self.assertLess(np.max(np.abs(t1 - t2)), 1e-3)
|
||||
|
||||
@require_torch_non_multi_gpu
|
||||
@require_torchdynamo
|
||||
def test_torchdynamo_memory(self):
|
||||
@ -1849,7 +1866,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
|
||||
mod = MyModule()
|
||||
|
||||
# 1. Default - without TorchDynamo
|
||||
# 1. without TorchDynamo (eager baseline)
|
||||
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
|
||||
a.grad = None
|
||||
trainer = CustomTrainer(model=mod)
|
||||
@ -1857,16 +1874,15 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
for _ in range(10):
|
||||
orig_loss = trainer.training_step(mod, {"x": a})
|
||||
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
orig_loss = trainer.training_step(mod, {"x": a})
|
||||
orig_peak_mem = torch.cuda.max_memory_allocated()
|
||||
del trainer
|
||||
|
||||
# Reset the peak for another measurement
|
||||
# resets
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
orig_loss = trainer.training_step(mod, {"x": a})
|
||||
orig_peak_mem = torch.cuda.max_memory_allocated()
|
||||
del trainer
|
||||
|
||||
# 2. TorchDynamo nvfuser
|
||||
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
|
||||
a.grad = None
|
||||
@ -1876,7 +1892,11 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
for _ in range(10):
|
||||
loss = trainer.training_step(mod, {"x": a})
|
||||
|
||||
# resets
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
loss = trainer.training_step(mod, {"x": a})
|
||||
peak_mem = torch.cuda.max_memory_allocated()
|
||||
del trainer
|
||||
|
Loading…
Reference in New Issue
Block a user