Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Optimize ForCausalLMLoss by removing unnecessary contiguous() call to reduce memory overhead (#35646)
commit 8ebe9d7166
parent 1302c32a84
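Why this helps: for batch sizes greater than one, `logits[..., :-1, :]` is a non-contiguous view, so the old `.contiguous()` call materialized a full copy of the upcast float32 logits, a `[batch, seq_len - 1, vocab_size]` tensor that dominates memory in the loss computation. Padding the labels with `ignore_index` and shifting them instead leaves the logits untouched: `logits.view(-1, vocab_size)` becomes a zero-copy reshape, and the padded final position is skipped by the cross-entropy loss.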
@@ -36,15 +36,15 @@ def ForCausalLMLoss(
     logits = logits.float()
     labels = labels.to(logits.device)
     # Shift so that tokens < n predict n
-    shift_logits = logits[..., :-1, :].contiguous()
+    labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
     shift_labels = labels[..., 1:].contiguous()

     # Flatten the tokens
-    shift_logits = shift_logits.view(-1, vocab_size)
+    logits = logits.view(-1, vocab_size)
     shift_labels = shift_labels.view(-1)
     # Enable model parallelism
-    shift_labels = shift_labels.to(shift_logits.device)
-    loss = fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
+    shift_labels = shift_labels.to(logits.device)
+    loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
     return loss
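For reference, a minimal standalone sketch of the equivalence. The helper names `loss_old`/`loss_new` are hypothetical, and a plain `F.cross_entropy` stands in for the repo's `fixed_cross_entropy` (its `num_items_in_batch` handling is omitted):

import torch
import torch.nn.functional as F

def loss_old(logits, labels, vocab_size, ignore_index=-100):
    # Old path: the slice is non-contiguous for batch > 1, so .contiguous()
    # materializes a full copy of the [batch, seq-1, vocab] float logits.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(shift_logits.view(-1, vocab_size),
                           shift_labels.view(-1), ignore_index=ignore_index)

def loss_new(logits, labels, vocab_size, ignore_index=-100):
    # New path: pad the (tiny) labels tensor instead of slicing the logits.
    # Every logit row keeps a target; the final position of each sequence
    # gets ignore_index, so cross-entropy skips it and the mean is unchanged.
    labels = F.pad(labels, (0, 1), value=ignore_index)
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(logits.view(-1, vocab_size),
                           shift_labels.view(-1), ignore_index=ignore_index)

batch, seq_len, vocab_size = 2, 8, 32
logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))
torch.testing.assert_close(loss_old(logits, labels, vocab_size),
                           loss_new(logits, labels, vocab_size))

Since the logits are upcast to float32, the skipped copy weighs in at batch × seq_len × vocab_size × 4 bytes, which for large vocabularies and long sequences can amount to gigabytes.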