mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Add new dim to num_items_in_batch
if necessary (#36967)
* Add new dim to `num_items_in_batch` if necessary * Unsqueeze only in the DP case --------- Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
parent
98601cc818
commit
12048990a9
@ -5268,6 +5268,10 @@ class Trainer:
|
||||
if torch.is_tensor(num_items_in_batch):
|
||||
num_items_in_batch = num_items_in_batch.to(device)
|
||||
|
||||
if self.args.n_gpu > 1 and num_items_in_batch.dim() == 0:
|
||||
# In the DataParallel case, convert the scalar tensor into a 1-dim tensor
|
||||
num_items_in_batch = num_items_in_batch.unsqueeze(0)
|
||||
|
||||
return batch_samples, num_items_in_batch
|
||||
|
||||
def set_initial_training_values(
|
||||
|
Loading…
Reference in New Issue
Block a user