Fix finite IterableDataset test on multiple GPUs (#14445)

This commit is contained in:
Sylvain Gugger 2021-11-18 10:25:06 -05:00 committed by GitHub
parent da36c557f7
commit 83ef8bcac2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1069,13 +1069,17 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
def test_training_finite_iterable_dataset(self):
num_gpus = max(1, get_gpu_count())
if num_gpus > 2:
return
config = RegressionModelConfig()
model = RegressionPreTrainedModel(config)
batch_size = 1
num_samples = 10
available_steps = num_samples // batch_size
available_steps = num_samples // (batch_size * num_gpus)
data = FiniteIterableDataset(length=num_samples)
train_args = TrainingArguments(