diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 3d6bcf8cc3c..4e9c760f676 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -405,7 +405,7 @@ class TrainingArguments:
         default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
     )
     adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace Adam by Adafactor."})
-    _n_gpu: int = field(init=False, repr=False, default=0)
+    _n_gpu: int = field(init=False, repr=False, default=-1)
 
     def __post_init__(self):
         if self.disable_tqdm is None:
@@ -483,6 +483,10 @@ class TrainingArguments:
             # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
             # will use the first GPU in that env, i.e. GPU#1
             device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+            # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
+            # the default value.
+            if self._n_gpu == -1:
+                self._n_gpu = torch.cuda.device_count()
             n_gpu = self._n_gpu
         else:
             # Here, we'll use torch.distributed.
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 0443a1429ec..cfb01ece0ca 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -29,6 +29,7 @@ from transformers.testing_utils import (
     require_sentencepiece,
     require_tokenizers,
     require_torch,
+    require_torch_multi_gpu,
     slow,
 )
 from transformers.utils.hp_naming import TrialShortNamer
@@ -374,6 +375,22 @@ class TrainerIntegrationTest(unittest.TestCase):
         new_eval_dataset = RegressionDataset(length=128)
         self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // (32 * n_gpu))
 
+    @require_torch_multi_gpu
+    def test_data_is_not_parallelized_when_model_is_parallel(self):
+        model = RegressionModel()
+        # Make the Trainer believe it's a parallelized model
+        model.is_parallelizable = True
+        model.model_parallel = True
+        trainer = Trainer(model=model, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
+        # Check the Trainer was fooled
+        self.assertTrue(trainer.is_model_parallel)
+
+        # The batch size of the training and evaluation dataloaders should be 16, not 16 * n_gpu
+        self.assertEqual(trainer.get_train_dataloader().batch_size, 16)
+        self.assertEqual(len(trainer.get_train_dataloader()), 64 // 16)
+        self.assertEqual(trainer.get_eval_dataloader().batch_size, 16)
+        self.assertEqual(len(trainer.get_eval_dataloader()), 64 // 16)
+
     def test_evaluate(self):
         trainer = get_regression_trainer(a=1.5, b=2.5, compute_metrics=AlmostAccuracy())
         results = trainer.evaluate()
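
Note on the first hunk: with the old default of 0, code that read `_n_gpu` before `__post_init__` had populated it would silently report zero GPUs; -1 is an unambiguous "not initialized yet" value because a real device count is never negative. Below is a minimal, self-contained sketch of that sentinel pattern — not the transformers code itself; `DummyArgs` and its `n_gpu` property are illustrative stand-ins:

    from dataclasses import dataclass, field

    import torch


    @dataclass
    class DummyArgs:
        # -1 is a sentinel meaning "not computed yet"; a real device count
        # is always >= 0, so the two states can never be confused.
        _n_gpu: int = field(init=False, repr=False, default=-1)

        @property
        def n_gpu(self) -> int:
            # Lazily count the visible CUDA devices on first access, in case
            # this property is read before any setup code has run.
            if self._n_gpu == -1:
                self._n_gpu = torch.cuda.device_count()
            return self._n_gpu


    args = DummyArgs()
    print(args.n_gpu)  # 0 on a CPU-only machine, 2 with two visible GPUs, etc.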