Mirror of https://github.com/huggingface/transformers.git
Optimize ForCausalLMLoss by removing unnecessary contiguous() call to reduce memory overhead (#35646)
parent 1302c32a84
commit 8ebe9d7166
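Why this change helps: logits[..., :-1, :] is a non-contiguous view, so the old .contiguous() call materialized a full copy of the (batch, seq_len - 1, vocab_size) float logits, by far the largest tensor in the loss computation. For a 128k-token vocabulary at batch 8 x 4096 positions in float32, that copy alone is about 8 * 4095 * 128256 * 4 bytes, roughly 16.8 GB. The new code leaves the logits untouched (they are already contiguous, so view(-1, vocab_size) is a zero-copy reshape) and instead pads the small integer labels tensor with ignore_index so that labels[..., 1:] lines up with the unshifted logits; the remaining .contiguous() only copies integer labels, which is cheap. Below is a minimal sketch of the equivalence, using plain F.cross_entropy in place of the repo's fixed_cross_entropy helper and made-up shapes:

import torch
import torch.nn.functional as F

batch, seq_len, vocab_size, ignore_index = 2, 5, 7, -100
logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

# Old approach: slice the logits, then copy them into contiguous storage.
# The .contiguous() call allocates a fresh (batch, seq_len - 1, vocab_size)
# float tensor -- the dominant memory cost for large vocabularies.
old_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
old_labels = labels[..., 1:].contiguous().view(-1)

# New approach: pad the integer labels with ignore_index instead, so every
# logits position gets a target and the full logits tensor can be flattened
# as-is (it is already contiguous, so .view() allocates nothing).
padded = F.pad(labels, (0, 1), value=ignore_index)
new_logits = logits.view(-1, vocab_size)
new_labels = padded[..., 1:].contiguous().view(-1)

# Both variants give the same loss: the extra last position per sequence
# carries ignore_index and is skipped by cross_entropy's mean reduction.
old_loss = F.cross_entropy(old_logits, old_labels, ignore_index=ignore_index)
new_loss = F.cross_entropy(new_logits, new_labels, ignore_index=ignore_index)
assert torch.allclose(old_loss, new_loss)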
@@ -36,15 +36,15 @@ def ForCausalLMLoss(
     logits = logits.float()
     labels = labels.to(logits.device)
     # Shift so that tokens < n predict n
-    shift_logits = logits[..., :-1, :].contiguous()
+    labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
     shift_labels = labels[..., 1:].contiguous()
 
     # Flatten the tokens
-    shift_logits = shift_logits.view(-1, vocab_size)
+    logits = logits.view(-1, vocab_size)
     shift_labels = shift_labels.view(-1)
     # Enable model parallelism
-    shift_labels = shift_labels.to(shift_logits.device)
-    loss = fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
+    shift_labels = shift_labels.to(logits.device)
+    loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
     return loss
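The saving is also easy to verify directly: slicing along the sequence dimension breaks contiguity and forces .contiguous() to copy, while flattening the untouched logits is a zero-copy view. A small check with illustrative shapes:

import torch

vocab_size = 8
logits = torch.randn(2, 5, vocab_size)

# Dropping the last sequence position yields a non-contiguous view, so the
# old code had to call .contiguous(), which copies into fresh storage.
sliced = logits[..., :-1, :]
assert not sliced.is_contiguous()
copied = sliced.contiguous()
assert copied.data_ptr() != logits.data_ptr()  # new allocation

# Flattening the unsliced logits is free: .view() returns a view sharing
# the original storage, so no float copy is ever made.
flat = logits.view(-1, vocab_size)
assert flat.data_ptr() == logits.data_ptr()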