mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[T5
, MT5
, UMT5
] Add [T5, MT5, UMT5]ForSequenceClassification (#24726)
* Initial addition of t5forsequenceclassification * Adding imports and adding tests * Formatting * Running make fix-copies * Adding mt5forseq * Formatting * run make fix-copies * Adding to docs * Add model_parallel * Fix bug * Fix * Remove TODO * Fixing tests for T5ForSequenceClassification * Undo changes to dependency_versions_table.py * Change classification head to work with T5Config directly * Change seq length to let tests pass * PR comments for formatting * Formatting * Initial addition of UMT5ForSequenceClassification * Adding to inits and formatting * run make fix-copies * Add doc for UMT5ForSeqClass * Update UMT5 config * Fix docs * Skip torch fx test for SequenceClassification * Formatting * Add skip to UMT5 tests as well * Fix umt5 tests * Running make fix-copies * PR comments * Fix for change to sentence_representation * Rename seq_len to hidden_size since that's what it is * Use base_model to follow format of the rest of the library * Update docs * Extract the decoder_input_ids changes and make one liner * Make one-liner
This commit is contained in:
parent
21150cb0f3
commit
8f36ab3e22
@ -95,6 +95,10 @@ See [`T5TokenizerFast`] for all details.
|
||||
|
||||
[[autodoc]] MT5EncoderModel
|
||||
|
||||
## MT5ForSequenceClassification
|
||||
|
||||
[[autodoc]] MT5ForSequenceClassification
|
||||
|
||||
## MT5ForQuestionAnswering
|
||||
|
||||
[[autodoc]] MT5ForQuestionAnswering
|
||||
|
@ -401,6 +401,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
||||
[[autodoc]] T5EncoderModel
|
||||
- forward
|
||||
|
||||
## T5ForSequenceClassification
|
||||
|
||||
[[autodoc]] T5ForSequenceClassification
|
||||
- forward
|
||||
|
||||
## T5ForQuestionAnswering
|
||||
|
||||
[[autodoc]] T5ForQuestionAnswering
|
||||
|
@ -92,6 +92,11 @@ The conversion script is also different because the model was saved in t5x's lat
|
||||
[[autodoc]] UMT5EncoderModel
|
||||
- forward
|
||||
|
||||
## UMT5ForSequenceClassification
|
||||
|
||||
[[autodoc]] UMT5ForSequenceClassification
|
||||
- forward
|
||||
|
||||
## UMT5ForQuestionAnswering
|
||||
|
||||
[[autodoc]] UMT5ForQuestionAnswering
|
||||
|
@ -33,7 +33,7 @@ The task illustrated in this tutorial is supported by the following model archit
|
||||
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
|
||||
|
||||
|
||||
[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
|
||||
[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [T5](../model_doc/t5), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
|
||||
|
||||
|
||||
|
||||
|
@ -2240,7 +2240,14 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["models.mt5"].extend(
|
||||
["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5ForQuestionAnswering", "MT5Model", "MT5PreTrainedModel"]
|
||||
[
|
||||
"MT5EncoderModel",
|
||||
"MT5ForConditionalGeneration",
|
||||
"MT5ForQuestionAnswering",
|
||||
"MT5ForSequenceClassification",
|
||||
"MT5Model",
|
||||
"MT5PreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.musicgen"].extend(
|
||||
[
|
||||
@ -2694,6 +2701,7 @@ else:
|
||||
"T5EncoderModel",
|
||||
"T5ForConditionalGeneration",
|
||||
"T5ForQuestionAnswering",
|
||||
"T5ForSequenceClassification",
|
||||
"T5Model",
|
||||
"T5PreTrainedModel",
|
||||
"load_tf_weights_in_t5",
|
||||
@ -2763,6 +2771,7 @@ else:
|
||||
"UMT5EncoderModel",
|
||||
"UMT5ForConditionalGeneration",
|
||||
"UMT5ForQuestionAnswering",
|
||||
"UMT5ForSequenceClassification",
|
||||
"UMT5Model",
|
||||
"UMT5PreTrainedModel",
|
||||
]
|
||||
@ -5930,6 +5939,7 @@ if TYPE_CHECKING:
|
||||
MT5EncoderModel,
|
||||
MT5ForConditionalGeneration,
|
||||
MT5ForQuestionAnswering,
|
||||
MT5ForSequenceClassification,
|
||||
MT5Model,
|
||||
MT5PreTrainedModel,
|
||||
)
|
||||
@ -6303,6 +6313,7 @@ if TYPE_CHECKING:
|
||||
T5EncoderModel,
|
||||
T5ForConditionalGeneration,
|
||||
T5ForQuestionAnswering,
|
||||
T5ForSequenceClassification,
|
||||
T5Model,
|
||||
T5PreTrainedModel,
|
||||
load_tf_weights_in_t5,
|
||||
@ -6356,6 +6367,7 @@ if TYPE_CHECKING:
|
||||
UMT5EncoderModel,
|
||||
UMT5ForConditionalGeneration,
|
||||
UMT5ForQuestionAnswering,
|
||||
UMT5ForSequenceClassification,
|
||||
UMT5Model,
|
||||
UMT5PreTrainedModel,
|
||||
)
|
||||
|
@ -724,6 +724,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("mpnet", "MPNetForSequenceClassification"),
|
||||
("mpt", "MptForSequenceClassification"),
|
||||
("mra", "MraForSequenceClassification"),
|
||||
("mt5", "MT5ForSequenceClassification"),
|
||||
("mvp", "MvpForSequenceClassification"),
|
||||
("nezha", "NezhaForSequenceClassification"),
|
||||
("nystromformer", "NystromformerForSequenceClassification"),
|
||||
@ -740,8 +741,10 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("roc_bert", "RoCBertForSequenceClassification"),
|
||||
("roformer", "RoFormerForSequenceClassification"),
|
||||
("squeezebert", "SqueezeBertForSequenceClassification"),
|
||||
("t5", "T5ForSequenceClassification"),
|
||||
("tapas", "TapasForSequenceClassification"),
|
||||
("transfo-xl", "TransfoXLForSequenceClassification"),
|
||||
("umt5", "UMT5ForSequenceClassification"),
|
||||
("xlm", "XLMForSequenceClassification"),
|
||||
("xlm-roberta", "XLMRobertaForSequenceClassification"),
|
||||
("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
|
||||
|
@ -51,6 +51,7 @@ else:
|
||||
"MT5EncoderModel",
|
||||
"MT5ForConditionalGeneration",
|
||||
"MT5ForQuestionAnswering",
|
||||
"MT5ForSequenceClassification",
|
||||
"MT5Model",
|
||||
"MT5PreTrainedModel",
|
||||
"MT5Stack",
|
||||
@ -86,6 +87,7 @@ if TYPE_CHECKING:
|
||||
MT5EncoderModel,
|
||||
MT5ForConditionalGeneration,
|
||||
MT5ForQuestionAnswering,
|
||||
MT5ForSequenceClassification,
|
||||
MT5Model,
|
||||
MT5PreTrainedModel,
|
||||
MT5Stack,
|
||||
|
@ -56,6 +56,8 @@ class MT5Config(PretrainedConfig):
|
||||
The maximum distance of the longer sequences for the bucket separation.
|
||||
dropout_rate (`float`, *optional*, defaults to 0.1):
|
||||
The ratio for all dropout layers.
|
||||
classifier_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for classifier.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
|
||||
The epsilon used by the layer normalization layers.
|
||||
initializer_factor (`float`, *optional*, defaults to 1):
|
||||
@ -91,6 +93,7 @@ class MT5Config(PretrainedConfig):
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
decoder_start_token_id=0,
|
||||
classifier_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
@ -114,6 +117,7 @@ class MT5Config(PretrainedConfig):
|
||||
self.relative_attention_num_buckets = relative_attention_num_buckets
|
||||
self.relative_attention_max_distance = relative_attention_max_distance
|
||||
self.dropout_rate = dropout_rate
|
||||
self.classifier_dropout = classifier_dropout
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_factor = initializer_factor
|
||||
self.feed_forward_proj = feed_forward_proj
|
||||
|
@ -18,11 +18,11 @@ import copy
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from ...activations import ACT2FN
|
||||
@ -32,6 +32,7 @@ from ...modeling_outputs import (
|
||||
Seq2SeqLMOutput,
|
||||
Seq2SeqModelOutput,
|
||||
Seq2SeqQuestionAnsweringModelOutput,
|
||||
Seq2SeqSequenceClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
@ -742,6 +743,25 @@ def load_tf_weights_in_mt5(model, config, tf_checkpoint_path):
|
||||
return model
|
||||
|
||||
|
||||
# Copied from transformers.models.t5.modeling_t5.T5ClassificationHead with T5->MT5
|
||||
class MT5ClassificationHead(nn.Module):
|
||||
"""Head for sentence-level classification tasks."""
|
||||
|
||||
def __init__(self, config: MT5Config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.d_model, config.d_model)
|
||||
self.dropout = nn.Dropout(p=config.classifier_dropout)
|
||||
self.out_proj = nn.Linear(config.d_model, config.num_labels)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = torch.tanh(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.out_proj(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel with T5->MT5, t5->mt5
|
||||
class MT5PreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
@ -773,7 +793,10 @@ class MT5PreTrainedModel(PreTrainedModel):
|
||||
factor = self.config.initializer_factor # Used for testing weights initialization
|
||||
if isinstance(module, MT5LayerNorm):
|
||||
module.weight.data.fill_(factor * 1.0)
|
||||
elif isinstance(module, (MT5Model, MT5ForConditionalGeneration, MT5EncoderModel, MT5ForQuestionAnswering)):
|
||||
elif isinstance(
|
||||
module,
|
||||
(MT5Model, MT5ForConditionalGeneration, MT5EncoderModel, MT5ForQuestionAnswering),
|
||||
):
|
||||
# Mesh TensorFlow embeddings initialization
|
||||
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
|
||||
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
|
||||
@ -782,6 +805,13 @@ class MT5PreTrainedModel(PreTrainedModel):
|
||||
if hasattr(module, "qa_outputs"):
|
||||
module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
|
||||
module.qa_outputs.bias.data.zero_()
|
||||
elif isinstance(module, MT5ClassificationHead):
|
||||
module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
|
||||
if hasattr(module.dense, "bias") and module.dense.bias is not None:
|
||||
module.dense.bias.data.zero_()
|
||||
module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
|
||||
if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
|
||||
module.out_proj.bias.data.zero_()
|
||||
elif isinstance(module, MT5DenseActDense):
|
||||
# Mesh TensorFlow FF initialization
|
||||
# See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
|
||||
@ -1996,6 +2026,141 @@ class MT5EncoderModel(MT5PreTrainedModel):
|
||||
return encoder_outputs
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
MT5 model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
|
||||
tasks.
|
||||
""",
|
||||
MT5_START_DOCSTRING,
|
||||
)
|
||||
class MT5ForSequenceClassification(MT5PreTrainedModel):
|
||||
_keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
# Copied from transformers.models.t5.modeling_t5.T5ForSequenceClassification.__init__ with T5->MT5
|
||||
def __init__(self, config: MT5Config):
|
||||
super().__init__(config)
|
||||
self.transformer = MT5Model(config)
|
||||
self.classification_head = MT5ClassificationHead(config)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
self.model_parallel = False
|
||||
|
||||
@add_start_docstrings_to_model_forward(MT5_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=Seq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
||||
# Copied from transformers.models.t5.modeling_t5.T5ForSequenceClassification.forward
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
Returns:
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
if labels is not None:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is None and inputs_embeds is not None:
|
||||
raise NotImplementedError(
|
||||
f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
# Copied from models.bart.modeling_bart.BartModel.forward different to other models, T5 automatically creates
|
||||
# decoder_input_ids from input_ids if no decoder_input_ids are provided
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
if input_ids is None:
|
||||
raise ValueError(
|
||||
"If no `decoder_input_ids` or `decoder_inputs_embeds` are "
|
||||
"passed, `input_ids` cannot be `None`. Please pass either "
|
||||
"`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
|
||||
)
|
||||
decoder_input_ids = self._shift_right(input_ids)
|
||||
|
||||
outputs = self.transformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
|
||||
eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)
|
||||
|
||||
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||
batch_size, _, hidden_size = sequence_output.shape
|
||||
sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
|
||||
logits = self.classification_head(sentence_representation)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.config.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.config.num_labels == 1:
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return Seq2SeqSequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
decoder_hidden_states=outputs.decoder_hidden_states,
|
||||
decoder_attentions=outputs.decoder_attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||||
encoder_hidden_states=outputs.encoder_hidden_states,
|
||||
encoder_attentions=outputs.encoder_attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
MT5 Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers
|
||||
|
@ -57,6 +57,7 @@ else:
|
||||
"T5PreTrainedModel",
|
||||
"load_tf_weights_in_t5",
|
||||
"T5ForQuestionAnswering",
|
||||
"T5ForSequenceClassification",
|
||||
]
|
||||
|
||||
try:
|
||||
@ -117,6 +118,7 @@ if TYPE_CHECKING:
|
||||
T5EncoderModel,
|
||||
T5ForConditionalGeneration,
|
||||
T5ForQuestionAnswering,
|
||||
T5ForSequenceClassification,
|
||||
T5Model,
|
||||
T5PreTrainedModel,
|
||||
load_tf_weights_in_t5,
|
||||
|
@ -64,6 +64,8 @@ class T5Config(PretrainedConfig):
|
||||
The maximum distance of the longer sequences for the bucket separation.
|
||||
dropout_rate (`float`, *optional*, defaults to 0.1):
|
||||
The ratio for all dropout layers.
|
||||
classifier_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for classifier.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
|
||||
The epsilon used by the layer normalization layers.
|
||||
initializer_factor (`float`, *optional*, defaults to 1):
|
||||
@ -98,6 +100,7 @@ class T5Config(PretrainedConfig):
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
classifier_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
@ -112,6 +115,7 @@ class T5Config(PretrainedConfig):
|
||||
self.relative_attention_num_buckets = relative_attention_num_buckets
|
||||
self.relative_attention_max_distance = relative_attention_max_distance
|
||||
self.dropout_rate = dropout_rate
|
||||
self.classifier_dropout = classifier_dropout
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_factor = initializer_factor
|
||||
self.feed_forward_proj = feed_forward_proj
|
||||
|
@ -19,11 +19,11 @@ import copy
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from ...activations import ACT2FN
|
||||
@ -33,6 +33,7 @@ from ...modeling_outputs import (
|
||||
Seq2SeqLMOutput,
|
||||
Seq2SeqModelOutput,
|
||||
Seq2SeqQuestionAnsweringModelOutput,
|
||||
Seq2SeqSequenceClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
@ -772,6 +773,24 @@ class T5Block(nn.Module):
|
||||
return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
|
||||
|
||||
|
||||
class T5ClassificationHead(nn.Module):
|
||||
"""Head for sentence-level classification tasks."""
|
||||
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.d_model, config.d_model)
|
||||
self.dropout = nn.Dropout(p=config.classifier_dropout)
|
||||
self.out_proj = nn.Linear(config.d_model, config.num_labels)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = torch.tanh(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.out_proj(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class T5PreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
@ -802,7 +821,10 @@ class T5PreTrainedModel(PreTrainedModel):
|
||||
factor = self.config.initializer_factor # Used for testing weights initialization
|
||||
if isinstance(module, T5LayerNorm):
|
||||
module.weight.data.fill_(factor * 1.0)
|
||||
elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel, T5ForQuestionAnswering)):
|
||||
elif isinstance(
|
||||
module,
|
||||
(T5Model, T5ForConditionalGeneration, T5EncoderModel, T5ForQuestionAnswering),
|
||||
):
|
||||
# Mesh TensorFlow embeddings initialization
|
||||
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
|
||||
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
|
||||
@ -811,6 +833,13 @@ class T5PreTrainedModel(PreTrainedModel):
|
||||
if hasattr(module, "qa_outputs"):
|
||||
module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
|
||||
module.qa_outputs.bias.data.zero_()
|
||||
elif isinstance(module, T5ClassificationHead):
|
||||
module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
|
||||
if hasattr(module.dense, "bias") and module.dense.bias is not None:
|
||||
module.dense.bias.data.zero_()
|
||||
module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
|
||||
if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
|
||||
module.out_proj.bias.data.zero_()
|
||||
elif isinstance(module, T5DenseActDense):
|
||||
# Mesh TensorFlow FF initialization
|
||||
# See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
|
||||
@ -1945,6 +1974,139 @@ class T5EncoderModel(T5PreTrainedModel):
|
||||
return encoder_outputs
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
T5 model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
|
||||
tasks.
|
||||
""",
|
||||
T5_START_DOCSTRING,
|
||||
)
|
||||
class T5ForSequenceClassification(T5PreTrainedModel):
|
||||
_keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__(config)
|
||||
self.transformer = T5Model(config)
|
||||
self.classification_head = T5ClassificationHead(config)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
self.model_parallel = False
|
||||
|
||||
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=Seq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
Returns:
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
if labels is not None:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is None and inputs_embeds is not None:
|
||||
raise NotImplementedError(
|
||||
f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
# Copied from models.bart.modeling_bart.BartModel.forward different to other models, T5 automatically creates
|
||||
# decoder_input_ids from input_ids if no decoder_input_ids are provided
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
if input_ids is None:
|
||||
raise ValueError(
|
||||
"If no `decoder_input_ids` or `decoder_inputs_embeds` are "
|
||||
"passed, `input_ids` cannot be `None`. Please pass either "
|
||||
"`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
|
||||
)
|
||||
decoder_input_ids = self._shift_right(input_ids)
|
||||
|
||||
outputs = self.transformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
|
||||
eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)
|
||||
|
||||
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||
batch_size, _, hidden_size = sequence_output.shape
|
||||
sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
|
||||
logits = self.classification_head(sentence_representation)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.config.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.config.num_labels == 1:
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return Seq2SeqSequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
decoder_hidden_states=outputs.decoder_hidden_states,
|
||||
decoder_attentions=outputs.decoder_attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||||
encoder_hidden_states=outputs.encoder_hidden_states,
|
||||
encoder_attentions=outputs.encoder_attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
T5 Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers
|
||||
|
@ -30,6 +30,7 @@ else:
|
||||
"UMT5EncoderModel",
|
||||
"UMT5ForConditionalGeneration",
|
||||
"UMT5ForQuestionAnswering",
|
||||
"UMT5ForSequenceClassification",
|
||||
"UMT5Model",
|
||||
"UMT5PreTrainedModel",
|
||||
]
|
||||
@ -47,6 +48,7 @@ if TYPE_CHECKING:
|
||||
UMT5EncoderModel,
|
||||
UMT5ForConditionalGeneration,
|
||||
UMT5ForQuestionAnswering,
|
||||
UMT5ForSequenceClassification,
|
||||
UMT5Model,
|
||||
UMT5PreTrainedModel,
|
||||
)
|
||||
|
@ -61,6 +61,8 @@ class UMT5Config(PretrainedConfig):
|
||||
The maximum distance of the longer sequences for the bucket separation.
|
||||
dropout_rate (`float`, *optional*, defaults to 0.1):
|
||||
The ratio for all dropout layers.
|
||||
classifier_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for classifier.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
|
||||
The epsilon used by the layer normalization layers.
|
||||
initializer_factor (`float`, *optional*, defaults to 1):
|
||||
@ -96,6 +98,7 @@ class UMT5Config(PretrainedConfig):
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
decoder_start_token_id=0,
|
||||
classifier_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
@ -119,6 +122,7 @@ class UMT5Config(PretrainedConfig):
|
||||
self.relative_attention_num_buckets = relative_attention_num_buckets
|
||||
self.relative_attention_max_distance = relative_attention_max_distance
|
||||
self.dropout_rate = dropout_rate
|
||||
self.classifier_dropout = classifier_dropout
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_factor = initializer_factor
|
||||
self.feed_forward_proj = feed_forward_proj
|
||||
|
@ -16,11 +16,11 @@
|
||||
|
||||
import copy
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from ...activations import ACT2FN
|
||||
@ -30,6 +30,7 @@ from ...modeling_outputs import (
|
||||
Seq2SeqLMOutput,
|
||||
Seq2SeqModelOutput,
|
||||
Seq2SeqQuestionAnsweringModelOutput,
|
||||
Seq2SeqSequenceClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
@ -451,6 +452,25 @@ class UMT5Block(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.t5.modeling_t5.T5ClassificationHead with T5->UMT5
|
||||
class UMT5ClassificationHead(nn.Module):
|
||||
"""Head for sentence-level classification tasks."""
|
||||
|
||||
def __init__(self, config: UMT5Config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.d_model, config.d_model)
|
||||
self.dropout = nn.Dropout(p=config.classifier_dropout)
|
||||
self.out_proj = nn.Linear(config.d_model, config.num_labels)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = torch.tanh(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.out_proj(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UMT5PreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
@ -479,7 +499,15 @@ class UMT5PreTrainedModel(PreTrainedModel):
|
||||
factor = self.config.initializer_factor # Used for testing weights initialization
|
||||
if isinstance(module, UMT5LayerNorm):
|
||||
module.weight.data.fill_(factor * 1.0)
|
||||
elif isinstance(module, (UMT5Model, UMT5ForConditionalGeneration, UMT5EncoderModel, UMT5ForQuestionAnswering)):
|
||||
elif isinstance(
|
||||
module,
|
||||
(
|
||||
UMT5Model,
|
||||
UMT5ForConditionalGeneration,
|
||||
UMT5EncoderModel,
|
||||
UMT5ForQuestionAnswering,
|
||||
),
|
||||
):
|
||||
# Mesh TensorFlow embeddings initialization
|
||||
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
|
||||
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
|
||||
@ -488,6 +516,13 @@ class UMT5PreTrainedModel(PreTrainedModel):
|
||||
if hasattr(module, "qa_outputs"):
|
||||
module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
|
||||
module.qa_outputs.bias.data.zero_()
|
||||
elif isinstance(module, UMT5ClassificationHead):
|
||||
module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
|
||||
if hasattr(module.dense, "bias") and module.dense.bias is not None:
|
||||
module.dense.bias.data.zero_()
|
||||
module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
|
||||
if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
|
||||
module.out_proj.bias.data.zero_()
|
||||
elif isinstance(module, UMT5DenseActDense):
|
||||
# Mesh TensorFlow FF initialization
|
||||
# See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
|
||||
@ -1401,6 +1436,140 @@ class UMT5EncoderModel(UMT5PreTrainedModel):
|
||||
return encoder_outputs
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
UMT5 model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
|
||||
tasks.
|
||||
""",
|
||||
UMT5_START_DOCSTRING,
|
||||
)
|
||||
class UMT5ForSequenceClassification(UMT5PreTrainedModel):
|
||||
_keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
# Copied from transformers.models.t5.modeling_t5.T5ForSequenceClassification.__init__ with T5->UMT5
|
||||
def __init__(self, config: UMT5Config):
|
||||
super().__init__(config)
|
||||
self.transformer = UMT5Model(config)
|
||||
self.classification_head = UMT5ClassificationHead(config)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
self.model_parallel = False
|
||||
|
||||
@add_start_docstrings_to_model_forward(UMT5_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=Seq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
Returns:
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
if labels is not None:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is None and inputs_embeds is not None:
|
||||
raise NotImplementedError(
|
||||
f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
# Copied from models.bart.modeling_bart.BartModel.forward different to other models, T5 automatically creates
|
||||
# decoder_input_ids from input_ids if no decoder_input_ids are provided
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
if input_ids is None:
|
||||
raise ValueError(
|
||||
"If no `decoder_input_ids` or `decoder_inputs_embeds` are "
|
||||
"passed, `input_ids` cannot be `None`. Please pass either "
|
||||
"`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
|
||||
)
|
||||
decoder_input_ids = self._shift_right(input_ids)
|
||||
|
||||
outputs = self.transformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
|
||||
eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)
|
||||
|
||||
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||
batch_size, _, hidden_size = sequence_output.shape
|
||||
sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
|
||||
logits = self.classification_head(sentence_representation)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.config.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.config.num_labels == 1:
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return Seq2SeqSequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
decoder_hidden_states=outputs.decoder_hidden_states,
|
||||
decoder_attentions=outputs.decoder_attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||||
encoder_hidden_states=outputs.encoder_hidden_states,
|
||||
encoder_attentions=outputs.encoder_attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
UMT5 Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers
|
||||
|
@ -5219,6 +5219,13 @@ class MT5ForQuestionAnswering(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MT5ForSequenceClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MT5Model(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@ -7070,6 +7077,13 @@ class T5ForQuestionAnswering(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class T5ForSequenceClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class T5Model(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@ -7320,6 +7334,13 @@ class UMT5ForQuestionAnswering(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class UMT5ForSequenceClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class UMT5Model(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -15,10 +15,13 @@
|
||||
|
||||
|
||||
import copy
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import T5Config, is_torch_available
|
||||
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
|
||||
from transformers.testing_utils import (
|
||||
require_accelerate,
|
||||
require_sentencepiece,
|
||||
@ -27,14 +30,18 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import cached_property
|
||||
from transformers.utils import cached_property, is_torch_fx_available
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_fx_available():
|
||||
from transformers.utils.fx import symbolic_trace
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
@ -44,6 +51,7 @@ if is_torch_available():
|
||||
T5EncoderModel,
|
||||
T5ForConditionalGeneration,
|
||||
T5ForQuestionAnswering,
|
||||
T5ForSequenceClassification,
|
||||
T5Model,
|
||||
T5Tokenizer,
|
||||
)
|
||||
@ -57,7 +65,7 @@ class T5ModelTester:
|
||||
vocab_size=99,
|
||||
batch_size=13,
|
||||
encoder_seq_length=7,
|
||||
decoder_seq_length=9,
|
||||
decoder_seq_length=7,
|
||||
# For common tests
|
||||
is_training=True,
|
||||
use_attention_mask=True,
|
||||
@ -102,7 +110,8 @@ class T5ModelTester:
|
||||
return T5Config.from_pretrained("t5-base")
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
|
||||
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size).clamp(2)
|
||||
input_ids[:, -1] = self.eos_token_id # Eos Token
|
||||
decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||
|
||||
attention_mask = None
|
||||
@ -251,6 +260,26 @@ class T5ModelTester:
|
||||
self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size))
|
||||
self.parent.assertEqual(outputs["loss"].size(), ())
|
||||
|
||||
def create_and_check_with_sequence_classification_head(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
decoder_input_ids,
|
||||
attention_mask,
|
||||
decoder_attention_mask,
|
||||
lm_labels,
|
||||
):
|
||||
labels = torch.tensor([1] * self.batch_size, dtype=torch.long, device=torch_device)
|
||||
model = T5ForSequenceClassification(config=config).to(torch_device).eval()
|
||||
outputs = model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=input_ids,
|
||||
labels=labels,
|
||||
)
|
||||
# self.parent.assertEqual(len(outputs), 4)
|
||||
self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, config.num_labels))
|
||||
self.parent.assertEqual(outputs["loss"].size(), ())
|
||||
|
||||
def create_and_check_decoder_model_past(
|
||||
self,
|
||||
config,
|
||||
@ -521,7 +550,11 @@ class T5ModelTester:
|
||||
|
||||
@require_torch
|
||||
class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (T5Model, T5ForConditionalGeneration, T5ForQuestionAnswering) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(T5Model, T5ForConditionalGeneration, T5ForSequenceClassification, T5ForQuestionAnswering)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
@ -531,6 +564,8 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
"text2text-generation": T5ForConditionalGeneration,
|
||||
"translation": T5ForConditionalGeneration,
|
||||
"question-answering": T5ForQuestionAnswering,
|
||||
"text-classification": T5ForSequenceClassification,
|
||||
"zero-shot": T5ForSequenceClassification,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
@ -548,6 +583,126 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
self.model_tester = T5ModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)
|
||||
|
||||
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
|
||||
if not is_torch_fx_available() or not self.fx_compatible:
|
||||
return
|
||||
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
configs_no_init.return_dict = False
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class.__name__ == "T5ForSequenceClassification":
|
||||
continue
|
||||
model = model_class(config=configs_no_init)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
|
||||
|
||||
try:
|
||||
if model.config.is_encoder_decoder:
|
||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||
labels = inputs.get("labels", None)
|
||||
input_names = [
|
||||
"attention_mask",
|
||||
"decoder_attention_mask",
|
||||
"decoder_input_ids",
|
||||
"input_features",
|
||||
"input_ids",
|
||||
"input_values",
|
||||
]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
else:
|
||||
input_names = [
|
||||
"attention_mask",
|
||||
"bbox",
|
||||
"input_features",
|
||||
"input_ids",
|
||||
"input_values",
|
||||
"pixel_values",
|
||||
"token_type_ids",
|
||||
"visual_feats",
|
||||
"visual_pos",
|
||||
]
|
||||
|
||||
labels = inputs.get("labels", None)
|
||||
start_positions = inputs.get("start_positions", None)
|
||||
end_positions = inputs.get("end_positions", None)
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
if start_positions is not None:
|
||||
input_names.append("start_positions")
|
||||
if end_positions is not None:
|
||||
input_names.append("end_positions")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
|
||||
not hasattr(model.config, "problem_type") or model.config.problem_type is None
|
||||
):
|
||||
model.config.problem_type = "single_label_classification"
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
except Exception as e:
|
||||
self.fail(f"Couldn't trace module: {e}")
|
||||
|
||||
def flatten_output(output):
|
||||
flatten = []
|
||||
for x in output:
|
||||
if isinstance(x, (tuple, list)):
|
||||
flatten += flatten_output(x)
|
||||
elif not isinstance(x, torch.Tensor):
|
||||
continue
|
||||
else:
|
||||
flatten.append(x)
|
||||
return flatten
|
||||
|
||||
model_output = flatten_output(model_output)
|
||||
traced_output = flatten_output(traced_output)
|
||||
num_outputs = len(model_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], traced_output[i]),
|
||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Test that the model can be serialized and restored properly
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
||||
try:
|
||||
with open(pkl_file_name, "wb") as f:
|
||||
pickle.dump(traced_model, f)
|
||||
with open(pkl_file_name, "rb") as f:
|
||||
loaded = pickle.load(f)
|
||||
except Exception as e:
|
||||
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
||||
|
||||
loaded_output = loaded(**filtered_inputs)
|
||||
loaded_output = flatten_output(loaded_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], loaded_output[i]),
|
||||
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
|
||||
# (Even with this call, there are still memory leak by ~0.04MB)
|
||||
self.clear_torch_jit_class_registry()
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@@ -567,6 +722,36 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
        config.feed_forward_proj = "gated-gelu"
        self.model_tester.create_and_check_model(config, *config_and_inputs[1:])

    # T5ForSequenceClassification does not support inputs_embeds
    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in (T5Model, T5ForConditionalGeneration, T5ForQuestionAnswering):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))

            if not self.is_encoder_decoder:
                input_ids = inputs["input_ids"]
                del inputs["input_ids"]
            else:
                encoder_input_ids = inputs["input_ids"]
                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
                del inputs["input_ids"]
                inputs.pop("decoder_input_ids", None)

            wte = model.get_input_embeddings()
            if not self.is_encoder_decoder:
                inputs["inputs_embeds"] = wte(input_ids)
            else:
                inputs["inputs_embeds"] = wte(encoder_input_ids)
                inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)

            with torch.no_grad():
                model(**inputs)[0]

    def test_config_and_model_silu_gated(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        config = config_and_inputs[0]
@@ -577,6 +762,10 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_with_lm_head(*config_and_inputs)

    def test_with_sequence_classification_head(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_with_sequence_classification_head(*config_and_inputs)

    def test_decoder_model_past(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)

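For reference, a minimal usage sketch of the sequence-classification API these tests exercise (the `t5-small` checkpoint and the two-label setup are illustrative assumptions; the classification head is freshly initialized):

```python
import torch
from transformers import AutoTokenizer, T5ForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = T5ForSequenceClassification.from_pretrained("t5-small", num_labels=2)  # head is randomly initialized

inputs = tokenizer("Studies have shown that owning a dog is good for you.", return_tensors="pt")
labels = torch.tensor([1])

outputs = model(**inputs, labels=labels)
print(outputs.logits.shape)  # torch.Size([1, 2]) -> (batch_size, num_labels)
print(outputs.loss)          # scalar cross-entropy loss
```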
@@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import os
import pickle
import tempfile
import unittest

from transformers import T5Config, is_torch_available
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
from transformers.testing_utils import (
    require_sentencepiece,
    require_tokenizers,
@@ -23,16 +27,27 @@ from transformers.testing_utils import (
    slow,
    torch_device,
)
from transformers.utils import is_torch_fx_available

from ...generation.test_utils import GenerationTesterMixin
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_fx_available():
    from transformers.utils.fx import symbolic_trace


if is_torch_available():
    import torch

    from transformers import AutoTokenizer, UMT5ForConditionalGeneration, UMT5ForQuestionAnswering, UMT5Model
    from transformers import (
        AutoTokenizer,
        UMT5ForConditionalGeneration,
        UMT5ForQuestionAnswering,
        UMT5ForSequenceClassification,
        UMT5Model,
    )


# Copied from test.models.t5.test_modeling_t5.T5ModelTester with T5->UMT5
@@ -43,7 +58,7 @@ class UMT5ModelTester:
        vocab_size=99,
        batch_size=13,
        encoder_seq_length=7,
        decoder_seq_length=9,
        decoder_seq_length=7,
        # For common tests
        is_training=True,
        use_attention_mask=True,
@@ -131,7 +146,8 @@ class UMT5ModelTester:
        # but when using past, there is no way of knowing if the past input ids had
        # pad tokens in them, which results in incorrect seq_length and which in turn results in
        # position_ids being off by num_pad_tokens in past input
        input_ids = input_ids.clamp(self.pad_token_id + 1)
        input_ids = input_ids.clamp(self.pad_token_id + 2)
        input_ids[:, -1] = self.eos_token_id  # Eos Token
        decoder_input_ids = decoder_input_ids.clamp(self.pad_token_id + 1)

        config = self.get_config()
@@ -255,11 +271,25 @@ class UMT5ModelTester:
        output = model(**input_dict)["last_hidden_state"]
        self.parent.assertFalse(torch.isnan(output).any().item())

    def create_and_check_with_sequence_classification_head(
        self,
        config,
        input_dict,
    ):
        labels = torch.tensor([1] * self.batch_size, dtype=torch.long, device=torch_device)
        model = UMT5ForSequenceClassification(config=config).to(torch_device).eval()
        outputs = model(**input_dict, labels=labels)
        # self.parent.assertEqual(len(outputs), 4)
        self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, config.num_labels))
        self.parent.assertEqual(outputs["loss"].size(), ())


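The shape checks above rest on how the encoder-decoder classification head pools a sentence representation: the decoder hidden state at the last eos position is fed to a small classification head, BART-style. A rough, simplified sketch of that pooling step (variable names are illustrative, not the exact implementation):

```python
import torch

def eos_sentence_representation(sequence_output, input_ids, eos_token_id):
    # sequence_output: (batch, seq_len, d_model) decoder hidden states
    eos_mask = input_ids.eq(eos_token_id)
    # assumes every example in the batch contains the same number of eos tokens
    batch_size, _, hidden_size = sequence_output.shape
    # keep only the eos positions, then take the last one per example
    return sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
```

The logits asserted in the check are then just a linear head applied to this pooled vector, which is why their shape is (batch_size, num_labels).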
@require_torch
class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (
        (UMT5Model, UMT5ForConditionalGeneration, UMT5ForQuestionAnswering) if is_torch_available() else ()
        (UMT5Model, UMT5ForConditionalGeneration, UMT5ForSequenceClassification, UMT5ForQuestionAnswering)
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (UMT5ForConditionalGeneration,) if is_torch_available() else ()
    pipeline_model_mapping = (
@@ -270,6 +300,8 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
            "text2text-generation": UMT5ForConditionalGeneration,
            "translation": UMT5ForConditionalGeneration,
            "question-answering": UMT5ForQuestionAnswering,
            "text-classification": UMT5ForSequenceClassification,
            "zero-shot": UMT5ForSequenceClassification,
        }
        if is_torch_available()
        else {}
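With the text-classification entry registered above, the new head can also be driven through the pipeline API. A hedged sketch (the `google/umt5-small` checkpoint and two-label setup are illustrative; the classification head starts randomly initialized, so scores are only meaningful after fine-tuning):

```python
from transformers import AutoTokenizer, UMT5ForSequenceClassification, pipeline

tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
model = UMT5ForSequenceClassification.from_pretrained("google/umt5-small", num_labels=2)

classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
print(classifier("An example sentence to classify."))  # e.g. [{'label': 'LABEL_0', 'score': ...}]
```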
@@ -285,6 +317,160 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
    def setUp(self):
        self.model_tester = UMT5ModelTester(self)

    def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
        if not is_torch_fx_available() or not self.fx_compatible:
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no NaN
        configs_no_init.return_dict = False

        for model_class in self.all_model_classes:
            if model_class.__name__ == "UMT5ForSequenceClassification":
                continue
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()
            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)

            try:
                if model.config.is_encoder_decoder:
                    model.config.use_cache = False  # FSMT still requires this hack -> FSMT should probably be refactored similar to BART afterward
                    labels = inputs.get("labels", None)
                    input_names = [
                        "attention_mask",
                        "decoder_attention_mask",
                        "decoder_input_ids",
                        "input_features",
                        "input_ids",
                        "input_values",
                    ]
                    if labels is not None:
                        input_names.append("labels")

                    filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
                    input_names = list(filtered_inputs.keys())

                    model_output = model(**filtered_inputs)

                    traced_model = symbolic_trace(model, input_names)
                    traced_output = traced_model(**filtered_inputs)
                else:
                    input_names = [
                        "attention_mask",
                        "bbox",
                        "input_features",
                        "input_ids",
                        "input_values",
                        "pixel_values",
                        "token_type_ids",
                        "visual_feats",
                        "visual_pos",
                    ]

                    labels = inputs.get("labels", None)
                    start_positions = inputs.get("start_positions", None)
                    end_positions = inputs.get("end_positions", None)
                    if labels is not None:
                        input_names.append("labels")
                    if start_positions is not None:
                        input_names.append("start_positions")
                    if end_positions is not None:
                        input_names.append("end_positions")

                    filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
                    input_names = list(filtered_inputs.keys())

                    if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
                        not hasattr(model.config, "problem_type") or model.config.problem_type is None
                    ):
                        model.config.problem_type = "single_label_classification"

                    traced_model = symbolic_trace(model, input_names)
                    traced_output = traced_model(**filtered_inputs)
                    model_output = model(**filtered_inputs)

            except Exception as e:
                self.fail(f"Couldn't trace module: {e}")

            def flatten_output(output):
                flatten = []
                for x in output:
                    if isinstance(x, (tuple, list)):
                        flatten += flatten_output(x)
                    elif not isinstance(x, torch.Tensor):
                        continue
                    else:
                        flatten.append(x)
                return flatten

            model_output = flatten_output(model_output)
            traced_output = flatten_output(traced_output)
            num_outputs = len(model_output)

            for i in range(num_outputs):
                self.assertTrue(
                    torch.allclose(model_output[i], traced_output[i]),
                    f"traced {i}th output doesn't match model {i}th output for {model_class}",
                )

            # Test that the model can be serialized and restored properly
            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
                try:
                    with open(pkl_file_name, "wb") as f:
                        pickle.dump(traced_model, f)
                    with open(pkl_file_name, "rb") as f:
                        loaded = pickle.load(f)
                except Exception as e:
                    self.fail(f"Couldn't serialize / deserialize the traced model: {e}")

                loaded_output = loaded(**filtered_inputs)
                loaded_output = flatten_output(loaded_output)

                for i in range(num_outputs):
                    self.assertTrue(
                        torch.allclose(model_output[i], loaded_output[i]),
                        f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
                    )

            # Avoid a memory leak. Without this, each call increases RAM usage by ~20MB.
            # (Even with this call, there is still a memory leak of ~0.04MB.)
            self.clear_torch_jit_class_registry()

    # UMT5ForSequenceClassification does not support inputs_embeds
    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in (UMT5Model, UMT5ForConditionalGeneration, UMT5ForQuestionAnswering):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))

            if not self.is_encoder_decoder:
                input_ids = inputs["input_ids"]
                del inputs["input_ids"]
            else:
                encoder_input_ids = inputs["input_ids"]
                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
                del inputs["input_ids"]
                inputs.pop("decoder_input_ids", None)

            wte = model.get_input_embeddings()
            if not self.is_encoder_decoder:
                inputs["inputs_embeds"] = wte(input_ids)
            else:
                inputs["inputs_embeds"] = wte(encoder_input_ids)
                inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)

            with torch.no_grad():
                model(**inputs)[0]

    def test_with_sequence_classification_head(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_with_sequence_classification_head(*config_and_inputs)

    @unittest.skip("Test has a segmentation fault on torch 1.8.0")
    def test_export_to_onnx(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()