ElectraForMultipleChoice (#4954)

* add ElectraForMultipleChoice

* add  test_for_multiple_choice

* add ElectraForMultipleChoice in auto model

* add ElectraForMultipleChoice in all_model_classes

* add SequenceSummary related parameters

* get rid pooler, use SequenceSummary instead

* add electra multiple choice test

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Suraj Patil 2020-06-19 00:29:35 +05:30 committed by GitHub
parent 279d8e24f7
commit ca2d0f98c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 182 additions and 0 deletions

View File

@ -333,6 +333,7 @@ if is_torch_available():
ElectraForMaskedLM,
ElectraForTokenClassification,
ElectraPreTrainedModel,
ElectraForMultipleChoice,
ElectraForSequenceClassification,
ElectraForQuestionAnswering,
ElectraModel,

View File

@ -76,6 +76,27 @@ class ElectraConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
The epsilon used by the layer normalization layers.
summary_type (:obj:`string`, optional, defaults to "first"):
Argument used when doing sequence summary. Used in for the multiple choice head in
:class:`~transformers.ElectraForMultipleChoice`.
Is one of the following options:
- 'last' => take the last token hidden state (like XLNet)
- 'first' => take the first token hidden state (like Bert)
- 'mean' => take the mean of all tokens hidden states
- 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
- 'attn' => Not implemented now, use multi-head attention
summary_use_proj (:obj:`boolean`, optional, defaults to :obj:`True`):
Argument used when doing sequence summary. Used in for the multiple choice head in
:class:`~transformers.ElectraForMultipleChoice`.
Add a projection after the vector extraction
summary_activation (:obj:`string` or :obj:`None`, optional, defaults to :obj:`None`):
Argument used when doing sequence summary. Used in for the multiple choice head in
:class:`~transformers.ElectraForMultipleChoice`.
'gelu' => add a gelu activation to the output, Other => no activation.
summary_last_dropout (:obj:`float`, optional, defaults to 0.0):
Argument used when doing sequence summary. Used in for the multiple choice head in
:class:`~transformers.ElectraForMultipleChoice`.
Add a dropout after the projection and activation
Example::
@ -107,6 +128,10 @@ class ElectraConfig(PretrainedConfig):
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
summary_type="first",
summary_use_proj=True,
summary_activation="gelu",
summary_last_dropout=0.1,
pad_token_id=0,
**kwargs
):
@ -125,3 +150,8 @@ class ElectraConfig(PretrainedConfig):
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_last_dropout = summary_last_dropout

View File

@ -87,6 +87,7 @@ from .modeling_distilbert import (
)
from .modeling_electra import (
ElectraForMaskedLM,
ElectraForMultipleChoice,
ElectraForPreTraining,
ElectraForQuestionAnswering,
ElectraForSequenceClassification,
@ -315,6 +316,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
[
(CamembertConfig, CamembertForMultipleChoice),
(ElectraConfig, ElectraForMultipleChoice),
(XLMRobertaConfig, XLMRobertaForMultipleChoice),
(LongformerConfig, LongformerForMultipleChoice),
(RobertaConfig, RobertaForMultipleChoice),

View File

@ -10,6 +10,7 @@ from .activations import get_activation
from .configuration_electra import ElectraConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_bert import BertEmbeddings, BertEncoder, BertLayerNorm, BertPreTrainedModel
from .modeling_utils import SequenceSummary
logger = logging.getLogger(__name__)
@ -860,3 +861,115 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
@add_start_docstrings(
"""ELECTRA Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
ELECTRA_INPUTS_DOCSTRING,
)
class ElectraForMultipleChoice(ElectraPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.electra = ElectraModel(config)
self.summary = SequenceSummary(config)
self.classifier = nn.Linear(config.hidden_size, 1)
self.init_weights()
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
from transformers import ElectraTokenizer, ElectraForMultipleChoice
import torch
tokenizer = ElectraTokenizer.from_pretrained('google/electra-base-discriminator')
model = ElectraForMultipleChoice.from_pretrained('google/electra-base-discriminator')
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
choice0 = "It is eaten with a fork and a knife."
choice1 = "It is eaten while held in the hand."
labels = torch.tensor(0) # choice0 is correct (according to Wikipedia ;))
encoding = tokenizer.batch_encode_plus([[prompt, choice0], [prompt, choice1]], return_tensors='pt', pad_to_max_length=True)
outputs = model(**{k: v.unsqueeze(0) for k,v in encoding.items()}, labels=labels) # batch size is 1
# the linear classifier still needs to be trained
loss, logits = outputs[:2]
"""
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
inputs_embeds = (
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
if inputs_embeds is not None
else None
)
discriminator_hidden_states = self.electra(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
)
sequence_output = discriminator_hidden_states[0]
pooled_output = self.summary(sequence_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
outputs = (reshaped_logits,) + discriminator_hidden_states[
1:
] # add hidden states and attention if they are here
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
outputs = (loss,) + outputs
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)

View File

@ -30,6 +30,7 @@ if is_torch_available():
ElectraForMaskedLM,
ElectraForTokenClassification,
ElectraForPreTraining,
ElectraForMultipleChoice,
ElectraForSequenceClassification,
ElectraForQuestionAnswering,
)
@ -266,6 +267,37 @@ class ElectraModelTester:
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
def create_and_check_electra_for_multiple_choice(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
fake_token_labels,
):
config.num_choices = self.num_choices
model = ElectraForMultipleChoice(config=config)
model.to(torch_device)
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels,
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@ -329,6 +361,10 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_question_answering(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_multiple_choice(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: