Run torchdynamo tests (#19056)

* Enable torchdynamo tests

* make style

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent f7ce4f1ff7
commit 16242e1bf0
@@ -27,6 +27,24 @@ RUN python3 -m pip uninstall -y deepspeed
# RUN git clone https://github.com/microsoft/DeepSpeed && cd DeepSpeed && rm -rf build && \
# DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 DS_BUILD_AIO=1 DS_BUILD_UTILS=1 python3 -m pip install . --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check 2>&1

# For `torchdynamo` tests
# (see https://github.com/huggingface/transformers/pull/17765)
RUN git clone https://github.com/pytorch/functorch
RUN python3 -m pip install --no-cache-dir ./functorch[aot]
RUN cd functorch && python3 setup.py develop

RUN git clone https://github.com/pytorch/torchdynamo
RUN python3 -m pip install -r ./torchdynamo/requirements.txt
RUN cd torchdynamo && python3 setup.py develop

# install TensorRT
RUN python3 -m pip install --no-cache-dir -U nvidia-pyindex
RUN python3 -m pip install --no-cache-dir -U nvidia-tensorrt==8.2.4.2

# install torch_tensorrt (fx path)
RUN git clone https://github.com/pytorch/TensorRT.git
RUN cd TensorRT/py && python3 setup.py install --fx-only

# When installing in editable mode, `transformers` is not recognized as a package.
# this line must be added in order for python to be aware of transformers.
RUN cd transformers && python3 setup.py develop
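A quick way to sanity-check the image built from the steps above is to probe the optional dependencies at container start. This is a minimal sketch, not part of the commit; the importable module names (functorch, torchdynamo, tensorrt, torch_tensorrt) are assumptions based on the packages installed above.

    # Hedged smoke test (not part of this commit): confirm that the optional
    # torchdynamo / TensorRT dependencies installed above actually import.
    import importlib

    for pkg in ("torch", "functorch", "torchdynamo", "tensorrt", "torch_tensorrt"):
        try:
            mod = importlib.import_module(pkg)
            print(f"{pkg}: {getattr(mod, '__version__', 'version unknown')}")
        except Exception as err:  # ImportError or backend initialization failures
            print(f"{pkg}: not available ({err})")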
@@ -638,14 +638,13 @@ class Trainer:
raise RuntimeError("Torchdynamo is not installed.")
import torchdynamo
from torchdynamo.optimizations import backends
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

def get_ctx():
# Normal
if args.torchdynamo == "eager":
return torchdynamo.optimize("eager")
elif args.torchdynamo == "nvfuser":
return torchdynamo.optimize(aot_autograd_speedup_strategy)
return torchdynamo.optimize("aot_nvfuser")
# TensorRT
if args.torchdynamo in ["fx2trt-fp16", "fx2trt"]:
if not is_torch_tensorrt_fx_available():
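The get_ctx() helper above returns a TorchDynamo optimization context picked from args.torchdynamo. A minimal standalone sketch of how such a context is typically used with the pre-PyTorch-2.0 torchdynamo package installed by the Dockerfile above (the toy function is illustrative, not from the commit):

    # Sketch only: wrap ordinary PyTorch code in a dynamo context.
    # The "eager" backend captures graphs but runs them with the normal
    # interpreter, which makes it a cheap correctness baseline.
    import torch
    import torchdynamo

    def toy_step(x):
        return torch.cos(torch.nn.functional.relu(x)).sum()

    x = torch.randn(16)
    with torchdynamo.optimize("eager"):  # or another backend string / callable
        out = toy_step(x)
    print(out)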
@@ -1799,6 +1799,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
@require_torchdynamo
@require_torch_tensorrt_fx
def test_torchdynamo_full_eval(self):
import torchdynamo

# torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
n_gpus = get_gpu_count()
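For context, these tests exercise the user-facing torchdynamo option on TrainingArguments, which is what routes a backend name into get_ctx() above. A hedged sketch of that entry point (the output directory and batch size are placeholders, not taken from the commit):

    # Sketch of the option under test; the concrete values are placeholders.
    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="tmp_dynamo_eval",
        per_device_eval_batch_size=8,
        torchdynamo="eager",  # also exercised below: "nvfuser", "fx2trt", "fx2trt-fp16"
    )
    print(args.torchdynamo)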
@@ -1820,11 +1822,13 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
metrics = trainer.evaluate()
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
del trainer
torchdynamo.reset()

# 3. TorchDynamo nvfuser
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="nvfuser")
metrics = trainer.evaluate()
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
torchdynamo.reset()

# 4. TorchDynamo fx2trt
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt")
@@ -1832,6 +1836,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
t1 = metrics["eval_loss"]
t2 = original_eval_loss
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
torchdynamo.reset()

# 5. TorchDynamo fx2trt-fp16
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt-fp16")
@@ -1840,11 +1845,14 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
t2 = original_eval_loss
# fp16 has accuracy degradation
self.assertLess(np.max(np.abs(t1 - t2)), 1e-3)
torchdynamo.reset()

@require_torch_non_multi_gpu
@require_torchdynamo
def test_torchdynamo_memory(self):
# torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
import torchdynamo

class CustomTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
x = inputs["x"]
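The memory test builds a CustomTrainer whose compute_loss feeds inputs["x"] straight through the toy module. The hunk is cut off here, so the following is only a hedged, self-contained sketch of such an override; the scalar reduction and return shape are assumptions, not the commit's code.

    # Sketch of a minimal compute_loss override (details are assumptions).
    from transformers import Trainer

    class SketchTrainer(Trainer):
        def compute_loss(self, model, inputs, return_outputs=False):
            x = inputs["x"]
            output = model(x)
            loss = output.mean()  # reduce to a scalar "loss" (assumption)
            return (loss, output) if return_outputs else loss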
@@ -1861,7 +1869,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):

def forward(self, x):
for _ in range(20):
x = torch.nn.functional.relu(x)
x = torch.cos(x)
return x
mod = MyModule()
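MyModule stacks twenty pointwise relu/cos pairs, the kind of chain the nvfuser path is expected to execute with fewer live intermediates than eager mode. A small sketch (assuming a CUDA GPU) of the peak-memory measurement the comparison below relies on:

    # Sketch of the measurement pattern (requires a CUDA GPU).
    import torch

    torch.cuda.reset_peak_memory_stats()
    x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
    y = torch.cos(torch.nn.functional.relu(x)).sum()
    y.backward()
    print("peak allocated bytes:", torch.cuda.max_memory_allocated())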
@@ -1881,6 +1889,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):

orig_loss = trainer.training_step(mod, {"x": a})
orig_peak_mem = torch.cuda.max_memory_allocated()
torchdynamo.reset()
del trainer

# 2. TorchDynamo nvfuser
@@ -1899,6 +1908,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):

loss = trainer.training_step(mod, {"x": a})
peak_mem = torch.cuda.max_memory_allocated()
torchdynamo.reset()
del trainer
# Functional check
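The hunk ends right after this comment; presumably the functional check compares the eager and nvfuser losses and their peak-memory readings. A hedged, self-contained illustration of that kind of comparison (placeholder values, not the test's actual assertions):

    # Placeholder values; the real test compares the variables computed above.
    import torch

    orig_loss, loss = torch.tensor(0.5), torch.tensor(0.5)
    orig_peak_mem, peak_mem = 2_000_000, 1_500_000

    assert torch.allclose(orig_loss, loss)  # same result under dynamo/nvfuser
    assert peak_mem <= orig_peak_mem        # without a higher GPU memory peak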