From 44c985facdf562d6cf3d7cd72f2900e3a0d85d6e Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 15 Jul 2019 11:36:50 +0200 Subject: [PATCH] update doc for XLM and XLNet --- pytorch_transformers/modeling_bert.py | 26 +- pytorch_transformers/modeling_gpt2.py | 20 +- pytorch_transformers/modeling_openai.py | 20 +- pytorch_transformers/modeling_transfo_xl.py | 8 +- pytorch_transformers/modeling_utils.py | 1 - pytorch_transformers/modeling_xlm.py | 471 +++++++++---------- pytorch_transformers/modeling_xlnet.py | 474 +++++++++----------- 7 files changed, 459 insertions(+), 561 deletions(-) diff --git a/pytorch_transformers/modeling_bert.py b/pytorch_transformers/modeling_bert.py index 78dbc699828..a044832282d 100644 --- a/pytorch_transformers/modeling_bert.py +++ b/pytorch_transformers/modeling_bert.py @@ -611,11 +611,11 @@ BERT_INPUTS_DOCSTRING = r""" (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details). **attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, sequence_length)``: Mask to avoid performing attention on padding token indices. - Mask indices selected in ``[0, 1]``: + Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. **head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: Mask to nullify selected heads of the self-attention modules. - Mask indices selected in ``[0, 1]``: + Mask values selected in ``[0, 1]``: ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. """ @@ -714,7 +714,7 @@ class BertModel(BertPreTrainedModel): return outputs # sequence_output, pooled_output, (hidden_states), (attentions) -@add_start_docstrings("""Bert Model transformer BERT model with two heads on top as done during the pre-training: +@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and a `next sentence prediction (classification)` head. """, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) class BertForPreTraining(BertPreTrainedModel): @@ -791,7 +791,7 @@ class BertForPreTraining(BertPreTrainedModel): return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions) -@add_start_docstrings("""Bert Model transformer BERT model with a `language modeling` head on top. """, +@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) class BertForMaskedLM(BertPreTrainedModel): r""" @@ -856,7 +856,7 @@ class BertForMaskedLM(BertPreTrainedModel): return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions) -@add_start_docstrings("""Bert Model transformer BERT model with a `next sentence prediction (classification)` head on top. """, +@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) class BertForNextSentencePrediction(BertPreTrainedModel): r""" @@ -913,7 +913,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel): return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions) -@add_start_docstrings("""Bert Model transformer BERT model with a sequence classification/regression head on top (a linear layer on top of +@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. 
""", BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) class BertForSequenceClassification(BertPreTrainedModel): @@ -981,7 +981,7 @@ class BertForSequenceClassification(BertPreTrainedModel): return outputs # (loss), logits, (hidden_states), (attentions) -@add_start_docstrings("""Bert Model transformer BERT model with a multiple choice classification head on top (a linear layer on top of +@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """, BERT_START_DOCSTRING) class BertForMultipleChoice(BertPreTrainedModel): @@ -1016,11 +1016,11 @@ class BertForMultipleChoice(BertPreTrainedModel): **attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, num_choices, sequence_length)``: Mask to avoid performing attention on padding token indices. The second dimension of the input (`num_choices`) indicates the number of choices to score. - Mask indices selected in ``[0, 1]``: + Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. **head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: Mask to nullify selected heads of the self-attention modules. - Mask indices selected in ``[0, 1]``: + Mask values selected in ``[0, 1]``: ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: Labels for computing the multiple choice classification loss. @@ -1087,7 +1087,7 @@ class BertForMultipleChoice(BertPreTrainedModel): return outputs # (loss), reshaped_logits, (hidden_states), (attentions) -@add_start_docstrings("""Bert Model transformer BERT model with a token classification head on top (a linear layer on top of +@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) class BertForTokenClassification(BertPreTrainedModel): @@ -1154,17 +1154,17 @@ class BertForTokenClassification(BertPreTrainedModel): return outputs # (loss), scores, (hidden_states), (attentions) -@add_start_docstrings("""Bert Model transformer BERT model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of +@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) class BertForQuestionAnswering(BertPreTrainedModel): r""" **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: - Position (index) of the start of the labelled span for computing the token classification loss. + Labels for position (index) of the start of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: - Position (index) of the end of the labelled span for computing the token classification loss. + Labels for position (index) of the end of the labelled span for computing the token classification loss. 
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. diff --git a/pytorch_transformers/modeling_gpt2.py b/pytorch_transformers/modeling_gpt2.py index 06386f9aceb..415396496cf 100644 --- a/pytorch_transformers/modeling_gpt2.py +++ b/pytorch_transformers/modeling_gpt2.py @@ -404,11 +404,11 @@ GPT2_INPUTS_DOCSTRING = r""" Inputs: (see `past` output below). Can be used to speed up sequential decoding. **attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, sequence_length)``: Mask to avoid performing attention on padding token indices. - Mask indices selected in ``[0, 1]``: + Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. **head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: Mask to nullify selected heads of the self-attention modules. - Mask indices selected in ``[0, 1]``: + Mask values selected in ``[0, 1]``: ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. """ @@ -541,7 +541,7 @@ class GPT2Model(GPT2PreTrainedModel): (linear layer with weights tied to the input embeddings). """, GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING) class GPT2LMHeadModel(GPT2PreTrainedModel): r""" - **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids`` Indices are selected in ``[-1, 0, ..., config.vocab_size]`` @@ -549,7 +549,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): computed for labels in ``[0, ..., config.vocab_size]`` Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: - **loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: Language modeling loss. **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
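As an illustrative aside (assuming the publicly released ``gpt2`` checkpoint is available; the variable names are only for the sketch), the renamed ``labels`` argument can be exercised like this, with positions set to ``-1`` excluded from the loss::

    >>> import torch
    >>> from pytorch_transformers import GPT2Tokenizer, GPT2LMHeadModel
    >>> tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    >>> model = GPT2LMHeadModel.from_pretrained('gpt2')
    >>> input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])  # Batch size 1
    >>> labels = input_ids.clone()
    >>> labels[:, :2] = -1                         # labels set to -1 are ignored (masked) in the loss
    >>> loss = model(input_ids, labels=labels)[0]  # first element of the output tuple is the LM loss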
@@ -571,7 +571,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): >>> tokenizer = GPT2Tokenizer.from_pretrained('gpt2') >>> model = GPT2LMHeadModel(config) >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 - >>> outputs = model(input_ids, lm_labels=input_ids) + >>> outputs = model(input_ids, labels=input_ids) >>> loss, logits = outputs[:2] """ @@ -590,17 +590,17 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): self._tie_or_clone_weights(self.lm_head, self.transformer.wte) - def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None, head_mask=None): + def forward(self, input_ids, position_ids=None, token_type_ids=None, labels=None, past=None, head_mask=None): transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask) hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) outputs = (lm_logits,) + transformer_outputs[1:] - if lm_labels is not None: + if labels is not None: # Shift so that tokens < n predict n shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = lm_labels[..., 1:].contiguous() + shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss(ignore_index=-1) loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), @@ -639,11 +639,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): (see `past` output below). Can be used to speed up sequential decoding. **attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, num_choices, sequence_length)``: Mask to avoid performing attention on padding token indices. - Mask indices selected in ``[0, 1]``: + Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. **head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: Mask to nullify selected heads of the self-attention modules. - Mask indices selected in ``[0, 1]``: + Mask values selected in ``[0, 1]``: ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: Labels for language modeling. diff --git a/pytorch_transformers/modeling_openai.py b/pytorch_transformers/modeling_openai.py index 268252a12c8..d51e4309b83 100644 --- a/pytorch_transformers/modeling_openai.py +++ b/pytorch_transformers/modeling_openai.py @@ -414,11 +414,11 @@ OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs: Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices). **attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, sequence_length)``: Mask to avoid performing attention on padding token indices. - Mask indices selected in ``[0, 1]``: + Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. **head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: Mask to nullify selected heads of the self-attention modules. - Mask indices selected in ``[0, 1]``: + Mask values selected in ``[0, 1]``: ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. """ @@ -536,7 +536,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): (linear layer with weights tied to the input embeddings). 
""", OPENAI_GPT_START_DOCSTRING, OPENAI_GPT_INPUTS_DOCSTRING) class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): r""" - **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids`` Indices are selected in ``[-1, 0, ..., config.vocab_size]`` @@ -544,7 +544,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): computed for labels in ``[0, ..., config.vocab_size]`` Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: - **loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: Language modeling loss. **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). @@ -562,7 +562,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): >>> tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt') >>> model = OpenAIGPTLMHeadModel(config) >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 - >>> outputs = model(input_ids, lm_labels=input_ids) + >>> outputs = model(input_ids, labels=input_ids) >>> loss, logits = outputs[:2] """ @@ -581,16 +581,16 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): self._tie_or_clone_weights(self.lm_head, self.transformer.tokens_embed) - def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None): + def forward(self, input_ids, position_ids=None, token_type_ids=None, labels=None, head_mask=None): transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask) hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) outputs = (lm_logits,) + transformer_outputs[1:] - if lm_labels is not None: + if labels is not None: # Shift so that tokens < n predict n shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = lm_labels[..., 1:].contiguous() + shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss(ignore_index=-1) loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), @@ -625,11 +625,11 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices). **attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, num_choices, sequence_length)``: Mask to avoid performing attention on padding token indices. - Mask indices selected in ``[0, 1]``: + Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. **head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: Mask to nullify selected heads of the self-attention modules. - Mask indices selected in ``[0, 1]``: + Mask values selected in ``[0, 1]``: ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: Labels for language modeling. 
diff --git a/pytorch_transformers/modeling_transfo_xl.py b/pytorch_transformers/modeling_transfo_xl.py index 7eb7a46df36..d9c8cba8dbb 100644 --- a/pytorch_transformers/modeling_transfo_xl.py +++ b/pytorch_transformers/modeling_transfo_xl.py @@ -937,13 +937,13 @@ TRANSFO_XL_INPUTS_DOCSTRING = r""" Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`. See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. - **mems**: + **mems**: (`optional`) list of ``torch.FloatTensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see `mems` output below). Can be used to speed up sequential decoding and attend to longer context. **head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: Mask to nullify selected heads of the self-attention modules. - Mask indices selected in ``[0, 1]``: + Mask values selected in ``[0, 1]``: ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. """ @@ -954,7 +954,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel): Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` Sequence of hidden-states at the last layer of the model. - **mems**: ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: + **mems**: list of ``torch.FloatTensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context. @@ -1270,7 +1270,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): **prediction_scores**: ``None`` if ``lm_labels`` is provided else ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). We don't output them when the loss is computed to speedup adaptive softmax decoding. - **mems**: ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: + **mems**: list of ``torch.FloatTensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context. diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 71fa9e37476..4e5fe92001e 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -538,7 +538,6 @@ class PoolerAnswerClass(nn.Module): class SQuADHead(nn.Module): """ A SQuAD head inspired by XLNet. 
-        Compute
     """
     def __init__(self, config):
         super(SQuADHead, self).__init__()
diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py
index 755e504b7d5..33b5bcf7fe3 100644
--- a/pytorch_transformers/modeling_xlm.py
+++ b/pytorch_transformers/modeling_xlm.py
@@ -30,7 +30,7 @@ from torch import nn
 from torch.nn import functional as F
 from torch.nn import CrossEntropyLoss, MSELoss
 
-from .modeling_utils import (PretrainedConfig, PreTrainedModel,
+from .modeling_utils import (PretrainedConfig, PreTrainedModel, add_start_docstrings,
                              prune_linear_layer, SequenceSummary, SQuADHead)
 
 logger = logging.getLogger(__name__)
@@ -392,28 +392,94 @@ class XLMPreTrainedModel(PreTrainedModel):
             module.weight.data.fill_(1.0)
 
+XLM_START_DOCSTRING = r""" The XLM model was proposed in
+    `Cross-lingual Language Model Pretraining`_
+    by Guillaume Lample*, Alexis Conneau*. It's a transformer pre-trained using one of the following objectives:
+
+        - a causal language modeling (CLM) objective (next token prediction),
+        - a masked language modeling (MLM) objective (Bert-like), or
+        - a Translation Language Modeling (TLM) objective (extension of Bert's MLM to multiple language inputs)
+
+    Original code can be found `here`_.
+
+    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
+    refer to the PyTorch documentation for all matters related to general usage and behavior.
+
+    .. _`Cross-lingual Language Model Pretraining`:
+        https://arxiv.org/abs/1901.07291
+
+    .. _`torch.nn.Module`:
+        https://pytorch.org/docs/stable/nn.html#module
+
+    .. _`here`:
+        https://github.com/facebookresearch/XLM
+
+    Parameters:
+        config (:class:`~pytorch_transformers.XLMConfig`): Model configuration class with all the parameters of the model.
+"""
+
+XLM_INPUTS_DOCSTRING = r"""
+    Inputs:
+        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+            Indices of input sequence tokens in the vocabulary.
+            Indices can be obtained using :class:`pytorch_transformers.XLMTokenizer`.
+            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
+            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
+        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+            Indices of positions of each input sequence tokens in the position embeddings.
+            Selected in the range ``[0, config.max_position_embeddings - 1]``.
+        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+            A parallel sequence of tokens (can be used to indicate various portions of the inputs).
+            The embeddings from these tokens will be summed with the respective token embeddings.
+            Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
+        **langs**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+            A parallel sequence of tokens to be used to indicate the language of each token in the input.
+            Indices are selected in the pre-trained language vocabulary,
+            i.e. in the range ``[0, config.n_langs - 1]``.
+        **attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, sequence_length)``:
+            Mask to avoid performing attention on padding token indices.
+            Mask values selected in ``[0, 1]``:
+            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
+        **lengths**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
+            Length of each sentence that can be used to avoid performing attention on padding token indices.
+            You can also use `attention_mask` for the same result (see above), kept here for compatibility.
+            Indices selected in ``[0, ..., input_ids.size(-1)]``.
+        **cache**: (`optional`) dictionary with ``torch.FloatTensor`` that contains pre-computed
+            hidden-states (key and values in the attention blocks) as computed by the model
+            (see `cache` output below). Can be used to speed up sequential decoding.
+            The dictionary object will be modified in-place during the forward pass to add newly computed hidden-states.
+        **head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
+            Mask to nullify selected heads of the self-attention modules.
+            Mask values selected in ``[0, 1]``:
+            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+"""
+
+@add_start_docstrings("The bare XLM Model transformer outputting raw hidden-states without any specific head on top.",
+                      XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
 class XLMModel(XLMPreTrainedModel):
+    r"""
+    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
+            Sequence of hidden-states at the last layer of the model.
+        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+            of shape ``(batch_size, sequence_length, hidden_size)``:
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+
+    Examples::
+
+        >>> config = XLMConfig.from_pretrained('xlm-mlm-en-2048')
+        >>> tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
+        >>> model = XLMModel(config)
+        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
+        >>> outputs = model(input_ids)
+        >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
+    """
-    XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau
-
-    Paper: https://arxiv.org/abs/1901.07291
-
-    Original code: https://github.com/facebookresearch/XLM
-
-    Args:
-        `config`: a XLMConfig class instance with the configuration to build a new model
-        `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
-        `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
-            This can be used to compute head importance metrics.
Default: False - - Example:: - - config = modeling.XLMConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = modeling.XLMModel(config=config) - """ - ATTRIBUTES = ['encoder', 'eos_index', 'pad_index', # 'with_output', 'n_langs', 'n_words', 'dim', 'n_layers', 'n_heads', 'hidden_dim', 'dropout', 'attention_dropout', 'asm', @@ -493,57 +559,8 @@ class XLMModel(XLMPreTrainedModel): for layer, heads in heads_to_prune.items(): self.attentions[layer].prune_heads(heads) - def forward(self, input_ids, lengths=None, positions=None, langs=None, + def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None, attention_mask=None, cache=None, head_mask=None): # src_enc=None, src_len=None, - """ - Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.** - - Parameters: - `input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`) - `lengths`: ``torch.LongTensor`` of size ``bs``, containing the length of each sentence - `positions`: ``torch.LongTensor`` of size ``(bs, slen)``, containing word positions - `langs`: ``torch.LongTensor`` of size ``(bs, slen)``, containing language IDs - `token_type_ids`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see XLM paper for more details). - `attention_mask`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `cache`: TODO - `head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. - It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. - - - Returns: - A ``tuple(encoded_layers, pooled_output)``, with - - ``encoded_layers``: controlled by ``output_all_encoded_layers`` argument: - - - ``output_all_encoded_layers=True``: outputs a list of the full sequences of encoded-hidden-states at the end \ - of each attention block (i.e. 12 full sequences for XLM-base, 24 for XLM-large), each \ - encoded-hidden-state is a ``torch.FloatTensor`` of size [batch_size, sequence_length, hidden_size], - - - ``output_all_encoded_layers=False``: outputs only the full sequence of hidden-states corresponding \ - to the last attention block of shape [batch_size, sequence_length, hidden_size], - - ``pooled_output``: a ``torch.FloatTensor`` of size [batch_size, hidden_size] which is the output of a - classifier pre-trained on top of the hidden state associated to the first character of the - input (`CLS`) to train on the Next-Sentence task (see XLM's paper). 
- - Example:: - - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) - # or - all_encoder_layers, pooled_output = model.forward(input_ids, token_type_ids, input_mask) - """ if lengths is None: lengths = (input_ids != self.pad_index).sum(dim=1).long() # mask = input_ids != self.pad_index @@ -563,18 +580,15 @@ class XLMModel(XLMPreTrainedModel): # if self.is_decoder and src_enc is not None: # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None] - # positions - if positions is None: - positions = input_ids.new((slen,)).long() - positions = torch.arange(slen, out=positions).unsqueeze(0) + # position_ids + if position_ids is None: + position_ids = input_ids.new((slen,)).long() + position_ids = torch.arange(slen, out=position_ids).unsqueeze(0) else: - assert positions.size() == (bs, slen) # (slen, bs) - # positions = positions.transpose(0, 1) + assert position_ids.size() == (bs, slen) # (slen, bs) + # position_ids = position_ids.transpose(0, 1) # langs - assert langs is None or token_type_ids is None, "You can only use one among langs and token_type_ids" - if token_type_ids is not None: - langs = token_type_ids if langs is not None: assert langs.size() == (bs, slen) # (slen, bs) # langs = langs.transpose(0, 1) @@ -598,7 +612,7 @@ class XLMModel(XLMPreTrainedModel): if cache is not None: _slen = slen - cache['slen'] input_ids = input_ids[:, -_slen:] - positions = positions[:, -_slen:] + position_ids = position_ids[:, -_slen:] if langs is not None: langs = langs[:, -_slen:] mask = mask[:, -_slen:] @@ -606,9 +620,11 @@ class XLMModel(XLMPreTrainedModel): # embeddings tensor = self.embeddings(input_ids) - tensor = tensor + self.position_embeddings(positions).expand_as(tensor) + tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor) if langs is not None: tensor = tensor + self.lang_embeddings(langs) + if token_type_ids is not None: + tensor = tensor + self.embeddings(token_type_ids) tensor = self.layer_norm_emb(tensor) tensor = F.dropout(tensor, p=self.dropout, training=self.training) tensor *= mask.unsqueeze(-1).to(tensor.dtype) @@ -702,25 +718,40 @@ class XLMPredLayer(nn.Module): return outputs +@add_start_docstrings("""The XLM Model transformer with a language modeling head on top + (linear layer with weights tied to the input embeddings). """, + XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING) class XLMWithLMHeadModel(XLMPreTrainedModel): - """ XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau + r""" + **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Labels for language modeling. + Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids`` + Indices are selected in ``[-1, 0, ..., config.vocab_size]`` + All labels set to ``-1`` are ignored (masked), the loss is only + computed for labels in ``[0, ..., config.vocab_size]`` - Paper: https://arxiv.org/abs/1901.07291 + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Language modeling loss. 
+ **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. - Original code: https://github.com/facebookresearch/XLM + Examples:: - Args: - `config`: a XLMConfig class instance with the configuration to build a new model - `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False - `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. - This can be used to compute head importance metrics. Default: False + >>> config = XLMConfig.from_pretrained('xlm-mlm-en-2048') + >>> tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048') + >>> model = XLMWithLMHeadModel(config) + >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + >>> outputs = model(input_ids) + >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple - Example:: - - config = modeling.XLMConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = modeling.XLMModel(config=config) """ def __init__(self, config): super(XLMWithLMHeadModel, self).__init__(config) @@ -735,57 +766,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel): """ self._tie_or_clone_weights(self.pred_layer.proj, self.transformer.embeddings) - def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None, + def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None, attention_mask=None, cache=None, labels=None, head_mask=None): - """ - Args: - `input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`) - `lengths`: TODO - `positions`: TODO - `langs`: TODO - `token_type_ids`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see XLM paper for more details). - `attention_mask`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `cache`: TODO - `labels`: TODO - `head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. 
- It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. - - - Returns: - A ``tuple(encoded_layers, pooled_output)``, with - - ``encoded_layers``: controlled by ``output_all_encoded_layers`` argument: - - If ``output_all_encoded_layers=True``: outputs a list of the full sequences of encoded-hidden-states \ - at the end of each attention block (i.e. 12 full sequences for XLM-base, 24 for XLM-large), each \ - encoded-hidden-state is a ``torch.FloatTensor`` of size [batch_size, sequence_length, hidden_size], - - If ``output_all_encoded_layers=False``: outputs only the full sequence of hidden-states corresponding \ - to the last attention block of shape [batch_size, sequence_length, hidden_size], - - ``pooled_output``: a ``torch.FloatTensor`` of size [batch_size, hidden_size] which is the output of a \ - classifier pre-trained on top of the hidden state associated to the first character of the \ - input (`CLS`) to train on the Next-Sentence task (see XLM's paper). - - Example:: - - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) - # or - all_encoder_layers, pooled_output = model.forward(input_ids, token_type_ids, input_mask) - """ - transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids, + transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, token_type_ids=token_type_ids, langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask) output = transformer_outputs[0] @@ -795,25 +778,40 @@ class XLMWithLMHeadModel(XLMPreTrainedModel): return outputs +@add_start_docstrings("""XLM Model with a sequence classification/regression head on top (a linear layer on top of + the pooled output) e.g. for GLUE tasks. """, + XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING) class XLMForSequenceClassification(XLMPreTrainedModel): - """XLM model ("XLM: Generalized Autoregressive Pretraining for Language Understanding"). + r""" + **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for computing the sequence classification/regression loss. + Indices should be in ``[0, ..., config.num_labels]``. + If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), + If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). - Args: - `config`: a XLMConfig class instance with the configuration to build a new model - `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False - `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. - This can be used to compute head importance metrics. Default: False - `summary_type`: str, "last", "first", "mean", or "attn". The method - to pool the input to get a vector representation. Default: last + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Classification (or regression if config.num_labels==1) loss. 
+ **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)`` + Classification (or regression if config.num_labels==1) scores (before SoftMax). + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + Examples:: - - Example:: - - config = modeling.XLMConfig(vocab_size_or_config_json_file=32000, d_model=768, - n_layer=12, num_attention_heads=12, intermediate_size=3072) - - model = modeling.XLMModel(config=config) + >>> config = XLMConfig.from_pretrained('xlm-mlm-en-2048') + >>> tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048') + >>> + >>> model = XLMForSequenceClassification(config) + >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + >>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 + >>> outputs = model(input_ids, labels=labels) + >>> loss, logits = outputs[:2] """ def __init__(self, config): @@ -825,42 +823,9 @@ class XLMForSequenceClassification(XLMPreTrainedModel): self.apply(self.init_weights) - def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None, + def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None, attention_mask=None, cache=None, labels=None, head_mask=None): - """ - Args: - input_ids: TODO - lengths: TODO - positions: TODO - langs: TODO - token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs. - attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask` - but with 1 for real tokens and 0 for padding. - Added for easy compatibility with the XLM model (which uses this negative masking). - You can only uses one among `input_mask` and `attention_mask` - cache: TODO - labels: TODO - head_mask: TODO - - - Returns: - A ``tuple(logits_or_loss, new_mems)``. If ``labels`` is ``None``, return token logits with shape - [batch_size, sequence_length]. If it isn't ``None``, return the ``CrossEntropy`` loss with the targets. 
- - ``new_mems`` is a list (num layers) of updated mem states at the entry of each layer \ - each mem state is a ``torch.FloatTensor`` of size [self.config.mem_len, batch_size, self.config.d_model] \ - Note that the first two dimensions are transposed in ``mems`` with regards to ``input_ids`` and ``labels`` - - Example:: - - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) - """ - transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids, + transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, token_type_ids=token_type_ids, langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask) output = transformer_outputs[0] @@ -881,26 +846,53 @@ class XLMForSequenceClassification(XLMPreTrainedModel): return outputs +@add_start_docstrings("""XLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of + the hidden-states output to compute `span start logits` and `span end logits`). """, + XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING) class XLMForQuestionAnswering(XLMPreTrainedModel): - """ - XLM model for Question Answering (span extraction). - This module is composed of the XLM model with a linear layer on top of - the sequence output that computes start_logits and end_logits + r""" + **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + **is_impossible**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels whether a question has an answer or no answer (SQuAD 2.0) + **cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for position (index) of the classification token to use as input for computing plausibility of the answer. + **p_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...) - Args: - `config`: a XLMConfig class instance with the configuration to build a new model - `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False - `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. - This can be used to compute head importance metrics. Default: False + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. 
+        **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
+            Span-start scores (before SoftMax).
+        **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
+            Span-end scores (before SoftMax).
+        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+            of shape ``(batch_size, sequence_length, hidden_size)``:
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+
+    Examples::
+
+        >>> config = XLMConfig.from_pretrained('xlm-mlm-en-2048')
+        >>> tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
+        >>>
+        >>> model = XLMForQuestionAnswering(config)
+        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
+        >>> start_positions = torch.tensor([1])
+        >>> end_positions = torch.tensor([3])
+        >>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
+        >>> loss, start_scores, end_scores = outputs[:3]
+
     """
     def __init__(self, config):
         super(XLMForQuestionAnswering, self).__init__(config)
@@ -910,63 +902,10 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
 
         self.apply(self.init_weights)
 
-    def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
+    def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
                 attention_mask=None, cache=None, start_positions=None, end_positions=None,
                 cls_index=None, is_impossible=None, p_mask=None, head_mask=None):
-
-        """
-        Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**
-
-        Args:
-            input_ids: a ``torch.LongTensor`` of shape [batch_size, sequence_length]
-                with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
-                `run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
-            lengths: TODO
-            positions: TODO
-            langs: TODO
-            token_type_ids: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with the token
-                types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
-                a `sentence B` token (see XLM paper for more details).
-            attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask`
-                but with 1 for real tokens and 0 for padding.
-                Added for easy compatibility with the XLM model (which uses this negative masking).
-                You can only uses one among `input_mask` and `attention_mask`
-            cache: TODO
-            start_positions: position of the first token for the labeled span: ``torch.LongTensor`` of shape [batch_size].
-                Positions are clamped to the length of the sequence and position outside of the sequence are not taken
-                into account for computing the loss.
-            end_positions: position of the last token for the labeled span: ``torch.LongTensor`` of shape [batch_size].
-                Positions are clamped to the length of the sequence and position outside of the sequence are not taken
-                into account for computing the loss.
-            cls_index: TODO
-            is_impossible: TODO
-            p_mask: TODO
-            head_mask: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
-                It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
-
-        Returns:
-            Either the ``total_loss`` or a ``tuple(start_logits, end_logits)``
-
-            if ``start_positions`` and ``end_positions`` are not ``None``, \
-            outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
-
-            if ``start_positions`` or ``end_positions`` is ``None``:
-                Outputs a ``tuple(start_logits, end_logits)`` which are the logits respectively for the start and end
-                position tokens of shape [batch_size, sequence_length].
-
-        Example::
-
-            # Already been converted into WordPiece token ids
-            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
-            input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
-            token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
-
-            start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
-            # or
-            start_logits, end_logits = model.forward(input_ids, token_type_ids, input_mask)
-        """
-
-        transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids,
+        transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, token_type_ids=token_type_ids,
                                                langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
 
         output = transformer_outputs[0]
diff --git a/pytorch_transformers/modeling_xlnet.py b/pytorch_transformers/modeling_xlnet.py
index d3efd2799a8..a46426d82a7 100644
--- a/pytorch_transformers/modeling_xlnet.py
+++ b/pytorch_transformers/modeling_xlnet.py
@@ -15,8 +15,6 @@
 # limitations under the License.
 """ PyTorch XLNet model. """
-from __future__ import (absolute_import, division, print_function,
-                        unicode_literals)
 from __future__ import absolute_import, division, print_function, unicode_literals
 
 import json
@@ -32,7 +30,8 @@ from torch.nn import functional as F
 from torch.nn import CrossEntropyLoss, MSELoss
 
 from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
-                             SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits)
+                             SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits,
+                             add_start_docstrings)
 
 logger = logging.getLogger(__name__)
 
@@ -619,26 +618,105 @@ class XLNetPreTrainedModel(PreTrainedModel):
             module.mask_emb.data.normal_(mean=0.0, std=self.config.initializer_range)
 
+XLNET_START_DOCSTRING = r""" The XLNet model was proposed in
+    `XLNet: Generalized Autoregressive Pretraining for Language Understanding`_
+    by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
+    XLNet is an extension of the Transformer-XL model pre-trained using an autoregressive method
+    to learn bidirectional contexts by maximizing the expected likelihood over all permutations
+    of the input sequence factorization order.
+
+    The specific attention pattern can be controlled at training and test time using the `perm_mask` input.
+
+    Due to the difficulty of training a fully auto-regressive model over various factorization orders,
+    XLNet is pretrained using only a sub-set of the output tokens as targets, which are selected
+    with the `target_mapping` input.
+
+    To use XLNet for sequential decoding (i.e. not in a fully bi-directional setting), use the `perm_mask` and
+    `target_mapping` inputs to control the attention span and outputs (see examples in `examples/run_generation.py`).
+
+    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
+    refer to the PyTorch documentation for all matters related to general usage and behavior.
+
+    .. _`XLNet: Generalized Autoregressive Pretraining for Language Understanding`:
+        http://arxiv.org/abs/1906.08237
+
+    .. _`torch.nn.Module`:
+        https://pytorch.org/docs/stable/nn.html#module
+
+    Parameters:
+        config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
+"""
+
+XLNET_INPUTS_DOCSTRING = r"""
+    Inputs:
+        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+            Indices of input sequence tokens in the vocabulary.
+            Indices can be obtained using :class:`pytorch_transformers.XLNetTokenizer`.
+            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
+            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
+        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+            A parallel sequence of tokens (can be used to indicate various portions of the inputs).
+            The embeddings from these tokens will be summed with the respective token embeddings.
+            Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
+        **attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, sequence_length)``:
+            Mask to avoid performing attention on padding token indices.
+            Mask values selected in ``[0, 1]``:
+            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
+        **input_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, sequence_length)``:
+            Mask to avoid performing attention on padding token indices.
+            Negative of `attention_mask`, i.e. with 0 for real tokens and 1 for padding.
+            Kept for compatibility with the original code base.
+            You can only use one of `input_mask` and `attention_mask`.
+            Mask values selected in ``[0, 1]``:
+            ``1`` for tokens that are MASKED, ``0`` for tokens that are NOT MASKED.
+        **mems**: (`optional`)
+            list of ``torch.FloatTensor`` (one for each layer):
+            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
+            (see `mems` output below). Can be used to speed up sequential decoding and attend to longer context.
+        **perm_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, sequence_length)``:
+            Mask to indicate the attention pattern for each input token with values selected in ``[0, 1]``:
+            If ``perm_mask[k, i, j] = 0``, i attends to j in batch k;
+            if ``perm_mask[k, i, j] = 1``, i does not attend to j in batch k.
+            If None, each token attends to all the others (full bidirectional attention).
+            Only used during pretraining (to define factorization order) or for sequential decoding (generation).
+        **target_mapping**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_predict, sequence_length)``:
+            Mask to indicate the output tokens to use.
+            If ``target_mapping[k, i, j] = 1``, the i-th prediction in batch k is on the j-th token.
+            Only used during pretraining for partial prediction or for sequential decoding (generation).
+        **head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
+            Mask to nullify selected heads of the self-attention modules.
+            Mask values selected in ``[0, 1]``:
+            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+"""
+
+@add_start_docstrings("The bare XLNet Model transformer outputting raw hidden-states without any specific head on top.",
+                      XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
 class XLNetModel(XLNetPreTrainedModel):
-    """XLNet model ("XLNet: Generalized Autoregressive Pretraining for Language Understanding").
+    r"""
+    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
+            Sequence of hidden-states at the last layer of the model.
+        **mems**:
+            list of ``torch.FloatTensor`` (one for each layer):
+            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
+            (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
+        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+            of shape ``(batch_size, sequence_length, hidden_size)``:
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
 
-    TODO Lysandre filled: this was copied from the XLNetLMHeadModel, check that it's ok.
+    Examples::
 
-    Args:
-        `config`: a XLNetConfig class instance with the configuration to build a new model
-        `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
-        `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
-            This can be used to compute head importance metrics. Default: False
+        >>> config = XLNetConfig.from_pretrained('xlnet-large-cased')
+        >>> tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
+        >>> model = XLNetModel(config)
+        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
+        >>> outputs = model(input_ids)
+        >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
 
-
-    Example::
-
-        config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
-            n_layer=12, num_attention_heads=12, intermediate_size=3072)
-
-        model = modeling.XLNetModel(config=config)
-
-    TODO Lysandre filled: Added example usage
     """
     def __init__(self, config):
         super(XLNetModel, self).__init__(config)
@@ -765,50 +843,6 @@ class XLNetModel(XLNetPreTrainedModel):
 
     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, head_mask=None):
-        """
-        Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**
-
-        Args:
-            input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
-            token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
- input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask. - 0 for real tokens and 1 for padding. - attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask` - but with 1 for real tokens and 0 for padding. - Added for easy compatibility with the BERT model (which uses this negative masking). - You can only uses one among `input_mask` and `attention_mask` - mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory - from previous batches. The length of the list equals n_layer. - If None, no memory is used. - perm_mask: [optional] float32 Tensor in shape [bsz, len, len]. - If perm_mask[k, i, j] = 0, i attend to j in batch k; - if perm_mask[k, i, j] = 1, i does not attend to j in batch k. - If None, each position attends to all the others. - target_mapping: [optional] float32 Tensor in shape [bsz, num_predict, len]. - If target_mapping[k, i, j] = 1, the i-th predict in batch k is - on the j-th token. - Only used during pretraining for partial prediction. - Set to None during finetuning. - head_mask: TODO Lysandre didn't fill - - - Returns: - TODO Lysandre didn't fill: Missing returns! - - Example:: - - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) - # or - all_encoder_layers, pooled_output = model.forward(input_ids, token_type_ids, input_mask) - - TODO Lysandre filled: Filled with the LMHead example, is probably different since it has a different output - - """ # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end # but we want a unified interface in the library with the batch size on the first dimension # so we move here the first dimension (batch) to the end @@ -952,23 +986,49 @@ class XLNetModel(XLNetPreTrainedModel): return outputs # outputs, new_mems, (hidden_states), (attentions) +@add_start_docstrings("""XLNet Model with a language modeling head on top + (linear layer with weights tied to the input embeddings). """, + XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING) class XLNetLMHeadModel(XLNetPreTrainedModel): - """XLNet model ("XLNet: Generalized Autoregressive Pretraining for Language Understanding"). + r""" + **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Labels for language modeling. + Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids`` + Indices are selected in ``[-1, 0, ..., config.vocab_size]`` + All labels set to ``-1`` are ignored (masked), the loss is only + computed for labels in ``[0, ..., config.vocab_size]`` - Args: - `config`: a XLNetConfig class instance with the configuration to build a new model - `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False - `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. - This can be used to compute head importance metrics. Default: False + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Language modeling loss. 
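The ``-1`` labelling convention described above is the standard ``ignore_index`` mechanism of ``torch.nn.CrossEntropyLoss``; the snippet below illustrates that convention on toy tensors and is not the model's exact loss code:

    import torch
    from torch.nn import CrossEntropyLoss

    # Toy logits for one sequence of length 4 over a 10-token vocabulary.
    logits = torch.randn(1, 4, 10)
    labels = torch.tensor([[5, -1, -1, 2]])  # positions labelled -1 do not contribute to the loss

    loss_fct = CrossEntropyLoss(ignore_index=-1)
    loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))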
+ **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + **mems**: + list of ``torch.FloatTensor`` (one for each layer): + that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model + (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. - Example:: + Examples:: - config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768, - n_layer=12, num_attention_heads=12, intermediate_size=3072) + >>> config = XLNetConfig.from_pretrained('xlnet-large-cased') + >>> tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased') + >>> model = XLNetLMHeadModel(config) + >>> # We show how to setup inputs to predict a next token using a bi-directional context. + >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is very ")).unsqueeze(0) # We will predict the masked token + >>> perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float) + >>> perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token + >>> target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float) # Shape [1, 1, seq_length] => let's predict one token + >>> target_mapping[0, 0, -1] = 1.0 # Our first (and only) prediction will be the last token of the sequence (the masked token) + >>> outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping) + >>> next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size] - model = modeling.XLNetLMHeadModel(config=config) - - TODO Lysandre modified: Changed XLNetModel to XLNetLMHeadModel in the example """ def __init__(self, config): super(XLNetLMHeadModel, self).__init__(config) @@ -989,58 +1049,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None, labels=None, head_mask=None): - """ - all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) - - Args: - input_ids: int32 Tensor in shape [bsz, len], the input token IDs. - token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs. - input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask. - 0 for real tokens and 1 for padding. - attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask` - but with 1 for real tokens and 0 for padding. - Added for easy compatibility with the BERT model (which uses this negative masking). - You can only uses one among `input_mask` and `attention_mask` - mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory - from previous batches. 
The length of the list equals n_layer. - If None, no memory is used. - perm_mask: [optional] float32 Tensor in shape [bsz, len, len]. - If perm_mask[k, i, j] = 0, i attend to j in batch k; - if perm_mask[k, i, j] = 1, i does not attend to j in batch k. - If None, each position attends to all the others. - target_mapping: [optional] float32 Tensor in shape [bsz, num_predict, len]. - If target_mapping[k, i, j] = 1, the i-th predict in batch k is - on the j-th token. - Only used during pretraining for partial prediction. - Set to None during finetuning. - - Returns: - A ``tuple(encoded_layers, pooled_output)``, with - - ``encoded_layers``: controlled by ``output_all_encoded_layers`` argument: - - - ``output_all_encoded_layers=True``: outputs a list of the full sequences of encoded-hidden-states \ - at the end of each attention block (i.e. 12 full sequences for XLNet-base, 24 for XLNet-large), \ - each encoded-hidden-state is a ``torch.FloatTensor`` of size [batch_size, sequence_length, d_model], - - - ``output_all_encoded_layers=False``: outputs only the full sequence of hidden-states corresponding \ - to the last attention block of shape [batch_size, sequence_length, d_model], - - ``pooled_output``: a ``torch.FloatTensor`` of size [batch_size, d_model] which is the output of a \ - classifier pretrained on top of the hidden state associated to the first character of the \ - input (`CLS`) to train on the Next-Sentence task (see XLNet's paper). - - Example:: - - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) - # or - all_encoder_layers, pooled_output = model.forward(input_ids, token_type_ids, input_mask) - """ transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask, mems, perm_mask, target_mapping, head_mask) @@ -1055,30 +1063,48 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): labels.view(-1)) outputs = (loss,) + outputs - return outputs # return (loss), logits, (mems), (hidden states), (attentions) + return outputs # return (loss), logits, mems, (hidden states), (attentions) +@add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of + the pooled output) e.g. for GLUE tasks. """, + XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING) class XLNetForSequenceClassification(XLNetPreTrainedModel): - """XLNet model ("XLNet: Generalized Autoregressive Pretraining for Language Understanding"). + r""" + **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for computing the sequence classification/regression loss. + Indices should be in ``[0, ..., config.num_labels]``. + If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), + If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). - Args: - `config`: a XLNetConfig class instance with the configuration to build a new model - `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False - `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. - This can be used to compute head importance metrics. Default: False - `summary_type`: str, "last", "first", "mean", or "attn". The method - to pool the input to get a vector representation. 
Default: last + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Classification (or regression if config.num_labels==1) loss. + **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)`` + Classification (or regression if config.num_labels==1) scores (before SoftMax). + **mems**: + list of ``torch.FloatTensor`` (one for each layer): + that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model + (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + Examples:: + >>> config = XLNetConfig.from_pretrained('xlnet-large-cased') + >>> tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased') + >>> + >>> model = XLNetForSequenceClassification(config) + >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + >>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 + >>> outputs = model(input_ids, labels=labels) + >>> loss, logits = outputs[:2] - Example:: - - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) """ def __init__(self, config): super(XLNetForSequenceClassification, self).__init__(config) @@ -1093,57 +1119,6 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None, labels=None, head_mask=None): - """ - Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.** - - Args: - input_ids: int32 Tensor in shape [bsz, len], the input token IDs. - token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs. - input_mask: float32 Tensor in shape [bsz, len], the input mask. - 0 for real tokens and 1 for padding. - attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask` - but with 1 for real tokens and 0 for padding. - Added for easy compatibility with the BERT model (which uses this negative masking). - You can only uses one among `input_mask` and `attention_mask` - mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory - from previous batches. The length of the list equals n_layer. - If None, no memory is used. - perm_mask: float32 Tensor in shape [bsz, len, len]. - If perm_mask[k, i, j] = 0, i attend to j in batch k; - if perm_mask[k, i, j] = 1, i does not attend to j in batch k. - If None, each position attends to all the others. 
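The regression-vs-classification rule documented above (``config.num_labels == 1`` selects a Mean-Square loss, more than one label selects Cross-Entropy) amounts to the following branch; a sketch on made-up tensors, not the model's exact code:

    import torch
    from torch.nn import CrossEntropyLoss, MSELoss

    num_labels = 3                        # hypothetical; set to 1 to exercise the regression branch
    logits = torch.randn(2, num_labels)   # (batch_size, num_labels)
    labels = torch.tensor([0, 2])

    if num_labels == 1:
        loss = MSELoss()(logits.view(-1), labels.view(-1).float())               # regression
    else:
        loss = CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))  # classification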
- target_mapping: float32 Tensor in shape [bsz, num_predict, len]. - If target_mapping[k, i, j] = 1, the i-th predict in batch k is - on the j-th token. - Only used during pre-training for partial prediction. - Set to None during fine-tuning. - labels: TODO Lysandre didn't fill - head_mask: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. - It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. - - - Returns: - A ``tuple(logits_or_loss, mems)`` - - ``logits_or_loss``: if ``labels`` is ``None``, ``logits_or_loss`` corresponds to token logits with shape \ - [batch_size, sequence_length]. If it is not ``None``, it corresponds to the ``CrossEntropy`` loss \ - with the targets. - - ``new_mems``: list (num layers) of updated mem states at the entry of each layer \ - each mem state is a ``torch.FloatTensor`` of size [self.config.mem_len, batch_size, self.config.d_model] \ - Note that the first two dimensions are transposed in ``mems`` with regards to ``input_ids`` and ``labels`` - - Example:: - - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) - # or - all_encoder_layers, pooled_output = model.forward(input_ids, token_type_ids, input_mask) - """ transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask, mems, perm_mask, target_mapping, head_mask) output = transformer_outputs[0] @@ -1163,28 +1138,60 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) outputs = (loss,) + outputs - return outputs # return (loss), logits, (mems), (hidden states), (attentions) + return outputs # return (loss), logits, mems, (hidden states), (attentions) +@add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of + the hidden-states output to compute `span start logits` and `span end logits`). """, + XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING) class XLNetForQuestionAnswering(XLNetPreTrainedModel): - """ - XLNet model for Question Answering (span extraction). + r""" + **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + **is_impossible**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels whether a question has an answer or no answer (SQuAD 2.0) + **cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for position (index) of the classification token to use as input for computing plausibility of the answer. 
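The `start_positions` / `end_positions` labels documented above feed a span-extraction loss, described below as the sum of a cross-entropy over the start positions and one over the end positions. A minimal sketch of that convention on made-up tensors (XLNet's actual answer head is more involved and also uses `cls_index` / `is_impossible`):

    import torch
    from torch.nn import CrossEntropyLoss

    seq_len = 10
    start_logits = torch.randn(2, seq_len)   # (batch_size, sequence_length)
    end_logits = torch.randn(2, seq_len)
    start_positions = torch.tensor([1, 4])
    end_positions = torch.tensor([3, 7])

    loss_fct = CrossEntropyLoss()
    # Sum of the two cross-entropies; some implementations average the two terms instead.
    total_loss = loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)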
+ **p_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...) - This module is composed of the XLNet model with a linear layer on top of - the sequence output that computes ``start_logits`` and ``end_logits`` + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)`` + Span-start scores (before SoftMax). + **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)`` + Span-end scores (before SoftMax). + **mems**: + list of ``torch.FloatTensor`` (one for each layer): + that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model + (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. - Args: - `config`: a XLNetConfig class instance with the configuration to build a new model - `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False - `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. - This can be used to compute head importance metrics. Default: False + Examples:: - Example:: + >>> config = XLNetConfig.from_pretrained('xlnet-large-cased') + >>> tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased') + >>> + >>> model = XLNetForQuestionAnswering(config) + >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + >>> start_positions = torch.tensor([1]) + >>> end_positions = torch.tensor([3]) + >>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions) + >>> loss, start_scores, end_scores = outputs[:3] - config = XLNetConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = XLNetForQuestionAnswering(config) """ def __init__(self, config): super(XLNetForQuestionAnswering, self).__init__(config) @@ -1202,53 +1209,6 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): mems=None, perm_mask=None, target_mapping=None, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None, head_mask=None): - - """ - Performs a model forward pass.
**Can be called by calling the class directly, once it has been instantiated.** - - Args: - `input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens pre-processing logic in the scripts - `run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`) - `token_type_ids`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see XLNet paper for more details). - `attention_mask`: [optional] float32 Tensor, SAME FUNCTION as `input_mask` - but with 1 for real tokens and 0 for padding. - Added for easy compatibility with the BERT model (which uses this negative masking). - You can only uses one among ``input_mask`` and ``attention_mask`` - `input_mask`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `start_positions`: position of the first token for the labeled span: ``torch.LongTensor`` of shape [batch_size]. - Positions are clamped to the length of the sequence and position outside of the sequence are not taken - into account for computing the loss. - `end_positions`: position of the last token for the labeled span: ``torch.LongTensor`` of shape [batch_size]. - Positions are clamped to the length of the sequence and position outside of the sequence are not taken - into account for computing the loss. - `head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. - It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. - - Returns: - if ``start_positions`` and ``end_positions`` are not ``None``, outputs the total_loss which is the sum of the \ - ``CrossEntropy`` loss for the start and end token positions. - - if ``start_positions`` or ``end_positions`` is ``None``, outputs a tuple of ``start_logits``, ``end_logits`` - which are the logits respectively for the start and end position tokens of shape \ - [batch_size, sequence_length]. - - Example:: - - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - start_logits, end_logits = model(input_ids, token_type_ids, input_mask) - # or - start_logits, end_logits = model.forward(input_ids, token_type_ids, input_mask) - """ transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask, mems, perm_mask, target_mapping, head_mask) hidden_states = transformer_outputs[0]