From a24f830604fc150526d9fd4596a4f3900916abe9 Mon Sep 17 00:00:00 2001 From: wangfei <1140554608@qq.com> Date: Sat, 3 Aug 2019 12:17:06 +0800 Subject: [PATCH] Fix comment typo --- pytorch_transformers/modeling_bert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_transformers/modeling_bert.py b/pytorch_transformers/modeling_bert.py index b59445513a4..418939f7dad 100644 --- a/pytorch_transformers/modeling_bert.py +++ b/pytorch_transformers/modeling_bert.py @@ -857,7 +857,7 @@ class BertForMaskedLM(BertPreTrainedModel): sequence_output = outputs[0] prediction_scores = self.cls(sequence_output) - outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention is they are here + outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here if masked_lm_labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-1) masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))