Add new dim to `num_items_in_batch` if necessary (#36967)

* Add new dim to `num_items_in_batch` if necessary

* Unsqueeze only in the DP case

---------

Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
commit 12048990a9
parent 98601cc818
Author: regisss
Date: 2025-04-03 01:57:03 -06:00 (committed via GitHub)


@@ -5268,6 +5268,10 @@ class Trainer:
         if torch.is_tensor(num_items_in_batch):
             num_items_in_batch = num_items_in_batch.to(device)
 
+            if self.args.n_gpu > 1 and num_items_in_batch.dim() == 0:
+                # In the DataParallel case, convert the scalar tensor into a 1-dim tensor
+                num_items_in_batch = num_items_in_batch.unsqueeze(0)
+
         return batch_samples, num_items_in_batch
 
     def set_initial_training_values(
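
For context: `torch.nn.DataParallel` scatters tensor arguments to each replica's `forward` by chunking them along dim 0, which a 0-dim (scalar) tensor does not have; hence the `unsqueeze(0)` guard above. A minimal sketch of that guard outside of `Trainer`, where the hypothetical `n_gpu` variable stands in for `self.args.n_gpu`:

import torch

n_gpu = 2  # hypothetical stand-in for self.args.n_gpu
num_items_in_batch = torch.tensor(4096)  # 0-dim (scalar) tensor, shape ()

if n_gpu > 1 and num_items_in_batch.dim() == 0:
    # Add a leading axis of size 1 so DataParallel has a dim 0 to scatter along
    num_items_in_batch = num_items_in_batch.unsqueeze(0)

print(num_items_in_batch.shape)  # torch.Size([1])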