Optimize ForCausalLMLoss by removing unnecessary contiguous() call to reduce memory overhead (#35646)

Authored by efsotr on 2025-01-16 23:47:43 +08:00, committed by GitHub
commit 8ebe9d7166, parent 1302c32a84

@@ -36,15 +36,15 @@ def ForCausalLMLoss(
     logits = logits.float()
     labels = labels.to(logits.device)
     # Shift so that tokens < n predict n
-    shift_logits = logits[..., :-1, :].contiguous()
+    labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
     shift_labels = labels[..., 1:].contiguous()
     # Flatten the tokens
-    shift_logits = shift_logits.view(-1, vocab_size)
+    logits = logits.view(-1, vocab_size)
     shift_labels = shift_labels.view(-1)
     # Enable model parallelism
-    shift_labels = shift_labels.to(shift_logits.device)
-    loss = fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
+    shift_labels = shift_labels.to(logits.device)
+    loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
     return loss
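
For context, a minimal standalone sketch of why the change saves memory. It assumes a PyTorch environment with illustrative tensor shapes and uses torch.nn.functional.cross_entropy in place of the library's fixed_cross_entropy helper; none of these specifics come from the commit itself. Slicing the logits with logits[..., :-1, :] produces a non-contiguous view, so the old .contiguous() call copied almost the entire (batch, seq_len, vocab_size) activation. Right-padding the labels with ignore_index and shifting them instead lets the full logits tensor be used directly.

import torch
import torch.nn as nn

# Illustrative shapes and values; not taken from the commit.
batch, seq_len, vocab_size = 2, 8, 32
ignore_index = -100

logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

# Old approach: slicing the logits yields a non-contiguous view, so
# .contiguous() materializes a copy of nearly the whole
# (batch, seq_len, vocab_size) tensor just to drop the last position.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels_old = labels[..., 1:].contiguous()
loss_old = nn.functional.cross_entropy(
    shift_logits.view(-1, vocab_size),
    shift_labels_old.view(-1),
    ignore_index=ignore_index,
)

# New approach: right-pad the labels with ignore_index and shift them instead,
# so the full logits tensor is reused without an extra copy; the padded
# position contributes nothing to the loss because of ignore_index.
padded_labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
shift_labels_new = padded_labels[..., 1:].contiguous()
loss_new = nn.functional.cross_entropy(
    logits.view(-1, vocab_size),
    shift_labels_new.view(-1),
    ignore_index=ignore_index,
)

# Both formulations yield the same loss value.
assert torch.allclose(loss_old, loss_new)

The two formulations agree because the padded target is excluded from both the loss sum and the averaging denominator by ignore_index, so only the original (logit, next-token) pairs are counted.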