Device agnostic trainer testing (#27131)

Hz, Ji 2023-10-31 02:16:40 +08:00 committed by GitHub
parent 84724efd10
commit 5bbf671276
3 changed files with 87 additions and 46 deletions

View File

@@ -629,6 +629,20 @@ def require_torch_multi_gpu(test_case):
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
def require_torch_multi_accelerator(test_case):
"""
Decorator marking a test that requires a multi-accelerator (in PyTorch). These tests are skipped on a machine
without multiple accelerators. To run *only* the multi_accelerator tests, assuming all test names contain
multi_accelerator: $ pytest -sv ./tests -k "multi_accelerator"
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
return unittest.skipUnless(backend_device_count(torch_device) > 1, "test requires multiple accelerators")(
test_case
)
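The decorator above is the first spot in this commit where torch.cuda.device_count() gives way to backend_device_count. As a rough sketch only, not the actual testing_utils implementation, a device-agnostic counter along these lines is all the skip conditions need in order to work on non-CUDA hardware:

    # Illustrative sketch; the real backend_device_count is imported from transformers.testing_utils.
    import torch

    def backend_device_count_sketch(device: str) -> int:
        """Return how many devices the named backend can see (0 on a CPU-only host)."""
        if device == "cuda":
            return torch.cuda.device_count()
        if device == "xpu" and hasattr(torch, "xpu"):  # e.g. Intel GPUs
            return torch.xpu.device_count()
        if device == "npu" and hasattr(torch, "npu"):  # e.g. Ascend NPUs via torch_npu
            return torch.npu.device_count()
        # Treat plain CPU (or an unknown backend) as having no accelerators so that
        # multi-accelerator tests are skipped rather than run in the wrong setup.
        return 0

Only the returned count matters to the decorators; the thresholds (> 1, < 2, < 3) stay the same for every backend.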
def require_torch_non_multi_gpu(test_case):
"""
Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
@@ -641,6 +655,16 @@ def require_torch_non_multi_gpu(test_case):
return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case)
def require_torch_non_multi_accelerator(test_case):
"""
Decorator marking a test that requires 0 or 1 accelerator setup (in PyTorch).
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
return unittest.skipUnless(backend_device_count(torch_device) < 2, "test requires 0 or 1 accelerator")(test_case)
def require_torch_up_to_2_gpus(test_case):
"""
Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch).
@@ -653,6 +677,17 @@ def require_torch_up_to_2_gpus(test_case):
return unittest.skipUnless(torch.cuda.device_count() < 3, "test requires 0 or 1 or 2 GPUs")(test_case)
def require_torch_up_to_2_accelerators(test_case):
"""
Decorator marking a test that requires 0 or 1 or 2 accelerator setup (in PyTorch).
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
return unittest.skipUnless(backend_device_count(torch_device) < 3, "test requires 0 or 1 or 2 accelerators")(
test_case
)
def require_torch_tpu(test_case):
"""
Decorator marking a test that requires a TPU (in PyTorch).
@@ -774,7 +809,9 @@ def require_torch_gpu(test_case):
def require_torch_accelerator(test_case):
"""Decorator marking a test that requires an accessible accelerator and PyTorch."""
return unittest.skipUnless(torch_device != "cpu", "test requires accelerator")(test_case)
return unittest.skipUnless(torch_device is not None and torch_device != "cpu", "test requires accelerator")(
test_case
)
def require_torch_fp16(test_case):
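Taken together, these testing_utils changes let a test module state its device requirements without naming CUDA. A usage sketch, with a made-up test class but the same helpers this commit touches or imports:

    import unittest

    from transformers.testing_utils import (
        backend_device_count,
        require_torch_accelerator,
        require_torch_multi_accelerator,
        torch_device,
    )

    class ExampleDeviceAgnosticTests(unittest.TestCase):
        @require_torch_accelerator
        def test_on_any_single_accelerator(self):
            # Skipped on CPU-only machines; runs on CUDA, XPU, NPU, and similar backends.
            self.assertNotEqual(torch_device, "cpu")

        @require_torch_multi_accelerator
        def test_needs_more_than_one_device(self):
            # Skipped unless the active backend reports at least two devices.
            self.assertGreater(backend_device_count(torch_device), 1)

As the docstring earlier in this file points out, the multi-accelerator subset can then be selected with pytest -sv ./tests -k "multi_accelerator".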

View File

@@ -26,16 +26,17 @@ from transformers.testing_utils import (
CaptureStderr,
ExtendSysPath,
TestCasePlus,
backend_device_count,
execute_subprocess_async,
get_gpu_count,
get_torch_dist_unique_port,
require_apex,
require_bitsandbytes,
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
require_torch_non_multi_gpu,
require_torch_multi_accelerator,
require_torch_non_multi_accelerator,
slow,
torch_device,
)
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
@@ -89,17 +90,17 @@ class TestTrainerExt(TestCasePlus):
assert isinstance(last_step_stats["eval_bleu"], float)
assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"
@require_torch_non_multi_gpu
@require_torch_non_multi_accelerator
def test_run_seq2seq_no_dist(self):
self.run_seq2seq_quick()
# verify that the trainer can handle non-distributed with n_gpu > 1
@require_torch_multi_gpu
@require_torch_multi_accelerator
def test_run_seq2seq_dp(self):
self.run_seq2seq_quick(distributed=False)
# verify that the trainer can handle distributed with n_gpu > 1
@require_torch_multi_gpu
@require_torch_multi_accelerator
def test_run_seq2seq_ddp(self):
self.run_seq2seq_quick(distributed=True)
@@ -120,7 +121,7 @@ class TestTrainerExt(TestCasePlus):
self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
@parameterized.expand(["base", "low", "high", "mixed"])
@require_torch_multi_gpu
@require_torch_multi_accelerator
def test_trainer_log_level_replica(self, experiment_id):
# as each sub-test is slow-ish split into multiple sub-tests to avoid CI timeout
experiments = {
@@ -331,7 +332,7 @@ class TestTrainerExt(TestCasePlus):
if distributed:
if n_gpus_to_use is None:
n_gpus_to_use = get_gpu_count()
n_gpus_to_use = backend_device_count(torch_device)
master_port = get_torch_dist_unique_port()
distributed_args = f"""
-m torch.distributed.run
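That last hunk is where the extended trainer tests choose how many processes to launch for a distributed run. A simplified sketch of how the backend-agnostic count feeds torch.distributed.run; the script path and the exact launcher flags shown here are illustrative rather than copied from the file:

    import sys

    from transformers.testing_utils import (
        backend_device_count,
        get_torch_dist_unique_port,
        torch_device,
    )

    n_procs = backend_device_count(torch_device)  # e.g. 2 on a dual-GPU or dual-NPU host
    launch_cmd = [
        sys.executable,
        "-m", "torch.distributed.run",
        f"--nproc_per_node={n_procs}",
        f"--master_port={get_torch_dist_unique_port()}",
        "path/to/training_script.py",  # placeholder for the real example script
    ]
    # In the test, a command line like this is handed to execute_subprocess_async.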

View File

@@ -49,6 +49,7 @@ from transformers.testing_utils import (
USER,
CaptureLogger,
TestCasePlus,
backend_device_count,
execute_subprocess_async,
get_gpu_count,
get_tests_dir,
@@ -63,17 +64,19 @@ from transformers.testing_utils import (
require_tensorboard,
require_tokenizers,
require_torch,
require_torch_bf16_cpu,
require_torch_bf16_gpu,
require_torch_accelerator,
require_torch_bf16,
require_torch_gpu,
require_torch_multi_gpu,
require_torch_multi_accelerator,
require_torch_non_multi_accelerator,
require_torch_non_multi_gpu,
require_torch_tensorrt_fx,
require_torch_tf32,
require_torch_up_to_2_gpus,
require_torch_up_to_2_accelerators,
require_torchdynamo,
require_wandb,
slow,
torch_device,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
from transformers.training_args import OptimizerNames
@@ -606,7 +609,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
)
def test_training_loss(self):
n_gpus = max(1, get_gpu_count())
n_gpus = max(1, backend_device_count(torch_device))
# With even logs
trainer = get_regression_trainer(logging_steps=64 / (8 * n_gpus))
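The 64 and the 8 in that expression are, by all appearances, the regression dataset length and the per-device batch size (the same 64 reappears in the IPEX step counts further down), so logging_steps works out to one epoch's worth of optimizer steps whatever the accelerator count. A quick check of the arithmetic, under that assumption:

    # Assumes 64 training samples and a per-device batch size of 8, as suggested
    # by the expressions in this diff.
    for n_devices in (1, 2, 4):
        steps_per_epoch = 64 // (8 * n_devices)  # 8, 4, 2
        logging_steps = 64 / (8 * n_devices)     # 8.0, 4.0, 2.0
        assert logging_steps == steps_per_epoch  # the log interval matches one epoch exactly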
@@ -726,8 +729,8 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
self.assertFalse(torch.allclose(trainer.model.b, b))
self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
@require_torch_gpu
@require_torch_bf16_gpu
@require_torch_accelerator
@require_torch_bf16
def test_mixed_bf16(self):
# very basic test
trainer = get_regression_trainer(learning_rate=0.1, bf16=True)
@@ -812,25 +815,25 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
train_output = trainer.train()
self.assertEqual(train_output.global_step, 10)
@require_torch_bf16_cpu
@require_torch_bf16
@require_intel_extension_for_pytorch
def test_number_of_steps_in_training_with_ipex(self):
for mix_bf16 in [True, False]:
# Regular training has n_epochs * len(train_dl) steps
trainer = get_regression_trainer(learning_rate=0.1, use_ipex=True, bf16=mix_bf16, no_cuda=True)
trainer = get_regression_trainer(learning_rate=0.1, use_ipex=True, bf16=mix_bf16, use_cpu=True)
train_output = trainer.train()
self.assertEqual(train_output.global_step, self.n_epochs * 64 / trainer.args.train_batch_size)
# Check passing num_train_epochs works (and a float version too):
trainer = get_regression_trainer(
learning_rate=0.1, num_train_epochs=1.5, use_ipex=True, bf16=mix_bf16, no_cuda=True
learning_rate=0.1, num_train_epochs=1.5, use_ipex=True, bf16=mix_bf16, use_cpu=True
)
train_output = trainer.train()
self.assertEqual(train_output.global_step, int(1.5 * 64 / trainer.args.train_batch_size))
# If we pass a max_steps, num_train_epochs is ignored
trainer = get_regression_trainer(
learning_rate=0.1, max_steps=10, use_ipex=True, bf16=mix_bf16, no_cuda=True
learning_rate=0.1, max_steps=10, use_ipex=True, bf16=mix_bf16, use_cpu=True
)
train_output = trainer.train()
self.assertEqual(train_output.global_step, 10)
@@ -861,7 +864,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertFalse(is_any_loss_nan_or_inf(log_history_filter))
def test_train_and_eval_dataloaders(self):
n_gpu = max(1, torch.cuda.device_count())
n_gpu = max(1, backend_device_count(torch_device))
trainer = get_regression_trainer(learning_rate=0.1, per_device_train_batch_size=16)
self.assertEqual(trainer.get_train_dataloader().total_batch_size, 16 * n_gpu)
trainer = get_regression_trainer(learning_rate=0.1, per_device_eval_batch_size=16)
@@ -898,7 +901,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train()
trainer.evaluate()
@require_torch_multi_gpu
@require_torch_multi_accelerator
def test_data_is_not_parallelized_when_model_is_parallel(self):
model = RegressionModel()
# Make the Trainer believe it's a parallelized model
@@ -995,12 +998,12 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
@require_torch_bf16_cpu
@require_torch_bf16
@require_intel_extension_for_pytorch
def test_evaluate_with_ipex(self):
for mix_bf16 in [True, False]:
trainer = get_regression_trainer(
a=1.5, b=2.5, use_ipex=True, compute_metrics=AlmostAccuracy(), bf16=mix_bf16, no_cuda=True
a=1.5, b=2.5, use_ipex=True, compute_metrics=AlmostAccuracy(), bf16=mix_bf16, use_cpu=True
)
results = trainer.evaluate()
@@ -1019,7 +1022,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
eval_len=66,
compute_metrics=AlmostAccuracy(),
bf16=mix_bf16,
no_cuda=True,
use_cpu=True,
)
results = trainer.evaluate()
@@ -1038,7 +1041,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
compute_metrics=AlmostAccuracy(),
preprocess_logits_for_metrics=lambda logits, labels: logits + 1,
bf16=mix_bf16,
no_cuda=True,
use_cpu=True,
)
results = trainer.evaluate()
@@ -1115,24 +1118,24 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
@require_torch_bf16_cpu
@require_torch_bf16
@require_intel_extension_for_pytorch
def test_predict_with_ipex(self):
for mix_bf16 in [True, False]:
trainer = get_regression_trainer(a=1.5, b=2.5, use_ipex=True, bf16=mix_bf16, no_cuda=True)
trainer = get_regression_trainer(a=1.5, b=2.5, use_ipex=True, bf16=mix_bf16, use_cpu=True)
preds = trainer.predict(trainer.eval_dataset).predictions
x = trainer.eval_dataset.x
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
# With a number of elements not a round multiple of the batch size
trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, use_ipex=True, bf16=mix_bf16, no_cuda=True)
trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, use_ipex=True, bf16=mix_bf16, use_cpu=True)
preds = trainer.predict(trainer.eval_dataset).predictions
x = trainer.eval_dataset.x
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
# With more than one output of the model
trainer = get_regression_trainer(
a=1.5, b=2.5, double_output=True, use_ipex=True, bf16=mix_bf16, no_cuda=True
a=1.5, b=2.5, double_output=True, use_ipex=True, bf16=mix_bf16, use_cpu=True
)
preds = trainer.predict(trainer.eval_dataset).predictions
x = trainer.eval_dataset.x
@@ -1148,7 +1151,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
label_names=["labels", "labels_2"],
use_ipex=True,
bf16=mix_bf16,
no_cuda=True,
use_cpu=True,
)
outputs = trainer.predict(trainer.eval_dataset)
preds = outputs.predictions
@@ -1255,7 +1258,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False, safe_weights=save_safetensors
)
@require_torch_multi_gpu
@require_torch_multi_accelerator
def test_run_seq2seq_double_train_wrap_once(self):
# test that we don't wrap the model more than once
# since wrapping primarily happens on multi-gpu setup we want multiple gpus to test for
@@ -1268,7 +1271,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
model_wrapped_after = trainer.model_wrapped
self.assertIs(model_wrapped_before, model_wrapped_after, "should be not wrapped twice")
@require_torch_up_to_2_gpus
@require_torch_up_to_2_accelerators
def test_can_resume_training(self):
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
@@ -1424,7 +1427,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
@slow
@require_accelerate
@require_torch_non_multi_gpu
@require_torch_non_multi_accelerator
def test_auto_batch_size_finder(self):
if torch.cuda.is_available():
torch.backends.cudnn.deterministic = True
@@ -1471,7 +1474,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train(resume_from_checkpoint=False)
@require_torch_up_to_2_gpus
@require_torch_up_to_2_accelerators
def test_resume_training_with_shard_checkpoint(self):
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
@@ -1497,7 +1500,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.check_trainer_state_are_the_same(state, state1)
@require_safetensors
@require_torch_up_to_2_gpus
@require_torch_up_to_2_accelerators
def test_resume_training_with_safe_checkpoint(self):
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
@@ -1532,7 +1535,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)
@require_torch_up_to_2_gpus
@require_torch_up_to_2_accelerators
def test_resume_training_with_gradient_accumulation(self):
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
@@ -1570,7 +1573,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)
@require_torch_up_to_2_gpus
@require_torch_up_to_2_accelerators
def test_resume_training_with_frozen_params(self):
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
@@ -1715,7 +1718,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
)
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
training_args = TrainingArguments(output_dir="./examples", use_cpu=True)
trainer = Trainer(model=model, args=training_args, eval_dataset=eval_dataset)
result = trainer.evaluate()
self.assertLess(result["eval_loss"], 0.2)
@@ -1920,12 +1923,12 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer = get_regression_trainer(skip_memory_metrics=True)
self.check_mem_metrics(trainer, self.assertNotIn)
@require_torch_gpu
@require_torch_accelerator
def test_fp16_full_eval(self):
# this is a sensitive test so let's keep debugging printouts in place for quick diagnosis.
# it's using pretty large safety margins, but small enough to detect broken functionality.
debug = 0
n_gpus = get_gpu_count()
n_gpus = backend_device_count(torch_device)
bs = 8
eval_len = 16 * n_gpus
@@ -2090,15 +2093,15 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# aggressively fuses the operations and reduce the memory footprint.
self.assertGreater(orig_peak_mem, peak_mem * 2)
@require_torch_gpu
@require_torch_bf16_gpu
@require_torch_accelerator
@require_torch_bf16
def test_bf16_full_eval(self):
# note: most of the logic is the same as test_fp16_full_eval
# this is a sensitive test so let's keep debugging printouts in place for quick diagnosis.
# it's using pretty large safety margins, but small enough to detect broken functionality.
debug = 0
n_gpus = get_gpu_count()
n_gpus = backend_device_count(torch_device)
bs = 8
eval_len = 16 * n_gpus
@@ -2163,7 +2166,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
@slow
@require_torch_multi_gpu
@require_torch_multi_accelerator
def test_end_to_end_example(self):
# Tests that `translation.py` will run without issues
script_path = os.path.abspath(
@@ -2302,7 +2305,7 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
self.assertIn(f"Training in progress, epoch {i}", commits)
def test_push_to_hub_with_saves_each_n_steps(self):
num_gpus = max(1, get_gpu_count())
num_gpus = max(1, backend_device_count(torch_device))
if num_gpus > 2:
return