From a75b9ffb5c653efbc8279d563f2275c2fd592575 Mon Sep 17 00:00:00 2001
From: mohammed benyamna
Date: Mon, 2 Jun 2025 10:31:44 +0100
Subject: [PATCH] Update Loss Functions to Accept Tensor num_items_in_batch
 (#38029)

* Update Loss Functions to Accept Tensor num_items_in_batch

* Fix device mismatch by moving num_items_in_batch to loss device in fixed_cross_entropy

* Fix the ruff check failure

* Delete the unused if statement

* Fix the type problem
---
 src/transformers/loss/loss_utils.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py
index aad42d3fd52..cf6d9078a94 100644
--- a/src/transformers/loss/loss_utils.py
+++ b/src/transformers/loss/loss_utils.py
@@ -35,6 +35,11 @@ def fixed_cross_entropy(
     reduction = "sum" if num_items_in_batch is not None else "mean"
     loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
     if reduction == "sum":
+        if not isinstance(num_items_in_batch, torch.Tensor):
+            num_items_in_batch = torch.tensor(num_items_in_batch, device=loss.device, dtype=loss.dtype)
+        elif num_items_in_batch.device != loss.device:
+            num_items_in_batch = num_items_in_batch.to(loss.device)
+
         loss = loss / num_items_in_batch
     return loss