[ForCausalLMLoss] allow users to pass shifted labels (#36607)

* [ForCausalLMLoss] allow users to pass shifted labels

Signed-off-by: Stas Bekman <stas@stason.org>

* style

Signed-off-by: Stas Bekman <stas@stason.org>

---------

Signed-off-by: Stas Bekman <stas@stason.org>
Stas Bekman 2025-03-20 03:25:22 -07:00 committed by GitHub
parent 94555437e2
commit 8f64b177f6
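
To illustrate what "pass shifted labels" means here, the following is a minimal pure-Python sketch, not the library code. It mirrors the pad-then-slice shift from the diff using plain lists, and checks that supplying pre-shifted labels gives the same loss as letting the function shift them. The helper names `shift_for_next_token` and `causal_lm_loss` are hypothetical, and the hand-written mean cross-entropy is only a stand-in for the real `fixed_cross_entropy`.

```python
import math

IGNORE_INDEX = -100  # same default value as ignore_index in the diff


def shift_for_next_token(labels, ignore_index=IGNORE_INDEX):
    """Hypothetical helper: append one ignore_index, then drop the first token.

    Afterwards position i holds the label for token i + 1, which is what
    pad(labels, (0, 1)) followed by labels[..., 1:] does in the diff.
    """
    return [row[1:] + [ignore_index] for row in labels]


def causal_lm_loss(logits, labels=None, shift_labels=None, ignore_index=IGNORE_INDEX):
    """Stand-in loss: mean cross-entropy over positions that are not ignored.

    Assumes at least one position is not ignored.
    """
    if shift_labels is None:
        # Only shift when the caller did not already provide shifted labels.
        shift_labels = shift_for_next_token(labels, ignore_index)
    total, count = 0.0, 0
    for row_logits, row_labels in zip(logits, shift_labels):
        for scores, target in zip(row_logits, row_labels):
            if target == ignore_index:
                continue
            log_norm = math.log(sum(math.exp(s) for s in scores))
            total += log_norm - scores[target]
            count += 1
    return total / count


# Toy input: batch of 1, sequence length 3, vocabulary size 3.
logits = [[[2.0, 0.5, 0.1], [0.2, 1.5, 0.3], [0.1, 0.2, 3.0]]]
labels = [[0, 1, 2]]

pre_shifted = shift_for_next_token(labels)
loss_internal = causal_lm_loss(logits, labels=labels)
loss_external = causal_lm_loss(logits, shift_labels=pre_shifted)
```

In this sketch both calls follow the same code path once the shifted labels exist, so the two losses are identical. The last position of each row is always ignored because no next token exists for it.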

@@ -31,14 +31,22 @@ def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_i
 def ForCausalLMLoss(
-    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
+    logits,
+    labels,
+    vocab_size: int,
+    num_items_in_batch: int = None,
+    ignore_index: int = -100,
+    shift_labels=None,
+    **kwargs,
 ):
     # Upcast to float if we need to compute the loss to avoid potential precision issues
     logits = logits.float()
-    labels = labels.to(logits.device)
-    # Shift so that tokens < n predict n
-    labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
-    shift_labels = labels[..., 1:].contiguous()
+    if shift_labels is None:
+        labels = labels.to(logits.device)
+        # Shift so that tokens < n predict n
+        labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
+        shift_labels = labels[..., 1:].contiguous()
     # Flatten the tokens
     logits = logits.view(-1, vocab_size)