GPT2ForQuestionAnswering (#23030)
* first draft - gives index error in question_answering.py
* maturing
* no labels
* pipeline should know about QA
* fixing checks
* formatting
* fixed docstring
* make sure legacy code executes
* comment
* like this

---------

Co-authored-by: Prof. Peter Schneider-Kamp <jps@ordbogen.com>
Parent: bcedd0a471
Commit: 2b0c924568
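The commit message notes that the question-answering pipeline should know about the new head; a minimal usage sketch under that assumption (the base `gpt2` checkpoint ships no trained QA head, so this only demonstrates the wiring, not answer quality):

# Minimal sketch: running the question-answering pipeline with the new GPT-2 head.
# The base "gpt2" checkpoint's qa_outputs layer is randomly initialized, so answers are meaningless
# until the model is fine-tuned on SQuAD-style data.
from transformers import pipeline

qa = pipeline("question-answering", model="gpt2")
result = qa(
    question="Who released GPT-2?",
    context="GPT-2 was released by OpenAI in February 2019.",
)
print(result)  # dict with "score", "start", "end", and "answer" keys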
@@ -111,6 +111,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] GPT2DoubleHeadsModel
    - forward

## GPT2ForQuestionAnswering

[[autodoc]] GPT2ForQuestionAnswering
    - forward

## GPT2ForSequenceClassification

[[autodoc]] GPT2ForSequenceClassification
@@ -31,7 +31,7 @@ The task illustrated in this tutorial is supported by the following model archit
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->

-[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [LXMERT](../model_doc/lxmert), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OPT](../model_doc/opt), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [Splinter](../model_doc/splinter), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
+[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [OpenAI GPT-2](../model_doc/gpt2), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [LXMERT](../model_doc/lxmert), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OPT](../model_doc/opt), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [Splinter](../model_doc/splinter), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)

<!--End of the generated tip-->
@@ -1666,6 +1666,7 @@ else:
        [
            "GPT2_PRETRAINED_MODEL_ARCHIVE_LIST",
            "GPT2DoubleHeadsModel",
            "GPT2ForQuestionAnswering",
            "GPT2ForSequenceClassification",
            "GPT2ForTokenClassification",
            "GPT2LMHeadModel",
@@ -5212,6 +5213,7 @@ if TYPE_CHECKING:
        from .models.gpt2 import (
            GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
            GPT2DoubleHeadsModel,
            GPT2ForQuestionAnswering,
            GPT2ForSequenceClassification,
            GPT2ForTokenClassification,
            GPT2LMHeadModel,
@@ -735,6 +735,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
        ("flaubert", "FlaubertForQuestionAnsweringSimple"),
        ("fnet", "FNetForQuestionAnswering"),
        ("funnel", "FunnelForQuestionAnswering"),
        ("gpt2", "GPT2ForQuestionAnswering"),
        ("gptj", "GPTJForQuestionAnswering"),
        ("ibert", "IBertForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
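Registering `gpt2` in `MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES` is what lets the auto classes resolve GPT-2 for this task; a minimal sketch (again assuming the untuned base `gpt2` checkpoint, whose head weights are freshly initialized):

# Sketch: the new auto-mapping entry makes AutoModelForQuestionAnswering build a GPT2ForQuestionAnswering.
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForQuestionAnswering.from_pretrained("gpt2")  # warns that qa_outputs is newly initialized
print(type(model).__name__)  # "GPT2ForQuestionAnswering"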
@@ -48,6 +48,7 @@ else:
    _import_structure["modeling_gpt2"] = [
        "GPT2_PRETRAINED_MODEL_ARCHIVE_LIST",
        "GPT2DoubleHeadsModel",
        "GPT2ForQuestionAnswering",
        "GPT2ForSequenceClassification",
        "GPT2ForTokenClassification",
        "GPT2LMHeadModel",
@@ -109,6 +110,7 @@ if TYPE_CHECKING:
        from .modeling_gpt2 import (
            GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
            GPT2DoubleHeadsModel,
            GPT2ForQuestionAnswering,
            GPT2ForSequenceClassification,
            GPT2ForTokenClassification,
            GPT2LMHeadModel,
@@ -31,6 +31,7 @@ from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
@@ -51,6 +52,7 @@ from .configuration_gpt2 import GPT2Config
logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "gpt2"
_REAL_CHECKPOINT_FOR_DOC = "gpt2"
_CONFIG_FOR_DOC = "GPT2Config"

GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -1586,3 +1588,109 @@ class GPT2ForTokenClassification(GPT2PreTrainedModel):
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """
    The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    GPT2_START_DOCSTRING,
)
class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = GPT2Model(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
        real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
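The head above scores every token position as a possible answer start or end; a short decoding sketch built on the forward pass in this hunk (assumption: the plain `gpt2` checkpoint, whose `qa_outputs` layer is untrained, so the extracted span is arbitrary):

# Sketch: decoding the start/end logits of GPT2ForQuestionAnswering into a text span.
import torch

from transformers import GPT2ForQuestionAnswering, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2ForQuestionAnswering.from_pretrained("gpt2")  # qa_outputs is newly initialized here

question = "Who released GPT-2?"
context = "GPT-2 was released by OpenAI in February 2019."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Pick the highest-scoring start and end positions (a real decoder would also enforce start <= end).
start_index = int(outputs.start_logits.argmax(dim=-1))
end_index = int(outputs.end_logits.argmax(dim=-1))
answer_ids = inputs["input_ids"][0, start_index : end_index + 1]
print(tokenizer.decode(answer_ids))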
@@ -3186,6 +3186,13 @@ class GPT2DoubleHeadsModel(metaclass=DummyObject):
        requires_backends(self, ["torch"])


class GPT2ForQuestionAnswering(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class GPT2ForSequenceClassification(metaclass=DummyObject):
    _backends = ["torch"]
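These dummy objects are what `transformers` exposes when PyTorch is not installed; a small illustrative sketch of that behaviour (the exact error text is an assumption):

# Illustrative sketch: without PyTorch installed, the name still imports,
# but instantiating the dummy class raises an informative ImportError via requires_backends.
from transformers import GPT2ForQuestionAnswering
from transformers.utils import is_torch_available

if not is_torch_available():
    try:
        GPT2ForQuestionAnswering()
    except ImportError as err:
        print(err)  # message asks you to install PyTorch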
@@ -33,6 +33,7 @@ if is_torch_available():
    from transformers import (
        GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
        GPT2DoubleHeadsModel,
        GPT2ForQuestionAnswering,
        GPT2ForSequenceClassification,
        GPT2ForTokenClassification,
        GPT2LMHeadModel,
@@ -377,6 +378,17 @@ class GPT2ModelTester:
        )
        self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))

    def create_and_check_gpt2_for_question_answering(
        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
    ):
        config.num_labels = self.num_labels
        model = GPT2ForQuestionAnswering(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

    def create_and_check_gpt2_for_sequence_classification(
        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
    ):
@@ -432,7 +444,14 @@ class GPT2ModelTester:
@require_torch
class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (
-        (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2ForSequenceClassification, GPT2ForTokenClassification)
+        (
+            GPT2Model,
+            GPT2LMHeadModel,
+            GPT2DoubleHeadsModel,
+            GPT2ForQuestionAnswering,
+            GPT2ForSequenceClassification,
+            GPT2ForTokenClassification,
+        )
        if is_torch_available()
        else ()
    )
@@ -440,6 +459,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
    pipeline_model_mapping = (
        {
            "feature-extraction": GPT2Model,
            "question-answering": GPT2ForQuestionAnswering,
            "text-classification": GPT2ForSequenceClassification,
            "text-generation": GPT2LMHeadModel,
            "token-classification": GPT2ForTokenClassification,
@@ -507,6 +527,10 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs)

    def test_gpt2_question_answering_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt2_for_question_answering(*config_and_inputs)

    def test_gpt2_sequence_classification_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt2_for_sequence_classification(*config_and_inputs)
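The new tester method only checks the logit shapes; for completeness, a small sketch of the training-style call that exercises the loss branch shown earlier (the tiny config values below are illustrative assumptions, not the tester's exact settings):

# Sketch: a training-style call that exercises the span loss of GPT2ForQuestionAnswering.
import torch

from transformers import GPT2Config, GPT2ForQuestionAnswering

config = GPT2Config(n_embd=64, n_layer=2, n_head=2)  # tiny config for illustration
model = GPT2ForQuestionAnswering(config)
model.train()

input_ids = torch.randint(0, config.vocab_size, (2, 16))  # batch of 2 sequences, 16 tokens each
start_positions = torch.tensor([3, 5])                     # gold span starts
end_positions = torch.tensor([7, 9])                       # gold span ends

outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
print(outputs.loss)                # averaged start/end cross-entropy
print(outputs.start_logits.shape)  # torch.Size([2, 16])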