mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[ForCausalLMLoss] allow users to pass shifted labels (#36607)
* [ForCausalLMLoss] allow users to pass shifted labels Signed-off-by: Stas Bekman <stas@stason.org> * style Signed-off-by: Stas Bekman <stas@stason.org> --------- Signed-off-by: Stas Bekman <stas@stason.org>
This commit is contained in:
parent
94555437e2
commit
8f64b177f6
@ -31,14 +31,22 @@ def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_i
|
||||
|
||||
|
||||
def ForCausalLMLoss(
|
||||
logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
|
||||
logits,
|
||||
labels,
|
||||
vocab_size: int,
|
||||
num_items_in_batch: int = None,
|
||||
ignore_index: int = -100,
|
||||
shift_labels=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
labels = labels.to(logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
|
||||
if shift_labels is None:
|
||||
labels = labels.to(logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
|
||||
# Flatten the tokens
|
||||
logits = logits.view(-1, vocab_size)
|
||||
|
Loading…
Reference in New Issue
Block a user