Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-23 22:38:58 +06:00)

[Longformer] Multiple choice for longformer (#4645)

* add multiple choice for longformer
* add models to docs
* adapt docstring
* add test to longformer
* add longformer for mc in init and modeling auto
* fix tests

parent 91487cbb8e
commit 9c17256447
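For orientation, a minimal usage sketch of the head this commit introduces, following the pattern of the docstring example added below. The checkpoint name `allenai/longformer-base-4096` and the manual padding of the two encoded choices are assumptions for illustration, not taken from this diff::

    import torch
    from transformers import LongformerTokenizer, LongformerForMultipleChoice

    # assumed checkpoint id; the docstring example below uses 'longformer-base-4096'
    tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
    model = LongformerForMultipleChoice.from_pretrained("allenai/longformer-base-4096")

    # two candidate answers for one example -> input of shape (1, num_choices, seq_len)
    choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
    encoded = [tokenizer.encode(c, add_special_tokens=True) for c in choices]
    max_len = max(len(ids) for ids in encoded)
    padded = [ids + [tokenizer.pad_token_id] * (max_len - len(ids)) for ids in encoded]
    input_ids = torch.tensor(padded).unsqueeze(0)
    labels = torch.tensor([1])  # the second choice is the correct one

    loss, classification_scores = model(input_ids, labels=labels)[:2]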
@@ -94,3 +94,17 @@ TFAlbertForSequenceClassification
 
 .. autoclass:: transformers.TFAlbertForSequenceClassification
     :members:
+
+
+TFAlbertForMultipleChoice
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TFAlbertForMultipleChoice
+    :members:
+
+
+TFAlbertForQuestionAnswering
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TFAlbertForQuestionAnswering
+    :members:
@@ -74,3 +74,18 @@ LongformerForQuestionAnswering
 
 .. autoclass:: transformers.LongformerForQuestionAnswering
     :members:
+
+
+LongformerForMultipleChoice
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.LongformerForMultipleChoice
+    :members:
+
+
+LongformerForTokenClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.LongformerForTokenClassification
+    :members:
+
@@ -74,6 +74,13 @@ RobertaForSequenceClassification
     :members:
 
 
+RobertaForMultipleChoice
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.RobertaForMultipleChoice
+    :members:
+
+
 RobertaForTokenClassification
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -326,6 +326,7 @@ if is_torch_available():
         LongformerModel,
         LongformerForMaskedLM,
         LongformerForSequenceClassification,
+        LongformerForMultipleChoice,
         LongformerForTokenClassification,
         LongformerForQuestionAnswering,
         LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
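Because of the `__init__.py` hunk above, the new head becomes importable from the package root when PyTorch is installed (the import sits behind `is_torch_available()`). A quick check, with the printed value taken from the class attributes added later in this diff::

    from transformers import LongformerForMultipleChoice  # requires a PyTorch install

    print(LongformerForMultipleChoice.base_model_prefix)  # "longformer"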
@@ -104,6 +104,7 @@ from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, G
 from .modeling_longformer import (
     LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
     LongformerForMaskedLM,
+    LongformerForMultipleChoice,
     LongformerForQuestionAnswering,
     LongformerForSequenceClassification,
     LongformerForTokenClassification,
@@ -297,6 +298,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
     [
         (CamembertConfig, CamembertForMultipleChoice),
         (XLMRobertaConfig, XLMRobertaForMultipleChoice),
+        (LongformerConfig, LongformerForMultipleChoice),
        (RobertaConfig, RobertaForMultipleChoice),
         (BertConfig, BertForMultipleChoice),
         (XLNetConfig, XLNetForMultipleChoice),
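The second hunk registers the new head in `MODEL_FOR_MULTIPLE_CHOICE_MAPPING`, the table the multiple-choice auto class (e.g. `AutoModelForMultipleChoice`, assuming it is present in this version) walks to pick a head for a given config class. A sketch of the dispatch the new entry enables::

    from transformers import LongformerConfig
    from transformers.modeling_auto import MODEL_FOR_MULTIPLE_CHOICE_MAPPING

    # with the added entry, a Longformer config now resolves to the new head
    head_class = MODEL_FOR_MULTIPLE_CHOICE_MAPPING[LongformerConfig]
    print(head_class.__name__)  # LongformerForMultipleChoice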
@@ -543,7 +543,7 @@ BERT_START_DOCSTRING = r"""
 
 BERT_INPUTS_DOCSTRING = r"""
     Args:
-        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
             Indices of input sequence tokens in the vocabulary.
 
             Indices can be obtained using :class:`transformers.BertTokenizer`.
@@ -551,19 +551,19 @@ BERT_INPUTS_DOCSTRING = r"""
             :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
 
             `What are input IDs? <../glossary.html#input-ids>`__
-        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
             Mask to avoid performing attention on padding token indices.
             Mask values selected in ``[0, 1]``:
             ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
 
             `What are attention masks? <../glossary.html#attention-mask>`__
-        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
             Segment token indices to indicate first and second portions of the inputs.
             Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
             corresponds to a `sentence B` token
 
             `What are token type IDs? <../glossary.html#token-type-ids>`_
-        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
             Indices of positions of each input sequence tokens in the position embeddings.
             Selected in the range ``[0, config.max_position_embeddings - 1]``.
 
@@ -632,7 +632,7 @@ class BertModel(BertPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)
 
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -759,7 +759,7 @@ class BertForPreTraining(BertPreTrainedModel):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -859,7 +859,7 @@ class BertForMaskedLM(BertPreTrainedModel):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -992,7 +992,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
 
         self.init_weights()
 
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -1081,7 +1081,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
 
         self.init_weights()
 
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -1177,7 +1177,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
 
         self.init_weights()
 
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -1278,7 +1278,7 @@ class BertForTokenClassification(BertPreTrainedModel):
 
         self.init_weights()
 
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -1375,7 +1375,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
 
         self.init_weights()
 
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
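The docstring hunks above replace the hard-coded `(batch_size, sequence_length)` shape with a `{0}` placeholder, and each `@add_start_docstrings_to_callable(...)` call now fills it via `str.format`, so the multiple-choice heads can document `(batch_size, num_choices, sequence_length)` while every other head keeps the plain shape. A simplified stand-in for that mechanism (not the library's actual helper)::

    TOY_INPUTS_DOCSTRING = r"""
        Args:
            input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
                Indices of input sequence tokens in the vocabulary.
    """

    def add_docstring(docstr):
        # prepend a (formatted) template to the decorated function's docstring
        def decorator(fn):
            fn.__doc__ = docstr + (fn.__doc__ or "")
            return fn
        return decorator

    class ToyForMultipleChoice:
        @add_docstring(TOY_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
        def forward(self, input_ids=None):
            """Head-specific documentation goes here."""

    print(ToyForMultipleChoice.forward.__doc__)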
@@ -411,7 +411,7 @@ LONGFORMER_START_DOCSTRING = r"""
 
 LONGFORMER_INPUTS_DOCSTRING = r"""
     Args:
-        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
             Indices of input sequence tokens in the vocabulary.
 
             Indices can be obtained using :class:`transformers.LonmgformerTokenizer`.
@@ -419,7 +419,7 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
             :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
 
             `What are input IDs? <../glossary.html#input-ids>`__
-        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
             Mask to decide the attention given on each token, local attention, global attenion, or no attention (for padding tokens).
             Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is important for
             task-specific finetuning because it makes the model more flexible at representing the task. For example,
@@ -431,13 +431,13 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
             ``2`` for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
 
             `What are attention masks? <../glossary.html#attention-mask>`__
-        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
             Segment token indices to indicate first and second portions of the inputs.
             Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
             corresponds to a `sentence B` token
 
             `What are token type IDs? <../glossary.html#token-type-ids>`_
-        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
             Indices of positions of each input sequence tokens in the position embeddings.
             Selected in the range ``[0, config.max_position_embeddings - 1]``.
 
@@ -537,7 +537,7 @@ class LongformerModel(RobertaModel):
 
         return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
 
-    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -641,7 +641,7 @@ class LongformerForMaskedLM(BertPreTrainedModel):
 
         self.init_weights()
 
-    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -729,7 +729,7 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
         self.longformer = LongformerModel(config)
         self.classifier = LongformerClassificationHead(config)
 
-    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -866,7 +866,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
 
         return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]
 
-    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids,
@@ -993,7 +993,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
 
         self.init_weights()
 
-    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -1070,3 +1070,100 @@ class LongformerForTokenClassification(BertPreTrainedModel):
             outputs = (loss,) + outputs
 
         return outputs  # (loss), scores, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """Longformer Model with a multiple choice classification head on top (a linear layer on top of
+    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
+    LONGFORMER_START_DOCSTRING,
+)
+class LongformerForMultipleChoice(BertPreTrainedModel):
+    config_class = LongformerConfig
+    pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP
+    base_model_prefix = "longformer"
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.longformer = LongformerModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, 1)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
+    def forward(
+        self,
+        input_ids=None,
+        token_type_ids=None,
+        attention_mask=None,
+        labels=None,
+        position_ids=None,
+        inputs_embeds=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the multiple choice classification loss.
+            Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
+            of the input tensors. (see `input_ids` above)
+
+    Returns:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
+        loss (:obj:`torch.FloatTensor`` of shape ``(1,)`, `optional`, returned when :obj:`labels` is provided):
+            Classification loss.
+        classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
+            `num_choices` is the second dimension of the input tensors. (see `input_ids` above).
+
+            Classification scores (before SoftMax).
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+    Examples::
+
+        from transformers import LongformerTokenizer, LongformerForTokenClassification
+        import torch
+
+        tokenizer = LongformerTokenizer.from_pretrained('longformer-base-4096')
+        model = LongformerForMultipleChoice.from_pretrained('longformer-base-4096')
+        choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
+        input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
+        labels = torch.tensor(1).unsqueeze(0)  # Batch size 1
+        outputs = model(input_ids, labels=labels)
+        loss, classification_scores = outputs[:2]
+
+        """
+        num_choices = input_ids.shape[1]
+
+        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
+        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+        outputs = self.longformer(
+            flat_input_ids,
+            position_ids=flat_position_ids,
+            token_type_ids=flat_token_type_ids,
+            attention_mask=flat_attention_mask,
+        )
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+        reshaped_logits = logits.view(-1, num_choices)
+
+        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here
+
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(reshaped_logits, labels)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
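The new forward folds the choice dimension into the batch before calling the encoder and unfolds it again for the loss, which is the standard multiple-choice trick. A self-contained sketch of just that reshaping, on dummy tensors standing in for the encoder output::

    import torch
    from torch.nn import CrossEntropyLoss, Linear

    batch_size, num_choices, seq_len, hidden_size = 2, 3, 5, 8
    input_ids = torch.randint(0, 100, (batch_size, num_choices, seq_len))

    # (batch, choices, seq) -> (batch * choices, seq), as in flat_input_ids above
    flat_input_ids = input_ids.view(-1, input_ids.size(-1))

    # stand-in for the pooled output the Longformer encoder would return
    pooled_output = torch.randn(batch_size * num_choices, hidden_size)
    classifier = Linear(hidden_size, 1)

    logits = classifier(pooled_output)              # (batch * choices, 1)
    reshaped_logits = logits.view(-1, num_choices)  # (batch, choices)

    labels = torch.tensor([0, 2])
    loss = CrossEntropyLoss()(reshaped_logits, labels)
    print(flat_input_ids.shape, reshaped_logits.shape, loss.item())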
@@ -95,7 +95,7 @@ ROBERTA_START_DOCSTRING = r"""
 
 ROBERTA_INPUTS_DOCSTRING = r"""
     Args:
-        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
             Indices of input sequence tokens in the vocabulary.
 
             Indices can be obtained using :class:`transformers.RobertaTokenizer`.
@@ -103,19 +103,19 @@ ROBERTA_INPUTS_DOCSTRING = r"""
             :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
 
             `What are input IDs? <../glossary.html#input-ids>`__
-        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
             Mask to avoid performing attention on padding token indices.
             Mask values selected in ``[0, 1]``:
             ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
 
             `What are attention masks? <../glossary.html#attention-mask>`__
-        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
             Segment token indices to indicate first and second portions of the inputs.
             Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
             corresponds to a `sentence B` token
 
             `What are token type IDs? <../glossary.html#token-type-ids>`_
-        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
             Indices of positions of each input sequence tokens in the position embeddings.
             Selected in the range ``[0, config.max_position_embeddings - 1]``.
 
@@ -175,7 +175,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head.decoder
 
-    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -286,7 +286,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
         self.roberta = RobertaModel(config)
         self.classifier = RobertaClassificationHead(config)
 
-    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -379,7 +379,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
 
         self.init_weights()
 
-    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -479,7 +479,7 @@ class RobertaForTokenClassification(BertPreTrainedModel):
 
         self.init_weights()
 
-    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -598,7 +598,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
 
         self.init_weights()
 
-    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids,
@@ -628,7 +628,7 @@ ALBERT_START_DOCSTRING = r"""
 
 ALBERT_INPUTS_DOCSTRING = r"""
     Args:
-        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`):
             Indices of input sequence tokens in the vocabulary.
 
             Indices can be obtained using :class:`transformers.AlbertTokenizer`.
@@ -636,19 +636,19 @@ ALBERT_INPUTS_DOCSTRING = r"""
             :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
 
             `What are input IDs? <../glossary.html#input-ids>`__
-        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional, defaults to :obj:`None`):
+        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional, defaults to :obj:`None`):
             Mask to avoid performing attention on padding token indices.
             Mask values selected in ``[0, 1]``:
             ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
 
             `What are attention masks? <../glossary.html#attention-mask>`__
-        token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
             Segment token indices to indicate first and second portions of the inputs.
             Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
             corresponds to a `sentence B` token
 
             `What are token type IDs? <../glossary.html#token-type-ids>`_
-        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
             Indices of positions of each input sequence tokens in the position embeddings.
             Selected in the range ``[0, config.max_position_embeddings - 1]``.
 
@@ -676,7 +676,7 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
         super().__init__(config, *inputs, **kwargs)
         self.albert = TFAlbertMainLayer(config, name="albert")
 
-    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def call(self, inputs, **kwargs):
         r"""
         Returns:
@@ -734,7 +734,7 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
     def get_output_embeddings(self):
         return self.albert.embeddings
 
-    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def call(self, inputs, **kwargs):
         r"""
         Return:
@@ -795,7 +795,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
     def get_output_embeddings(self):
         return self.albert.embeddings
 
-    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def call(self, inputs, **kwargs):
         r"""
         Returns:
@@ -852,7 +852,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel):
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
         )
 
-    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def call(self, inputs, **kwargs):
         r"""
         Returns:
@@ -908,7 +908,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel):
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
         )
 
-    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def call(self, inputs, **kwargs):
         r"""
         Return:
@@ -983,7 +983,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel):
         """
         return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
 
-    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
     def call(
         self,
         inputs,
@@ -621,7 +621,7 @@ BERT_START_DOCSTRING = r"""
 
 BERT_INPUTS_DOCSTRING = r"""
     Args:
-        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`):
             Indices of input sequence tokens in the vocabulary.
 
             Indices can be obtained using :class:`transformers.BertTokenizer`.
@@ -629,19 +629,19 @@ BERT_INPUTS_DOCSTRING = r"""
             :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
 
             `What are input IDs? <../glossary.html#input-ids>`__
-        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
             Mask to avoid performing attention on padding token indices.
             Mask values selected in ``[0, 1]``:
             ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
 
             `What are attention masks? <../glossary.html#attention-mask>`__
-        token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
             Segment token indices to indicate first and second portions of the inputs.
             Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
             corresponds to a `sentence B` token
 
             `What are token type IDs? <../glossary.html#token-type-ids>`__
-        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
             Indices of positions of each input sequence tokens in the position embeddings.
             Selected in the range ``[0, config.max_position_embeddings - 1]``.
 
@@ -669,7 +669,7 @@ class TFBertModel(TFBertPreTrainedModel):
         super().__init__(config, *inputs, **kwargs)
         self.bert = TFBertMainLayer(config, name="bert")
 
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def call(self, inputs, **kwargs):
         r"""
         Returns:
@@ -726,7 +726,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
     def get_output_embeddings(self):
         return self.bert.embeddings
 
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def call(self, inputs, **kwargs):
         r"""
         Return:
@@ -782,7 +782,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
     def get_output_embeddings(self):
         return self.bert.embeddings
 
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def call(self, inputs, **kwargs):
         r"""
         Return:
@@ -832,7 +832,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
         self.bert = TFBertMainLayer(config, name="bert")
         self.nsp = TFBertNSPHead(config, name="nsp___cls")
 
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def call(self, inputs, **kwargs):
         r"""
         Return:
@@ -888,7 +888,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
         )
 
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def call(self, inputs, **kwargs):
         r"""
         Return:
@@ -954,7 +954,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
         """
         return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
 
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
     def call(
         self,
         inputs,
@@ -1065,7 +1065,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
         )
 
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def call(self, inputs, **kwargs):
         r"""
         Return:
@@ -1122,7 +1122,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel):
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
         )
 
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def call(self, inputs, **kwargs):
         r"""
         Return:
@@ -506,7 +506,7 @@ XLNET_START_DOCSTRING = r"""
 
 XLNET_INPUTS_DOCSTRING = r"""
     Args:
-        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
             Indices of input sequence tokens in the vocabulary.
 
             Indices can be obtained using :class:`transformers.BertTokenizer`.
@@ -514,7 +514,7 @@ XLNET_INPUTS_DOCSTRING = r"""
             :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
 
             `What are input IDs? <../glossary.html#input-ids>`__
-        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
             Mask to avoid performing attention on padding token indices.
             Mask values selected in ``[0, 1]``:
             ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
@@ -535,13 +535,13 @@ XLNET_INPUTS_DOCSTRING = r"""
             Mask to indicate the output tokens to use.
             If ``target_mapping[k, i, j] = 1``, the i-th predict in batch k is on the j-th token.
             Only used during pretraining for partial prediction or for sequential decoding (generation).
-        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
             Segment token indices to indicate first and second portions of the inputs.
             Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
             corresponds to a `sentence B` token. The classifier token should be represented by a ``2``.
 
             `What are token type IDs? <../glossary.html#token-type-ids>`_
-        input_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        input_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
             Mask to avoid performing attention on padding token indices.
             Negative of `attention_mask`, i.e. with 0 for real tokens and 1 for padding.
             Kept for compatibility with the original code base.
@@ -688,7 +688,7 @@ class XLNetModel(XLNetPreTrainedModel):
         pos_emb = pos_emb.to(self.device)
         return pos_emb
 
-    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -971,7 +971,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
 
         return inputs
 
-    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -1091,7 +1091,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
 
         self.init_weights()
 
-    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -1196,7 +1196,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
 
         self.init_weights()
 
-    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -1305,7 +1305,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
 
         self.init_weights()
 
-    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -1418,7 +1418,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
 
         self.init_weights()
 
-    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -1544,7 +1544,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
 
         self.init_weights()
 
-    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
         input_ids=None,
@@ -32,6 +32,7 @@ if is_torch_available():
         LongformerForSequenceClassification,
         LongformerForTokenClassification,
         LongformerForQuestionAnswering,
+        LongformerForMultipleChoice,
     )
 
 
@@ -228,6 +229,29 @@ class LongformerModelTester(object):
         self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
         self.check_loss_output(result)
 
+    def create_and_check_longformer_for_multiple_choice(
+        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+    ):
+        config.num_choices = self.num_choices
+        model = LongformerForMultipleChoice(config=config)
+        model.to(torch_device)
+        model.eval()
+        multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        loss, logits = model(
+            multiple_choice_inputs_ids,
+            attention_mask=multiple_choice_input_mask,
+            token_type_ids=multiple_choice_token_type_ids,
+            labels=choice_labels,
+        )
+        result = {
+            "loss": loss,
+            "logits": logits,
+        }
+        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
+        self.check_loss_output(result)
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -298,6 +322,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_longformer_for_token_classification(*config_and_inputs)
 
+    def test_for_multiple_choice(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_longformer_for_multiple_choice(*config_and_inputs)
+
 
 class LongformerModelIntegrationTest(unittest.TestCase):
     @slow
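The new tester method builds its `(batch_size, num_choices, seq_len)` inputs by repeating one sequence per choice with `unsqueeze(1).expand(...)`. The same trick on dummy shapes, for reference::

    import torch

    batch_size, seq_len, num_choices = 2, 7, 4
    input_ids = torch.randint(0, 100, (batch_size, seq_len))

    multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
    assert multiple_choice_inputs_ids.shape == (batch_size, num_choices, seq_len)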