Run torchdynamo tests (#19056)

* Enable torchdynamo tests
* make style

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

parent: f7ce4f1ff7
commit: 16242e1bf0
@@ -27,6 +27,24 @@ RUN python3 -m pip uninstall -y deepspeed
 # RUN git clone https://github.com/microsoft/DeepSpeed && cd DeepSpeed && rm -rf build && \
 #    DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 DS_BUILD_AIO=1 DS_BUILD_UTILS=1 python3 -m pip install . --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check 2>&1
 
+# For `torchdynamo` tests
+# (see https://github.com/huggingface/transformers/pull/17765)
+RUN git clone https://github.com/pytorch/functorch
+RUN python3 -m pip install --no-cache-dir ./functorch[aot]
+RUN cd functorch && python3 setup.py develop
+
+RUN git clone https://github.com/pytorch/torchdynamo
+RUN python3 -m pip install -r ./torchdynamo/requirements.txt
+RUN cd torchdynamo && python3 setup.py develop
+
+# install TensorRT
+RUN python3 -m pip install --no-cache-dir -U nvidia-pyindex
+RUN python3 -m pip install --no-cache-dir -U nvidia-tensorrt==8.2.4.2
+
+# install torch_tensorrt (fx path)
+RUN git clone https://github.com/pytorch/TensorRT.git
+RUN cd TensorRT/py && python3 setup.py install --fx-only
+
 # When installing in editable mode, `transformers` is not recognized as a package.
 # this line must be added in order for python to be aware of transformers.
 RUN cd transformers && python3 setup.py develop
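To sanity-check the image these steps produce, a minimal smoke test can be run inside it (a sketch, not part of this commit; it assumes the image's python3 and the packages installed above):

    # smoke test for the torchdynamo install -- a sketch, not from this commit
    import torch
    import torchdynamo

    def fn(x):
        return torch.relu(x) + 1

    # "eager" backend: dynamo captures graphs but runs them unmodified
    with torchdynamo.optimize("eager"):
        out = fn(torch.randn(8))
    print("torchdynamo OK:", out.shape)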
@@ -638,14 +638,13 @@ class Trainer:
                 raise RuntimeError("Torchdynamo is not installed.")
             import torchdynamo
             from torchdynamo.optimizations import backends
-            from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
 
             def get_ctx():
                 # Normal
                 if args.torchdynamo == "eager":
                     return torchdynamo.optimize("eager")
                 elif args.torchdynamo == "nvfuser":
-                    return torchdynamo.optimize(aot_autograd_speedup_strategy)
+                    return torchdynamo.optimize("aot_nvfuser")
                 # TensorRT
                 if args.torchdynamo in ["fx2trt-fp16", "fx2trt"]:
                     if not is_torch_tensorrt_fx_available():
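For context, `get_ctx()` returns a torchdynamo context manager that the Trainer enters around the forward/loss computation, so dynamo can compile graphs on the first step and reuse them afterwards. A simplified sketch of that usage (model/inputs are assumed names, not code from this commit):

    ctx_manager = get_ctx()  # e.g. torchdynamo.optimize("eager") or optimize("aot_nvfuser")
    with ctx_manager:
        outputs = model(**inputs)  # dynamo compiles/reuses graphs inside the ctx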
@@ -1799,6 +1799,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
     @require_torchdynamo
     @require_torch_tensorrt_fx
     def test_torchdynamo_full_eval(self):
+        import torchdynamo
+
         # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
         n_gpus = get_gpu_count()
 
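The `import torchdynamo` moves inside the test body so it only runs once the `@require_torchdynamo` guard has passed. That decorator follows the usual skip-if-unavailable pattern; a simplified stand-in is sketched below (the real one lives in transformers.testing_utils and checks availability differently):

    import unittest

    def require_torchdynamo(test_case):
        # simplified stand-in: skip the test unless torchdynamo is importable
        try:
            import torchdynamo  # noqa: F401
        except ImportError:
            return unittest.skip("test requires torchdynamo")(test_case)
        return test_case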
@@ -1820,11 +1822,13 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         metrics = trainer.evaluate()
         self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
         del trainer
+        torchdynamo.reset()
 
         # 3. TorchDynamo nvfuser
         trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="nvfuser")
         metrics = trainer.evaluate()
         self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
+        torchdynamo.reset()
 
         # 4. TorchDynamo fx2trt
         trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt")
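The added `torchdynamo.reset()` calls matter because dynamo caches compiled graphs per code object; without a reset, caches (and the cache-size accounting) from one backend carry over into the next one under test. The pattern in isolation, as a sketch assuming a CUDA device:

    import torch
    import torchdynamo

    model = torch.nn.Linear(4, 4).cuda()
    x = torch.randn(2, 4, device="cuda")

    with torchdynamo.optimize("eager"):
        model(x)             # compiled and cached with the eager backend
    torchdynamo.reset()      # drop cached graphs before switching backends
    with torchdynamo.optimize("aot_nvfuser"):
        model(x)             # recompiled from scratch with nvfuser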
@@ -1832,6 +1836,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         t1 = metrics["eval_loss"]
         t2 = original_eval_loss
         self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
+        torchdynamo.reset()
 
         # 5. TorchDynamo fx2trt-fp16
         trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt-fp16")
@@ -1840,11 +1845,14 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         t2 = original_eval_loss
         # fp16 has accuracy degradation
         self.assertLess(np.max(np.abs(t1 - t2)), 1e-3)
+        torchdynamo.reset()
 
     @require_torch_non_multi_gpu
     @require_torchdynamo
     def test_torchdynamo_memory(self):
         # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
+        import torchdynamo
+
         class CustomTrainer(Trainer):
             def compute_loss(self, model, inputs, return_outputs=False):
                 x = inputs["x"]
@@ -1861,7 +1869,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
 
             def forward(self, x):
                 for _ in range(20):
-                    x = torch.nn.functional.relu(x)
+                    x = torch.cos(x)
                 return x
 
         mod = MyModule()
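A plausible reason for swapping relu for cos here (not stated in the commit): relu is idempotent, so an optimizing backend can collapse a chain of 20 relus into one op, leaving the memory test little to measure, whereas 20 chained cos calls stay 20 real ops. A quick check of that property:

    import torch

    x = torch.randn(8)
    # relu(relu(x)) == relu(x): a relu chain collapses to a single op
    assert torch.equal(torch.relu(torch.relu(x)), torch.relu(x))
    # cos is not idempotent, so each of the 20 calls does real work
    assert not torch.equal(torch.cos(torch.cos(x)), torch.cos(x))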
@@ -1881,6 +1889,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
 
         orig_loss = trainer.training_step(mod, {"x": a})
         orig_peak_mem = torch.cuda.max_memory_allocated()
+        torchdynamo.reset()
         del trainer
 
         # 2. TorchDynamo nvfuser
@@ -1899,6 +1908,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
 
         loss = trainer.training_step(mod, {"x": a})
         peak_mem = torch.cuda.max_memory_allocated()
+        torchdynamo.reset()
         del trainer
 
         # Functional check
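The surrounding test then asserts that the nvfuser run peaks lower than the eager baseline, since fusing the 20 pointwise ops avoids materializing intermediates. The measurement pattern, condensed into a sketch (trainer, trainer_nvfuser, mod, and a stand in for the objects the test builds above):

    import torch
    import torchdynamo

    def peak_mem_of(step):
        # measure the GPU high-water mark of a single training step
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        step()
        return torch.cuda.max_memory_allocated()

    orig_peak_mem = peak_mem_of(lambda: trainer.training_step(mod, {"x": a}))
    torchdynamo.reset()  # start the nvfuser run with clean caches
    peak_mem = peak_mem_of(lambda: trainer_nvfuser.training_step(mod, {"x": a}))
    assert peak_mem < orig_peak_mem  # fusion should lower peak memory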