Fix batch size handling in prediction_loop for DataLoaderShard (#34343)

* Fix batch size handling in prediction_loop for DataLoaderShard

Updated the prediction_loop method in the Trainer class to correctly handle the batch size when the dataloader is a DataLoaderShard. The batch size is now retrieved from total_batch_size in distributed training scenarios, preventing a TypeError caused by a None batch size during evaluation.
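A minimal, self-contained sketch of the failure mode and the guard. FakeShard and resolve_batch_size are illustrative names, not transformers or accelerate APIs; the assumption, taken from the message above, is that an accelerate-prepared dataloader can report batch_size as None while exposing the effective size as total_batch_size.

# Sketch only: mirrors the guard added in this commit, outside of Trainer.
class FakeShard:
    """Stands in for an accelerate-prepared dataloader (e.g. DataLoaderShard)."""
    _is_accelerate_prepared = True
    batch_size = None        # None when batching is driven by a batch sampler
    total_batch_size = 32    # effective batch size across all processes

def resolve_batch_size(dataloader):
    # Prefer total_batch_size for accelerate-prepared dataloaders.
    batch_size = (
        dataloader.total_batch_size
        if getattr(dataloader, "_is_accelerate_prepared", False)
        else dataloader.batch_size
    )
    if batch_size is None:
        raise ValueError(
            "Batch size cannot be None. Ensure the dataloader has a valid "
            "batch_size or total_batch_size."
        )
    return batch_size

print(resolve_batch_size(FakeShard()))  # 32, where the old code produced None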

* Update src/transformers/trainer.py

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* Removed unused imports

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Authored by Nischay on 2024-10-28 17:53:52 +05:30, committed by GitHub
parent 9360f1827d
commit 92bcdff2ef

@@ -4714,7 +4714,17 @@ class Trainer:
             elif args.bf16_full_eval:
                 model = model.to(dtype=torch.bfloat16, device=args.device)
 
-        batch_size = dataloader.batch_size
+        batch_size = (
+            dataloader.total_batch_size
+            if getattr(dataloader, "_is_accelerate_prepared", False)
+            else dataloader.batch_size
+        )
+
+        if batch_size is None:
+            raise ValueError(
+                "Batch size cannot be None. Ensure the dataloader has a valid batch_size or total_batch_size."
+            )
+
         num_examples = self.num_examples(dataloader)
         logger.info(f"\n***** Running {description} *****")
         logger.info(f"  Num examples = {num_examples}")