diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py
index 1c142fcd288..549bc5950b2 100644
--- a/transformers/modeling_bert.py
+++ b/transformers/modeling_bert.py
@@ -891,7 +891,7 @@ class BertForMaskedLM(BertPreTrainedModel):
         # 2. If `lm_labels` is provided we are in a causal scenario where we
         #    try to predict the next token for each input in the decoder.
         if masked_lm_labels is not None:
-            loss_fct = CrossEntropyLoss() # -1 index = padding token
+            loss_fct = CrossEntropyLoss() # -100 index = padding token
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             outputs = (masked_lm_loss,) + outputs
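
For reference, a minimal standalone sketch (not part of the PR) of why the corrected comment says -100 rather than -1: PyTorch's CrossEntropyLoss is constructed with ignore_index=-100 by default, so target positions labeled -100 are excluded from the loss. The tensor shapes and label values below are illustrative only.

    import torch
    from torch.nn import CrossEntropyLoss

    vocab_size = 10
    logits = torch.randn(4, vocab_size)        # 4 token positions, illustrative shapes
    labels = torch.tensor([3, -100, 7, -100])  # two positions masked out of the loss

    loss_fct = CrossEntropyLoss()              # ignore_index defaults to -100
    loss = loss_fct(logits, labels)

    # Equivalent: compute the loss only over the non-ignored positions.
    keep = labels != -100
    manual = CrossEntropyLoss()(logits[keep], labels[keep])
    assert torch.allclose(loss, manual)

This is why BertForMaskedLM can be fed labels where non-masked positions are set to -100: those positions contribute nothing to masked_lm_loss.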