BartForQuestionAnswering (#4908)

Suraj Patil 2020-06-13 01:17:57 +05:30 committed by GitHub
parent 538531cde5
commit e93ccb3290
5 changed files with 146 additions and 1 deletion


@@ -55,6 +55,13 @@ BartForSequenceClassification
:members: forward
BartForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BartForQuestionAnswering
:members: forward
BartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


@@ -250,6 +250,7 @@ if is_torch_available():
BartForSequenceClassification,
BartModel,
BartForConditionalGeneration,
BartForQuestionAnswering,
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_marian import MarianMTModel
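
With the export above, the new head can be imported straight from the package root. A quick sanity check, assuming this branch is installed locally and torch is available (the export sits inside the is_torch_available() block):

import transformers
from transformers import BartForQuestionAnswering

# The class is defined in modeling_bart but re-exported at the top level.
assert hasattr(transformers, "BartForQuestionAnswering")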


@@ -52,7 +52,12 @@ from .modeling_albert import (
AlbertForTokenClassification,
AlbertModel,
)
from .modeling_bart import BartForConditionalGeneration, BartForSequenceClassification, BartModel
from .modeling_bart import (
BartForConditionalGeneration,
BartForQuestionAnswering,
BartForSequenceClassification,
BartModel,
)
from .modeling_bert import (
BertForMaskedLM,
BertForMultipleChoice,
@@ -274,6 +279,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
[
(DistilBertConfig, DistilBertForQuestionAnswering),
(AlbertConfig, AlbertForQuestionAnswering),
(BartConfig, BartForQuestionAnswering),
(LongformerConfig, LongformerForQuestionAnswering),
(XLMRobertaConfig, XLMRobertaForQuestionAnswering),
(RobertaConfig, RobertaForQuestionAnswering),
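
Registering BartConfig in MODEL_FOR_QUESTION_ANSWERING_MAPPING means the auto classes can now dispatch BART checkpoints to the new head. A minimal sketch of that dispatch; note that facebook/bart-large is not fine-tuned for question answering, so the QA head weights are freshly initialised when it loads:

from transformers import AutoModelForQuestionAnswering, BartForQuestionAnswering

# The mapping is keyed on the config class, so any checkpoint whose config is a
# BartConfig now resolves to BartForQuestionAnswering.
model = AutoModelForQuestionAnswering.from_pretrained("facebook/bart-large")
assert isinstance(model, BartForQuestionAnswering)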


@@ -23,6 +23,7 @@ import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import CrossEntropyLoss
from .activations import ACT2FN
from .configuration_bart import BartConfig
@@ -1123,6 +1124,122 @@ class BartForSequenceClassification(PretrainedBartModel):
return outputs
@add_start_docstrings(
"""BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
BART_START_DOCSTRING,
)
class BartForQuestionAnswering(PretrainedBartModel):
def __init__(self, config):
super().__init__(config)
config.num_labels = 2
self.num_labels = config.num_labels
self.model = BartModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.model._init_weights(self.qa_outputs)
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
def forward(
self,
input_ids,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
start_positions=None,
end_positions=None,
output_attentions=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss.
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BartConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`start_positions` and :obj:`end_positions` are provided):
Total span extraction loss: the average of the cross-entropy losses for the start and end positions.
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
# The checkpoint bart-large is not fine-tuned for question answering. Please see the
# examples/question-answering/run_squad.py example to see how to fine-tune a model to a question answering task.
from transformers import BartTokenizer, BartForQuestionAnswering
import torch
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
model = BartForQuestionAnswering.from_pretrained('facebook/bart-large')
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
input_ids = tokenizer.encode(question, text)
start_scores, end_scores = model(torch.tensor([input_ids]))[:2]  # extra model outputs follow the two score tensors
all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])
"""
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
output_attentions=output_attentions,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
outputs = (start_logits, end_logits,) + outputs[1:]
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split adds an extra dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
class SinusoidalPositionalEmbedding(nn.Embedding):
"""This module produces sinusoidal positional embeddings of any length."""


@@ -35,6 +35,7 @@ if is_torch_available():
BartModel,
BartForConditionalGeneration,
BartForSequenceClassification,
BartForQuestionAnswering,
BartConfig,
BartTokenizer,
MBartTokenizer,
@@ -375,6 +376,19 @@ class BartHeadTests(unittest.TestCase):
loss = outputs[0]
self.assertIsInstance(loss.item(), float)
def test_question_answering_forward(self):
config, input_ids, batch_size = self._get_config_and_data()
sequence_labels = ids_tensor([batch_size], 2).to(torch_device)
model = BartForQuestionAnswering(config)
model.to(torch_device)
loss, start_logits, end_logits, _ = model(
input_ids=input_ids, start_positions=sequence_labels, end_positions=sequence_labels,
)
self.assertEqual(start_logits.shape, input_ids.shape)
self.assertEqual(end_logits.shape, input_ids.shape)
self.assertIsInstance(loss.item(), float)
@timeout_decorator.timeout(1)
def test_lm_forward(self):
config, input_ids, batch_size = self._get_config_and_data()
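
The new test_question_answering_forward only checks that the two logit tensors match the shape of input_ids and that the loss is a float. For context, a self-contained sketch of the same check with a tiny randomly initialised config; the hyper-parameters below are invented and are not the ones produced by _get_config_and_data:

import torch
from transformers import BartConfig, BartForQuestionAnswering

# Deliberately tiny model so the sketch runs quickly on CPU; values are illustrative only.
config = BartConfig(
    vocab_size=99, d_model=24,
    encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
    max_position_embeddings=48,
)
model = BartForQuestionAnswering(config).eval()

input_ids = torch.tensor([[0, 10, 11, 12, 13, 2]])  # (batch_size=1, sequence_length=6)
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])

loss, start_logits, end_logits = model(
    input_ids=input_ids, start_positions=start_positions, end_positions=end_positions,
)[:3]
assert start_logits.shape == input_ids.shape
assert end_logits.shape == input_ids.shape
assert isinstance(loss.item(), float)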