Merge pull request #1792 from stefan-it/distilbert-for-token-classification
DistilBERT for token classification
commit 74ce8de7d8

@@ -554,6 +554,16 @@ On the test dataset the following results could be achieved:
10/04/2019 00:42:42 - INFO - __main__ - recall = 0.8624150210424085
```

### Comparing BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased)

Here is a small comparison between BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased) with the same hyperparameters as specified in the [example documentation](https://huggingface.co/transformers/examples.html#named-entity-recognition) (one run):

| Model                     | F-Score Dev | F-Score Test |
| ------------------------- | ----------- | ------------ |
| `bert-large-cased`        | 95.59       | 91.70        |
| `roberta-large`           | 95.96       | 91.87        |
| `distilbert-base-uncased` | 94.34       | 90.32        |
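
The F-Score columns above are entity-level scores of the kind `run_ner.py` logs during evaluation (see the `recall` line earlier in this section). As a purely illustrative sketch of how such entity-level metrics are computed, assuming the `seqeval` package that the NER example relies on, with invented tag sequences:

```
# Illustrative only: entity-level precision/recall/F1 on made-up tag sequences.
from seqeval.metrics import f1_score, precision_score, recall_score

y_true = [["B-PER", "I-PER", "O", "B-LOC", "O"]]
y_pred = [["B-PER", "I-PER", "O", "O", "O"]]

print("precision:", precision_score(y_true, y_pred))  # 1.0 -- the predicted PER span is exact
print("recall:", recall_score(y_true, y_pred))        # 0.5 -- the LOC span is missed
print("f1:", f1_score(y_true, y_pred))                # ~0.67
```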

## Abstractive summarization

Based on the script

@@ -36,16 +36,18 @@ from utils_ner import convert_examples_to_features, get_labels, read_examples_fr
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer
from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer
from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer

logger = logging.getLogger(__name__)

ALL_MODELS = sum(
    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig)),
    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)),
    ())

MODEL_CLASSES = {
    "bert": (BertConfig, BertForTokenClassification, BertTokenizer),
    "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer)
    "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
    "distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer)
}
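
For orientation, a script in the style of `run_ner.py` resolves these mappings roughly as sketched below; the checkpoint name and the label count are illustrative assumptions, and argument parsing is omitted:

```
# Rough sketch (not part of the diff): picking classes for --model_type distilbert.
from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer

MODEL_CLASSES = {
    "distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
}

config_class, model_class, tokenizer_class = MODEL_CLASSES["distilbert"]
config = config_class.from_pretrained("distilbert-base-uncased", num_labels=9)  # e.g. CoNLL-2003 BIO tags
tokenizer = tokenizer_class.from_pretrained("distilbert-base-uncased", do_lower_case=True)
model = model_class.from_pretrained("distilbert-base-uncased", config=config)
```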

@@ -121,9 +123,10 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {"input_ids": batch[0],
                      "attention_mask": batch[1],
                      "token_type_ids": batch[2] if args.model_type in ["bert", "xlnet"] else None,
                      # XLM and RoBERTa don't use segment_ids
                      "labels": batch[3]}
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None  # XLM and RoBERTa don't use segment_ids

            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)

@@ -210,9 +213,9 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
        with torch.no_grad():
            inputs = {"input_ids": batch[0],
                      "attention_mask": batch[1],
                      "token_type_ids": batch[2] if args.model_type in ["bert", "xlnet"] else None,
                      # XLM and RoBERTa don't use segment_ids
                      "labels": batch[3]}
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None  # XLM and RoBERTa don't use segment_ids
            outputs = model(**inputs)
            tmp_eval_loss, logits = outputs[:2]
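
Outside this hunk, the collected `logits` eventually have to be turned into tag predictions while skipping padded and sub-word positions. A generic sketch of that step (not code from this PR; the label set, shapes and the `-100` padding label id are assumptions for illustration):

```
import numpy as np

# Illustrative post-processing: argmax over the label dimension, then keep only
# positions whose gold label id is not the padding label id.
pad_token_label_id = -100                          # commonly CrossEntropyLoss().ignore_index
label_list = ["O", "B-PER", "I-PER"]               # toy label set

logits = np.random.rand(1, 5, len(label_list))     # stand-in for accumulated model scores
out_label_ids = np.array([[1, 2, 0, -100, -100]])  # stand-in for gold label ids

preds = np.argmax(logits, axis=2)
predictions = [
    [label_list[p] for p, l in zip(pred_row, label_row) if l != pad_token_label_id]
    for pred_row, label_row in zip(preds, out_label_ids)
]
print(predictions)  # three tags per sentence here, values depend on the random scores
```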

@@ -524,3 +527,4 @@ def main():

if __name__ == "__main__":
    main()

@@ -94,6 +94,7 @@ if is_torch_available():
                                   ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
    from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
                                      DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
                                      DistilBertForTokenClassification,
                                      DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
    from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model

@@ -30,6 +30,7 @@ import numpy as np

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from .modeling_utils import PreTrainedModel, prune_linear_layer
from .configuration_distilbert import DistilBertConfig

@@ -702,3 +703,75 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)


@add_start_docstrings("""DistilBert Model with a token classification head on top (a linear layer on top of
                      the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
                      DISTILBERT_START_DOCSTRING,
                      DISTILBERT_INPUTS_DOCSTRING)
class DistilBertForTokenClassification(DistilBertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss.
        **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        model = DistilBertForTokenClassification.from_pretrained('distilbert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, scores = outputs[:2]

    """
    def __init__(self, config):
        super(DistilBertForTokenClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.distilbert = DistilBertModel(config)
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, head_mask=None,
                inputs_embeds=None, labels=None):

        outputs = self.distilbert(input_ids,
                                  attention_mask=attention_mask,
                                  head_mask=head_mask,
                                  inputs_embeds=inputs_embeds)

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)
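
To make the `active_loss` masking above concrete, here is a self-contained toy sketch of the same selection on invented tensors (shapes, values and the label count are assumptions; this is not part of the diff):

```
import torch

# Only positions with attention_mask == 1 contribute to the loss.
num_labels = 3
logits = torch.randn(2, 4, num_labels)           # (batch_size, seq_len, num_labels)
labels = torch.randint(0, num_labels, (2, 4))    # toy gold labels
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])    # trailing positions are padding

active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, num_labels)[active_loss]   # 5 of 8 positions survive
active_labels = labels.view(-1)[active_loss]
loss = torch.nn.CrossEntropyLoss()(active_logits, active_labels)
print(active_logits.shape, loss.item())
```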

@@ -23,6 +23,7 @@ from transformers import is_torch_available

if is_torch_available():
    from transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM,
                              DistilBertForTokenClassification,
                              DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
else:
    pytestmark = pytest.mark.skip("Require Torch")

@@ -180,6 +181,21 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
                [self.batch_size, self.num_labels])
            self.check_loss_output(result)

        def create_and_check_distilbert_for_token_classification(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
            config.num_labels = self.num_labels
            model = DistilBertForTokenClassification(config=config)
            model.eval()

            loss, logits = model(input_ids, attention_mask=input_mask, labels=token_labels)
            result = {
                "loss": loss,
                "logits": logits,
            }
            self.parent.assertListEqual(
                list(result["logits"].size()),
                [self.batch_size, self.seq_length, self.num_labels])
            self.check_loss_output(result)

        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs

@@ -209,6 +225,10 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_distilbert_for_token_classification(*config_and_inputs)

    # @pytest.mark.slow
    # def test_model_from_pretrained(self):
    #     cache_dir = "/tmp/transformers_test/"