mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
ElectraForMultipleChoice (#4954)
* add ElectraForMultipleChoice * add test_for_multiple_choice * add ElectraForMultipleChoice in auto model * add ElectraForMultipleChoice in all_model_classes * add SequenceSummary related parameters * get rid pooler, use SequenceSummary instead * add electra multiple choice test Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent
279d8e24f7
commit
ca2d0f98c4
@ -333,6 +333,7 @@ if is_torch_available():
|
||||
ElectraForMaskedLM,
|
||||
ElectraForTokenClassification,
|
||||
ElectraPreTrainedModel,
|
||||
ElectraForMultipleChoice,
|
||||
ElectraForSequenceClassification,
|
||||
ElectraForQuestionAnswering,
|
||||
ElectraModel,
|
||||
|
@ -76,6 +76,27 @@ class ElectraConfig(PretrainedConfig):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
summary_type (:obj:`string`, optional, defaults to "first"):
|
||||
Argument used when doing sequence summary. Used in for the multiple choice head in
|
||||
:class:`~transformers.ElectraForMultipleChoice`.
|
||||
Is one of the following options:
|
||||
- 'last' => take the last token hidden state (like XLNet)
|
||||
- 'first' => take the first token hidden state (like Bert)
|
||||
- 'mean' => take the mean of all tokens hidden states
|
||||
- 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
|
||||
- 'attn' => Not implemented now, use multi-head attention
|
||||
summary_use_proj (:obj:`boolean`, optional, defaults to :obj:`True`):
|
||||
Argument used when doing sequence summary. Used in for the multiple choice head in
|
||||
:class:`~transformers.ElectraForMultipleChoice`.
|
||||
Add a projection after the vector extraction
|
||||
summary_activation (:obj:`string` or :obj:`None`, optional, defaults to :obj:`None`):
|
||||
Argument used when doing sequence summary. Used in for the multiple choice head in
|
||||
:class:`~transformers.ElectraForMultipleChoice`.
|
||||
'gelu' => add a gelu activation to the output, Other => no activation.
|
||||
summary_last_dropout (:obj:`float`, optional, defaults to 0.0):
|
||||
Argument used when doing sequence summary. Used in for the multiple choice head in
|
||||
:class:`~transformers.ElectraForMultipleChoice`.
|
||||
Add a dropout after the projection and activation
|
||||
|
||||
Example::
|
||||
|
||||
@ -107,6 +128,10 @@ class ElectraConfig(PretrainedConfig):
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
summary_type="first",
|
||||
summary_use_proj=True,
|
||||
summary_activation="gelu",
|
||||
summary_last_dropout=0.1,
|
||||
pad_token_id=0,
|
||||
**kwargs
|
||||
):
|
||||
@ -125,3 +150,8 @@ class ElectraConfig(PretrainedConfig):
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
|
||||
self.summary_type = summary_type
|
||||
self.summary_use_proj = summary_use_proj
|
||||
self.summary_activation = summary_activation
|
||||
self.summary_last_dropout = summary_last_dropout
|
||||
|
@ -87,6 +87,7 @@ from .modeling_distilbert import (
|
||||
)
|
||||
from .modeling_electra import (
|
||||
ElectraForMaskedLM,
|
||||
ElectraForMultipleChoice,
|
||||
ElectraForPreTraining,
|
||||
ElectraForQuestionAnswering,
|
||||
ElectraForSequenceClassification,
|
||||
@ -315,6 +316,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
|
||||
[
|
||||
(CamembertConfig, CamembertForMultipleChoice),
|
||||
(ElectraConfig, ElectraForMultipleChoice),
|
||||
(XLMRobertaConfig, XLMRobertaForMultipleChoice),
|
||||
(LongformerConfig, LongformerForMultipleChoice),
|
||||
(RobertaConfig, RobertaForMultipleChoice),
|
||||
|
@ -10,6 +10,7 @@ from .activations import get_activation
|
||||
from .configuration_electra import ElectraConfig
|
||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_bert import BertEmbeddings, BertEncoder, BertLayerNorm, BertPreTrainedModel
|
||||
from .modeling_utils import SequenceSummary
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -860,3 +861,115 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
|
||||
outputs = (total_loss,) + outputs
|
||||
|
||||
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""ELECTRA Model with a multiple choice classification head on top (a linear layer on top of
|
||||
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
|
||||
ELECTRA_INPUTS_DOCSTRING,
|
||||
)
|
||||
class ElectraForMultipleChoice(ElectraPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.electra = ElectraModel(config)
|
||||
self.summary = SequenceSummary(config)
|
||||
self.classifier = nn.Linear(config.hidden_size, 1)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for computing the multiple choice classification loss.
|
||||
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
|
||||
of the input tensors. (see `input_ids` above)
|
||||
|
||||
Returns:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
|
||||
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
Classification loss.
|
||||
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
||||
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
|
||||
|
||||
Classification scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
Examples::
|
||||
|
||||
from transformers import ElectraTokenizer, ElectraForMultipleChoice
|
||||
import torch
|
||||
|
||||
tokenizer = ElectraTokenizer.from_pretrained('google/electra-base-discriminator')
|
||||
model = ElectraForMultipleChoice.from_pretrained('google/electra-base-discriminator')
|
||||
|
||||
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||
choice0 = "It is eaten with a fork and a knife."
|
||||
choice1 = "It is eaten while held in the hand."
|
||||
labels = torch.tensor(0) # choice0 is correct (according to Wikipedia ;))
|
||||
|
||||
encoding = tokenizer.batch_encode_plus([[prompt, choice0], [prompt, choice1]], return_tensors='pt', pad_to_max_length=True)
|
||||
outputs = model(**{k: v.unsqueeze(0) for k,v in encoding.items()}, labels=labels) # batch size is 1
|
||||
|
||||
# the linear classifier still needs to be trained
|
||||
loss, logits = outputs[:2]
|
||||
"""
|
||||
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
||||
|
||||
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
||||
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
||||
inputs_embeds = (
|
||||
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
||||
if inputs_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
discriminator_hidden_states = self.electra(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
sequence_output = discriminator_hidden_states[0]
|
||||
|
||||
pooled_output = self.summary(sequence_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = logits.view(-1, num_choices)
|
||||
|
||||
outputs = (reshaped_logits,) + discriminator_hidden_states[
|
||||
1:
|
||||
] # add hidden states and attention if they are here
|
||||
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(reshaped_logits, labels)
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
|
||||
|
@ -30,6 +30,7 @@ if is_torch_available():
|
||||
ElectraForMaskedLM,
|
||||
ElectraForTokenClassification,
|
||||
ElectraForPreTraining,
|
||||
ElectraForMultipleChoice,
|
||||
ElectraForSequenceClassification,
|
||||
ElectraForQuestionAnswering,
|
||||
)
|
||||
@ -266,6 +267,37 @@ class ElectraModelTester:
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_electra_for_multiple_choice(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
fake_token_labels,
|
||||
):
|
||||
config.num_choices = self.num_choices
|
||||
model = ElectraForMultipleChoice(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
loss, logits = model(
|
||||
multiple_choice_inputs_ids,
|
||||
attention_mask=multiple_choice_input_mask,
|
||||
token_type_ids=multiple_choice_token_type_ids,
|
||||
labels=choice_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@ -329,6 +361,10 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
|
Loading…
Reference in New Issue
Block a user