From 0d97ba8a9846b4856f8920bf0c955dbce91e8b4f Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 21 Jun 2021 19:51:36 -0700 Subject: [PATCH] [tests] multiple improvements (#12294) * [tests] multiple improvements * cleanup * style * todo to investigate * fix --- docs/source/testing.rst | 3 + src/transformers/testing_utils.py | 15 ++ tests/test_trainer.py | 259 ++++++++++++++++-------------- 3 files changed, 156 insertions(+), 121 deletions(-) diff --git a/docs/source/testing.rst b/docs/source/testing.rst index 665a1d8f315..68da03bfa9c 100644 --- a/docs/source/testing.rst +++ b/docs/source/testing.rst @@ -431,6 +431,7 @@ decorators are used to set the requirements of tests CPU/GPU/TPU-wise: * ``require_torch_gpu`` - as ``require_torch`` plus requires at least 1 GPU * ``require_torch_multi_gpu`` - as ``require_torch`` plus requires at least 2 GPUs * ``require_torch_non_multi_gpu`` - as ``require_torch`` plus requires 0 or 1 GPUs +* ``require_torch_up_to_2_gpus`` - as ``require_torch`` plus requires 0 or 1 or 2 GPUs * ``require_torch_tpu`` - as ``require_torch`` plus requires at least 1 TPU Let's depict the GPU requirements in the following table: @@ -447,6 +448,8 @@ Let's depict the GPU requirements in the following table: +----------+----------------------------------+ | ``< 2`` | ``@require_torch_non_multi_gpu`` | +----------+----------------------------------+ +| ``< 3`` | ``@require_torch_up_to_2_gpus`` | ++----------+----------------------------------+ For example, here is a test that must be run only when there are 2 or more GPUs available and pytorch is installed: diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index ca607c33016..d315785ed95 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -383,6 +383,21 @@ def require_torch_non_multi_gpu(test_case): return test_case +def require_torch_up_to_2_gpus(test_case): + """ + Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch). + """ + if not is_torch_available(): + return unittest.skip("test requires PyTorch")(test_case) + + import torch + + if torch.cuda.device_count() > 2: + return unittest.skip("test requires 0 or 1 or 2 GPUs")(test_case) + else: + return test_case + + def require_torch_tpu(test_case): """ Decorator marking a test that requires a TPU (in PyTorch). 
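For reference, a minimal, illustrative sketch of how the new ``require_torch_up_to_2_gpus`` decorator can be applied (the test name and body below are placeholders, not part of this patch):

    from transformers.testing_utils import require_torch_up_to_2_gpus

    @require_torch_up_to_2_gpus
    def test_example_up_to_2_gpus():
        # Runs on setups with 0, 1 or 2 GPUs; skipped when PyTorch is not
        # installed or when more than 2 GPUs are visible.
        ...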
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 7e3e02e618f..7107cea56df 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -15,7 +15,6 @@
 
 import dataclasses
 import gc
-import math
 import os
 import random
 import re
@@ -53,6 +52,8 @@ from transformers.testing_utils import (
     require_torch,
     require_torch_gpu,
     require_torch_multi_gpu,
+    require_torch_non_multi_gpu,
+    require_torch_up_to_2_gpus,
     slow,
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
@@ -337,7 +338,14 @@ class TrainerIntegrationCommon:
 @require_torch
 @require_sentencepiece
 @require_tokenizers
-class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
+class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
+    """
+    Only tests that need to tap into the two auto-pre-run trainings:
+    - self.default_trained_model
+    - self.alternate_trained_model
+    either directly or via check_trained_model.
+    """
+
     def setUp(self):
         super().setUp()
         args = TrainingArguments(".")
@@ -357,6 +365,115 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertTrue(torch.allclose(model.a, a))
         self.assertTrue(torch.allclose(model.b, b))
 
+    def test_reproducible_training(self):
+        # Checks that training worked, model trained and seed made a reproducible training.
+        trainer = get_regression_trainer(learning_rate=0.1)
+        trainer.train()
+        self.check_trained_model(trainer.model)
+
+        # Checks that a different seed gets different (reproducible) results.
+        trainer = get_regression_trainer(learning_rate=0.1, seed=314)
+        trainer.train()
+        self.check_trained_model(trainer.model, alternate_seed=True)
+
+    @require_datasets
+    def test_trainer_with_datasets(self):
+        import datasets
+
+        np.random.seed(42)
+        x = np.random.normal(size=(64,)).astype(np.float32)
+        y = 2.0 * x + 3.0 + np.random.normal(scale=0.1, size=(64,))
+        train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y})
+
+        # Base training. Should have the same results as test_reproducible_training
+        model = RegressionModel()
+        args = TrainingArguments("./regression", learning_rate=0.1)
+        trainer = Trainer(model, args, train_dataset=train_dataset)
+        trainer.train()
+        self.check_trained_model(trainer.model)
+
+        # Can return tensors.
+        train_dataset.set_format(type="torch", dtype=torch.float32)
+        model = RegressionModel()
+        trainer = Trainer(model, args, train_dataset=train_dataset)
+        trainer.train()
+        self.check_trained_model(trainer.model)
+
+        # Adding one column not used by the model should have no impact
+        z = np.random.normal(size=(64,)).astype(np.float32)
+        train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y, "extra": z})
+        model = RegressionModel()
+        trainer = Trainer(model, args, train_dataset=train_dataset)
+        trainer.train()
+        self.check_trained_model(trainer.model)
+
+    def test_model_init(self):
+        train_dataset = RegressionDataset()
+        args = TrainingArguments("./regression", learning_rate=0.1)
+        trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel())
+        trainer.train()
+        self.check_trained_model(trainer.model)
+
+        # Re-training should restart from scratch, thus lead to the same results.
+        trainer.train()
+        self.check_trained_model(trainer.model)
+
+        # Re-training should restart from scratch, thus lead to the same results, and the new seed should be used.
+ trainer.args.seed = 314 + trainer.train() + self.check_trained_model(trainer.model, alternate_seed=True) + + def test_gradient_accumulation(self): + # Training with half the batch size but accumulation steps as 2 should give the same results. + trainer = get_regression_trainer( + gradient_accumulation_steps=2, per_device_train_batch_size=4, learning_rate=0.1 + ) + trainer.train() + self.check_trained_model(trainer.model) + + def test_custom_optimizer(self): + train_dataset = RegressionDataset() + args = TrainingArguments("./regression") + model = RegressionModel() + optimizer = torch.optim.SGD(model.parameters(), lr=1.0) + lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 1.0) + trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler)) + trainer.train() + + (a, b) = self.default_trained_model + self.assertFalse(torch.allclose(trainer.model.a, a)) + self.assertFalse(torch.allclose(trainer.model.b, b)) + self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0) + + def test_adafactor_lr_none(self): + # test the special case where lr=None, since Trainer can't not have lr_scheduler + + from transformers.optimization import Adafactor, AdafactorSchedule + + train_dataset = RegressionDataset() + args = TrainingArguments("./regression") + model = RegressionModel() + optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) + lr_scheduler = AdafactorSchedule(optimizer) + trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler)) + trainer.train() + + (a, b) = self.default_trained_model + self.assertFalse(torch.allclose(trainer.model.a, a)) + self.assertFalse(torch.allclose(trainer.model.b, b)) + self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0) + + +@require_torch +@require_sentencepiece +@require_tokenizers +class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): + def setUp(self): + super().setUp() + args = TrainingArguments(".") + self.n_epochs = args.num_train_epochs + self.batch_size = args.train_batch_size + def test_trainer_works_with_dict(self): # Edge case because Apex with mode O2 will change our models to return dicts. This test checks it doesn't break # anything. @@ -394,17 +511,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): if key != "logging_dir": self.assertEqual(dict1[key], dict2[key]) - def test_reproducible_training(self): - # Checks that training worked, model trained and seed made a reproducible training. - trainer = get_regression_trainer(learning_rate=0.1) - trainer.train() - self.check_trained_model(trainer.model) - - # Checks that a different seed gets different (reproducible) results. 
- trainer = get_regression_trainer(learning_rate=0.1, seed=314) - trainer.train() - self.check_trained_model(trainer.model, alternate_seed=True) - def test_number_of_steps_in_training(self): # Regular training has n_epochs * len(train_dl) steps trainer = get_regression_trainer(learning_rate=0.1) @@ -558,70 +664,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertTrue(np.array_equal(2 * expected + 1, seen[: expected.shape[0]])) self.assertTrue(np.all(seen[expected.shape[0] :] == -100)) - @require_datasets - def test_trainer_with_datasets(self): - import datasets - - np.random.seed(42) - x = np.random.normal(size=(64,)).astype(np.float32) - y = 2.0 * x + 3.0 + np.random.normal(scale=0.1, size=(64,)) - train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y}) - - # Base training. Should have the same results as test_reproducible_training - model = RegressionModel() - args = TrainingArguments("./regression", learning_rate=0.1) - trainer = Trainer(model, args, train_dataset=train_dataset) - trainer.train() - self.check_trained_model(trainer.model) - - # Can return tensors. - train_dataset.set_format(type="torch", dtype=torch.float32) - model = RegressionModel() - trainer = Trainer(model, args, train_dataset=train_dataset) - trainer.train() - self.check_trained_model(trainer.model) - - # Adding one column not used by the model should have no impact - z = np.random.normal(size=(64,)).astype(np.float32) - train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y, "extra": z}) - model = RegressionModel() - trainer = Trainer(model, args, train_dataset=train_dataset) - trainer.train() - self.check_trained_model(trainer.model) - - def test_custom_optimizer(self): - train_dataset = RegressionDataset() - args = TrainingArguments("./regression") - model = RegressionModel() - optimizer = torch.optim.SGD(model.parameters(), lr=1.0) - lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 1.0) - trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler)) - trainer.train() - - (a, b) = self.default_trained_model - self.assertFalse(torch.allclose(trainer.model.a, a)) - self.assertFalse(torch.allclose(trainer.model.b, b)) - self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0) - - @require_torch - def test_adafactor_lr_none(self): - # test the special case where lr=None, since Trainer can't not have lr_scheduler - - from transformers.optimization import Adafactor, AdafactorSchedule - - train_dataset = RegressionDataset() - args = TrainingArguments("./regression") - model = RegressionModel() - optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) - lr_scheduler = AdafactorSchedule(optimizer) - trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler)) - trainer.train() - - (a, b) = self.default_trained_model - self.assertFalse(torch.allclose(trainer.model.a, a)) - self.assertFalse(torch.allclose(trainer.model.b, b)) - self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0) - def test_log_level(self): # testing only --log_level (--log_level_replica requires multiple nodes) logger = logging.get_logger() @@ -645,22 +687,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): trainer.train() self.assertNotIn(log_info_string, cl.out) - def test_model_init(self): - train_dataset = RegressionDataset() - args = 
TrainingArguments("./regression", learning_rate=0.1) - trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel()) - trainer.train() - self.check_trained_model(trainer.model) - - # Re-training should restart from scratch, thus lead the same results. - trainer.train() - self.check_trained_model(trainer.model) - - # Re-training should restart from scratch, thus lead the same results and new seed should be used. - trainer.args.seed = 314 - trainer.train() - self.check_trained_model(trainer.model, alternate_seed=True) - def test_save_checkpoints(self): with tempfile.TemporaryDirectory() as tmpdir: trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5) @@ -673,14 +699,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): trainer.train() self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False) - def test_gradient_accumulation(self): - # Training with half the batch size but accumulation steps as 2 should give the same results. - trainer = get_regression_trainer( - gradient_accumulation_steps=2, per_device_train_batch_size=4, learning_rate=0.1 - ) - trainer.train() - self.check_trained_model(trainer.model) - @require_torch_multi_gpu def test_run_seq2seq_double_train_wrap_once(self): # test that we don't wrap the model more than once @@ -694,12 +712,11 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): model_wrapped_after = trainer.model_wrapped self.assertIs(model_wrapped_before, model_wrapped_after, "should be not wrapped twice") + @require_torch_up_to_2_gpus def test_can_resume_training(self): - if torch.cuda.device_count() > 2: - # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of - # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model - # won't be the same since the training dataloader is shuffled). - return + # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of + # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model + # won't be the same since the training dataloader is shuffled). with tempfile.TemporaryDirectory() as tmpdir: kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1) @@ -782,10 +799,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): trainer.train(resume_from_checkpoint=True) self.assertTrue("No valid checkpoint found in output directory" in str(context.exception)) + @require_torch_non_multi_gpu def test_resume_training_with_randomness(self): - if torch.cuda.device_count() >= 2: - # This test will fail flakily for more than 2 GPUs since the result will be slightly more different. - return + # This test will fail flakily for more than 1 GPUs since the result will be slightly more different + # TODO: investigate why it fails for 2 GPUs? 
if torch.cuda.is_available(): torch.backends.cudnn.deterministic = True @@ -807,15 +824,15 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, "checkpoint-15")) (a1, b1) = trainer.model.a.item(), trainer.model.b.item() - self.assertTrue(math.isclose(a, a1, rel_tol=1e-8)) - self.assertTrue(math.isclose(b, b1, rel_tol=1e-8)) + self.assertAlmostEqual(a, a1, delta=1e-8) + self.assertAlmostEqual(b, b1, delta=1e-8) + @require_torch_up_to_2_gpus def test_resume_training_with_gradient_accumulation(self): - if torch.cuda.device_count() > 2: - # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of - # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model - # won't be the same since the training dataloader is shuffled). - return + # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of + # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model + # won't be the same since the training dataloader is shuffled). + with tempfile.TemporaryDirectory() as tmpdir: trainer = get_regression_trainer( output_dir=tmpdir, @@ -848,12 +865,12 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertEqual(b, b1) self.check_trainer_state_are_the_same(state, state1) + @require_torch_up_to_2_gpus def test_resume_training_with_frozen_params(self): - if torch.cuda.device_count() > 2: - # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of - # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model - # won't be the same since the training dataloader is shuffled). - return + # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of + # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model + # won't be the same since the training dataloader is shuffled). + with tempfile.TemporaryDirectory() as tmpdir: trainer = get_regression_trainer( output_dir=tmpdir,