Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)
Fix data parallelism in Trainer (#9566)
* Fix data parallelism in Trainer

* Update src/transformers/training_args.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent b2dfcc567b
commit 04dc65e5c6
src/transformers/training_args.py

@@ -405,7 +405,7 @@ class TrainingArguments:
         default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
     )
     adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace Adam by Adafactor."})
-    _n_gpu: int = field(init=False, repr=False, default=0)
+    _n_gpu: int = field(init=False, repr=False, default=-1)
 
     def __post_init__(self):
         if self.disable_tqdm is None:
@@ -483,6 +483,10 @@ class TrainingArguments:
             # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
             # will use the first GPU in that env, i.e. GPU#1
             device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+            # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
+            # the default value.
+            if self._n_gpu == -1:
+                self._n_gpu = torch.cuda.device_count()
             n_gpu = self._n_gpu
         else:
             # Here, we'll use torch.distributed.
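The heart of the training_args.py change is the -1 sentinel: `_n_gpu` now starts out as "not yet resolved", and device setup only falls back to `torch.cuda.device_count()` if nothing filled the value in earlier (for example, a caller that pins it to 1 for a model-parallel model, which is what the new test below relies on). Presumably -1 was chosen because 0 is a meaningful GPU count on CPU-only machines, so it cannot double as "unset". Below is a minimal sketch of that pattern only, with a hypothetical GpuArgsSketch class standing in for the real TrainingArguments; it requires just torch.

from dataclasses import dataclass, field

import torch


@dataclass
class GpuArgsSketch:
    # -1 means "not resolved yet"; 0 is a real answer on CPU-only machines.
    _n_gpu: int = field(init=False, repr=False, default=-1)

    @property
    def n_gpu(self) -> int:
        # Fall back to counting CUDA devices only if nothing set _n_gpu beforehand.
        if self._n_gpu == -1:
            self._n_gpu = torch.cuda.device_count()
        return self._n_gpu


args = GpuArgsSketch()
print(args.n_gpu)    # resolves to torch.cuda.device_count() on first access

pinned = GpuArgsSketch()
pinned._n_gpu = 1    # e.g. set up front for a model-parallel model
print(pinned.n_gpu)  # stays 1, so batch sizes are not multiplied by the GPU count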
tests/test_trainer.py

@@ -29,6 +29,7 @@ from transformers.testing_utils import (
     require_sentencepiece,
     require_tokenizers,
     require_torch,
+    require_torch_multi_gpu,
     slow,
 )
 from transformers.utils.hp_naming import TrialShortNamer
@@ -374,6 +375,22 @@ class TrainerIntegrationTest(unittest.TestCase):
         new_eval_dataset = RegressionDataset(length=128)
         self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // (32 * n_gpu))
 
+    @require_torch_multi_gpu
+    def test_data_is_not_parallelized_when_model_is_parallel(self):
+        model = RegressionModel()
+        # Make the Trainer believe it's a parallelized model
+        model.is_parallelizable = True
+        model.model_parallel = True
+        trainer = Trainer(model=model, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
+        # Check the Trainer was fooled
+        self.assertTrue(trainer.is_model_parallel)
+
+        # The batch size of the training and evaluation dataloaders should be 16, not 16 * n_gpu
+        self.assertEqual(trainer.get_train_dataloader().batch_size, 16)
+        self.assertEqual(len(trainer.get_train_dataloader()), 64 // 16)
+        self.assertEqual(trainer.get_eval_dataloader().batch_size, 16)
+        self.assertEqual(len(trainer.get_eval_dataloader()), 64 // 16)
+
     def test_evaluate(self):
         trainer = get_regression_trainer(a=1.5, b=2.5, compute_metrics=AlmostAccuracy())
         results = trainer.evaluate()
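The new test pins down the user-visible behaviour: a model that reports is_parallelizable = True and model_parallel = True is treated as model-parallel, so the Trainer keeps the per-device batch size of 16 instead of multiplying it by the number of GPUs. Below is a standalone sketch of the same check outside the test suite; ToyModel and ToyDataset are hypothetical stand-ins for the RegressionModel/RegressionDataset test helpers, and exact dataloader internals may differ across transformers versions.

import torch
from torch.utils.data import Dataset

from transformers import Trainer, TrainingArguments


class ToyDataset(Dataset):
    # 64 tiny regression examples, mirroring the dataset size used by the test above.
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return {"input_ids": torch.tensor([float(idx)]), "labels": torch.tensor([0.0])}


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)
        # The two flags the Trainer inspects to decide the model is already parallelized.
        self.is_parallelizable = True
        self.model_parallel = True

    def forward(self, input_ids=None, labels=None):
        preds = self.linear(input_ids)
        loss = torch.nn.functional.mse_loss(preds, labels)
        return (loss, preds)


args = TrainingArguments(output_dir="tmp_trainer", per_device_train_batch_size=16)
trainer = Trainer(model=ToyModel(), args=args, train_dataset=ToyDataset())
print(trainer.is_model_parallel)                  # True: the flags were picked up
print(trainer.get_train_dataloader().batch_size)  # 16, not 16 * n_gpu

Note that the test itself is decorated with @require_torch_multi_gpu, so it is skipped on single-GPU and CPU-only machines.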