From 511bce58bd93d2cbc6eef21773b99b7a35b0d814 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 30 Nov 2018 22:56:02 +0100 Subject: [PATCH] update new token classification model --- pytorch_pretrained_bert/modeling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 13666c86dff..3af5854072f 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -932,8 +932,8 @@ class BertForTokenClassification(PreTrainedBertModel): def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) - pooled_output = self.dropout(sequence_output) - logits = self.classifier(pooled_output) + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) if labels is not None: loss_fct = CrossEntropyLoss()