Fix SMP batch not divisible by microbatches (#10778)

* Added debug prints

* Added config

* Added prints

* Added prints

* Added extra samples to SequentialDistributedSampler

* Added extra samples to SequentialDistributedSampler

Updated SequentialDistributedSampler call

* Added debug prints

* Removed extra prints

* Making predictions and labels a multiple of batch size

* updated number of microbatches

* Removed extra prints

* Made start_remainder similar to DistributedSamplerWithLoop

* Minor spacing update

* Added debug prints

Added config

Added prints

Added prints

* Added extra samples to SequentialDistributedSampler

Updated SequentialDistributedSampler call

Added extra samples to SequentialDistributedSampler

Added debug prints

Removed extra prints

Making predictions and labels a multiple of batch size

updated number of microbatches

Removed extra prints

Squashing redundant commits

* Made start_remainder similar to DistributedSamplerWithLoop

Minor spacing update

Made start_remainder similar to DistributedSamplerWithLoop

* Test and styling

* Rename test

Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
Commit: 0282e24eef (parent: 40b049c701)
Author: Mansi Mane, 2021-03-17 16:18:11 -07:00, committed by GitHub
4 changed files with 49 additions and 5 deletions


@@ -112,7 +112,12 @@ class SageMakerTrainer(Trainer):
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
         if self.is_model_parallel_enabled:
-            return SequentialDistributedSampler(eval_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
+            return SequentialDistributedSampler(
+                eval_dataset,
+                num_replicas=smp.dp_size(),
+                rank=smp.dp_rank(),
+                batch_size=self.args.per_device_eval_batch_size,
+            )
         else:
             return super()._get_eval_sampler(eval_dataset)
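
For context, a minimal sketch of what the new argument buys us, using illustrative stand-in values for smp.dp_size(), smp.dp_rank(), and per_device_eval_batch_size: padding each data-parallel rank's share up to a multiple of the eval batch size means every batch is full, so SMP can always split it into the configured number of microbatches.

from torch.utils.data import DataLoader

from transformers.trainer_pt_utils import SequentialDistributedSampler

dataset = list(range(123))   # 123 evaluation examples
dp_size, dp_rank = 2, 0      # stand-ins for smp.dp_size() / smp.dp_rank()
eval_batch_size = 16         # stand-in for args.per_device_eval_batch_size

sampler = SequentialDistributedSampler(
    dataset, num_replicas=dp_size, rank=dp_rank, batch_size=eval_batch_size
)
loader = DataLoader(dataset, sampler=sampler, batch_size=eval_batch_size)

print(len(list(sampler)))          # 64 per rank (62 without batch_size padding)
print([len(b) for b in loader])    # [16, 16, 16, 16] -- no short final batch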


@@ -1812,8 +1812,8 @@ class Trainer:
         eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
         if not prediction_loss_only:
-            preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
-            labels_gatherer = DistributedTensorGatherer(world_size, num_examples)
+            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
+            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)

         model.eval()
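
Because the sampler above now yields a padded, batch-aligned number of samples per rank, the prediction and label gatherers need matching storage; make_multiple_of=batch_size rounds their capacity up in the same way. A rough sketch of the sizing arithmetic (illustrative numbers, not the exact DistributedTensorGatherer internals):

import math

world_size, num_examples, batch_size = 2, 123, 16

# Without make_multiple_of, storage is rounded up to a multiple of world_size.
plain_slots = math.ceil(num_examples / world_size) * world_size    # 124

# With make_multiple_of=batch_size, it is rounded up so every process
# contributes whole batches, matching what the padded sampler yields.
block = world_size * batch_size
padded_slots = math.ceil(num_examples / block) * block             # 128
per_process = padded_slots // world_size                           # 64

# The gathered arrays are truncated back to the true number of examples
# at the end, so the padded extras are dropped before metrics are computed.
print(plain_slots, padded_slots, per_process)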


@@ -220,7 +220,7 @@ class SequentialDistributedSampler(Sampler):
     or `reduce` resulting tensors at the end of the loop.
     """

-    def __init__(self, dataset, num_replicas=None, rank=None):
+    def __init__(self, dataset, num_replicas=None, rank=None, batch_size=None):
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
@@ -232,8 +232,14 @@ class SequentialDistributedSampler(Sampler):
         self.dataset = dataset
         self.num_replicas = num_replicas
         self.rank = rank
-        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+        num_samples = len(self.dataset)
+        # Add extra samples to make num_samples a multiple of batch_size if passed
+        if batch_size is not None:
+            self.num_samples = int(math.ceil(num_samples / (batch_size * num_replicas))) * batch_size
+        else:
+            self.num_samples = int(math.ceil(num_samples / num_replicas))
         self.total_size = self.num_samples * self.num_replicas
+        self.batch_size = batch_size

     def __iter__(self):
         indices = list(range(len(self.dataset)))
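
The remainder of __iter__ (not shown in this hunk) pads the index list up to total_size by wrapping around to the front of the dataset and then gives each rank one contiguous num_samples-long slice. A minimal standalone sketch of that behaviour combined with the new batch_size rounding, assuming the existing wrap-around padding:

import math

def shard_indices(length, num_replicas, rank, batch_size=None):
    # Per-rank sample count, rounded up to a full batch when batch_size is given.
    if batch_size is not None:
        num_samples = math.ceil(length / (batch_size * num_replicas)) * batch_size
    else:
        num_samples = math.ceil(length / num_replicas)
    total_size = num_samples * num_replicas

    indices = list(range(length))
    indices += indices[: total_size - len(indices)]   # wrap around to pad
    return indices[rank * num_samples : (rank + 1) * num_samples]

# length=23, 2 replicas, batch_size=16 -> 16 samples per rank (32 total, 9 repeats)
print(shard_indices(23, 2, 0, 16))   # [0, 1, ..., 15]
print(shard_indices(23, 2, 1, 16))   # [16, ..., 22, 0, 1, ..., 8]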


@@ -31,6 +31,7 @@ if is_torch_available():
         DistributedTensorGatherer,
         LabelSmoother,
         LengthGroupedSampler,
+        SequentialDistributedSampler,
         get_parameter_names,
     )
@@ -167,3 +168,35 @@ class TrainerUtilsTest(unittest.TestCase):
         self.assertEqual(set(total[:length]), set(dataset))
         self.assertEqual(set(total[length:]), set(total[: (len(total) - length)]))
+
+    def test_sequential_distributed_sampler(self):
+        batch_size = 16
+        for length in [23, 64, 123]:
+            dataset = list(range(length))
+            shard1 = SequentialDistributedSampler(dataset, num_replicas=2, rank=0)
+            shard2 = SequentialDistributedSampler(dataset, num_replicas=2, rank=1)
+
+            # Sample
+            samples1 = list(shard1)
+            samples2 = list(shard2)
+
+            total = samples1 + samples2
+
+            self.assertListEqual(total[:length], dataset)
+            self.assertListEqual(total[length:], dataset[: (len(total) - length)])
+
+            # With a batch_size passed
+            shard1 = SequentialDistributedSampler(dataset, num_replicas=2, rank=0, batch_size=batch_size)
+            shard2 = SequentialDistributedSampler(dataset, num_replicas=2, rank=1, batch_size=batch_size)
+
+            # Sample
+            samples1 = list(shard1)
+            samples2 = list(shard2)
+
+            self.assertTrue(len(samples1) % batch_size == 0)
+            self.assertTrue(len(samples2) % batch_size == 0)
+
+            total = samples1 + samples2
+
+            self.assertListEqual(total[:length], dataset)
+            self.assertListEqual(total[length:], dataset[: (len(total) - length)])
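
For the dataset lengths exercised here, the expected per-replica counts with batch_size=16 can be checked by hand:

import math

for length in [23, 64, 123]:
    print(length, math.ceil(length / (16 * 2)) * 16)   # -> 16, 32, 64 samples per replica

Each count is a multiple of the batch size, while the concatenated shards still reproduce the full dataset followed by its wrap-around padding, which is what the final assertions verify.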