mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 04:40:06 +06:00
Update Loss Functions to Accept Tensor num_items_in_batch (#38029)
* Update Loss Functions to Accept Tensor num_items_in_batch
* Fix device mismatch by moving num_items_in_batch to loss device in fixed_cross_entropy
* Fix the ruff check
* Delete the unused if statement
* Fix the type problem
This commit is contained in:
parent
493cf1554b
commit
a75b9ffb5c
@@ -35,6 +35,11 @@ def fixed_cross_entropy(
|
||||
reduction = "sum" if num_items_in_batch is not None else "mean"
|
||||
loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
|
||||
if reduction == "sum":
|
||||
if not isinstance(num_items_in_batch, torch.Tensor):
|
||||
num_items_in_batch = torch.tensor(num_items_in_batch, device=loss.device, dtype=loss.dtype)
|
||||
elif num_items_in_batch.device != loss.device:
|
||||
num_items_in_batch = num_items_in_batch.to(loss.device)
|
||||
|
||||
loss = loss / num_items_in_batch
|
||||
return loss
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user