Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
[T5] Add T5ForQuestionAnswering and MT5ForQuestionAnswering (#24481)

* Adding T5ForQuestionAnswering
* Changed weight initialization that results in better initial loss when fine-tuning
* Update to class variables
* Running make fixup
* Running make fix-copies
* Remove model_parallel
* Adding MT5ForQuestionAnswering
* Adding docs
* Fix wrong doc
* Update src/transformers/models/mt5/modeling_mt5.py
  Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* Update src/transformers/models/t5/modeling_t5.py
  Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* File formatting
* Undoing change

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
parent bcf02ec701
commit 06910f5a76
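For orientation before the diff: a minimal usage sketch of the new head. The checkpoint name, question, and context are illustrative, not part of the commit, and the span head is freshly initialized until fine-tuned.

```python
import torch
from transformers import AutoTokenizer, T5ForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # illustrative checkpoint
model = T5ForQuestionAnswering.from_pretrained("t5-small")  # qa_outputs weights are newly initialized

question = "Who wrote T5?"
context = "T5 was proposed by researchers at Google."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)  # decoder_input_ids are derived from input_ids internally

# Pick the most likely start/end token positions and decode the span.
start = outputs.start_logits.argmax(-1).item()
end = outputs.end_logits.argmax(-1).item()
answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])
print(answer)  # arbitrary until the head is fine-tuned
```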
docs/source/en/model_doc/mt5.md
@@ -95,6 +95,10 @@ See [`T5TokenizerFast`] for all details.
 
 [[autodoc]] MT5EncoderModel
 
+## MT5ForQuestionAnswering
+
+[[autodoc]] MT5ForQuestionAnswering
+
 ## TFMT5Model
 
 [[autodoc]] TFMT5Model
docs/source/en/model_doc/t5.md
@@ -399,6 +399,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
 [[autodoc]] T5EncoderModel
     - forward
 
+## T5ForQuestionAnswering
+
+[[autodoc]] T5ForQuestionAnswering
+    - forward
+
 ## TFT5Model
 
 [[autodoc]] TFT5Model
docs/source/en/tasks/question_answering.md
@@ -35,7 +35,7 @@ The task illustrated in this tutorial is supported by the following model archit
 
 <!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
 
-[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [LXMERT](../model_doc/lxmert), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OPT](../model_doc/opt), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [Splinter](../model_doc/splinter), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
+[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [LXMERT](../model_doc/lxmert), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OPT](../model_doc/opt), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [Splinter](../model_doc/splinter), [SqueezeBERT](../model_doc/squeezebert), [T5](../model_doc/t5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
 
 <!--End of the generated tip-->
src/transformers/__init__.py
@@ -2132,7 +2132,7 @@ else:
         ]
     )
     _import_structure["models.mt5"].extend(
-        ["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5Model", "MT5PreTrainedModel"]
+        ["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5ForQuestionAnswering", "MT5Model", "MT5PreTrainedModel"]
     )
     _import_structure["models.mvp"].extend(
         [
@@ -2573,6 +2573,7 @@ else:
             "T5_PRETRAINED_MODEL_ARCHIVE_LIST",
             "T5EncoderModel",
             "T5ForConditionalGeneration",
+            "T5ForQuestionAnswering",
             "T5Model",
             "T5PreTrainedModel",
             "load_tf_weights_in_t5",
@@ -5701,7 +5702,13 @@ if TYPE_CHECKING:
             MPNetModel,
             MPNetPreTrainedModel,
         )
-        from .models.mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model, MT5PreTrainedModel
+        from .models.mt5 import (
+            MT5EncoderModel,
+            MT5ForConditionalGeneration,
+            MT5ForQuestionAnswering,
+            MT5Model,
+            MT5PreTrainedModel,
+        )
        from .models.mvp import (
            MVP_PRETRAINED_MODEL_ARCHIVE_LIST,
            MvpForCausalLM,
@@ -6064,6 +6071,7 @@ if TYPE_CHECKING:
             T5_PRETRAINED_MODEL_ARCHIVE_LIST,
             T5EncoderModel,
             T5ForConditionalGeneration,
+            T5ForQuestionAnswering,
             T5Model,
             T5PreTrainedModel,
             load_tf_weights_in_t5,
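The export changes above register the new classes with the package's lazy import machinery. A quick sketch of what now resolves, assuming torch is installed:

```python
# Both the root-level and the submodule import paths resolve to the same class.
from transformers import MT5ForQuestionAnswering, T5ForQuestionAnswering
from transformers.models.t5.modeling_t5 import T5ForQuestionAnswering as T5QA

assert T5QA is T5ForQuestionAnswering
```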
src/transformers/models/auto/modeling_auto.py
@@ -768,6 +768,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
         ("megatron-bert", "MegatronBertForQuestionAnswering"),
         ("mobilebert", "MobileBertForQuestionAnswering"),
         ("mpnet", "MPNetForQuestionAnswering"),
+        ("mt5", "MT5ForQuestionAnswering"),
         ("mvp", "MvpForQuestionAnswering"),
         ("nezha", "NezhaForQuestionAnswering"),
         ("nystromformer", "NystromformerForQuestionAnswering"),
@@ -781,6 +782,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
         ("roformer", "RoFormerForQuestionAnswering"),
         ("splinter", "SplinterForQuestionAnswering"),
         ("squeezebert", "SqueezeBertForQuestionAnswering"),
+        ("t5", "T5ForQuestionAnswering"),
         ("xlm", "XLMForQuestionAnsweringSimple"),
         ("xlm-roberta", "XLMRobertaForQuestionAnswering"),
         ("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"),
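Registering the model types in `MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES` wires the new heads into the auto classes and the question-answering pipeline. A sketch of the effect; the checkpoint name is illustrative and the head is untrained, so predictions are meaningless until fine-tuning:

```python
from transformers import AutoModelForQuestionAnswering, pipeline

# The auto class now resolves T5/MT5 configs to the new QA heads.
model = AutoModelForQuestionAnswering.from_pretrained("t5-small")

qa = pipeline("question-answering", model="t5-small")
print(qa(question="Who wrote T5?", context="T5 was proposed by researchers at Google."))
```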
src/transformers/models/mt5/__init__.py
@@ -50,6 +50,7 @@ else:
     _import_structure["modeling_mt5"] = [
         "MT5EncoderModel",
         "MT5ForConditionalGeneration",
+        "MT5ForQuestionAnswering",
         "MT5Model",
         "MT5PreTrainedModel",
         "MT5Stack",
@@ -81,7 +82,14 @@ if TYPE_CHECKING:
     except OptionalDependencyNotAvailable:
         pass
     else:
-        from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model, MT5PreTrainedModel, MT5Stack
+        from .modeling_mt5 import (
+            MT5EncoderModel,
+            MT5ForConditionalGeneration,
+            MT5ForQuestionAnswering,
+            MT5Model,
+            MT5PreTrainedModel,
+            MT5Stack,
+        )
 
     try:
         if not is_tf_available():
src/transformers/models/mt5/modeling_mt5.py
@@ -31,6 +31,7 @@ from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     Seq2SeqLMOutput,
     Seq2SeqModelOutput,
+    Seq2SeqQuestionAnsweringModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
@@ -772,12 +773,15 @@ class MT5PreTrainedModel(PreTrainedModel):
         factor = self.config.initializer_factor  # Used for testing weights initialization
         if isinstance(module, MT5LayerNorm):
             module.weight.data.fill_(factor * 1.0)
-        elif isinstance(module, (MT5Model, MT5ForConditionalGeneration, MT5EncoderModel)):
+        elif isinstance(module, (MT5Model, MT5ForConditionalGeneration, MT5EncoderModel, MT5ForQuestionAnswering)):
             # Mesh TensorFlow embeddings initialization
             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
             module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
             if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                 module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
+            if hasattr(module, "qa_outputs"):
+                module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
+                module.qa_outputs.bias.data.zero_()
         elif isinstance(module, MT5DenseActDense):
             # Mesh TensorFlow FF initialization
             # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
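The new `qa_outputs` branch scales the span head's initialization down with model width: standard deviation `initializer_factor * d_model ** -0.5` instead of the unscaled `factor * 1.0` used for `lm_head`. A quick numeric sketch of that scale (values illustrative, not from the diff):

```python
import torch

d_model, factor = 512, 1.0              # t5-small width, default initializer_factor
std = factor * d_model**-0.5            # ≈ 0.0442
w = torch.empty(2, d_model).normal_(mean=0.0, std=std)
print(std, w.std().item())              # sample std lands close to the target std
```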
@@ -2015,3 +2019,201 @@ class MT5EncoderModel(MT5PreTrainedModel):
         )
 
         return encoder_outputs
+
+
+@add_start_docstrings(
+    """
+    MT5 Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers
+    on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    """,
+    MT5_START_DOCSTRING,
+)
+class MT5ForQuestionAnswering(MT5PreTrainedModel):
+    _keys_to_ignore_on_load_missing = [
+        r"encoder.embed_tokens.weight",
+        r"decoder.embed_tokens.weight",
+        r"lm_head.weight",
+    ]
+    _keys_to_ignore_on_load_unexpected = [
+        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
+    ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+    # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.__init__ with T5->MT5
+    def __init__(self, config: MT5Config):
+        super().__init__(config)
+        self.model_dim = config.d_model
+
+        self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+        encoder_config = copy.deepcopy(config)
+        encoder_config.is_decoder = False
+        encoder_config.use_cache = False
+        encoder_config.is_encoder_decoder = False
+        self.encoder = MT5Stack(encoder_config, self.shared)
+
+        decoder_config = copy.deepcopy(config)
+        decoder_config.is_decoder = True
+        decoder_config.is_encoder_decoder = False
+        decoder_config.num_layers = config.num_decoder_layers
+        self.decoder = MT5Stack(decoder_config, self.shared)
+
+        self.num_labels = config.num_labels
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_input_embeddings
+    def get_input_embeddings(self):
+        return self.shared
+
+    # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.set_input_embeddings
+    def set_input_embeddings(self, new_embeddings):
+        self.shared = new_embeddings
+        self.encoder.set_input_embeddings(new_embeddings)
+        self.decoder.set_input_embeddings(new_embeddings)
+
+    # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_encoder
+    def get_encoder(self):
+        return self.encoder
+
+    # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_decoder
+    def get_decoder(self):
+        return self.decoder
+
+    @add_start_docstrings_to_model_forward(MT5_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Seq2SeqQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
+    # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.forward
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        decoder_head_mask: Optional[torch.FloatTensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]:
+        r"""
+        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
+            are not taken into account for computing the loss.
+        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
+            are not taken into account for computing the loss.
+        Returns:
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        if start_positions is not None and end_positions is not None:
+            use_cache = False
+
+        # Copied from models.bart.modeling_bart.BartModel.forward
+        # different to other models, T5 automatically creates decoder_input_ids from
+        # input_ids if no decoder_input_ids are provided
+        if decoder_input_ids is None and decoder_inputs_embeds is None:
+            if input_ids is None:
+                raise ValueError(
+                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
+                    "passed, `input_ids` cannot be `None`. Please pass either "
+                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
+                )
+            decoder_input_ids = self._shift_right(input_ids)
+
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+        if head_mask is not None and decoder_head_mask is None:
+            if self.config.num_layers == self.config.num_decoder_layers:
+                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+                decoder_head_mask = head_mask
+
+        # Encode if needed (training, first prediction pass)
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                head_mask=head_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        hidden_states = encoder_outputs[0]
+
+        # Decode
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            inputs_embeds=decoder_inputs_embeds,
+            past_key_values=None,
+            encoder_hidden_states=hidden_states,
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = decoder_outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split add a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1).to(start_logits.device)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1).to(end_logits.device)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return Seq2SeqQuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
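The head emits one start and one end logit per decoder position, obtained by splitting the 2-unit `qa_outputs` projection. A common decoding step on top of those logits (my sketch, not part of the commit) picks the best valid span under the constraints start <= end and a maximum span length:

```python
import torch

def best_span(start_logits: torch.Tensor, end_logits: torch.Tensor, max_len: int = 30):
    """Return (start, end) maximizing start_logits[s] + end_logits[e] with s <= e < s + max_len."""
    scores = start_logits[:, None] + end_logits[None, :]            # (seq_len, seq_len)
    mask = torch.triu(torch.ones_like(scores, dtype=torch.bool))     # keep s <= e
    mask &= ~torch.triu(torch.ones_like(mask), diagonal=max_len)     # cap the span length
    scores = scores.masked_fill(~mask, float("-inf"))
    flat = scores.view(-1).argmax().item()
    return divmod(flat, scores.size(-1))                             # (start, end)

start, end = best_span(torch.randn(16), torch.randn(16))             # random logits, demo only
print(start, end)
```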
src/transformers/models/t5/__init__.py
@@ -56,6 +56,7 @@ else:
         "T5Model",
         "T5PreTrainedModel",
         "load_tf_weights_in_t5",
+        "T5ForQuestionAnswering",
     ]
 
     try:
@@ -115,6 +116,7 @@ if TYPE_CHECKING:
             T5_PRETRAINED_MODEL_ARCHIVE_LIST,
             T5EncoderModel,
             T5ForConditionalGeneration,
+            T5ForQuestionAnswering,
             T5Model,
             T5PreTrainedModel,
             load_tf_weights_in_t5,
src/transformers/models/t5/modeling_t5.py
@@ -32,6 +32,7 @@ from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     Seq2SeqLMOutput,
     Seq2SeqModelOutput,
+    Seq2SeqQuestionAnsweringModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
@@ -801,12 +802,15 @@ class T5PreTrainedModel(PreTrainedModel):
         factor = self.config.initializer_factor  # Used for testing weights initialization
         if isinstance(module, T5LayerNorm):
             module.weight.data.fill_(factor * 1.0)
-        elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
+        elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel, T5ForQuestionAnswering)):
             # Mesh TensorFlow embeddings initialization
             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
             module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
             if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                 module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
+            if hasattr(module, "qa_outputs"):
+                module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
+                module.qa_outputs.bias.data.zero_()
         elif isinstance(module, T5DenseActDense):
             # Mesh TensorFlow FF initialization
             # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
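The commit message credits this scaled initialization with a better initial loss when fine-tuning. A back-of-the-envelope check (my sketch, not from the diff): with near-zero span logits at step 0, each cross-entropy term starts at roughly log(sequence_length), the uniform-guessing baseline, instead of starting much higher from over-scaled logits.

```python
import math

import torch
import torch.nn.functional as F

seq_len = 384                         # illustrative SQuAD-style input length
logits = torch.zeros(1, seq_len)      # what a small-std head produces at initialization
target = torch.tensor([42])           # arbitrary gold start position
print(F.cross_entropy(logits, target).item(), math.log(seq_len))  # both ≈ 5.95
```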
@@ -1949,3 +1953,195 @@ class T5EncoderModel(T5PreTrainedModel):
         )
 
         return encoder_outputs
+
+
+@add_start_docstrings(
+    """
+    T5 Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers
+    on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    """,
+    T5_START_DOCSTRING,
+)
+class T5ForQuestionAnswering(T5PreTrainedModel):
+    _keys_to_ignore_on_load_missing = [
+        r"encoder.embed_tokens.weight",
+        r"decoder.embed_tokens.weight",
+        r"lm_head.weight",
+    ]
+    _keys_to_ignore_on_load_unexpected = [
+        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
+    ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+    def __init__(self, config: T5Config):
+        super().__init__(config)
+        self.model_dim = config.d_model
+
+        self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+        encoder_config = copy.deepcopy(config)
+        encoder_config.is_decoder = False
+        encoder_config.use_cache = False
+        encoder_config.is_encoder_decoder = False
+        self.encoder = T5Stack(encoder_config, self.shared)
+
+        decoder_config = copy.deepcopy(config)
+        decoder_config.is_decoder = True
+        decoder_config.is_encoder_decoder = False
+        decoder_config.num_layers = config.num_decoder_layers
+        self.decoder = T5Stack(decoder_config, self.shared)
+
+        self.num_labels = config.num_labels
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, new_embeddings):
+        self.shared = new_embeddings
+        self.encoder.set_input_embeddings(new_embeddings)
+        self.decoder.set_input_embeddings(new_embeddings)
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
+    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Seq2SeqQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        decoder_head_mask: Optional[torch.FloatTensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]:
+        r"""
+        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
+            are not taken into account for computing the loss.
+        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
+            are not taken into account for computing the loss.
+        Returns:
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        if start_positions is not None and end_positions is not None:
+            use_cache = False
+
+        # Copied from models.bart.modeling_bart.BartModel.forward
+        # different to other models, T5 automatically creates decoder_input_ids from
+        # input_ids if no decoder_input_ids are provided
+        if decoder_input_ids is None and decoder_inputs_embeds is None:
+            if input_ids is None:
+                raise ValueError(
+                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
+                    "passed, `input_ids` cannot be `None`. Please pass either "
+                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
+                )
+            decoder_input_ids = self._shift_right(input_ids)
+
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+        if head_mask is not None and decoder_head_mask is None:
+            if self.config.num_layers == self.config.num_decoder_layers:
+                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+                decoder_head_mask = head_mask
+
+        # Encode if needed (training, first prediction pass)
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                head_mask=head_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        hidden_states = encoder_outputs[0]
+
+        # Decode
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            inputs_embeds=decoder_inputs_embeds,
+            past_key_values=None,
+            encoder_hidden_states=hidden_states,
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = decoder_outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split add a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1).to(start_logits.device)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1).to(end_logits.device)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return Seq2SeqQuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
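When `start_positions` and `end_positions` are supplied, the forward pass above returns the average of the start and end cross-entropy losses. A minimal fine-tuning step against the new head (my sketch: tiny random config and fabricated gold positions, just to show the call signature):

```python
import torch
from transformers import T5Config, T5ForQuestionAnswering

config = T5Config(vocab_size=128, d_model=32, d_ff=64, num_layers=2, num_heads=2)
model = T5ForQuestionAnswering(config)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

input_ids = torch.randint(0, 128, (2, 16))
outputs = model(
    input_ids=input_ids,
    attention_mask=torch.ones_like(input_ids),
    start_positions=torch.tensor([3, 5]),  # gold span starts (illustrative)
    end_positions=torch.tensor([6, 9]),    # gold span ends (illustrative)
)
outputs.loss.backward()  # (start_loss + end_loss) / 2, as computed in the diff above
optimizer.step()
```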
src/transformers/utils/dummy_pt_objects.py
@@ -4915,6 +4915,13 @@ class MT5ForConditionalGeneration(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class MT5ForQuestionAnswering(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class MT5Model(metaclass=DummyObject):
     _backends = ["torch"]
 
@@ -6742,6 +6749,13 @@ class T5ForConditionalGeneration(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class T5ForQuestionAnswering(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class T5Model(metaclass=DummyObject):
     _backends = ["torch"]
 
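These placeholders keep `from transformers import T5ForQuestionAnswering` importable in a torch-less environment while failing loudly on use. The pattern, reproduced as a standalone sketch (I assume `DummyObject` and `requires_backends` are importable from `transformers.utils`, as in the generated file):

```python
from transformers.utils import DummyObject, requires_backends

class T5ForQuestionAnswering(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Raises an ImportError naming the missing backend if torch is not installed;
        # otherwise the real class is imported and this dummy is never reached.
        requires_backends(self, ["torch"])
```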
tests/models/t5/test_modeling_t5.py
@@ -43,6 +43,7 @@ if is_torch_available():
         ByT5Tokenizer,
         T5EncoderModel,
         T5ForConditionalGeneration,
+        T5ForQuestionAnswering,
         T5Model,
         T5Tokenizer,
     )
@@ -520,7 +521,7 @@ class T5ModelTester:
 
 @require_torch
 class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
-    all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
+    all_model_classes = (T5Model, T5ForConditionalGeneration, T5ForQuestionAnswering) if is_torch_available() else ()
     all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
@@ -529,6 +530,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
             "summarization": T5ForConditionalGeneration,
             "text2text-generation": T5ForConditionalGeneration,
             "translation": T5ForConditionalGeneration,
+            "question-answering": T5ForQuestionAnswering,
         }
         if is_torch_available()
         else {}
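Adding `T5ForQuestionAnswering` to `all_model_classes` and the pipeline mapping folds the new head into the common test suite. An equivalent smoke test in miniature (my own choice of tiny sizes, not the suite's config):

```python
import torch
from transformers import T5Config, T5ForQuestionAnswering

config = T5Config(vocab_size=128, d_model=32, d_ff=64, num_layers=2, num_heads=2)
model = T5ForQuestionAnswering(config).eval()

ids = torch.randint(0, 128, (1, 8))
out = model(input_ids=ids, attention_mask=torch.ones_like(ids))
# One start and one end logit per decoder position.
assert out.start_logits.shape == out.end_logits.shape == (1, 8)
```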