Mirror of https://github.com/huggingface/transformers.git
Upstream (and rename) sortish_sampler

commit e07d0dcf65
parent c8b165638e
@@ -169,7 +169,7 @@ class TestFinetuneTrainer(TestCasePlus):
             --logging_steps 0
             --save_steps {str(eval_steps)}
             --eval_steps {str(eval_steps)}
-            --sortish_sampler
+            --group_by_length
             --label_smoothing_factor 0.1
             --adafactor
             --task translation
@@ -70,12 +70,13 @@ from .trainer_callback import (
     TrainerState,
 )
 from .trainer_pt_utils import (
+    DistributedLengthGroupedSampler,
     DistributedTensorGatherer,
     LabelSmoother,
     SequentialDistributedSampler,
     distributed_broadcast_scalars,
     distributed_concat,
-    get_tpu_sampler,
+    get_length_grouped_indices,
     nested_concat,
     nested_detach,
     nested_numpify,
@@ -94,7 +95,7 @@ from .trainer_utils import (
     set_seed,
     speed_metrics,
 )
-from .training_args import TrainingArguments
+from .training_args import ParallelMode, TrainingArguments
 from .utils import logging
 
 
@@ -448,14 +449,33 @@ class Trainer:
             self.train_dataset, collections.abc.Sized
         ):
             return None
-        elif is_torch_tpu_available():
-            return get_tpu_sampler(self.train_dataset)
+
+        # Gather the number of processes and this process index.
+        if self.args.parallel_mode == ParallelMode.TPU:
+            num_processes = xm.xrt_world_size()
+            process_index = xm.get_ordinal()
+        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+            num_processes = torch.distributed.get_world_size()
+            process_index = torch.distributed.get_rank()
         else:
-            return (
-                RandomSampler(self.train_dataset)
-                if self.args.local_rank == -1
-                else DistributedSampler(self.train_dataset)
-            )
+            num_processes = 1
+            process_index = 0
+
+        # Build the sampler.
+        if self.args.group_by_length:
+            if num_processes <= 1:
+                lengths = [len(feature["input_ids"]) for feature in self.train_dataset]
+                return get_length_grouped_indices(lengths, self.args.train_batch_size)
+            else:
+                return DistributedLengthGroupedSampler(
+                    self.train_dataset, self.args.train_batch_size, num_replicas=num_processes, rank=process_index
+                )
+
+        else:
+            if num_processes <= 1:
+                return RandomSampler(self.train_dataset)
+            else:
+                return DistributedSampler(self.train_dataset, num_replicas=num_processes, rank=process_index)
 
     def get_train_dataloader(self) -> DataLoader:
         """
@@ -20,10 +20,11 @@ import math
 import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import Iterator, List, Optional, Union
 
 import numpy as np
 import torch
+from torch.utils.data.dataset import Dataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler
 
@@ -390,3 +391,110 @@ class LabelSmoother:
         # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
         smoothed_loss = log_probs.mean(dim=-1).sum() / (padding_mask.numel() - padding_mask.long().sum())
         return (1 - self.epsilon) * model_loss + self.epsilon * smoothed_loss
+
+
+def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
+    """
+    Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
+    similar lengths. To do this, the indices are:
+
+    - randomly permuted
+    - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
+    - sorted by length in each mega-batch
+
+    The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
+    maximum length placed first, so that an OOM happens sooner rather than later.
+    """
+    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
+    if mega_batch_mult is None:
+        mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
+        # Just in case, for tiny datasets
+        if mega_batch_mult == 0:
+            mega_batch_mult = 1
+
+    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+    indices = torch.randperm(len(lengths), generator=generator)
+    megabatch_size = mega_batch_mult * batch_size
+    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
+    megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
+
+    # The rest is to get the biggest batch first.
+    # Since each megabatch is sorted by descending length, the longest element is the first
+    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
+    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
+    # Switch to put the longest element in first position
+    megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]
+
+    return sum(megabatches, [])
+
+
+class DistributedLengthGroupedSampler(DistributedSampler):
+    r"""
+    Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
+    length while keeping a bit of randomness.
+    """
+    # Copied and adapted from PyTorch DistributedSampler.
+    def __init__(
+        self,
+        dataset: Dataset,
+        batch_size: int,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        seed: int = 0,
+        drop_last: bool = False,
+        lengths: Optional[List[int]] = None,
+    ):
+        if num_replicas is None:
+            if not torch.distributed.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = torch.distributed.get_world_size()
+        if rank is None:
+            if not torch.distributed.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = torch.distributed.get_rank()
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.drop_last = drop_last
+        # If the dataset length is evenly divisible by # of replicas, then there
+        # is no need to drop any data, since the dataset will be split equally.
+        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
+            # Split to nearest available length that is evenly divisible.
+            # This is to ensure each rank receives the same amount of data when
+            # using this Sampler.
+            self.num_samples = math.ceil((len(self.dataset) - self.num_replicas) / self.num_replicas)
+        else:
+            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
+        self.total_size = self.num_samples * self.num_replicas
+        self.seed = seed
+
+        if lengths is None:
+            if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]:
+                raise ValueError(
+                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
+                    "'input_ids' key."
+                )
+            lengths = [len(feature["input_ids"]) for feature in dataset]
+        self.lengths = lengths
+
+    def __iter__(self) -> Iterator:
+        # Deterministically shuffle based on epoch and seed
+        g = torch.Generator()
+        g.manual_seed(self.seed + self.epoch)
+        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
+
+        if not self.drop_last:
+            # add extra samples to make it evenly divisible
+            indices += indices[: (self.total_size - len(indices))]
+        else:
+            # remove tail of data to make it evenly divisible.
+            indices = indices[: self.total_size]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank : self.total_size : self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
@@ -227,6 +227,9 @@ class TrainingArguments:
         adafactor (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to use the :class:`~transformers.Adafactor` optimizer instead of
             :class:`~transformers.AdamW`.
+        group_by_length (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to group together samples of roughly the same length in the training dataset (to minimize
+            padding applied and be more efficient). Only useful if applying dynamic padding.
     """
 
     output_dir: str = field(
@@ -405,6 +408,10 @@ class TrainingArguments:
         default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
     )
     adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace Adam by Adafactor."})
+    group_by_length: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to group samples of roughly the same length together when batching."},
+    )
     _n_gpu: int = field(init=False, repr=False, default=-1)
 
     def __post_init__(self):
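As a hedged usage sketch (not part of the commit), the new argument can be switched on from the command line with --group_by_length, as the updated finetune test does, or programmatically as below; the output_dir value is a placeholder and the snippet assumes a transformers version that includes this field.

from transformers import TrainingArguments

# Placeholder output directory; group_by_length defaults to False.
args = TrainingArguments(output_dir="tmp_trainer_output", group_by_length=True)
print(args.group_by_length)  # True, equivalent to passing --group_by_length on the CLI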
@@ -25,7 +25,12 @@ if is_torch_available():
     import torch
 
     from transformers.modeling_outputs import SequenceClassifierOutput
-    from transformers.trainer_pt_utils import DistributedTensorGatherer, LabelSmoother
+    from transformers.trainer_pt_utils import (
+        DistributedLengthGroupedSampler,
+        DistributedTensorGatherer,
+        LabelSmoother,
+        get_length_grouped_indices,
+    )
 
 
 @require_torch
@@ -87,3 +92,28 @@ class TrainerUtilsTest(unittest.TestCase):
         log_probs[2, 3] = 0.0
         expected_loss = (1 - epsilon) * loss + epsilon * log_probs.sum() / (num_labels * 17)
         self.assertTrue(torch.allclose(label_smoothed_loss, expected_loss))
+
+    def test_group_by_length(self):
+        # Get some inputs of random lengths
+        lengths = torch.randint(0, 25, (100,)).tolist()
+        # Put one bigger than the others to check it ends up in first position
+        lengths[32] = 50
+
+        indices = get_length_grouped_indices(lengths, 4)
+        # The biggest element should be first
+        self.assertEqual(lengths[indices[0]], 50)
+        # The indices should be a permutation of range(100)
+        self.assertEqual(list(sorted(indices)), list(range(100)))
+
+    def test_distributed_length_grouped(self):
+        # Get some inputs of random lengths
+        lengths = torch.randint(0, 25, (100,)).tolist()
+        # Put one bigger than the others to check it ends up in first position
+        lengths[32] = 50
+
+        indices_process_0 = list(DistributedLengthGroupedSampler(lengths, 4, 2, 0, lengths=lengths))
+        indices_process_1 = list(DistributedLengthGroupedSampler(lengths, 4, 2, 1, lengths=lengths))
+        # The biggest element should be first
+        self.assertEqual(lengths[indices_process_0[0]], 50)
+        # The indices should be a permutation of range(100)
+        self.assertEqual(list(sorted(indices_process_0 + indices_process_1)), list(range(100)))
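The tests above exercise the sampler directly; the sketch below (not part of the commit) shows one way it could be wired into a per-process DataLoader. The toy dataset of dicts with an "input_ids" key is invented, and the num_replicas/rank values stand in for what a real distributed launcher would provide.

import torch
from torch.utils.data import DataLoader
from transformers.trainer_pt_utils import DistributedLengthGroupedSampler

# Invented dataset: items are dicts with an "input_ids" key, which is what the
# sampler expects when it has to infer the lengths itself.
dataset = [{"input_ids": [0] * torch.randint(1, 20, (1,)).item()} for _ in range(64)]

# One sampler per process; rank would come from the launcher in a real run.
sampler = DistributedLengthGroupedSampler(dataset, batch_size=8, num_replicas=2, rank=0)
loader = DataLoader(dataset, batch_size=8, sampler=sampler, collate_fn=list)

for batch in loader:
    print([len(example["input_ids"]) for example in batch])  # roughly similar lengths per batch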