Fix finite IterableDataset test on multiple GPUs (#14445)
parent da36c557f7
commit 83ef8bcac2
@@ -1069,13 +1069,17 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
 
     def test_training_finite_iterable_dataset(self):
+        num_gpus = max(1, get_gpu_count())
+        if num_gpus > 2:
+            return
+
         config = RegressionModelConfig()
         model = RegressionPreTrainedModel(config)
 
         batch_size = 1
         num_samples = 10
 
-        available_steps = num_samples // batch_size
+        available_steps = num_samples // (batch_size * num_gpus)
 
         data = FiniteIterableDataset(length=num_samples)
         train_args = TrainingArguments(
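The key change is dividing the expected step count by num_gpus: with data-parallel training the trainer consumes batch_size * num_gpus samples per optimizer step, so a finite iterable dataset of num_samples items yields fewer full steps. A minimal arithmetic sketch of that relationship (illustrative only, not part of the commit):

# Illustrative sketch only (not part of the commit): how available_steps scales with GPU count.
batch_size = 1
num_samples = 10

for num_gpus in (1, 2):
    samples_per_step = batch_size * num_gpus  # samples consumed per optimizer step
    available_steps = num_samples // samples_per_step
    print(f"{num_gpus} GPU(s): {available_steps} full steps from {num_samples} samples")

# Output:
# 1 GPU(s): 10 full steps from 10 samples
# 2 GPU(s): 5 full steps from 10 samples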