Mirror of https://github.com/huggingface/transformers.git
Update trainer._get_eval_sampler() to support group_by_length arg (#33514)

Update 'trainer._get_eval_sampler()' to support the 'group_by_length' argument. Trainer didn't support grouping examples of similar length for evaluation, which made evaluation slow with 'eval_batch_size' > 1 because mixed-length batches are padded up to the longest sequence in the batch. The updated 'trainer._get_eval_sampler()' method is based on 'trainer._get_train_sampler()'.
parent 3f06f95ebe
commit 6d2b203339
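With this change, length grouping during evaluation is enabled through the existing 'group_by_length' training argument. A minimal sketch of how one might exercise the new code path (the checkpoint, toy texts, and 'out' directory are illustrative placeholders, not part of the commit):

```python
from datasets import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

# Toy eval set with very uneven lengths; grouping by length keeps the
# padding overhead of each batch small.
texts = ["short", "a much longer sentence " * 30] * 8
eval_ds = Dataset.from_dict({"text": texts, "label": [0, 1] * 8})
eval_ds = eval_ds.map(lambda ex: tokenizer(ex["text"]), remove_columns=["text"])

args = TrainingArguments(
    output_dir="out",
    group_by_length=True,          # with this commit, also honored by _get_eval_sampler()
    per_device_eval_batch_size=4,
)

trainer = Trainer(model=model, args=args, eval_dataset=eval_ds, tokenizer=tokenizer)
print(trainer.evaluate())          # eval batches now group similar-length examples
```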
@@ -959,6 +959,10 @@ class Trainer:
         return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
 
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
+        if self.eval_dataset is None or not has_length(self.eval_dataset):
+            return None
+        # Build the sampler.
+
         # Deprecated code
         if self.args.use_legacy_prediction_loop:
             if is_torch_xla_available():
@@ -975,6 +979,23 @@ class Trainer:
             else:
                 return SequentialSampler(eval_dataset)
 
+        if self.args.group_by_length:
+            if is_datasets_available() and isinstance(self.eval_dataset, datasets.Dataset):
+                lengths = (
+                    self.eval_dataset[self.args.length_column_name]
+                    if self.args.length_column_name in self.eval_dataset.column_names
+                    else None
+                )
+            else:
+                lengths = None
+            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
+            return LengthGroupedSampler(
+                self.args.eval_batch_size,
+                dataset=self.eval_dataset,
+                lengths=lengths,
+                model_input_name=model_input_name,
+            )
+
         if self.args.world_size <= 1:
             return SequentialSampler(eval_dataset)
         else:
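The sampler returned by the new branch can also be tried on its own. A quick sketch with made-up lengths (only 'LengthGroupedSampler' itself comes from the commit; everything else is invented for illustration):

```python
from transformers.trainer_pt_utils import LengthGroupedSampler

# Made-up sequence lengths: 32 short and 32 long examples, interleaved.
lengths = [5, 100] * 32

# The sampler shuffles indices, then sorts them by length within large
# "megabatches", so consecutive windows of batch_size indices mostly point
# at examples of similar length.
sampler = LengthGroupedSampler(batch_size=4, lengths=lengths)
order = list(sampler)

batches = [order[i : i + 4] for i in range(0, len(order), 4)]
print([[lengths[i] for i in b] for b in batches[:3]])
# Typically e.g. [[100, 100, 100, 100], [100, 100, 100, 100], [5, 5, 5, 5]]
```

Reusing the same sampler as '_get_train_sampler()' keeps training and evaluation batching behavior consistent rather than introducing a separate eval-only mechanism.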