From 9c17256447b91cf8483c856cb15e95ed30ace538 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 29 May 2020 13:46:08 +0200 Subject: [PATCH] [Longformer] Multiple choice for longformer (#4645) * add multiple choice for longformer * add models to docs * adapt docstring * add test to longformer * add longformer for mc in init and modeling auto * fix tests --- docs/source/model_doc/albert.rst | 14 +++ docs/source/model_doc/longformer.rst | 15 ++++ docs/source/model_doc/roberta.rst | 7 ++ src/transformers/__init__.py | 1 + src/transformers/modeling_auto.py | 2 + src/transformers/modeling_bert.py | 24 ++--- src/transformers/modeling_longformer.py | 115 ++++++++++++++++++++++-- src/transformers/modeling_roberta.py | 18 ++-- src/transformers/modeling_tf_albert.py | 20 ++--- src/transformers/modeling_tf_bert.py | 24 ++--- src/transformers/modeling_xlnet.py | 22 ++--- tests/test_modeling_longformer.py | 28 ++++++ 12 files changed, 227 insertions(+), 63 deletions(-) diff --git a/docs/source/model_doc/albert.rst b/docs/source/model_doc/albert.rst index 8b78a336b54..057187e3d06 100644 --- a/docs/source/model_doc/albert.rst +++ b/docs/source/model_doc/albert.rst @@ -94,3 +94,17 @@ TFAlbertForSequenceClassification .. autoclass:: transformers.TFAlbertForSequenceClassification :members: + + +TFAlbertForMultipleChoice +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFAlbertForMultipleChoice + :members: + + +TFAlbertForQuestionAnswering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFAlbertForQuestionAnswering + :members: diff --git a/docs/source/model_doc/longformer.rst b/docs/source/model_doc/longformer.rst index 7e8e816410e..07d0898ccf3 100644 --- a/docs/source/model_doc/longformer.rst +++ b/docs/source/model_doc/longformer.rst @@ -74,3 +74,18 @@ LongformerForQuestionAnswering .. autoclass:: transformers.LongformerForQuestionAnswering :members: + + +LongformerForMultipleChoice +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.LongformerForMultipleChoice + :members: + + +LongformerForTokenClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.LongformerForTokenClassification + :members: + diff --git a/docs/source/model_doc/roberta.rst b/docs/source/model_doc/roberta.rst index 07e511228a8..31b39998160 100644 --- a/docs/source/model_doc/roberta.rst +++ b/docs/source/model_doc/roberta.rst @@ -74,6 +74,13 @@ RobertaForSequenceClassification :members: +RobertaForMultipleChoice +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.RobertaForMultipleChoice + :members: + + RobertaForTokenClassification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 33907741f96..6c392478bd4 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -326,6 +326,7 @@ if is_torch_available(): LongformerModel, LongformerForMaskedLM, LongformerForSequenceClassification, + LongformerForMultipleChoice, LongformerForTokenClassification, LongformerForQuestionAnswering, LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP, diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index cc8604f560b..11a8281963f 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -104,6 +104,7 @@ from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, G from .modeling_longformer import ( LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP, LongformerForMaskedLM, + LongformerForMultipleChoice, LongformerForQuestionAnswering, LongformerForSequenceClassification, LongformerForTokenClassification, @@ -297,6 +298,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict( [ (CamembertConfig, CamembertForMultipleChoice), (XLMRobertaConfig, XLMRobertaForMultipleChoice), + (LongformerConfig, LongformerForMultipleChoice), (RobertaConfig, RobertaForMultipleChoice), (BertConfig, BertForMultipleChoice), (XLNetConfig, XLNetForMultipleChoice), diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index 1e31b5c402a..29aea1b0a68 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -543,7 +543,7 @@ BERT_START_DOCSTRING = r""" BERT_INPUTS_DOCSTRING = r""" Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using :class:`transformers.BertTokenizer`. @@ -551,19 +551,19 @@ BERT_INPUTS_DOCSTRING = r""" :func:`transformers.PreTrainedTokenizer.encode_plus` for details. `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` corresponds to a `sentence B` token `What are token type IDs? <../glossary.html#token-type-ids>`_ - position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, config.max_position_embeddings - 1]``. @@ -632,7 +632,7 @@ class BertModel(BertPreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -759,7 +759,7 @@ class BertForPreTraining(BertPreTrainedModel): def get_output_embeddings(self): return self.cls.predictions.decoder - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -859,7 +859,7 @@ class BertForMaskedLM(BertPreTrainedModel): def get_output_embeddings(self): return self.cls.predictions.decoder - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -992,7 +992,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel): self.init_weights() - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -1081,7 +1081,7 @@ class BertForSequenceClassification(BertPreTrainedModel): self.init_weights() - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -1177,7 +1177,7 @@ class BertForMultipleChoice(BertPreTrainedModel): self.init_weights() - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)")) def forward( self, input_ids=None, @@ -1278,7 +1278,7 @@ class BertForTokenClassification(BertPreTrainedModel): self.init_weights() - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -1375,7 +1375,7 @@ class BertForQuestionAnswering(BertPreTrainedModel): self.init_weights() - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index 70e5fbf903b..5baf056a10f 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -411,7 +411,7 @@ LONGFORMER_START_DOCSTRING = r""" LONGFORMER_INPUTS_DOCSTRING = r""" Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using :class:`transformers.LonmgformerTokenizer`. @@ -419,7 +419,7 @@ LONGFORMER_INPUTS_DOCSTRING = r""" :func:`transformers.PreTrainedTokenizer.encode_plus` for details. `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Mask to decide the attention given on each token, local attention, global attenion, or no attention (for padding tokens). Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is important for task-specific finetuning because it makes the model more flexible at representing the task. For example, @@ -431,13 +431,13 @@ LONGFORMER_INPUTS_DOCSTRING = r""" ``2`` for global attention (tokens that attend to all other tokens, and all other tokens attend to them). `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` corresponds to a `sentence B` token `What are token type IDs? <../glossary.html#token-type-ids>`_ - position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, config.max_position_embeddings - 1]``. @@ -537,7 +537,7 @@ class LongformerModel(RobertaModel): return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -641,7 +641,7 @@ class LongformerForMaskedLM(BertPreTrainedModel): self.init_weights() - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -729,7 +729,7 @@ class LongformerForSequenceClassification(BertPreTrainedModel): self.longformer = LongformerModel(config) self.classifier = LongformerClassificationHead(config) - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -866,7 +866,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel): return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1] - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids, @@ -993,7 +993,7 @@ class LongformerForTokenClassification(BertPreTrainedModel): self.init_weights() - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -1070,3 +1070,100 @@ class LongformerForTokenClassification(BertPreTrainedModel): outputs = (loss,) + outputs return outputs # (loss), scores, (hidden_states), (attentions) + + +@add_start_docstrings( + """Longformer Model with a multiple choice classification head on top (a linear layer on top of + the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """, + LONGFORMER_START_DOCSTRING, +) +class LongformerForMultipleChoice(BertPreTrainedModel): + config_class = LongformerConfig + pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP + base_model_prefix = "longformer" + + def __init__(self, config): + super().__init__(config) + + self.longformer = LongformerModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + self.init_weights() + + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)")) + def forward( + self, + input_ids=None, + token_type_ids=None, + attention_mask=None, + labels=None, + position_ids=None, + inputs_embeds=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the multiple choice classification loss. + Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension + of the input tensors. (see `input_ids` above) + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs: + loss (:obj:`torch.FloatTensor`` of shape ``(1,)`, `optional`, returned when :obj:`labels` is provided): + Classification loss. + classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`): + `num_choices` is the second dimension of the input tensors. (see `input_ids` above). + + Classification scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + Examples:: + + from transformers import LongformerTokenizer, LongformerForTokenClassification + import torch + + tokenizer = LongformerTokenizer.from_pretrained('longformer-base-4096') + model = LongformerForMultipleChoice.from_pretrained('longformer-base-4096') + choices = ["Hello, my dog is cute", "Hello, my cat is amazing"] + input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices + labels = torch.tensor(1).unsqueeze(0) # Batch size 1 + outputs = model(input_ids, labels=labels) + loss, classification_scores = outputs[:2] + + """ + num_choices = input_ids.shape[1] + + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) + flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + outputs = self.longformer( + flat_input_ids, + position_ids=flat_position_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + ) + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here + + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + outputs = (loss,) + outputs + + return outputs # (loss), reshaped_logits, (hidden_states), (attentions) diff --git a/src/transformers/modeling_roberta.py b/src/transformers/modeling_roberta.py index 9e1460c8300..2d085e3a8ae 100644 --- a/src/transformers/modeling_roberta.py +++ b/src/transformers/modeling_roberta.py @@ -95,7 +95,7 @@ ROBERTA_START_DOCSTRING = r""" ROBERTA_INPUTS_DOCSTRING = r""" Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using :class:`transformers.RobertaTokenizer`. @@ -103,19 +103,19 @@ ROBERTA_INPUTS_DOCSTRING = r""" :func:`transformers.PreTrainedTokenizer.encode_plus` for details. `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` corresponds to a `sentence B` token `What are token type IDs? <../glossary.html#token-type-ids>`_ - position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, config.max_position_embeddings - 1]``. @@ -175,7 +175,7 @@ class RobertaForMaskedLM(BertPreTrainedModel): def get_output_embeddings(self): return self.lm_head.decoder - @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -286,7 +286,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel): self.roberta = RobertaModel(config) self.classifier = RobertaClassificationHead(config) - @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -379,7 +379,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel): self.init_weights() - @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)")) def forward( self, input_ids=None, @@ -479,7 +479,7 @@ class RobertaForTokenClassification(BertPreTrainedModel): self.init_weights() - @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -598,7 +598,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel): self.init_weights() - @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids, diff --git a/src/transformers/modeling_tf_albert.py b/src/transformers/modeling_tf_albert.py index 186b0ae3288..da7c3d458f7 100644 --- a/src/transformers/modeling_tf_albert.py +++ b/src/transformers/modeling_tf_albert.py @@ -628,7 +628,7 @@ ALBERT_START_DOCSTRING = r""" ALBERT_INPUTS_DOCSTRING = r""" Args: - input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`): + input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using :class:`transformers.AlbertTokenizer`. @@ -636,19 +636,19 @@ ALBERT_INPUTS_DOCSTRING = r""" :func:`transformers.PreTrainedTokenizer.encode_plus` for details. `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional, defaults to :obj:`None`): + attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional, defaults to :obj:`None`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` corresponds to a `sentence B` token `What are token type IDs? <../glossary.html#token-type-ids>`_ - position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, config.max_position_embeddings - 1]``. @@ -676,7 +676,7 @@ class TFAlbertModel(TFAlbertPreTrainedModel): super().__init__(config, *inputs, **kwargs) self.albert = TFAlbertMainLayer(config, name="albert") - @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Returns: @@ -734,7 +734,7 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel): def get_output_embeddings(self): return self.albert.embeddings - @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Return: @@ -795,7 +795,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel): def get_output_embeddings(self): return self.albert.embeddings - @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Returns: @@ -852,7 +852,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel): config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" ) - @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Returns: @@ -908,7 +908,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel): config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" ) - @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Return: @@ -983,7 +983,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel): """ return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)} - @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)")) def call( self, inputs, diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py index b2dd660f995..48ad5656c79 100644 --- a/src/transformers/modeling_tf_bert.py +++ b/src/transformers/modeling_tf_bert.py @@ -621,7 +621,7 @@ BERT_START_DOCSTRING = r""" BERT_INPUTS_DOCSTRING = r""" Args: - input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`): + input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using :class:`transformers.BertTokenizer`. @@ -629,19 +629,19 @@ BERT_INPUTS_DOCSTRING = r""" :func:`transformers.PreTrainedTokenizer.encode_plus` for details. `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` corresponds to a `sentence B` token `What are token type IDs? <../glossary.html#token-type-ids>`__ - position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, config.max_position_embeddings - 1]``. @@ -669,7 +669,7 @@ class TFBertModel(TFBertPreTrainedModel): super().__init__(config, *inputs, **kwargs) self.bert = TFBertMainLayer(config, name="bert") - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Returns: @@ -726,7 +726,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel): def get_output_embeddings(self): return self.bert.embeddings - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Return: @@ -782,7 +782,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel): def get_output_embeddings(self): return self.bert.embeddings - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Return: @@ -832,7 +832,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel): self.bert = TFBertMainLayer(config, name="bert") self.nsp = TFBertNSPHead(config, name="nsp___cls") - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Return: @@ -888,7 +888,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel): config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" ) - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Return: @@ -954,7 +954,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel): """ return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)} - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)")) def call( self, inputs, @@ -1065,7 +1065,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel): config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" ) - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Return: @@ -1122,7 +1122,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel): config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" ) - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Return: diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py index 5aeb69fca0e..9dfbae4f6f9 100644 --- a/src/transformers/modeling_xlnet.py +++ b/src/transformers/modeling_xlnet.py @@ -506,7 +506,7 @@ XLNET_START_DOCSTRING = r""" XLNET_INPUTS_DOCSTRING = r""" Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using :class:`transformers.BertTokenizer`. @@ -514,7 +514,7 @@ XLNET_INPUTS_DOCSTRING = r""" :func:`transformers.PreTrainedTokenizer.encode_plus` for details. `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. @@ -535,13 +535,13 @@ XLNET_INPUTS_DOCSTRING = r""" Mask to indicate the output tokens to use. If ``target_mapping[k, i, j] = 1``, the i-th predict in batch k is on the j-th token. Only used during pretraining for partial prediction or for sequential decoding (generation). - token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` corresponds to a `sentence B` token. The classifier token should be represented by a ``2``. `What are token type IDs? <../glossary.html#token-type-ids>`_ - input_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + input_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Mask to avoid performing attention on padding token indices. Negative of `attention_mask`, i.e. with 0 for real tokens and 1 for padding. Kept for compatibility with the original code base. @@ -688,7 +688,7 @@ class XLNetModel(XLNetPreTrainedModel): pos_emb = pos_emb.to(self.device) return pos_emb - @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -971,7 +971,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): return inputs - @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -1091,7 +1091,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): self.init_weights() - @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -1196,7 +1196,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel): self.init_weights() - @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -1305,7 +1305,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel): self.init_weights() - @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)")) def forward( self, input_ids=None, @@ -1418,7 +1418,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel): self.init_weights() - @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -1544,7 +1544,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): self.init_weights() - @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py index dbf5e39ab36..00a67d716f8 100644 --- a/tests/test_modeling_longformer.py +++ b/tests/test_modeling_longformer.py @@ -32,6 +32,7 @@ if is_torch_available(): LongformerForSequenceClassification, LongformerForTokenClassification, LongformerForQuestionAnswering, + LongformerForMultipleChoice, ) @@ -228,6 +229,29 @@ class LongformerModelTester(object): self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]) self.check_loss_output(result) + def create_and_check_longformer_for_multiple_choice( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_choices = self.num_choices + model = LongformerForMultipleChoice(config=config) + model.to(torch_device) + model.eval() + multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + loss, logits = model( + multiple_choice_inputs_ids, + attention_mask=multiple_choice_input_mask, + token_type_ids=multiple_choice_token_type_ids, + labels=choice_labels, + ) + result = { + "loss": loss, + "logits": logits, + } + self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices]) + self.check_loss_output(result) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -298,6 +322,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_longformer_for_token_classification(*config_and_inputs) + def test_for_multiple_choice(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_longformer_for_multiple_choice(*config_and_inputs) + class LongformerModelIntegrationTest(unittest.TestCase): @slow