[Longformer] Multiple choice for longformer (#4645)

* add multiple choice for longformer

* add models to docs

* adapt docstring

* add test to longformer

* add longformer for mc in init and modeling auto

* fix tests
Patrick von Platen 2020-05-29 13:46:08 +02:00 committed by GitHub
parent 91487cbb8e
commit 9c17256447
12 changed files with 227 additions and 63 deletions

View File

@@ -94,3 +94,17 @@ TFAlbertForSequenceClassification
.. autoclass:: transformers.TFAlbertForSequenceClassification
    :members:
+TFAlbertForMultipleChoice
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: transformers.TFAlbertForMultipleChoice
+    :members:
+TFAlbertForQuestionAnswering
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: transformers.TFAlbertForQuestionAnswering
+    :members:

View File

@@ -74,3 +74,18 @@ LongformerForQuestionAnswering
.. autoclass:: transformers.LongformerForQuestionAnswering
    :members:
+LongformerForMultipleChoice
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: transformers.LongformerForMultipleChoice
+    :members:
+LongformerForTokenClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: transformers.LongformerForTokenClassification
+    :members:

View File

@@ -74,6 +74,13 @@ RobertaForSequenceClassification
    :members:
+RobertaForMultipleChoice
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: transformers.RobertaForMultipleChoice
+    :members:
RobertaForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@@ -326,6 +326,7 @@ if is_torch_available():
        LongformerModel,
        LongformerForMaskedLM,
        LongformerForSequenceClassification,
+        LongformerForMultipleChoice,
        LongformerForTokenClassification,
        LongformerForQuestionAnswering,
        LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,

View File

@@ -104,6 +104,7 @@ from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, G
from .modeling_longformer import (
    LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
    LongformerForMaskedLM,
+    LongformerForMultipleChoice,
    LongformerForQuestionAnswering,
    LongformerForSequenceClassification,
    LongformerForTokenClassification,
@@ -297,6 +298,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
    [
        (CamembertConfig, CamembertForMultipleChoice),
        (XLMRobertaConfig, XLMRobertaForMultipleChoice),
+        (LongformerConfig, LongformerForMultipleChoice),
        (RobertaConfig, RobertaForMultipleChoice),
        (BertConfig, BertForMultipleChoice),
        (XLNetConfig, XLNetForMultipleChoice),
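Note: with the LongformerConfig -> LongformerForMultipleChoice entry added to MODEL_FOR_MULTIPLE_CHOICE_MAPPING, the auto classes can resolve a Longformer checkpoint to the new head. A minimal sketch of that path; the checkpoint name "allenai/longformer-base-4096", the manual padding, and the tuple-style return values are assumptions about the transformers API of this period, not part of this diff:

    import torch
    from transformers import AutoTokenizer, AutoModelForMultipleChoice

    tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
    model = AutoModelForMultipleChoice.from_pretrained("allenai/longformer-base-4096")  # resolves to LongformerForMultipleChoice

    prompt = "Hello, my dog"
    choices = [" is cute", " is amazing"]
    # Encode each (prompt, choice) pair, pad to a common length, and stack into
    # a (batch_size=1, num_choices=2, sequence_length) tensor.
    encoded = [tokenizer.encode(prompt + c, add_special_tokens=True) for c in choices]
    max_len = max(len(e) for e in encoded)
    encoded = [e + [tokenizer.pad_token_id] * (max_len - len(e)) for e in encoded]
    input_ids = torch.tensor(encoded).unsqueeze(0)

    loss, scores = model(input_ids, labels=torch.tensor([1]))[:2]  # scores has shape (1, 2)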

View File

@@ -543,7 +543,7 @@ BERT_START_DOCSTRING = r"""
BERT_INPUTS_DOCSTRING = r"""
    Args:
-        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
            Indices of input sequence tokens in the vocabulary.
            Indices can be obtained using :class:`transformers.BertTokenizer`.
@@ -551,19 +551,19 @@ BERT_INPUTS_DOCSTRING = r"""
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
            `What are input IDs? <../glossary.html#input-ids>`__
-        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
            `What are attention masks? <../glossary.html#attention-mask>`__
-        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            `What are token type IDs? <../glossary.html#token-type-ids>`_
-        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
@@ -632,7 +632,7 @@ class BertModel(BertPreTrainedModel):
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -759,7 +759,7 @@ class BertForPreTraining(BertPreTrainedModel):
    def get_output_embeddings(self):
        return self.cls.predictions.decoder
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -859,7 +859,7 @@ class BertForMaskedLM(BertPreTrainedModel):
    def get_output_embeddings(self):
        return self.cls.predictions.decoder
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -992,7 +992,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
        self.init_weights()
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -1081,7 +1081,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
        self.init_weights()
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -1177,7 +1177,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
        self.init_weights()
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -1278,7 +1278,7 @@ class BertForTokenClassification(BertPreTrainedModel):
        self.init_weights()
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -1375,7 +1375,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
        self.init_weights()
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
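Note: the recurring change in this file (and in the files below) is that the shared *_INPUTS_DOCSTRING strings now carry a `{0}` placeholder for the input shape, and each head formats it with its own shape, so multiple-choice heads can document `(batch_size, num_choices, sequence_length)` while single-sequence heads keep `(batch_size, sequence_length)`. A simplified stand-in for `add_start_docstrings_to_callable` illustrating the mechanism; this is an illustrative toy, not the actual transformers.file_utils implementation:

    INPUTS_DOCSTRING = r"""
        Args:
            input_ids of shape :obj:`{0}`: indices of input sequence tokens in the vocabulary.
    """

    def add_start_docstrings_to_callable(*docstr):
        # Prepend the already-formatted docstring fragments to the decorated function's docstring.
        def decorator(fn):
            fn.__doc__ = "".join(docstr) + (fn.__doc__ or "")
            return fn
        return decorator

    class ToyForMultipleChoice:
        @add_start_docstrings_to_callable(INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
        def forward(self, input_ids=None):
            """Head-specific documentation continues here."""
            return input_ids

    print(ToyForMultipleChoice.forward.__doc__)  # the shape placeholder is filled in per head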

View File

@@ -411,7 +411,7 @@ LONGFORMER_START_DOCSTRING = r"""
LONGFORMER_INPUTS_DOCSTRING = r"""
    Args:
-        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
            Indices of input sequence tokens in the vocabulary.
            Indices can be obtained using :class:`transformers.LonmgformerTokenizer`.
@@ -419,7 +419,7 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
            `What are input IDs? <../glossary.html#input-ids>`__
-        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Mask to decide the attention given on each token, local attention, global attenion, or no attention (for padding tokens).
            Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is important for
            task-specific finetuning because it makes the model more flexible at representing the task. For example,
@@ -431,13 +431,13 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
            ``2`` for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
            `What are attention masks? <../glossary.html#attention-mask>`__
-        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            `What are token type IDs? <../glossary.html#token-type-ids>`_
-        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
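Note: as the docstring above describes, Longformer at this point encodes the attention pattern directly in `attention_mask`: 0 = no attention (padding), 1 = local attention, 2 = global attention. A small illustration of building such a mask; the choice of which tokens receive global attention is an arbitrary example here, not something this diff prescribes:

    import torch

    input_ids = torch.tensor([[0, 713, 16, 5, 864, 2, 2, 100, 657, 47, 2, 1, 1]])  # made-up token ids, 1 = <pad>
    attention_mask = (input_ids != 1).long()  # 1 = local attention on real tokens, 0 = padding
    attention_mask[:, :7] = 2                 # 2 = global attention, e.g. on the question segment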
@@ -537,7 +537,7 @@ class LongformerModel(RobertaModel):
        return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
-    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -641,7 +641,7 @@ class LongformerForMaskedLM(BertPreTrainedModel):
        self.init_weights()
-    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -729,7 +729,7 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
        self.longformer = LongformerModel(config)
        self.classifier = LongformerClassificationHead(config)
-    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -866,7 +866,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
        return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]
-    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids,
@@ -993,7 +993,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
        self.init_weights()
-    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -1070,3 +1070,100 @@ class LongformerForTokenClassification(BertPreTrainedModel):
            outputs = (loss,) + outputs
        return outputs  # (loss), scores, (hidden_states), (attentions)
+@add_start_docstrings(
+    """Longformer Model with a multiple choice classification head on top (a linear layer on top of
+    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
+    LONGFORMER_START_DOCSTRING,
+)
+class LongformerForMultipleChoice(BertPreTrainedModel):
+    config_class = LongformerConfig
+    pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP
+    base_model_prefix = "longformer"
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.longformer = LongformerModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, 1)
+        self.init_weights()
+
+    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
+    def forward(
+        self,
+        input_ids=None,
+        token_type_ids=None,
+        attention_mask=None,
+        labels=None,
+        position_ids=None,
+        inputs_embeds=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the multiple choice classification loss.
+            Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
+            of the input tensors. (see `input_ids` above)
+        Returns:
+            :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
+            loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
+                Classification loss.
+            classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
+                `num_choices` is the second dimension of the input tensors. (see `input_ids` above).
+                Classification scores (before SoftMax).
+            hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+                Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+            attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+                Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+                heads.
+        Examples::
+
+            from transformers import LongformerTokenizer, LongformerForMultipleChoice
+            import torch
+
+            tokenizer = LongformerTokenizer.from_pretrained('longformer-base-4096')
+            model = LongformerForMultipleChoice.from_pretrained('longformer-base-4096')
+            choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
+            input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
+            labels = torch.tensor(1).unsqueeze(0)  # Batch size 1
+            outputs = model(input_ids, labels=labels)
+            loss, classification_scores = outputs[:2]
+        """
+        num_choices = input_ids.shape[1]
+
+        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
+        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+        outputs = self.longformer(
+            flat_input_ids,
+            position_ids=flat_position_ids,
+            token_type_ids=flat_token_type_ids,
+            attention_mask=flat_attention_mask,
+        )
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+        reshaped_logits = logits.view(-1, num_choices)
+
+        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here
+
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(reshaped_logits, labels)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
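The new head follows the same pattern as the existing BERT/RoBERTa multiple-choice heads: the (batch_size, num_choices, sequence_length) inputs are flattened to (batch_size * num_choices, sequence_length), each choice's pooled output is scored by a Linear(hidden_size, 1), and the scores are reshaped back to (batch_size, num_choices) for a cross-entropy over the choice dimension. A standalone sketch of just that tensor bookkeeping, with toy sizes and random tensors standing in for the encoder:

    import torch
    import torch.nn as nn

    batch_size, num_choices, seq_len, hidden = 2, 4, 10, 16
    input_ids = torch.randint(0, 100, (batch_size, num_choices, seq_len))

    flat_input_ids = input_ids.view(-1, seq_len)                  # (8, 10): one row per (example, choice)
    pooled_output = torch.randn(flat_input_ids.size(0), hidden)   # stand-in for the encoder's pooled output
    logits = nn.Linear(hidden, 1)(pooled_output)                  # (8, 1): one score per choice
    reshaped_logits = logits.view(-1, num_choices)                # (2, 4): scores regrouped per example
    loss = nn.CrossEntropyLoss()(reshaped_logits, torch.tensor([1, 3]))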

View File

@@ -95,7 +95,7 @@ ROBERTA_START_DOCSTRING = r"""
ROBERTA_INPUTS_DOCSTRING = r"""
    Args:
-        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
            Indices of input sequence tokens in the vocabulary.
            Indices can be obtained using :class:`transformers.RobertaTokenizer`.
@@ -103,19 +103,19 @@ ROBERTA_INPUTS_DOCSTRING = r"""
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
            `What are input IDs? <../glossary.html#input-ids>`__
-        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
            `What are attention masks? <../glossary.html#attention-mask>`__
-        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            `What are token type IDs? <../glossary.html#token-type-ids>`_
-        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
@@ -175,7 +175,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
    def get_output_embeddings(self):
        return self.lm_head.decoder
-    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -286,7 +286,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
        self.roberta = RobertaModel(config)
        self.classifier = RobertaClassificationHead(config)
-    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -379,7 +379,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
        self.init_weights()
-    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -479,7 +479,7 @@ class RobertaForTokenClassification(BertPreTrainedModel):
        self.init_weights()
-    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -598,7 +598,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
        self.init_weights()
-    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids,

View File

@@ -628,7 +628,7 @@ ALBERT_START_DOCSTRING = r"""
ALBERT_INPUTS_DOCSTRING = r"""
    Args:
-        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`):
            Indices of input sequence tokens in the vocabulary.
            Indices can be obtained using :class:`transformers.AlbertTokenizer`.
@@ -636,19 +636,19 @@ ALBERT_INPUTS_DOCSTRING = r"""
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
            `What are input IDs? <../glossary.html#input-ids>`__
-        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional, defaults to :obj:`None`):
+        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
            `What are attention masks? <../glossary.html#attention-mask>`__
-        token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            `What are token type IDs? <../glossary.html#token-type-ids>`_
-        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
@@ -676,7 +676,7 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
        super().__init__(config, *inputs, **kwargs)
        self.albert = TFAlbertMainLayer(config, name="albert")
-    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def call(self, inputs, **kwargs):
        r"""
        Returns:
@@ -734,7 +734,7 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
    def get_output_embeddings(self):
        return self.albert.embeddings
-    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def call(self, inputs, **kwargs):
        r"""
        Return:
@@ -795,7 +795,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
    def get_output_embeddings(self):
        return self.albert.embeddings
-    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def call(self, inputs, **kwargs):
        r"""
        Returns:
@@ -852,7 +852,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel):
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
-    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def call(self, inputs, **kwargs):
        r"""
        Returns:
@@ -908,7 +908,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel):
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )
-    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def call(self, inputs, **kwargs):
        r"""
        Return:
@@ -983,7 +983,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel):
        """
        return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
-    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
    def call(
        self,
        inputs,

View File

@@ -621,7 +621,7 @@ BERT_START_DOCSTRING = r"""
BERT_INPUTS_DOCSTRING = r"""
    Args:
-        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`):
            Indices of input sequence tokens in the vocabulary.
            Indices can be obtained using :class:`transformers.BertTokenizer`.
@@ -629,19 +629,19 @@ BERT_INPUTS_DOCSTRING = r"""
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
            `What are input IDs? <../glossary.html#input-ids>`__
-        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
            `What are attention masks? <../glossary.html#attention-mask>`__
-        token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            `What are token type IDs? <../glossary.html#token-type-ids>`__
-        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
@@ -669,7 +669,7 @@ class TFBertModel(TFBertPreTrainedModel):
        super().__init__(config, *inputs, **kwargs)
        self.bert = TFBertMainLayer(config, name="bert")
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def call(self, inputs, **kwargs):
        r"""
        Returns:
@@ -726,7 +726,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
    def get_output_embeddings(self):
        return self.bert.embeddings
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def call(self, inputs, **kwargs):
        r"""
        Return:
@@ -782,7 +782,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
    def get_output_embeddings(self):
        return self.bert.embeddings
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def call(self, inputs, **kwargs):
        r"""
        Return:
@@ -832,7 +832,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
        self.bert = TFBertMainLayer(config, name="bert")
        self.nsp = TFBertNSPHead(config, name="nsp___cls")
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def call(self, inputs, **kwargs):
        r"""
        Return:
@@ -888,7 +888,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def call(self, inputs, **kwargs):
        r"""
        Return:
@@ -954,7 +954,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
        """
        return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
    def call(
        self,
        inputs,
@@ -1065,7 +1065,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def call(self, inputs, **kwargs):
        r"""
        Return:
@@ -1122,7 +1122,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel):
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def call(self, inputs, **kwargs):
        r"""
        Return:

View File

@@ -506,7 +506,7 @@ XLNET_START_DOCSTRING = r"""
XLNET_INPUTS_DOCSTRING = r"""
    Args:
-        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
            Indices of input sequence tokens in the vocabulary.
            Indices can be obtained using :class:`transformers.BertTokenizer`.
@@ -514,7 +514,7 @@ XLNET_INPUTS_DOCSTRING = r"""
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
            `What are input IDs? <../glossary.html#input-ids>`__
-        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
@@ -535,13 +535,13 @@ XLNET_INPUTS_DOCSTRING = r"""
            Mask to indicate the output tokens to use.
            If ``target_mapping[k, i, j] = 1``, the i-th predict in batch k is on the j-th token.
            Only used during pretraining for partial prediction or for sequential decoding (generation).
-        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token. The classifier token should be represented by a ``2``.
            `What are token type IDs? <../glossary.html#token-type-ids>`_
-        input_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+        input_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Negative of `attention_mask`, i.e. with 0 for real tokens and 1 for padding.
            Kept for compatibility with the original code base.
@@ -688,7 +688,7 @@ class XLNetModel(XLNetPreTrainedModel):
            pos_emb = pos_emb.to(self.device)
        return pos_emb
-    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -971,7 +971,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
        return inputs
-    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -1091,7 +1091,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
        self.init_weights()
-    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -1196,7 +1196,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
        self.init_weights()
-    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -1305,7 +1305,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
        self.init_weights()
-    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -1418,7 +1418,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
        self.init_weights()
-    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
@@ -1544,7 +1544,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
        self.init_weights()
-    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,

View File

@@ -32,6 +32,7 @@ if is_torch_available():
        LongformerForSequenceClassification,
        LongformerForTokenClassification,
        LongformerForQuestionAnswering,
+        LongformerForMultipleChoice,
    )
@@ -228,6 +229,29 @@ class LongformerModelTester(object):
        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
        self.check_loss_output(result)
+
+    def create_and_check_longformer_for_multiple_choice(
+        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+    ):
+        config.num_choices = self.num_choices
+        model = LongformerForMultipleChoice(config=config)
+        model.to(torch_device)
+        model.eval()
+        multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        loss, logits = model(
+            multiple_choice_inputs_ids,
+            attention_mask=multiple_choice_input_mask,
+            token_type_ids=multiple_choice_token_type_ids,
+            labels=choice_labels,
+        )
+        result = {
+            "loss": loss,
+            "logits": logits,
+        }
+        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
+        self.check_loss_output(result)
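The tester builds multiple-choice inputs by repeating each synthetic sequence across the choice dimension with unsqueeze/expand. A toy illustration of the resulting shapes (shapes only, independent of the model):

    import torch

    batch_size, num_choices, seq_length = 3, 4, 7
    input_ids = torch.randint(0, 50, (batch_size, seq_length))                      # (3, 7)
    mc_input_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous()  # (3, 4, 7)
    # every choice sees the same tokens here, which is fine for a shape/loss smoke test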
    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
@@ -298,6 +322,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_longformer_for_token_classification(*config_and_inputs)
+
+    def test_for_multiple_choice(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_longformer_for_multiple_choice(*config_and_inputs)
class LongformerModelIntegrationTest(unittest.TestCase):
    @slow