fix consistency CrossEntropyLoss in modeling_bart (#6265)

idoh 2020-08-07 12:44:28 +03:00 committed by GitHub
parent c72f9c90a1
commit 3be2d04884

@@ -1040,7 +1040,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
         masked_lm_loss = None
         if labels is not None:
-            loss_fct = nn.CrossEntropyLoss()
+            loss_fct = CrossEntropyLoss()
             # TODO(SS): do we need to ignore pad tokens in labels?
             masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
@@ -1179,7 +1179,8 @@ class BartForSequenceClassification(PretrainedBartModel):
         loss = None
         if labels is not None:
-            loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
         if not return_dict:
             output = (logits,) + outputs[1:]
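
Note: this change only unifies the spelling of the loss; `CrossEntropyLoss` here is presumably imported from `torch.nn` elsewhere in modeling_bart, and both the class-based form and the functional `F.cross_entropy` compute the same value. A minimal standalone sketch (not part of the patch; tensor sizes are made up for illustration) showing the equivalence:

# Standalone sketch: the three spellings touched by this commit give the same loss.
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

vocab_size = 8                    # hypothetical sizes, for illustration only
batch, seq_len = 2, 5

lm_logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

# Class imported directly (post-patch style), fully qualified class (pre-patch style),
# and the functional API (pre-patch style in BartForSequenceClassification).
loss_a = CrossEntropyLoss()(lm_logits.view(-1, vocab_size), labels.view(-1))
loss_b = torch.nn.CrossEntropyLoss()(lm_logits.view(-1, vocab_size), labels.view(-1))
loss_c = F.cross_entropy(lm_logits.view(-1, vocab_size), labels.view(-1))

assert torch.allclose(loss_a, loss_b) and torch.allclose(loss_a, loss_c)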