mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
update new token classification model
This commit is contained in:
parent
258eb50086
commit
511bce58bd
@ -932,8 +932,8 @@ class BertForTokenClassification(PreTrainedBertModel):
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
|
||||
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
|
||||
pooled_output = self.dropout(sequence_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
|
Loading…
Reference in New Issue
Block a user