Set ignore_index to -100 in T5 model

This commit is contained in:
thomwolf 2020-01-08 09:52:10 +01:00
parent 569da80ced
commit 1b59b57b57

View File

@@ -905,7 +905,7 @@ class T5WithLMHeadModel(T5PreTrainedModel):
if lm_labels is not None:
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = lm_labels[..., 1:].contiguous()
-loss_fct = CrossEntropyLoss(ignore_index=-1)
+loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
decoder_outputs = (
loss,