[tests] multiple improvements (#12294)

* [tests] multiple improvements
* cleanup
* style
* todo to investigate
* fix

parent dad414d5f9
commit 0d97ba8a98

@@ -431,6 +431,7 @@ decorators are used to set the requirements of tests CPU/GPU/TPU-wise:

* ``require_torch_gpu`` - as ``require_torch`` plus requires at least 1 GPU
* ``require_torch_multi_gpu`` - as ``require_torch`` plus requires at least 2 GPUs
* ``require_torch_non_multi_gpu`` - as ``require_torch`` plus requires 0 or 1 GPUs
* ``require_torch_up_to_2_gpus`` - as ``require_torch`` plus requires 0 or 1 or 2 GPUs
* ``require_torch_tpu`` - as ``require_torch`` plus requires at least 1 TPU

Let's depict the GPU requirements in the following table:

@@ -447,6 +448,8 @@ Let's depict the GPU requirements in the following table:

+----------+----------------------------------+
| ``< 2`` | ``@require_torch_non_multi_gpu`` |
+----------+----------------------------------+
| ``< 3`` | ``@require_torch_up_to_2_gpus``  |
+----------+----------------------------------+

For example, here is a test that must be run only when there are 2 or more GPUs available and PyTorch is installed:
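
(The hunk ends before the example itself. As a minimal sketch, assuming the decorator usage introduced in this commit, such a test would look like the following; the test name is hypothetical:)

    @require_torch_multi_gpu
    def test_example_with_multi_gpu():
        ...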

@@ -383,6 +383,21 @@ def require_torch_non_multi_gpu(test_case):
    return test_case


def require_torch_up_to_2_gpus(test_case):
    """
    Decorator marking a test that requires a setup with 0, 1, or 2 GPUs (in PyTorch).
    """
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    import torch

    if torch.cuda.device_count() > 2:
        return unittest.skip("test requires 0 or 1 or 2 GPUs")(test_case)
    else:
        return test_case
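
(A quick sketch of how the new decorator is applied, for context; the test name and body here are hypothetical, while the real usages appear in test_trainer.py below:)

    @require_torch_up_to_2_gpus
    def test_that_runs_on_at_most_2_gpus(self):
        # Runs on machines with 0, 1, or 2 GPUs; skipped on larger multi-GPU setups.
        ...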


def require_torch_tpu(test_case):
    """
    Decorator marking a test that requires a TPU (in PyTorch).

@@ -15,7 +15,6 @@

import dataclasses
import gc
import math
import os
import random
import re

@@ -53,6 +52,8 @@ from transformers.testing_utils import (
    require_torch,
    require_torch_gpu,
    require_torch_multi_gpu,
    require_torch_non_multi_gpu,
    require_torch_up_to_2_gpus,
    slow,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

@@ -337,7 +338,14 @@ class TrainerIntegrationCommon:

@require_torch
@require_sentencepiece
@require_tokenizers
class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
    """
    Only for tests that tap into the two auto-pre-run trainings:
    - self.default_trained_model
    - self.alternate_trained_model
    directly, or via check_trained_model
    """

    def setUp(self):
        super().setUp()
        args = TrainingArguments(".")

@@ -357,6 +365,115 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
        self.assertTrue(torch.allclose(model.a, a))
        self.assertTrue(torch.allclose(model.b, b))

    def test_reproducible_training(self):
        # Checks that training worked, the model trained, and the seed made training reproducible.
        trainer = get_regression_trainer(learning_rate=0.1)
        trainer.train()
        self.check_trained_model(trainer.model)

        # Checks that a different seed gets different (reproducible) results.
        trainer = get_regression_trainer(learning_rate=0.1, seed=314)
        trainer.train()
        self.check_trained_model(trainer.model, alternate_seed=True)

    @require_datasets
    def test_trainer_with_datasets(self):
        import datasets

        np.random.seed(42)
        x = np.random.normal(size=(64,)).astype(np.float32)
        y = 2.0 * x + 3.0 + np.random.normal(scale=0.1, size=(64,))
        train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y})

        # Base training. Should have the same results as test_reproducible_training
        model = RegressionModel()
        args = TrainingArguments("./regression", learning_rate=0.1)
        trainer = Trainer(model, args, train_dataset=train_dataset)
        trainer.train()
        self.check_trained_model(trainer.model)

        # Can return tensors.
        train_dataset.set_format(type="torch", dtype=torch.float32)
        model = RegressionModel()
        trainer = Trainer(model, args, train_dataset=train_dataset)
        trainer.train()
        self.check_trained_model(trainer.model)

        # Adding one column not used by the model should have no impact
        z = np.random.normal(size=(64,)).astype(np.float32)
        train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y, "extra": z})
        model = RegressionModel()
        trainer = Trainer(model, args, train_dataset=train_dataset)
        trainer.train()
        self.check_trained_model(trainer.model)

    def test_model_init(self):
        train_dataset = RegressionDataset()
        args = TrainingArguments("./regression", learning_rate=0.1)
        trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel())
        trainer.train()
        self.check_trained_model(trainer.model)

        # Re-training should restart from scratch, thus lead to the same results.
        trainer.train()
        self.check_trained_model(trainer.model)

        # Re-training should restart from scratch, thus lead to the same results, and the new seed should be used.
        trainer.args.seed = 314
        trainer.train()
        self.check_trained_model(trainer.model, alternate_seed=True)

    def test_gradient_accumulation(self):
        # Training with half the batch size but 2 gradient accumulation steps should give the same results.
        trainer = get_regression_trainer(
            gradient_accumulation_steps=2, per_device_train_batch_size=4, learning_rate=0.1
        )
        trainer.train()
        self.check_trained_model(trainer.model)
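
(A side note on why this equivalence holds, not part of the diff: TrainingArguments defaults per_device_train_batch_size to 8, so halving it to 4 while accumulating gradients over 2 steps keeps the effective batch size, and hence the sequence of optimizer updates, the same. A small sketch of the arithmetic, with illustrative variable names:)

    per_device_train_batch_size = 4
    gradient_accumulation_steps = 2
    effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps
    assert effective_batch_size == 8  # matches the default full-batch run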

    def test_custom_optimizer(self):
        train_dataset = RegressionDataset()
        args = TrainingArguments("./regression")
        model = RegressionModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 1.0)
        trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler))
        trainer.train()

        (a, b) = self.default_trained_model
        self.assertFalse(torch.allclose(trainer.model.a, a))
        self.assertFalse(torch.allclose(trainer.model.b, b))
        self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)

    def test_adafactor_lr_none(self):
        # Test the special case where lr=None, since Trainer can't run without an lr_scheduler.

        from transformers.optimization import Adafactor, AdafactorSchedule

        train_dataset = RegressionDataset()
        args = TrainingArguments("./regression")
        model = RegressionModel()
        optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
        lr_scheduler = AdafactorSchedule(optimizer)
        trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler))
        trainer.train()

        (a, b) = self.default_trained_model
        self.assertFalse(torch.allclose(trainer.model.a, a))
        self.assertFalse(torch.allclose(trainer.model.b, b))
        self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
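
(Context for the assertion above, not part of the diff: with lr=None and relative_step=True, Adafactor derives its learning rate internally, and AdafactorSchedule is a proxy scheduler that surfaces those internally computed values so Trainer, which always steps an lr_scheduler, keeps working. That is presumably why the test only checks that the reported lr is greater than 0 rather than equal to a fixed value.)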

@require_torch
@require_sentencepiece
@require_tokenizers
class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
    def setUp(self):
        super().setUp()
        args = TrainingArguments(".")
        self.n_epochs = args.num_train_epochs
        self.batch_size = args.train_batch_size

    def test_trainer_works_with_dict(self):
        # Edge case because Apex with mode O2 will change our models to return dicts. This test checks it doesn't
        # break anything.

@@ -394,17 +511,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
            if key != "logging_dir":
                self.assertEqual(dict1[key], dict2[key])

    def test_reproducible_training(self):
        # Checks that training worked, the model trained, and the seed made training reproducible.
        trainer = get_regression_trainer(learning_rate=0.1)
        trainer.train()
        self.check_trained_model(trainer.model)

        # Checks that a different seed gets different (reproducible) results.
        trainer = get_regression_trainer(learning_rate=0.1, seed=314)
        trainer.train()
        self.check_trained_model(trainer.model, alternate_seed=True)

    def test_number_of_steps_in_training(self):
        # Regular training has n_epochs * len(train_dl) steps
        trainer = get_regression_trainer(learning_rate=0.1)

@@ -558,70 +664,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
        self.assertTrue(np.array_equal(2 * expected + 1, seen[: expected.shape[0]]))
        self.assertTrue(np.all(seen[expected.shape[0] :] == -100))

    @require_datasets
    def test_trainer_with_datasets(self):
        import datasets

        np.random.seed(42)
        x = np.random.normal(size=(64,)).astype(np.float32)
        y = 2.0 * x + 3.0 + np.random.normal(scale=0.1, size=(64,))
        train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y})

        # Base training. Should have the same results as test_reproducible_training
        model = RegressionModel()
        args = TrainingArguments("./regression", learning_rate=0.1)
        trainer = Trainer(model, args, train_dataset=train_dataset)
        trainer.train()
        self.check_trained_model(trainer.model)

        # Can return tensors.
        train_dataset.set_format(type="torch", dtype=torch.float32)
        model = RegressionModel()
        trainer = Trainer(model, args, train_dataset=train_dataset)
        trainer.train()
        self.check_trained_model(trainer.model)

        # Adding one column not used by the model should have no impact
        z = np.random.normal(size=(64,)).astype(np.float32)
        train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y, "extra": z})
        model = RegressionModel()
        trainer = Trainer(model, args, train_dataset=train_dataset)
        trainer.train()
        self.check_trained_model(trainer.model)

    def test_custom_optimizer(self):
        train_dataset = RegressionDataset()
        args = TrainingArguments("./regression")
        model = RegressionModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 1.0)
        trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler))
        trainer.train()

        (a, b) = self.default_trained_model
        self.assertFalse(torch.allclose(trainer.model.a, a))
        self.assertFalse(torch.allclose(trainer.model.b, b))
        self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)

    @require_torch
    def test_adafactor_lr_none(self):
        # Test the special case where lr=None, since Trainer can't run without an lr_scheduler.

        from transformers.optimization import Adafactor, AdafactorSchedule

        train_dataset = RegressionDataset()
        args = TrainingArguments("./regression")
        model = RegressionModel()
        optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
        lr_scheduler = AdafactorSchedule(optimizer)
        trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler))
        trainer.train()

        (a, b) = self.default_trained_model
        self.assertFalse(torch.allclose(trainer.model.a, a))
        self.assertFalse(torch.allclose(trainer.model.b, b))
        self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)

    def test_log_level(self):
        # testing only --log_level (--log_level_replica requires multiple nodes)
        logger = logging.get_logger()

@@ -645,22 +687,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
            trainer.train()
        self.assertNotIn(log_info_string, cl.out)

    def test_model_init(self):
        train_dataset = RegressionDataset()
        args = TrainingArguments("./regression", learning_rate=0.1)
        trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel())
        trainer.train()
        self.check_trained_model(trainer.model)

        # Re-training should restart from scratch, thus lead to the same results.
        trainer.train()
        self.check_trained_model(trainer.model)

        # Re-training should restart from scratch, thus lead to the same results, and the new seed should be used.
        trainer.args.seed = 314
        trainer.train()
        self.check_trained_model(trainer.model, alternate_seed=True)

    def test_save_checkpoints(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)

@@ -673,14 +699,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
            trainer.train()
            self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)

    def test_gradient_accumulation(self):
        # Training with half the batch size but 2 gradient accumulation steps should give the same results.
        trainer = get_regression_trainer(
            gradient_accumulation_steps=2, per_device_train_batch_size=4, learning_rate=0.1
        )
        trainer.train()
        self.check_trained_model(trainer.model)

    @require_torch_multi_gpu
    def test_run_seq2seq_double_train_wrap_once(self):
        # test that we don't wrap the model more than once

@@ -694,12 +712,11 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
        model_wrapped_after = trainer.model_wrapped
        self.assertIs(model_wrapped_before, model_wrapped_after, "should not be wrapped twice")
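
(Background on this assertion, not part of the diff and hedged on the Trainer internals of this commit: under multi-GPU, trainer.model is assumed to keep the raw module while trainer.model_wrapped holds the wrapper such as torch.nn.DataParallel, so a second train() call must reuse the existing wrapper rather than rewrap. A sketch of the property being tested, with a hypothetical second run:)

    wrapped_before = trainer.model_wrapped          # captured after the first run
    trainer.train()                                 # hypothetical second training run
    assert trainer.model_wrapped is wrapped_before  # no double wrapping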

    @require_torch_up_to_2_gpus
    def test_can_resume_training(self):
        if torch.cuda.device_count() > 2:
            # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
            # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
            # won't be the same since the training dataloader is shuffled).
            return
        # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
        # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
        # won't be the same since the training dataloader is shuffled).

        with tempfile.TemporaryDirectory() as tmpdir:
            kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)

@@ -782,10 +799,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
            trainer.train(resume_from_checkpoint=True)
        self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))

    @require_torch_non_multi_gpu
    def test_resume_training_with_randomness(self):
        if torch.cuda.device_count() >= 2:
            # This test will fail flakily for more than 2 GPUs since the results will be slightly different.
            return
        # This test will fail flakily for more than 1 GPU since the results will be slightly different
        # TODO: investigate why it fails for 2 GPUs?

        if torch.cuda.is_available():
            torch.backends.cudnn.deterministic = True

@@ -807,15 +824,15 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
        trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, "checkpoint-15"))
        (a1, b1) = trainer.model.a.item(), trainer.model.b.item()

        self.assertTrue(math.isclose(a, a1, rel_tol=1e-8))
        self.assertTrue(math.isclose(b, b1, rel_tol=1e-8))
        self.assertAlmostEqual(a, a1, delta=1e-8)
        self.assertAlmostEqual(b, b1, delta=1e-8)
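
(Note on the change just above, not part of the diff: math.isclose with rel_tol checks a relative tolerance, while assertAlmostEqual with delta checks an absolute one, which behaves differently for values far from 1. A small self-contained illustration:)

    import math

    a = 1e6
    b = a * (1 + 5e-9)                # absolute difference ~5e-3, relative ~5e-9
    math.isclose(a, b, rel_tol=1e-8)  # True: within the relative tolerance
    abs(a - b) <= 1e-8                # False: far outside an absolute delta of 1e-8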

    @require_torch_up_to_2_gpus
    def test_resume_training_with_gradient_accumulation(self):
        if torch.cuda.device_count() > 2:
            # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
            # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
            # won't be the same since the training dataloader is shuffled).
            return
        # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
        # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
        # won't be the same since the training dataloader is shuffled).

        with tempfile.TemporaryDirectory() as tmpdir:
            trainer = get_regression_trainer(
                output_dir=tmpdir,

@@ -848,12 +865,12 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
        self.assertEqual(b, b1)
        self.check_trainer_state_are_the_same(state, state1)

    @require_torch_up_to_2_gpus
    def test_resume_training_with_frozen_params(self):
        if torch.cuda.device_count() > 2:
            # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
            # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
            # won't be the same since the training dataloader is shuffled).
            return
        # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
        # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
        # won't be the same since the training dataloader is shuffled).

        with tempfile.TemporaryDirectory() as tmpdir:
            trainer = get_regression_trainer(
                output_dir=tmpdir,