Fix multi-gpu loss (#35395)

push to device
This commit is contained in:
Arthur 2025-01-09 10:14:31 +01:00 committed by GitHub
parent 2e2f8015c0
commit 8cad65a698
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -34,6 +34,7 @@ def ForCausalLMLoss(
):
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@ -52,6 +53,7 @@ def ForMaskedLMLoss(
):
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
labels = labels.to(logits.device)
# Flatten the tokens
logits = logits.view(-1, vocab_size)
@ -73,6 +75,7 @@ def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs):
else:
config.problem_type = "multi_label_classification"
labels = labels.to(pooled_logits.device)
if config.problem_type == "regression":
loss_fct = MSELoss()
if num_labels == 1:
@ -109,7 +112,7 @@ def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_posi
def ForTokenClassification(logits, labels, config, **kwargs):
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.view(-1, config.num_labels)
labels = labels.view(-1)
labels = labels.view(-1).to(logits.device)
logits = logits.float()
# Flatten the tokens
return fixed_cross_entropy(logits, labels, **kwargs)