Add GPTJForQuestionAnswering (#14503)

* Add GPTJForQuestionAnswering

* Reformat for GPTJForQuestionAnswering

* Fix isort error

* make style for GPTJForQA

* Add _keys_to_ignore_on_load_missing

* Change the sequence of qa and classification

Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
tucan9389 2021-12-07 01:44:10 +09:00 committed by GitHub
parent 1ccc033c56
commit 0f3f045ebd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 141 additions and 2 deletions

View File

@ -121,6 +121,13 @@ GPTJForSequenceClassification
:members: forward
GPTJForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.GPTJForQuestionAnswering
:members: forward
FlaxGPTJModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -951,6 +951,7 @@ if is_torch_available():
[
"GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTJForCausalLM",
"GPTJForQuestionAnswering",
"GPTJForSequenceClassification",
"GPTJModel",
"GPTJPreTrainedModel",
@ -2833,6 +2834,7 @@ if TYPE_CHECKING:
from .models.gptj import (
GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTJForCausalLM,
GPTJForQuestionAnswering,
GPTJForSequenceClassification,
GPTJModel,
GPTJPreTrainedModel,

View File

@ -385,6 +385,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
# Model for Question Answering mapping
("qdqbert", "QDQBertForQuestionAnswering"),
("fnet", "FNetForQuestionAnswering"),
("gptj", "GPTJForQuestionAnswering"),
("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
("rembert", "RemBertForQuestionAnswering"),
("canine", "CanineForQuestionAnswering"),

View File

@ -28,6 +28,7 @@ if is_torch_available():
_import_structure["modeling_gptj"] = [
"GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTJForCausalLM",
"GPTJForQuestionAnswering",
"GPTJForSequenceClassification",
"GPTJModel",
"GPTJPreTrainedModel",
@ -48,6 +49,7 @@ if TYPE_CHECKING:
from .modeling_gptj import (
GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTJForCausalLM,
GPTJForQuestionAnswering,
GPTJForSequenceClassification,
GPTJModel,
GPTJPreTrainedModel,

View File

@ -23,7 +23,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from ...utils.model_parallel_utils import assert_device_map, get_device_map
@ -967,3 +972,108 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
    """
    The GPT-J Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    GPTJ_START_DOCSTRING,
)
class GPTJForQuestionAnswering(GPTJPreTrainedModel):
    # Checkpoint keys matching these patterns are not reported as missing when loading weights.
    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = GPTJModel(config)
        # One linear projection produces every span logit channel at once; `forward` unpacks it into exactly
        # two tensors (start / end), so it only works when `config.num_labels == 2`.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The first element of the backbone output is the per-token hidden state fed to the QA head.
        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        # Split the last dimension into single-channel chunks; unpacking into two names requires exactly two.
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        # The loss is only computed when both span boundaries are supplied.
        if start_positions is not None and end_positions is not None:
            # On multi-GPU the labels may arrive with an extra trailing dimension; drop it.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            # Keep the labels on the same device as the logits so the loss does not fail with a device
            # mismatch when the model is spread over several devices. This is a no-op otherwise.
            start_positions = start_positions.to(start_logits.device)
            end_positions = end_positions.to(end_logits.device)

            # Sometimes the start/end positions are outside our model inputs; we ignore these terms.
            # Positions past the end are clamped to `ignored_index` (the size of dim 1 of the logits), which
            # the loss skips; negative positions are clamped to 0.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            # Tuple output: the logits followed by the backbone outputs from index 2 onwards.
            # NOTE(review): index 1 is presumably the cached key/values - confirm against GPTJModel.
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

View File

@ -2494,6 +2494,18 @@ class GPTJForCausalLM:
requires_backends(self, ["torch"])
class GPTJForQuestionAnswering:
    # Placeholder exposed under the real model's name; every entry point below only
    # delegates to `requires_backends` with the "torch" backend.

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but unused; only the backend check runs.
        needed = ["torch"]
        requires_backends(self, needed)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Same backend check, performed on the class instead of an instance.
        needed = ["torch"]
        requires_backends(cls, needed)

    def forward(self, *args, **kwargs):
        # Same backend check for the forward entry point.
        needed = ["torch"]
        requires_backends(self, needed)
class GPTJForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

View File

@ -32,6 +32,7 @@ if is_torch_available():
GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
AutoTokenizer,
GPTJForCausalLM,
GPTJForQuestionAnswering,
GPTJForSequenceClassification,
GPTJModel,
)
@ -356,7 +357,11 @@ class GPTJModelTester:
@require_torch
class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (GPTJModel, GPTJForCausalLM, GPTJForSequenceClassification) if is_torch_available() else ()
all_model_classes = (
(GPTJModel, GPTJForCausalLM, GPTJForSequenceClassification, GPTJForQuestionAnswering)
if is_torch_available()
else ()
)
all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else ()
fx_ready_model_classes = all_model_classes
test_pruning = False