Upstream (and rename) sortish_sampler

Sylvain Gugger 2021-01-13 15:35:06 -05:00
parent c8b165638e
commit e07d0dcf65
5 changed files with 177 additions and 12 deletions

@@ -169,7 +169,7 @@ class TestFinetuneTrainer(TestCasePlus):
--logging_steps 0
--save_steps {str(eval_steps)}
--eval_steps {str(eval_steps)}
--sortish_sampler
--group_by_length
--label_smoothing_factor 0.1
--adafactor
--task translation

@@ -70,12 +70,13 @@ from .trainer_callback import (
TrainerState,
)
from .trainer_pt_utils import (
DistributedLengthGroupedSampler,
DistributedTensorGatherer,
LabelSmoother,
SequentialDistributedSampler,
distributed_broadcast_scalars,
distributed_concat,
get_tpu_sampler,
get_length_grouped_indices,
nested_concat,
nested_detach,
nested_numpify,
@@ -94,7 +95,7 @@ from .trainer_utils import (
set_seed,
speed_metrics,
)
from .training_args import TrainingArguments
from .training_args import ParallelMode, TrainingArguments
from .utils import logging
@@ -448,14 +449,33 @@ class Trainer:
self.train_dataset, collections.abc.Sized
):
return None
elif is_torch_tpu_available():
return get_tpu_sampler(self.train_dataset)
# Gather the number of processes and this process index.
if self.args.parallel_mode == ParallelMode.TPU:
num_processes = xm.xrt_world_size()
process_index = xm.get_ordinal()
elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
num_processes = torch.distributed.get_world_size()
process_index = torch.distributed.get_rank()
else:
return (
RandomSampler(self.train_dataset)
if self.args.local_rank == -1
else DistributedSampler(self.train_dataset)
)
num_processes = 1
process_index = 0
# Build the sampler.
if self.args.group_by_length:
if num_processes <= 1:
lengths = [len(feature["input_ids"]) for feature in self.train_dataset]
return get_length_grouped_indices(lengths, self.args.train_batch_size)
else:
return DistributedLengthGroupedSampler(
self.train_dataset, self.args.train_batch_size, num_replicas=num_processes, rank=process_index
)
else:
if num_processes <= 1:
return RandomSampler(self.train_dataset)
else:
return DistributedSampler(self.train_dataset, num_replicas=num_processes, rank=process_index)
def get_train_dataloader(self) -> DataLoader:
"""

@@ -20,10 +20,11 @@ import math
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import Iterator, List, Optional, Union
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler
@@ -390,3 +391,110 @@ class LabelSmoother:
# Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
smoothed_loss = log_probs.mean(dim=-1).sum() / (padding_mask.numel() - padding_mask.long().sum())
return (1 - self.epsilon) * model_loss + self.epsilon * smoothed_loss
def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
"""
Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
similar lengths. To do this, the indices are:
- randomly permuted
- grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
- sorted by length in each mega-batch
The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
maximum length placed first, so that an OOM happens sooner rather than later.
"""
# Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
if mega_batch_mult is None:
mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
# Just in case, for tiny datasets
if mega_batch_mult == 0:
mega_batch_mult = 1
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
indices = torch.randperm(len(lengths), generator=generator)
megabatch_size = mega_batch_mult * batch_size
megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
# The rest is to get the biggest batch first.
# Since each megabatch is sorted by descending length, the longest element is the first
megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
# Switch to put the longest element in first position
megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]
return sum(megabatches, [])
class DistributedLengthGroupedSampler(DistributedSampler):
r"""
Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
length while keeping a bit of randomness.
"""
# Copied and adapted from PyTorch DistributedSampler.
def __init__(
self,
dataset: Dataset,
batch_size: int,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
seed: int = 0,
drop_last: bool = False,
lengths: Optional[List[int]] = None,
):
if num_replicas is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = torch.distributed.get_world_size()
if rank is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = torch.distributed.get_rank()
self.dataset = dataset
self.batch_size = batch_size
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if self.drop_last and len(self.dataset) % self.num_replicas != 0:
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil((len(self.dataset) - self.num_replicas) / self.num_replicas)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
self.total_size = self.num_samples * self.num_replicas
self.seed = seed
if lengths is None:
if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]:
raise ValueError(
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
"'input_ids' key."
)
lengths = [len(feature["input_ids"]) for feature in dataset]
self.lengths = lengths
def __iter__(self) -> Iterator:
# Deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
if not self.drop_last:
# add extra samples to make it evenly divisible
indices += indices[: (self.total_size - len(indices))]
else:
# remove tail of data to make it evenly divisible.
indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
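
As a quick sanity check (illustrative, not part of the diff), the grouping behaviour described in the docstring of get_length_grouped_indices can be observed directly: the result is a permutation of all indices, the longest sample lands in the very first batch, and each batch-sized slice contains samples of similar length. The toy lengths below are made up.

import torch

from transformers.trainer_pt_utils import get_length_grouped_indices

lengths = torch.randint(1, 30, (64,)).tolist()
lengths[10] = 100  # make one sample clearly the longest
batch_size = 8

indices = get_length_grouped_indices(lengths, batch_size)

assert sorted(indices) == list(range(64))  # a permutation of all indices
assert lengths[indices[0]] == 100          # the longest sample comes first
for start in range(0, len(indices), batch_size):
    batch_lengths = [lengths[i] for i in indices[start : start + batch_size]]
    print(batch_lengths)  # lengths within a batch are roughly similar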

@@ -227,6 +227,9 @@ class TrainingArguments:
adafactor (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use the :class:`~transformers.Adafactor` optimizer instead of
:class:`~transformers.AdamW`.
group_by_length (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to group together samples of roughly the same length in the training dataset (to minimize
padding applied and be more efficient). Only useful if applying dynamic padding.
"""
output_dir: str = field(
@@ -405,6 +408,10 @@ class TrainingArguments:
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
)
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace Adam by Adafactor."})
group_by_length: bool = field(
default=False,
metadata={"help": "Whether or not to group samples of roughly the same length together when batching."},
)
_n_gpu: int = field(init=False, repr=False, default=-1)
def __post_init__(self):
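
Illustrative usage of the new flag (not part of the diff); the output directory and batch size below are placeholders. On the command line this corresponds to passing --group_by_length, as in the test script change above.

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",                # placeholder
    per_device_train_batch_size=8,   # placeholder
    group_by_length=True,            # batch together samples of similar length to reduce padding
)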

@@ -25,7 +25,12 @@ if is_torch_available():
import torch
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.trainer_pt_utils import DistributedTensorGatherer, LabelSmoother
from transformers.trainer_pt_utils import (
DistributedLengthGroupedSampler,
DistributedTensorGatherer,
LabelSmoother,
get_length_grouped_indices,
)
@require_torch
@@ -87,3 +92,28 @@ class TrainerUtilsTest(unittest.TestCase):
log_probs[2, 3] = 0.0
expected_loss = (1 - epsilon) * loss + epsilon * log_probs.sum() / (num_labels * 17)
self.assertTrue(torch.allclose(label_smoothed_loss, expected_loss))
def test_group_by_length(self):
# Get some inputs of random lengths
lengths = torch.randint(0, 25, (100,)).tolist()
# Put one bigger than the others to check it ends up in first position
lengths[32] = 50
indices = get_length_grouped_indices(lengths, 4)
# The biggest element should be first
self.assertEqual(lengths[indices[0]], 50)
# The indices should be a permutation of range(100)
self.assertEqual(list(sorted(indices)), list(range(100)))
def test_distributed_length_grouped(self):
# Get some inputs of random lengths
lengths = torch.randint(0, 25, (100,)).tolist()
# Put one bigger than the others to check it ends up in first position
lengths[32] = 50
indices_process_0 = list(DistributedLengthGroupedSampler(lengths, 4, 2, 0, lengths=lengths))
indices_process_1 = list(DistributedLengthGroupedSampler(lengths, 4, 2, 1, lengths=lengths))
# The biggest element should be first
self.assertEqual(lengths[indices_process_0[0]], 50)
# The indices should be a permutation of range(100)
self.assertEqual(list(sorted(indices_process_0 + indices_process_1)), list(range(100)))
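
For completeness, an illustrative sketch (not part of the test file) of how the distributed sampler shards the length-grouped order across two ranks without an actual torch.distributed setup (num_replicas and rank are passed explicitly, as in the test above), and how the inherited set_epoch() reshuffles deterministically between epochs. The lengths below are made up.

import torch

from transformers.trainer_pt_utils import DistributedLengthGroupedSampler

lengths = torch.randint(1, 30, (100,)).tolist()
samplers = [
    DistributedLengthGroupedSampler(lengths, batch_size=4, num_replicas=2, rank=r, lengths=lengths)
    for r in (0, 1)
]

epoch0 = [list(s) for s in samplers]
# Together the two ranks cover every sample exactly once.
assert sorted(epoch0[0] + epoch0[1]) == list(range(100))

for s in samplers:
    s.set_epoch(1)  # inherited from DistributedSampler; reseeds the shuffle
epoch1 = [list(s) for s in samplers]
print(epoch0[0][:5], epoch1[0][:5])  # a different, but still deterministic, order per epoch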