Update Loss Functions to Accept Tensor num_items_in_batch (#38029)

* Update Loss Functions to Accept Tensor num_items_in_batch

* Fix device mismatch by moving num_items_in_batch to loss device in fixed_cross_entropy

* fix the ruff check

* delete the unused if statement

* fix the type-handling problem (wrap non-Tensor num_items_in_batch)
This commit is contained in:
mohammed benyamna 2025-06-02 10:31:44 +01:00 committed by GitHub
parent 493cf1554b
commit a75b9ffb5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -35,6 +35,11 @@ def fixed_cross_entropy(
reduction = "sum" if num_items_in_batch is not None else "mean"
loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
if reduction == "sum":
if not isinstance(num_items_in_batch, torch.Tensor):
num_items_in_batch = torch.tensor(num_items_in_batch, device=loss.device, dtype=loss.dtype)
elif num_items_in_batch.device != loss.device:
num_items_in_batch = num_items_in_batch.to(loss.device)
loss = loss / num_items_in_batch
return loss