mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
BartForQuestionAnswering (#4908)
This commit is contained in:
parent
538531cde5
commit
e93ccb3290
@ -55,6 +55,13 @@ BartForSequenceClassification
|
||||
:members: forward
|
||||
|
||||
|
||||
BartForQuestionAnswering
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.BartForQuestionAnswering
|
||||
:members: forward
|
||||
|
||||
|
||||
BartForConditionalGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
@ -250,6 +250,7 @@ if is_torch_available():
|
||||
BartForSequenceClassification,
|
||||
BartModel,
|
||||
BartForConditionalGeneration,
|
||||
BartForQuestionAnswering,
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
)
|
||||
from .modeling_marian import MarianMTModel
|
||||
|
@ -52,7 +52,12 @@ from .modeling_albert import (
|
||||
AlbertForTokenClassification,
|
||||
AlbertModel,
|
||||
)
|
||||
from .modeling_bart import BartForConditionalGeneration, BartForSequenceClassification, BartModel
|
||||
from .modeling_bart import (
|
||||
BartForConditionalGeneration,
|
||||
BartForQuestionAnswering,
|
||||
BartForSequenceClassification,
|
||||
BartModel,
|
||||
)
|
||||
from .modeling_bert import (
|
||||
BertForMaskedLM,
|
||||
BertForMultipleChoice,
|
||||
@ -274,6 +279,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
||||
[
|
||||
(DistilBertConfig, DistilBertForQuestionAnswering),
|
||||
(AlbertConfig, AlbertForQuestionAnswering),
|
||||
(BartConfig, BartForQuestionAnswering),
|
||||
(LongformerConfig, LongformerForQuestionAnswering),
|
||||
(XLMRobertaConfig, XLMRobertaForQuestionAnswering),
|
||||
(RobertaConfig, RobertaForQuestionAnswering),
|
||||
|
@ -23,6 +23,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from .activations import ACT2FN
|
||||
from .configuration_bart import BartConfig
|
||||
@ -1123,6 +1124,122 @@ class BartForSequenceClassification(PretrainedBartModel):
|
||||
return outputs
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
BART_START_DOCSTRING,
|
||||
)
|
||||
class BartForQuestionAnswering(PretrainedBartModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
config.num_labels = 2
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.model = BartModel(config)
|
||||
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
self.model._init_weights(self.qa_outputs)
|
||||
|
||||
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
encoder_outputs=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
output_attentions=None,
|
||||
):
|
||||
r"""
|
||||
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
|
||||
Returns:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BartConfig`) and inputs:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
||||
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
||||
Span-start scores (before SoftMax).
|
||||
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
||||
Span-end scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
Examples::
|
||||
|
||||
# The checkpoint bart-large is not fine-tuned for question answering. Please see the
|
||||
# examples/question-answering/run_squad.py example to see how to fine-tune a model to a question answering task.
|
||||
|
||||
from transformers import BartTokenizer, BartForQuestionAnswering
|
||||
import torch
|
||||
|
||||
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
|
||||
model = BartForQuestionAnswering.from_pretrained('facebook/bart-large')
|
||||
|
||||
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
||||
input_ids = tokenizer.encode(question, text)
|
||||
start_scores, end_scores = model(torch.tensor([input_ids]))
|
||||
|
||||
all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
||||
answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])
|
||||
|
||||
"""
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = logits.split(1, dim=-1)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
outputs = (start_logits, end_logits,) + outputs[1:]
|
||||
if start_positions is not None and end_positions is not None:
|
||||
# If we are on multi-GPU, split add a dimension
|
||||
if len(start_positions.size()) > 1:
|
||||
start_positions = start_positions.squeeze(-1)
|
||||
if len(end_positions.size()) > 1:
|
||||
end_positions = end_positions.squeeze(-1)
|
||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||
ignored_index = start_logits.size(1)
|
||||
start_positions.clamp_(0, ignored_index)
|
||||
end_positions.clamp_(0, ignored_index)
|
||||
|
||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||
start_loss = loss_fct(start_logits, start_positions)
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
outputs = (total_loss,) + outputs
|
||||
|
||||
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
class SinusoidalPositionalEmbedding(nn.Embedding):
|
||||
"""This module produces sinusoidal positional embeddings of any length."""
|
||||
|
||||
|
@ -35,6 +35,7 @@ if is_torch_available():
|
||||
BartModel,
|
||||
BartForConditionalGeneration,
|
||||
BartForSequenceClassification,
|
||||
BartForQuestionAnswering,
|
||||
BartConfig,
|
||||
BartTokenizer,
|
||||
MBartTokenizer,
|
||||
@ -375,6 +376,19 @@ class BartHeadTests(unittest.TestCase):
|
||||
loss = outputs[0]
|
||||
self.assertIsInstance(loss.item(), float)
|
||||
|
||||
def test_question_answering_forward(self):
|
||||
config, input_ids, batch_size = self._get_config_and_data()
|
||||
sequence_labels = ids_tensor([batch_size], 2).to(torch_device)
|
||||
model = BartForQuestionAnswering(config)
|
||||
model.to(torch_device)
|
||||
loss, start_logits, end_logits, _ = model(
|
||||
input_ids=input_ids, start_positions=sequence_labels, end_positions=sequence_labels,
|
||||
)
|
||||
|
||||
self.assertEqual(start_logits.shape, input_ids.shape)
|
||||
self.assertEqual(end_logits.shape, input_ids.shape)
|
||||
self.assertIsInstance(loss.item(), float)
|
||||
|
||||
@timeout_decorator.timeout(1)
|
||||
def test_lm_forward(self):
|
||||
config, input_ids, batch_size = self._get_config_and_data()
|
||||
|
Loading…
Reference in New Issue
Block a user