Run torchdynamo tests (#19056)

* Enable torchdynamo tests

* make style

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-09-15 20:10:16 +02:00 committed by GitHub
parent f7ce4f1ff7
commit 16242e1bf0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 3 deletions

View File

@ -27,6 +27,24 @@ RUN python3 -m pip uninstall -y deepspeed
# RUN git clone https://github.com/microsoft/DeepSpeed && cd DeepSpeed && rm -rf build && \
# DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 DS_BUILD_AIO=1 DS_BUILD_UTILS=1 python3 -m pip install . --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check 2>&1
# For `torchdynamo` tests
# (see https://github.com/huggingface/transformers/pull/17765)
RUN git clone https://github.com/pytorch/functorch
RUN python3 -m pip install --no-cache-dir ./functorch[aot]
RUN cd functorch && python3 setup.py develop
RUN git clone https://github.com/pytorch/torchdynamo
RUN python3 -m pip install -r ./torchdynamo/requirements.txt
RUN cd torchdynamo && python3 setup.py develop
# install TensorRT
RUN python3 -m pip install --no-cache-dir -U nvidia-pyindex
RUN python3 -m pip install --no-cache-dir -U nvidia-tensorrt==8.2.4.2
# install torch_tensorrt (fx path)
RUN git clone https://github.com/pytorch/TensorRT.git
RUN cd TensorRT/py && python3 setup.py install --fx-only
# When installing in editable mode, `transformers` is not recognized as a package.
# this line must be added in order for python to be aware of transformers.
RUN cd transformers && python3 setup.py develop

View File

@ -638,14 +638,13 @@ class Trainer:
raise RuntimeError("Torchdynamo is not installed.")
import torchdynamo
from torchdynamo.optimizations import backends
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
def get_ctx():
# Normal
if args.torchdynamo == "eager":
return torchdynamo.optimize("eager")
elif args.torchdynamo == "nvfuser":
return torchdynamo.optimize(aot_autograd_speedup_strategy)
return torchdynamo.optimize("aot_nvfuser")
# TensorRT
if args.torchdynamo in ["fx2trt-fp16", "fx2trt"]:
if not is_torch_tensorrt_fx_available():

View File

@ -1799,6 +1799,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
@require_torchdynamo
@require_torch_tensorrt_fx
def test_torchdynamo_full_eval(self):
import torchdynamo
# torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
n_gpus = get_gpu_count()
@ -1820,11 +1822,13 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
metrics = trainer.evaluate()
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
del trainer
torchdynamo.reset()
# 3. TorchDynamo nvfuser
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="nvfuser")
metrics = trainer.evaluate()
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
torchdynamo.reset()
# 4. TorchDynamo fx2trt
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt")
@ -1832,6 +1836,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
t1 = metrics["eval_loss"]
t2 = original_eval_loss
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
torchdynamo.reset()
# 5. TorchDynamo fx2trt-fp16
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt-fp16")
@ -1840,11 +1845,14 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
t2 = original_eval_loss
# fp16 has accuracy accuracy degradation
self.assertLess(np.max(np.abs(t1 - t2)), 1e-3)
torchdynamo.reset()
@require_torch_non_multi_gpu
@require_torchdynamo
def test_torchdynamo_memory(self):
# torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
import torchdynamo
class CustomTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
x = inputs["x"]
@ -1861,7 +1869,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
def forward(self, x):
for _ in range(20):
x = torch.nn.functional.relu(x)
x = torch.cos(x)
return x
mod = MyModule()
@ -1881,6 +1889,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
orig_loss = trainer.training_step(mod, {"x": a})
orig_peak_mem = torch.cuda.max_memory_allocated()
torchdynamo.reset()
del trainer
# 2. TorchDynamo nvfuser
@ -1899,6 +1908,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
loss = trainer.training_step(mod, {"x": a})
peak_mem = torch.cuda.max_memory_allocated()
torchdynamo.reset()
del trainer
# Functional check