[tests] multiple improvements (#12294)

* [tests] multiple improvements

* cleanup

* style

* todo to investigate

* fix
This commit is contained in:
Stas Bekman 2021-06-21 19:51:36 -07:00 committed by GitHub
parent dad414d5f9
commit 0d97ba8a98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 156 additions and 121 deletions

View File

@ -431,6 +431,7 @@ decorators are used to set the requirements of tests CPU/GPU/TPU-wise:
* ``require_torch_gpu`` - as ``require_torch`` plus requires at least 1 GPU
* ``require_torch_multi_gpu`` - as ``require_torch`` plus requires at least 2 GPUs
* ``require_torch_non_multi_gpu`` - as ``require_torch`` plus requires 0 or 1 GPUs
* ``require_torch_up_to_2_gpus`` - as ``require_torch`` plus requires 0 or 1 or 2 GPUs
* ``require_torch_tpu`` - as ``require_torch`` plus requires at least 1 TPU
Let's depict the GPU requirements in the following table:
@ -447,6 +448,8 @@ Let's depict the GPU requirements in the following table:
+----------+----------------------------------+
| ``< 2`` | ``@require_torch_non_multi_gpu`` |
+----------+----------------------------------+
| ``< 3`` | ``@require_torch_up_to_2_gpus`` |
+----------+----------------------------------+
For example, here is a test that must be run only when there are 2 or more GPUs available and pytorch is installed:

View File

@ -383,6 +383,21 @@ def require_torch_non_multi_gpu(test_case):
return test_case
def require_torch_up_to_2_gpus(test_case):
"""
Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch).
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
import torch
if torch.cuda.device_count() > 2:
return unittest.skip("test requires 0 or 1 or 2 GPUs")(test_case)
else:
return test_case
def require_torch_tpu(test_case):
"""
Decorator marking a test that requires a TPU (in PyTorch).

View File

@ -15,7 +15,6 @@
import dataclasses
import gc
import math
import os
import random
import re
@ -53,6 +52,8 @@ from transformers.testing_utils import (
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
require_torch_non_multi_gpu,
require_torch_up_to_2_gpus,
slow,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
@ -337,7 +338,14 @@ class TrainerIntegrationCommon:
@require_torch
@require_sentencepiece
@require_tokenizers
class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
"""
Only tests that want to tap into the auto-pre-run 2 trainings:
- self.default_trained_model
- self.alternate_trained_model
directly, or via check_trained_model
"""
def setUp(self):
super().setUp()
args = TrainingArguments(".")
@ -357,6 +365,115 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertTrue(torch.allclose(model.a, a))
self.assertTrue(torch.allclose(model.b, b))
def test_reproducible_training(self):
# Checks that training worked, model trained and seed made a reproducible training.
trainer = get_regression_trainer(learning_rate=0.1)
trainer.train()
self.check_trained_model(trainer.model)
# Checks that a different seed gets different (reproducible) results.
trainer = get_regression_trainer(learning_rate=0.1, seed=314)
trainer.train()
self.check_trained_model(trainer.model, alternate_seed=True)
@require_datasets
def test_trainer_with_datasets(self):
import datasets
np.random.seed(42)
x = np.random.normal(size=(64,)).astype(np.float32)
y = 2.0 * x + 3.0 + np.random.normal(scale=0.1, size=(64,))
train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y})
# Base training. Should have the same results as test_reproducible_training
model = RegressionModel()
args = TrainingArguments("./regression", learning_rate=0.1)
trainer = Trainer(model, args, train_dataset=train_dataset)
trainer.train()
self.check_trained_model(trainer.model)
# Can return tensors.
train_dataset.set_format(type="torch", dtype=torch.float32)
model = RegressionModel()
trainer = Trainer(model, args, train_dataset=train_dataset)
trainer.train()
self.check_trained_model(trainer.model)
# Adding one column not used by the model should have no impact
z = np.random.normal(size=(64,)).astype(np.float32)
train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y, "extra": z})
model = RegressionModel()
trainer = Trainer(model, args, train_dataset=train_dataset)
trainer.train()
self.check_trained_model(trainer.model)
def test_model_init(self):
train_dataset = RegressionDataset()
args = TrainingArguments("./regression", learning_rate=0.1)
trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel())
trainer.train()
self.check_trained_model(trainer.model)
# Re-training should restart from scratch, thus lead the same results.
trainer.train()
self.check_trained_model(trainer.model)
# Re-training should restart from scratch, thus lead the same results and new seed should be used.
trainer.args.seed = 314
trainer.train()
self.check_trained_model(trainer.model, alternate_seed=True)
def test_gradient_accumulation(self):
# Training with half the batch size but accumulation steps as 2 should give the same results.
trainer = get_regression_trainer(
gradient_accumulation_steps=2, per_device_train_batch_size=4, learning_rate=0.1
)
trainer.train()
self.check_trained_model(trainer.model)
def test_custom_optimizer(self):
train_dataset = RegressionDataset()
args = TrainingArguments("./regression")
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 1.0)
trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler))
trainer.train()
(a, b) = self.default_trained_model
self.assertFalse(torch.allclose(trainer.model.a, a))
self.assertFalse(torch.allclose(trainer.model.b, b))
self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)
def test_adafactor_lr_none(self):
# test the special case where lr=None, since Trainer can't not have lr_scheduler
from transformers.optimization import Adafactor, AdafactorSchedule
train_dataset = RegressionDataset()
args = TrainingArguments("./regression")
model = RegressionModel()
optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
lr_scheduler = AdafactorSchedule(optimizer)
trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler))
trainer.train()
(a, b) = self.default_trained_model
self.assertFalse(torch.allclose(trainer.model.a, a))
self.assertFalse(torch.allclose(trainer.model.b, b))
self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
@require_torch
@require_sentencepiece
@require_tokenizers
class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
def setUp(self):
super().setUp()
args = TrainingArguments(".")
self.n_epochs = args.num_train_epochs
self.batch_size = args.train_batch_size
def test_trainer_works_with_dict(self):
# Edge case because Apex with mode O2 will change our models to return dicts. This test checks it doesn't break
# anything.
@ -394,17 +511,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
if key != "logging_dir":
self.assertEqual(dict1[key], dict2[key])
def test_reproducible_training(self):
# Checks that training worked, model trained and seed made a reproducible training.
trainer = get_regression_trainer(learning_rate=0.1)
trainer.train()
self.check_trained_model(trainer.model)
# Checks that a different seed gets different (reproducible) results.
trainer = get_regression_trainer(learning_rate=0.1, seed=314)
trainer.train()
self.check_trained_model(trainer.model, alternate_seed=True)
def test_number_of_steps_in_training(self):
# Regular training has n_epochs * len(train_dl) steps
trainer = get_regression_trainer(learning_rate=0.1)
@ -558,70 +664,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertTrue(np.array_equal(2 * expected + 1, seen[: expected.shape[0]]))
self.assertTrue(np.all(seen[expected.shape[0] :] == -100))
@require_datasets
def test_trainer_with_datasets(self):
import datasets
np.random.seed(42)
x = np.random.normal(size=(64,)).astype(np.float32)
y = 2.0 * x + 3.0 + np.random.normal(scale=0.1, size=(64,))
train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y})
# Base training. Should have the same results as test_reproducible_training
model = RegressionModel()
args = TrainingArguments("./regression", learning_rate=0.1)
trainer = Trainer(model, args, train_dataset=train_dataset)
trainer.train()
self.check_trained_model(trainer.model)
# Can return tensors.
train_dataset.set_format(type="torch", dtype=torch.float32)
model = RegressionModel()
trainer = Trainer(model, args, train_dataset=train_dataset)
trainer.train()
self.check_trained_model(trainer.model)
# Adding one column not used by the model should have no impact
z = np.random.normal(size=(64,)).astype(np.float32)
train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y, "extra": z})
model = RegressionModel()
trainer = Trainer(model, args, train_dataset=train_dataset)
trainer.train()
self.check_trained_model(trainer.model)
def test_custom_optimizer(self):
train_dataset = RegressionDataset()
args = TrainingArguments("./regression")
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 1.0)
trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler))
trainer.train()
(a, b) = self.default_trained_model
self.assertFalse(torch.allclose(trainer.model.a, a))
self.assertFalse(torch.allclose(trainer.model.b, b))
self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)
@require_torch
def test_adafactor_lr_none(self):
# test the special case where lr=None, since Trainer can't not have lr_scheduler
from transformers.optimization import Adafactor, AdafactorSchedule
train_dataset = RegressionDataset()
args = TrainingArguments("./regression")
model = RegressionModel()
optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
lr_scheduler = AdafactorSchedule(optimizer)
trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler))
trainer.train()
(a, b) = self.default_trained_model
self.assertFalse(torch.allclose(trainer.model.a, a))
self.assertFalse(torch.allclose(trainer.model.b, b))
self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
def test_log_level(self):
# testing only --log_level (--log_level_replica requires multiple nodes)
logger = logging.get_logger()
@ -645,22 +687,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train()
self.assertNotIn(log_info_string, cl.out)
def test_model_init(self):
train_dataset = RegressionDataset()
args = TrainingArguments("./regression", learning_rate=0.1)
trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel())
trainer.train()
self.check_trained_model(trainer.model)
# Re-training should restart from scratch, thus lead the same results.
trainer.train()
self.check_trained_model(trainer.model)
# Re-training should restart from scratch, thus lead the same results and new seed should be used.
trainer.args.seed = 314
trainer.train()
self.check_trained_model(trainer.model, alternate_seed=True)
def test_save_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
@ -673,14 +699,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train()
self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
def test_gradient_accumulation(self):
# Training with half the batch size but accumulation steps as 2 should give the same results.
trainer = get_regression_trainer(
gradient_accumulation_steps=2, per_device_train_batch_size=4, learning_rate=0.1
)
trainer.train()
self.check_trained_model(trainer.model)
@require_torch_multi_gpu
def test_run_seq2seq_double_train_wrap_once(self):
# test that we don't wrap the model more than once
@ -694,12 +712,11 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
model_wrapped_after = trainer.model_wrapped
self.assertIs(model_wrapped_before, model_wrapped_after, "should be not wrapped twice")
@require_torch_up_to_2_gpus
def test_can_resume_training(self):
if torch.cuda.device_count() > 2:
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
# won't be the same since the training dataloader is shuffled).
return
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
# won't be the same since the training dataloader is shuffled).
with tempfile.TemporaryDirectory() as tmpdir:
kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
@ -782,10 +799,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train(resume_from_checkpoint=True)
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
@require_torch_non_multi_gpu
def test_resume_training_with_randomness(self):
if torch.cuda.device_count() >= 2:
# This test will fail flakily for more than 2 GPUs since the result will be slightly more different.
return
# This test will fail flakily for more than 1 GPUs since the result will be slightly more different
# TODO: investigate why it fails for 2 GPUs?
if torch.cuda.is_available():
torch.backends.cudnn.deterministic = True
@ -807,15 +824,15 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, "checkpoint-15"))
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
self.assertTrue(math.isclose(a, a1, rel_tol=1e-8))
self.assertTrue(math.isclose(b, b1, rel_tol=1e-8))
self.assertAlmostEqual(a, a1, delta=1e-8)
self.assertAlmostEqual(b, b1, delta=1e-8)
@require_torch_up_to_2_gpus
def test_resume_training_with_gradient_accumulation(self):
if torch.cuda.device_count() > 2:
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
# won't be the same since the training dataloader is shuffled).
return
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
# won't be the same since the training dataloader is shuffled).
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
output_dir=tmpdir,
@ -848,12 +865,12 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)
@require_torch_up_to_2_gpus
def test_resume_training_with_frozen_params(self):
if torch.cuda.device_count() > 2:
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
# won't be the same since the training dataloader is shuffled).
return
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
# won't be the same since the training dataloader is shuffled).
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
output_dir=tmpdir,