update new token classification model

This commit is contained in:
thomwolf 2018-11-30 22:56:02 +01:00
parent 258eb50086
commit 511bce58bd

View File

@ -932,8 +932,8 @@ class BertForTokenClassification(PreTrainedBertModel):
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
pooled_output = self.dropout(sequence_output)
logits = self.classifier(pooled_output)
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
if labels is not None:
loss_fct = CrossEntropyLoss()