Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Fix batch size handling in prediction_loop for DataLoaderShard (#34343)
* Fix batch size handling in prediction_loop for DataLoaderShard

  Updated the prediction_loop method in the Trainer class to correctly handle the batch size when using DataLoaderShard. The batch size is now retrieved from total_batch_size in distributed training scenarios, preventing a TypeError caused by a None batch_size during evaluation.

* Update src/transformers/trainer.py

  Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* Applied the fix and removed unused imports

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
parent 9360f1827d
commit 92bcdff2ef
src/transformers/trainer.py

@@ -4714,7 +4714,17 @@ class Trainer:
         elif args.bf16_full_eval:
             model = model.to(dtype=torch.bfloat16, device=args.device)
 
-        batch_size = dataloader.batch_size
+        batch_size = (
+            dataloader.total_batch_size
+            if getattr(dataloader, "_is_accelerate_prepared", False)
+            else dataloader.batch_size
+        )
+
+        if batch_size is None:
+            raise ValueError(
+                "Batch size cannot be None. Ensure the dataloader has a valid batch_size or total_batch_size."
+            )
+
         num_examples = self.num_examples(dataloader)
         logger.info(f"\n***** Running {description} *****")
         logger.info(f"  Num examples = {num_examples}")
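For context, here is a minimal self-contained sketch (not the Trainer source) of the batch-size resolution the patch introduces. The DummyShard class is hypothetical and only mimics an accelerate-prepared dataloader, where the per-process batch_size can be None while total_batch_size carries the effective batch size across processes:

def resolve_batch_size(dataloader):
    # Prefer total_batch_size when the dataloader was prepared by accelerate
    # (e.g. a DataLoaderShard), since its per-process batch_size may be None.
    batch_size = (
        dataloader.total_batch_size
        if getattr(dataloader, "_is_accelerate_prepared", False)
        else dataloader.batch_size
    )
    if batch_size is None:
        raise ValueError(
            "Batch size cannot be None. Ensure the dataloader has a valid "
            "batch_size or total_batch_size."
        )
    return batch_size


class DummyShard:
    # Hypothetical stand-in for an accelerate-prepared dataloader.
    batch_size = None
    total_batch_size = 32  # e.g. 8 per process * 4 processes
    _is_accelerate_prepared = True


print(resolve_batch_size(DummyShard()))  # prints 32

With the old code, batch_size would be None for such a dataloader and later arithmetic on it would raise a TypeError; the conditional above is what the patched prediction_loop now does instead.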