Update trainer._get_eval_sampler() to support group_by_length arg (#33514)

Update 'trainer._get_eval_sampler()' to support 'group_by_length' argument

Trainer didn't support grouping by length for evaluation, which made evaluation slow with 'eval_batch_size' > 1: each batch is padded to its longest sequence, so mixing short and long examples in the same batch wastes compute on padding tokens.
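
To see why grouping helps, here is a small illustration in plain PyTorch (the sequence lengths are made up for the example): batching the same four sequences sequentially versus grouped by length changes how many padded tokens the model has to process.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Four dummy sequences with made-up lengths: two short, two long.
seqs = [torch.ones(n, dtype=torch.long) for n in (5, 100, 7, 98)]

# Sequential batching mixes short and long examples in one batch.
mixed = [pad_sequence([seqs[0], seqs[1]], batch_first=True),
         pad_sequence([seqs[2], seqs[3]], batch_first=True)]
# Length-grouped batching puts similarly sized examples together.
grouped = [pad_sequence([seqs[0], seqs[2]], batch_first=True),
           pad_sequence([seqs[1], seqs[3]], batch_first=True)]

print(sum(b.numel() for b in mixed))    # 396 tokens, padding included
print(sum(b.numel() for b in grouped))  # 214 tokens, padding included
```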

The updated 'trainer._get_eval_sampler()' method is based on 'trainer._get_train_sampler()'.
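
For reference, a minimal sketch of how this is enabled from user code ('model', 'eval_dataset' and 'tokenizer' are placeholders, not part of this commit); with this change, the same 'group_by_length' flag that already reordered training batches now also applies to evaluation:

```python
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",             # placeholder path
    group_by_length=True,         # now also groups *eval* examples by length
    length_column_name="length",  # optional column with precomputed lengths
    per_device_eval_batch_size=8,
)
trainer = Trainer(
    model=model,                  # placeholder model
    args=args,
    eval_dataset=eval_dataset,    # placeholder dataset
    tokenizer=tokenizer,          # used to infer the model input name
)
metrics = trainer.evaluate()      # eval DataLoader now uses LengthGroupedSampler
```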
commit 6d2b203339 (parent 3f06f95ebe)
Authored by larin92 on 2024-10-17 15:43:29 +03:00, committed by GitHub


@@ -959,6 +959,10 @@ class Trainer:
         return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
 
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
+        if self.eval_dataset is None or not has_length(self.eval_dataset):
+            return None
+
+        # Build the sampler.
         # Deprecated code
         if self.args.use_legacy_prediction_loop:
             if is_torch_xla_available():
@@ -975,6 +979,23 @@ class Trainer:
             else:
                 return SequentialSampler(eval_dataset)
 
+        if self.args.group_by_length:
+            if is_datasets_available() and isinstance(self.eval_dataset, datasets.Dataset):
+                lengths = (
+                    self.eval_dataset[self.args.length_column_name]
+                    if self.args.length_column_name in self.eval_dataset.column_names
+                    else None
+                )
+            else:
+                lengths = None
+            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
+            return LengthGroupedSampler(
+                self.args.eval_batch_size,
+                dataset=self.eval_dataset,
+                lengths=lengths,
+                model_input_name=model_input_name,
+            )
+
         if self.args.world_size <= 1:
             return SequentialSampler(eval_dataset)
         else: