Mirror of https://github.com/huggingface/transformers.git
commit 8cad65a698 (parent 2e2f8015c0)

The diff touches four loss helpers (ForCausalLMLoss, ForMaskedLMLoss, ForSequenceClassificationLoss, and ForTokenClassification), in each case moving the labels onto the device that holds the logits before the loss is computed.
@@ -34,6 +34,7 @@ def ForCausalLMLoss(
 ):
     # Upcast to float if we need to compute the loss to avoid potential precision issues
     logits = logits.float()
+    labels = labels.to(logits.device)
     # Shift so that tokens < n predict n
     shift_logits = logits[..., :-1, :].contiguous()
     shift_labels = labels[..., 1:].contiguous()
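The `+` line is the only change in this hunk; the rest is context. As a minimal, self-contained sketch of what this helper computes, with plain F.cross_entropy standing in for the file's fixed_cross_entropy (the name causal_lm_loss and the flattened signature are illustrative, not the library API): labels are moved onto the logits device, then the logits at position n-1 are scored against the token at position n, i.e. "tokens < n predict n".

import torch
import torch.nn.functional as F

def causal_lm_loss(logits, labels, ignore_index=-100):
    # Upcast to float: fp16/bf16 logits can lose precision in the loss.
    logits = logits.float()
    # The added line: with a model sharded across devices, the logits may
    # come back on a different GPU than the labels, so align them first.
    labels = labels.to(logits.device)
    # Shift so that tokens < n predict n.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens and compute cross-entropy, skipping ignore_index positions.
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,
    )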
@@ -52,6 +53,7 @@ def ForMaskedLMLoss(
 ):
     # Upcast to float if we need to compute the loss to avoid potential precision issues
     logits = logits.float()
+    labels = labels.to(logits.device)
 
     # Flatten the tokens
     logits = logits.view(-1, vocab_size)
@@ -73,6 +75,7 @@ def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs):
         else:
             config.problem_type = "multi_label_classification"
 
+    labels = labels.to(pooled_logits.device)
     if config.problem_type == "regression":
         loss_fct = MSELoss()
         if num_labels == 1:
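For orientation, a condensed sketch of the dispatch this hunk sits in, assuming the conventional three-way problem_type handling (regression uses MSE, single-label classification uses cross-entropy, multi-label uses BCE-with-logits); the helper name and flattened signature are illustrative, and the full function in the repo is the authority:

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

def sequence_classification_loss(labels, pooled_logits, problem_type, num_labels):
    # Align labels with the pooled logits' device first, as the added line does.
    labels = labels.to(pooled_logits.device)
    if problem_type == "regression":
        loss_fct = MSELoss()
        if num_labels == 1:
            return loss_fct(pooled_logits.squeeze(), labels.squeeze())
        return loss_fct(pooled_logits, labels)
    if problem_type == "single_label_classification":
        loss_fct = CrossEntropyLoss()
        return loss_fct(pooled_logits.view(-1, num_labels), labels.view(-1))
    # "multi_label_classification": an independent sigmoid per label.
    return BCEWithLogitsLoss()(pooled_logits, labels)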
@@ -109,7 +112,7 @@ def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions, **kwargs):
 def ForTokenClassification(logits, labels, config, **kwargs):
     # Upcast to float if we need to compute the loss to avoid potential precision issues
     logits = logits.view(-1, config.num_labels)
-    labels = labels.view(-1)
+    labels = labels.view(-1).to(logits.device)
     logits = logits.float()
     # Flatten the tokens
     return fixed_cross_entropy(logits, labels, **kwargs)
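All four hunks apply the same one-line fix: move labels onto the device that holds the logits before the loss is computed. A toy illustration of the failure mode this guards against (hypothetical multi-GPU setup; on a single device the mismatch never occurs):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 9)          # imagine these on "cuda:1" from a sharded model
labels = torch.randint(0, 9, (4,))  # while the batch's labels sit on "cuda:0" or CPU
# Without aligning devices, cross_entropy fails with a RuntimeError along the
# lines of "Expected all tensors to be on the same device".
labels = labels.to(logits.device)   # the move this commit adds
loss = F.cross_entropy(logits, labels)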