mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
LongformerForSequenceClassification (#4580)
* LongformerForSequenceClassification * better naming x=>hidden_states, fix typo in doc * Update src/transformers/modeling_longformer.py * Update src/transformers/modeling_longformer.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
4402879ee4
commit
ec4cdfdd05
@ -338,6 +338,7 @@ if is_torch_available():
|
||||
from .modeling_longformer import (
|
||||
LongformerModel,
|
||||
LongformerForMaskedLM,
|
||||
LongformerForSequenceClassification,
|
||||
LongformerForQuestionAnswering,
|
||||
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
@ -105,6 +105,7 @@ from .modeling_longformer import (
|
||||
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
LongformerForMaskedLM,
|
||||
LongformerForQuestionAnswering,
|
||||
LongformerForSequenceClassification,
|
||||
LongformerModel,
|
||||
)
|
||||
from .modeling_marian import MarianMTModel
|
||||
@ -252,6 +253,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
(CamembertConfig, CamembertForSequenceClassification),
|
||||
(XLMRobertaConfig, XLMRobertaForSequenceClassification),
|
||||
(BartConfig, BartForSequenceClassification),
|
||||
(LongformerConfig, LongformerForSequenceClassification),
|
||||
(RobertaConfig, RobertaForSequenceClassification),
|
||||
(BertConfig, BertForSequenceClassification),
|
||||
(XLNetConfig, XLNetForSequenceClassification),
|
||||
|
@ -19,7 +19,7 @@ import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .configuration_longformer import LongformerConfig
|
||||
@ -710,6 +710,121 @@ class LongformerForMaskedLM(BertPreTrainedModel):
|
||||
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""Longformer Model transformer with a sequence classification/regression head on top (a linear layer
|
||||
on top of the pooled output) e.g. for GLUE tasks. """,
|
||||
LONGFORMER_START_DOCSTRING,
|
||||
)
|
||||
class LongformerForSequenceClassification(BertPreTrainedModel):
|
||||
config_class = LongformerConfig
|
||||
pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
base_model_prefix = "longformer"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.longformer = LongformerModel(config)
|
||||
self.classifier = LongformerClassificationHead(config)
|
||||
|
||||
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for computing the sequence classification/regression loss.
|
||||
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
||||
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Returns:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
Examples::
|
||||
|
||||
from transformers import LongformerTokenizer, LongformerForSequenceClassification
|
||||
import torch
|
||||
|
||||
tokenizer = LongformerTokenizer.from_pretrained('longformer-base-4096')
|
||||
model = LongformerForSequenceClassification.from_pretrained('longformer-base-4096')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, labels=labels)
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# global attention on cls token
|
||||
attention_mask[:, 0] = 2
|
||||
|
||||
outputs = self.longformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
outputs = (logits,) + outputs[2:]
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs # (loss), logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
class LongformerClassificationHead(nn.Module):
|
||||
"""Head for sentence-level classification tasks."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
def forward(self, hidden_states, **kwargs):
|
||||
hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = torch.tanh(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
output = self.out_proj(hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD / TriviaQA (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
|
@ -29,6 +29,7 @@ if is_torch_available():
|
||||
LongformerConfig,
|
||||
LongformerModel,
|
||||
LongformerForMaskedLM,
|
||||
LongformerForSequenceClassification,
|
||||
LongformerForQuestionAnswering,
|
||||
)
|
||||
|
||||
@ -194,6 +195,23 @@ class LongformerModelTester(object):
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_longformer_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = LongformerForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@ -256,6 +274,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
|
||||
self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
|
||||
class LongformerModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
|
Loading…
Reference in New Issue
Block a user