mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
Device agnostic trainer testing (#27131)
This commit is contained in:
parent
84724efd10
commit
5bbf671276
@ -629,6 +629,20 @@ def require_torch_multi_gpu(test_case):
|
||||
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
|
||||
|
||||
|
||||
def require_torch_multi_accelerator(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a multi-accelerator (in PyTorch). These tests are skipped on a machine
|
||||
without multiple accelerators. To run *only* the multi_accelerator tests, assuming all test names contain
|
||||
multi_accelerator: $ pytest -sv ./tests -k "multi_accelerator"
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
return unittest.skipUnless(backend_device_count(torch_device) > 1, "test requires multiple accelerators")(
|
||||
test_case
|
||||
)
|
||||
|
||||
|
||||
def require_torch_non_multi_gpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
|
||||
@ -641,6 +655,16 @@ def require_torch_non_multi_gpu(test_case):
|
||||
return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case)
|
||||
|
||||
|
||||
def require_torch_non_multi_accelerator(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires 0 or 1 accelerator setup (in PyTorch).
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
return unittest.skipUnless(backend_device_count(torch_device) < 2, "test requires 0 or 1 accelerator")(test_case)
|
||||
|
||||
|
||||
def require_torch_up_to_2_gpus(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch).
|
||||
@ -653,6 +677,17 @@ def require_torch_up_to_2_gpus(test_case):
|
||||
return unittest.skipUnless(torch.cuda.device_count() < 3, "test requires 0 or 1 or 2 GPUs")(test_case)
|
||||
|
||||
|
||||
def require_torch_up_to_2_accelerators(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires 0 or 1 or 2 accelerator setup (in PyTorch).
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
return unittest.skipUnless(backend_device_count(torch_device) < 3, "test requires 0 or 1 or 2 accelerators")
|
||||
(test_case)
|
||||
|
||||
|
||||
def require_torch_tpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a TPU (in PyTorch).
|
||||
@ -774,7 +809,9 @@ def require_torch_gpu(test_case):
|
||||
|
||||
def require_torch_accelerator(test_case):
|
||||
"""Decorator marking a test that requires an accessible accelerator and PyTorch."""
|
||||
return unittest.skipUnless(torch_device != "cpu", "test requires accelerator")(test_case)
|
||||
return unittest.skipUnless(torch_device is not None and torch_device != "cpu", "test requires accelerator")(
|
||||
test_case
|
||||
)
|
||||
|
||||
|
||||
def require_torch_fp16(test_case):
|
||||
|
@ -26,16 +26,17 @@ from transformers.testing_utils import (
|
||||
CaptureStderr,
|
||||
ExtendSysPath,
|
||||
TestCasePlus,
|
||||
backend_device_count,
|
||||
execute_subprocess_async,
|
||||
get_gpu_count,
|
||||
get_torch_dist_unique_port,
|
||||
require_apex,
|
||||
require_bitsandbytes,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_non_multi_gpu,
|
||||
require_torch_multi_accelerator,
|
||||
require_torch_non_multi_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.trainer_callback import TrainerState
|
||||
from transformers.trainer_utils import set_seed
|
||||
@ -89,17 +90,17 @@ class TestTrainerExt(TestCasePlus):
|
||||
assert isinstance(last_step_stats["eval_bleu"], float)
|
||||
assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"
|
||||
|
||||
@require_torch_non_multi_gpu
|
||||
@require_torch_non_multi_accelerator
|
||||
def test_run_seq2seq_no_dist(self):
|
||||
self.run_seq2seq_quick()
|
||||
|
||||
# verify that the trainer can handle non-distributed with n_gpu > 1
|
||||
@require_torch_multi_gpu
|
||||
@require_torch_multi_accelerator
|
||||
def test_run_seq2seq_dp(self):
|
||||
self.run_seq2seq_quick(distributed=False)
|
||||
|
||||
# verify that the trainer can handle distributed with n_gpu > 1
|
||||
@require_torch_multi_gpu
|
||||
@require_torch_multi_accelerator
|
||||
def test_run_seq2seq_ddp(self):
|
||||
self.run_seq2seq_quick(distributed=True)
|
||||
|
||||
@ -120,7 +121,7 @@ class TestTrainerExt(TestCasePlus):
|
||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
|
||||
|
||||
@parameterized.expand(["base", "low", "high", "mixed"])
|
||||
@require_torch_multi_gpu
|
||||
@require_torch_multi_accelerator
|
||||
def test_trainer_log_level_replica(self, experiment_id):
|
||||
# as each sub-test is slow-ish split into multiple sub-tests to avoid CI timeout
|
||||
experiments = {
|
||||
@ -331,7 +332,7 @@ class TestTrainerExt(TestCasePlus):
|
||||
|
||||
if distributed:
|
||||
if n_gpus_to_use is None:
|
||||
n_gpus_to_use = get_gpu_count()
|
||||
n_gpus_to_use = backend_device_count(torch_device)
|
||||
master_port = get_torch_dist_unique_port()
|
||||
distributed_args = f"""
|
||||
-m torch.distributed.run
|
||||
|
@ -49,6 +49,7 @@ from transformers.testing_utils import (
|
||||
USER,
|
||||
CaptureLogger,
|
||||
TestCasePlus,
|
||||
backend_device_count,
|
||||
execute_subprocess_async,
|
||||
get_gpu_count,
|
||||
get_tests_dir,
|
||||
@ -63,17 +64,19 @@ from transformers.testing_utils import (
|
||||
require_tensorboard,
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
require_torch_bf16_cpu,
|
||||
require_torch_bf16_gpu,
|
||||
require_torch_accelerator,
|
||||
require_torch_bf16,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_multi_accelerator,
|
||||
require_torch_non_multi_accelerator,
|
||||
require_torch_non_multi_gpu,
|
||||
require_torch_tensorrt_fx,
|
||||
require_torch_tf32,
|
||||
require_torch_up_to_2_gpus,
|
||||
require_torch_up_to_2_accelerators,
|
||||
require_torchdynamo,
|
||||
require_wandb,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
|
||||
from transformers.training_args import OptimizerNames
|
||||
@ -606,7 +609,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
)
|
||||
|
||||
def test_training_loss(self):
|
||||
n_gpus = max(1, get_gpu_count())
|
||||
n_gpus = max(1, backend_device_count(torch_device))
|
||||
|
||||
# With even logs
|
||||
trainer = get_regression_trainer(logging_steps=64 / (8 * n_gpus))
|
||||
@ -726,8 +729,8 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertFalse(torch.allclose(trainer.model.b, b))
|
||||
self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_bf16_gpu
|
||||
@require_torch_accelerator
|
||||
@require_torch_bf16
|
||||
def test_mixed_bf16(self):
|
||||
# very basic test
|
||||
trainer = get_regression_trainer(learning_rate=0.1, bf16=True)
|
||||
@ -812,25 +815,25 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
train_output = trainer.train()
|
||||
self.assertEqual(train_output.global_step, 10)
|
||||
|
||||
@require_torch_bf16_cpu
|
||||
@require_torch_bf16
|
||||
@require_intel_extension_for_pytorch
|
||||
def test_number_of_steps_in_training_with_ipex(self):
|
||||
for mix_bf16 in [True, False]:
|
||||
# Regular training has n_epochs * len(train_dl) steps
|
||||
trainer = get_regression_trainer(learning_rate=0.1, use_ipex=True, bf16=mix_bf16, no_cuda=True)
|
||||
trainer = get_regression_trainer(learning_rate=0.1, use_ipex=True, bf16=mix_bf16, use_cpu=True)
|
||||
train_output = trainer.train()
|
||||
self.assertEqual(train_output.global_step, self.n_epochs * 64 / trainer.args.train_batch_size)
|
||||
|
||||
# Check passing num_train_epochs works (and a float version too):
|
||||
trainer = get_regression_trainer(
|
||||
learning_rate=0.1, num_train_epochs=1.5, use_ipex=True, bf16=mix_bf16, no_cuda=True
|
||||
learning_rate=0.1, num_train_epochs=1.5, use_ipex=True, bf16=mix_bf16, use_cpu=True
|
||||
)
|
||||
train_output = trainer.train()
|
||||
self.assertEqual(train_output.global_step, int(1.5 * 64 / trainer.args.train_batch_size))
|
||||
|
||||
# If we pass a max_steps, num_train_epochs is ignored
|
||||
trainer = get_regression_trainer(
|
||||
learning_rate=0.1, max_steps=10, use_ipex=True, bf16=mix_bf16, no_cuda=True
|
||||
learning_rate=0.1, max_steps=10, use_ipex=True, bf16=mix_bf16, use_cpu=True
|
||||
)
|
||||
train_output = trainer.train()
|
||||
self.assertEqual(train_output.global_step, 10)
|
||||
@ -861,7 +864,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertFalse(is_any_loss_nan_or_inf(log_history_filter))
|
||||
|
||||
def test_train_and_eval_dataloaders(self):
|
||||
n_gpu = max(1, torch.cuda.device_count())
|
||||
n_gpu = max(1, backend_device_count(torch_device))
|
||||
trainer = get_regression_trainer(learning_rate=0.1, per_device_train_batch_size=16)
|
||||
self.assertEqual(trainer.get_train_dataloader().total_batch_size, 16 * n_gpu)
|
||||
trainer = get_regression_trainer(learning_rate=0.1, per_device_eval_batch_size=16)
|
||||
@ -898,7 +901,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer.train()
|
||||
trainer.evaluate()
|
||||
|
||||
@require_torch_multi_gpu
|
||||
@require_torch_multi_accelerator
|
||||
def test_data_is_not_parallelized_when_model_is_parallel(self):
|
||||
model = RegressionModel()
|
||||
# Make the Trainer believe it's a parallelized model
|
||||
@ -995,12 +998,12 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
|
||||
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||
|
||||
@require_torch_bf16_cpu
|
||||
@require_torch_bf16
|
||||
@require_intel_extension_for_pytorch
|
||||
def test_evaluate_with_ipex(self):
|
||||
for mix_bf16 in [True, False]:
|
||||
trainer = get_regression_trainer(
|
||||
a=1.5, b=2.5, use_ipex=True, compute_metrics=AlmostAccuracy(), bf16=mix_bf16, no_cuda=True
|
||||
a=1.5, b=2.5, use_ipex=True, compute_metrics=AlmostAccuracy(), bf16=mix_bf16, use_cpu=True
|
||||
)
|
||||
results = trainer.evaluate()
|
||||
|
||||
@ -1019,7 +1022,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
eval_len=66,
|
||||
compute_metrics=AlmostAccuracy(),
|
||||
bf16=mix_bf16,
|
||||
no_cuda=True,
|
||||
use_cpu=True,
|
||||
)
|
||||
results = trainer.evaluate()
|
||||
|
||||
@ -1038,7 +1041,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
compute_metrics=AlmostAccuracy(),
|
||||
preprocess_logits_for_metrics=lambda logits, labels: logits + 1,
|
||||
bf16=mix_bf16,
|
||||
no_cuda=True,
|
||||
use_cpu=True,
|
||||
)
|
||||
results = trainer.evaluate()
|
||||
|
||||
@ -1115,24 +1118,24 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
|
||||
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
|
||||
|
||||
@require_torch_bf16_cpu
|
||||
@require_torch_bf16
|
||||
@require_intel_extension_for_pytorch
|
||||
def test_predict_with_ipex(self):
|
||||
for mix_bf16 in [True, False]:
|
||||
trainer = get_regression_trainer(a=1.5, b=2.5, use_ipex=True, bf16=mix_bf16, no_cuda=True)
|
||||
trainer = get_regression_trainer(a=1.5, b=2.5, use_ipex=True, bf16=mix_bf16, use_cpu=True)
|
||||
preds = trainer.predict(trainer.eval_dataset).predictions
|
||||
x = trainer.eval_dataset.x
|
||||
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
|
||||
|
||||
# With a number of elements not a round multiple of the batch size
|
||||
trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, use_ipex=True, bf16=mix_bf16, no_cuda=True)
|
||||
trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, use_ipex=True, bf16=mix_bf16, use_cpu=True)
|
||||
preds = trainer.predict(trainer.eval_dataset).predictions
|
||||
x = trainer.eval_dataset.x
|
||||
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
|
||||
|
||||
# With more than one output of the model
|
||||
trainer = get_regression_trainer(
|
||||
a=1.5, b=2.5, double_output=True, use_ipex=True, bf16=mix_bf16, no_cuda=True
|
||||
a=1.5, b=2.5, double_output=True, use_ipex=True, bf16=mix_bf16, use_cpu=True
|
||||
)
|
||||
preds = trainer.predict(trainer.eval_dataset).predictions
|
||||
x = trainer.eval_dataset.x
|
||||
@ -1148,7 +1151,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
label_names=["labels", "labels_2"],
|
||||
use_ipex=True,
|
||||
bf16=mix_bf16,
|
||||
no_cuda=True,
|
||||
use_cpu=True,
|
||||
)
|
||||
outputs = trainer.predict(trainer.eval_dataset)
|
||||
preds = outputs.predictions
|
||||
@ -1255,7 +1258,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False, safe_weights=save_safetensors
|
||||
)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
@require_torch_multi_accelerator
|
||||
def test_run_seq2seq_double_train_wrap_once(self):
|
||||
# test that we don't wrap the model more than once
|
||||
# since wrapping primarily happens on multi-gpu setup we want multiple gpus to test for
|
||||
@ -1268,7 +1271,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
model_wrapped_after = trainer.model_wrapped
|
||||
self.assertIs(model_wrapped_before, model_wrapped_after, "should be not wrapped twice")
|
||||
|
||||
@require_torch_up_to_2_gpus
|
||||
@require_torch_up_to_2_accelerators
|
||||
def test_can_resume_training(self):
|
||||
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
||||
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
|
||||
@ -1424,7 +1427,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
|
||||
@slow
|
||||
@require_accelerate
|
||||
@require_torch_non_multi_gpu
|
||||
@require_torch_non_multi_accelerator
|
||||
def test_auto_batch_size_finder(self):
|
||||
if torch.cuda.is_available():
|
||||
torch.backends.cudnn.deterministic = True
|
||||
@ -1471,7 +1474,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
|
||||
trainer.train(resume_from_checkpoint=False)
|
||||
|
||||
@require_torch_up_to_2_gpus
|
||||
@require_torch_up_to_2_accelerators
|
||||
def test_resume_training_with_shard_checkpoint(self):
|
||||
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
||||
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
|
||||
@ -1497,7 +1500,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.check_trainer_state_are_the_same(state, state1)
|
||||
|
||||
@require_safetensors
|
||||
@require_torch_up_to_2_gpus
|
||||
@require_torch_up_to_2_accelerators
|
||||
def test_resume_training_with_safe_checkpoint(self):
|
||||
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
||||
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
|
||||
@ -1532,7 +1535,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(b, b1)
|
||||
self.check_trainer_state_are_the_same(state, state1)
|
||||
|
||||
@require_torch_up_to_2_gpus
|
||||
@require_torch_up_to_2_accelerators
|
||||
def test_resume_training_with_gradient_accumulation(self):
|
||||
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
||||
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
|
||||
@ -1570,7 +1573,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(b, b1)
|
||||
self.check_trainer_state_are_the_same(state, state1)
|
||||
|
||||
@require_torch_up_to_2_gpus
|
||||
@require_torch_up_to_2_accelerators
|
||||
def test_resume_training_with_frozen_params(self):
|
||||
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
||||
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
|
||||
@ -1715,7 +1718,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
)
|
||||
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
|
||||
|
||||
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
|
||||
training_args = TrainingArguments(output_dir="./examples", use_cpu=True)
|
||||
trainer = Trainer(model=model, args=training_args, eval_dataset=eval_dataset)
|
||||
result = trainer.evaluate()
|
||||
self.assertLess(result["eval_loss"], 0.2)
|
||||
@ -1920,12 +1923,12 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer = get_regression_trainer(skip_memory_metrics=True)
|
||||
self.check_mem_metrics(trainer, self.assertNotIn)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_fp16_full_eval(self):
|
||||
# this is a sensitive test so let's keep debugging printouts in place for quick diagnosis.
|
||||
# it's using pretty large safety margins, but small enough to detect broken functionality.
|
||||
debug = 0
|
||||
n_gpus = get_gpu_count()
|
||||
n_gpus = backend_device_count(torch_device)
|
||||
|
||||
bs = 8
|
||||
eval_len = 16 * n_gpus
|
||||
@ -2090,15 +2093,15 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
# aggressively fuses the operations and reduce the memory footprint.
|
||||
self.assertGreater(orig_peak_mem, peak_mem * 2)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_bf16_gpu
|
||||
@require_torch_accelerator
|
||||
@require_torch_bf16
|
||||
def test_bf16_full_eval(self):
|
||||
# note: most of the logic is the same as test_fp16_full_eval
|
||||
|
||||
# this is a sensitive test so let's keep debugging printouts in place for quick diagnosis.
|
||||
# it's using pretty large safety margins, but small enough to detect broken functionality.
|
||||
debug = 0
|
||||
n_gpus = get_gpu_count()
|
||||
n_gpus = backend_device_count(torch_device)
|
||||
|
||||
bs = 8
|
||||
eval_len = 16 * n_gpus
|
||||
@ -2163,7 +2166,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
|
||||
|
||||
@slow
|
||||
@require_torch_multi_gpu
|
||||
@require_torch_multi_accelerator
|
||||
def test_end_to_end_example(self):
|
||||
# Tests that `translation.py` will run without issues
|
||||
script_path = os.path.abspath(
|
||||
@ -2302,7 +2305,7 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
||||
self.assertIn(f"Training in progress, epoch {i}", commits)
|
||||
|
||||
def test_push_to_hub_with_saves_each_n_steps(self):
|
||||
num_gpus = max(1, get_gpu_count())
|
||||
num_gpus = max(1, backend_device_count(torch_device))
|
||||
if num_gpus > 2:
|
||||
return
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user