LongformerForSequenceClassification (#4580)

* LongformerForSequenceClassification

* better naming x=>hidden_states, fix typo in doc

* Update src/transformers/modeling_longformer.py

* Update src/transformers/modeling_longformer.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Suraj Patil 2020-05-28 02:00:00 +05:30 committed by GitHub
parent 4402879ee4
commit ec4cdfdd05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 141 additions and 1 deletions

View File

@ -338,6 +338,7 @@ if is_torch_available():
from .modeling_longformer import (
LongformerModel,
LongformerForMaskedLM,
LongformerForSequenceClassification,
LongformerForQuestionAnswering,
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
)

View File

@ -105,6 +105,7 @@ from .modeling_longformer import (
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
LongformerForMaskedLM,
LongformerForQuestionAnswering,
LongformerForSequenceClassification,
LongformerModel,
)
from .modeling_marian import MarianMTModel
@ -252,6 +253,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
(CamembertConfig, CamembertForSequenceClassification),
(XLMRobertaConfig, XLMRobertaForSequenceClassification),
(BartConfig, BartForSequenceClassification),
(LongformerConfig, LongformerForSequenceClassification),
(RobertaConfig, RobertaForSequenceClassification),
(BertConfig, BertForSequenceClassification),
(XLNetConfig, XLNetForSequenceClassification),

View File

@ -19,7 +19,7 @@ import math
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn import functional as F
from .configuration_longformer import LongformerConfig
@ -710,6 +710,121 @@ class LongformerForMaskedLM(BertPreTrainedModel):
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
@add_start_docstrings(
"""Longformer Model transformer with a sequence classification/regression head on top (a linear layer
on top of the pooled output) e.g. for GLUE tasks. """,
LONGFORMER_START_DOCSTRING,
)
class LongformerForSequenceClassification(BertPreTrainedModel):
config_class = LongformerConfig
pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "longformer"
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.longformer = LongformerModel(config)
self.classifier = LongformerClassificationHead(config)
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
labels=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss.
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
from transformers import LongformerTokenizer, LongformerForSequenceClassification
import torch
tokenizer = LongformerTokenizer.from_pretrained('longformer-base-4096')
model = LongformerForSequenceClassification.from_pretrained('longformer-base-4096')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
# global attention on cls token
attention_mask[:, 0] = 2
outputs = self.longformer(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
)
sequence_output = outputs[0]
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:]
if labels is not None:
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
class LongformerClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, hidden_states, **kwargs):
hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
hidden_states = self.dropout(hidden_states)
hidden_states = self.dense(hidden_states)
hidden_states = torch.tanh(hidden_states)
hidden_states = self.dropout(hidden_states)
output = self.out_proj(hidden_states)
return output
@add_start_docstrings(
"""Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD / TriviaQA (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,

View File

@ -29,6 +29,7 @@ if is_torch_available():
LongformerConfig,
LongformerModel,
LongformerForMaskedLM,
LongformerForSequenceClassification,
LongformerForQuestionAnswering,
)
@ -194,6 +195,23 @@ class LongformerModelTester(object):
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
def create_and_check_longformer_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = LongformerForSequenceClassification(config)
model.to(torch_device)
model.eval()
loss, logits = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@ -256,6 +274,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_sequence_classification(*config_and_inputs)
class LongformerModelIntegrationTest(unittest.TestCase):
@slow