mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
update doc for XLM and XLNet
This commit is contained in:
parent
0201d86015
commit
44c985facd
@ -611,11 +611,11 @@ BERT_INPUTS_DOCSTRING = r"""
|
||||
(see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
|
||||
**attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask indices selected in ``[0, 1]``:
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
**head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask indices selected in ``[0, 1]``:
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@ -714,7 +714,7 @@ class BertModel(BertPreTrainedModel):
|
||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model transformer BERT model with two heads on top as done during the pre-training:
|
||||
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
|
||||
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
class BertForPreTraining(BertPreTrainedModel):
|
||||
@ -791,7 +791,7 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model transformer BERT model with a `language modeling` head on top. """,
|
||||
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
class BertForMaskedLM(BertPreTrainedModel):
|
||||
r"""
|
||||
@ -856,7 +856,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model transformer BERT model with a `next sentence prediction (classification)` head on top. """,
|
||||
@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
r"""
|
||||
@ -913,7 +913,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model transformer BERT model with a sequence classification/regression head on top (a linear layer on top of
|
||||
@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
class BertForSequenceClassification(BertPreTrainedModel):
|
||||
@ -981,7 +981,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
return outputs # (loss), logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model transformer BERT model with a multiple choice classification head on top (a linear layer on top of
|
||||
@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
|
||||
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
|
||||
BERT_START_DOCSTRING)
|
||||
class BertForMultipleChoice(BertPreTrainedModel):
|
||||
@ -1016,11 +1016,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
**attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, num_choices, sequence_length)``:
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
The second dimension of the input (`num_choices`) indicates the number of choices to score.
|
||||
Mask indices selected in ``[0, 1]``:
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
**head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask indices selected in ``[0, 1]``:
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for computing the multiple choice classification loss.
|
||||
@ -1087,7 +1087,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model transformer BERT model with a token classification head on top (a linear layer on top of
|
||||
@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
|
||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
class BertForTokenClassification(BertPreTrainedModel):
|
||||
@ -1154,17 +1154,17 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
return outputs # (loss), scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model transformer BERT model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
r"""
|
||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Position (index) of the start of the labelled span for computing the token classification loss.
|
||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
**end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Position (index) of the end of the labelled span for computing the token classification loss.
|
||||
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
|
||||
|
@ -404,11 +404,11 @@ GPT2_INPUTS_DOCSTRING = r""" Inputs:
|
||||
(see `past` output below). Can be used to speed up sequential decoding.
|
||||
**attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask indices selected in ``[0, 1]``:
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
**head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask indices selected in ``[0, 1]``:
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@ -541,7 +541,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
(linear layer with weights tied to the input embeddings). """, GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
|
||||
class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
r"""
|
||||
**lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Labels for language modeling.
|
||||
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
|
||||
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
|
||||
@ -549,7 +549,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
computed for labels in ``[0, ..., config.vocab_size]``
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Language modeling loss.
|
||||
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
@ -571,7 +571,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
>>> tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
>>> model = GPT2LMHeadModel(config)
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids, lm_labels=input_ids)
|
||||
>>> outputs = model(input_ids, labels=input_ids)
|
||||
>>> loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
@ -590,17 +590,17 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
self._tie_or_clone_weights(self.lm_head,
|
||||
self.transformer.wte)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None, head_mask=None):
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, labels=None, past=None, head_mask=None):
|
||||
transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
outputs = (lm_logits,) + transformer_outputs[1:]
|
||||
if lm_labels is not None:
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = lm_labels[..., 1:].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
|
||||
@ -639,11 +639,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
||||
(see `past` output below). Can be used to speed up sequential decoding.
|
||||
**attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, num_choices, sequence_length)``:
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask indices selected in ``[0, 1]``:
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
**head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask indices selected in ``[0, 1]``:
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
**lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Labels for language modeling.
|
||||
|
@ -414,11 +414,11 @@ OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs:
|
||||
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
|
||||
**attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask indices selected in ``[0, 1]``:
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
**head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask indices selected in ``[0, 1]``:
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@ -536,7 +536,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
(linear layer with weights tied to the input embeddings). """, OPENAI_GPT_START_DOCSTRING, OPENAI_GPT_INPUTS_DOCSTRING)
|
||||
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
r"""
|
||||
**lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Labels for language modeling.
|
||||
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
|
||||
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
|
||||
@ -544,7 +544,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
computed for labels in ``[0, ..., config.vocab_size]``
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Language modeling loss.
|
||||
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
@ -562,7 +562,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
>>> tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
|
||||
>>> model = OpenAIGPTLMHeadModel(config)
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids, lm_labels=input_ids)
|
||||
>>> outputs = model(input_ids, labels=input_ids)
|
||||
>>> loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
@ -581,16 +581,16 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
self._tie_or_clone_weights(self.lm_head,
|
||||
self.transformer.tokens_embed)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, labels=None, head_mask=None):
|
||||
transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
|
||||
hidden_states = transformer_outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
outputs = (lm_logits,) + transformer_outputs[1:]
|
||||
if lm_labels is not None:
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = lm_labels[..., 1:].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
|
||||
@ -625,11 +625,11 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
|
||||
**attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, num_choices, sequence_length)``:
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask indices selected in ``[0, 1]``:
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
**head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask indices selected in ``[0, 1]``:
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
**lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Labels for language modeling.
|
||||
|
@ -937,13 +937,13 @@ TRANSFO_XL_INPUTS_DOCSTRING = r"""
|
||||
Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`.
|
||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
**mems**:
|
||||
**mems**: (`optional`)
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `mems` output below). Can be used to speed up sequential decoding and attend to longer context.
|
||||
**head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask indices selected in ``[0, 1]``:
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@ -954,7 +954,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
||||
Sequence of hidden-states at the last layer of the model.
|
||||
**mems**: ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
**mems**:
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
|
||||
@ -1270,7 +1270,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
||||
**prediction_scores**: ``None`` if ``lm_labels`` is provided else ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
We don't output them when the loss is computed to speedup adaptive softmax decoding.
|
||||
**mems**: ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
**mems**:
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
|
||||
|
@ -538,7 +538,6 @@ class PoolerAnswerClass(nn.Module):
|
||||
|
||||
class SQuADHead(nn.Module):
|
||||
""" A SQuAD head inspired by XLNet.
|
||||
Compute
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(SQuADHead, self).__init__()
|
||||
|
@ -30,7 +30,7 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from .modeling_utils import (PretrainedConfig, PreTrainedModel,
|
||||
from .modeling_utils import (PretrainedConfig, PreTrainedModel, add_start_docstrings,
|
||||
prune_linear_layer, SequenceSummary, SQuADHead)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -392,28 +392,94 @@ class XLMPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
XLM_START_DOCSTRING = r""" The XLM model was proposed in
|
||||
`Cross-lingual Language Model Pretraining`_
|
||||
by Guillaume Lample*, Alexis Conneau*. It's a transformer pre-trained using one of the following objectives:
|
||||
|
||||
- a causal language modeling (CLM) objective (next token prediction),
|
||||
- a masked language modeling (MLM) objective (Bert-like), or
|
||||
- a Translation Language Modeling (TLM) object (extension of Bert's MLM to multiple language inputs)
|
||||
|
||||
Original code can be found `here`_.
|
||||
|
||||
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
|
||||
refer to the PyTorch documentation for all matter related to general usage and behavior.
|
||||
|
||||
.. _`Cross-lingual Language Model Pretraining`:
|
||||
https://arxiv.org/abs/1901.07291
|
||||
|
||||
.. _`torch.nn.Module`:
|
||||
https://pytorch.org/docs/stable/nn.html#module
|
||||
|
||||
.. _`here`:
|
||||
https://github.com/facebookresearch/XLM
|
||||
|
||||
Parameters:
|
||||
config (:class:`~pytorch_transformers.XLMConfig`): Model configuration class with all the parameters of the model.
|
||||
"""
|
||||
|
||||
XLM_INPUTS_DOCSTRING = r"""
|
||||
Inputs:
|
||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
Indices can be obtained using :class:`pytorch_transformers.XLMTokenizer`.
|
||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Indices of positions of each input sequence tokens in the position embeddings.
|
||||
Selected in the range ``[0, config.max_position_embeddings - 1[``.
|
||||
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
A parallel sequence of tokens (can be used to indicate various portions of the inputs).
|
||||
The embeddings from these tokens will be summed with the respective token embeddings.
|
||||
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
|
||||
**langs**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
A parallel sequence of tokens to be used to indicate the language of each token in the input.
|
||||
Indices are selected in the pre-trained language vocabulary,
|
||||
i.e. in the range ``[0, config.n_langs - 1[``.
|
||||
**attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
**lengths**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Length of each sentence that can be used to avoid performing attention on padding token indices.
|
||||
You can also use `attention_mask` for the same result (see above), kept here for compatbility.
|
||||
Indices selected in ``[0, ..., input_ids.size(-1)]``:
|
||||
**cache**:
|
||||
dictionary with ``torch.FloatTensor`` that contains pre-computed
|
||||
hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `cache` output below). Can be used to speed up sequential decoding.
|
||||
The dictionary object will be modified in-place during the forward pass to add newly computed hidden-states.
|
||||
**head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare XLM Model transformer outputing raw hidden-states without any specific head on top.",
|
||||
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
|
||||
class XLMModel(XLMPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
||||
Sequence of hidden-states at the last layer of the model.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> config = XLMConfig.from_pretrained('xlm-mlm-en-2048')
|
||||
>>> tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
|
||||
>>> model = XLMModel(config)
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids)
|
||||
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
"""
|
||||
XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau
|
||||
|
||||
Paper: https://arxiv.org/abs/1901.07291
|
||||
|
||||
Original code: https://github.com/facebookresearch/XLM
|
||||
|
||||
Args:
|
||||
`config`: a XLMConfig class instance with the configuration to build a new model
|
||||
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
|
||||
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
|
||||
This can be used to compute head importance metrics. Default: False
|
||||
|
||||
Example::
|
||||
|
||||
config = modeling.XLMConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
|
||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
||||
|
||||
model = modeling.XLMModel(config=config)
|
||||
"""
|
||||
|
||||
ATTRIBUTES = ['encoder', 'eos_index', 'pad_index', # 'with_output',
|
||||
'n_langs', 'n_words', 'dim', 'n_layers', 'n_heads',
|
||||
'hidden_dim', 'dropout', 'attention_dropout', 'asm',
|
||||
@ -493,57 +559,8 @@ class XLMModel(XLMPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.attentions[layer].prune_heads(heads)
|
||||
|
||||
def forward(self, input_ids, lengths=None, positions=None, langs=None,
|
||||
def forward(self, input_ids, lengths=None, position_ids=None, langs=None,
|
||||
token_type_ids=None, attention_mask=None, cache=None, head_mask=None): # src_enc=None, src_len=None,
|
||||
"""
|
||||
Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**
|
||||
|
||||
Parameters:
|
||||
`input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length]
|
||||
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
||||
`run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
|
||||
`lengths`: ``torch.LongTensor`` of size ``bs``, containing the length of each sentence
|
||||
`positions`: ``torch.LongTensor`` of size ``(bs, slen)``, containing word positions
|
||||
`langs`: ``torch.LongTensor`` of size ``(bs, slen)``, containing language IDs
|
||||
`token_type_ids`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with the token
|
||||
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
||||
a `sentence B` token (see XLM paper for more details).
|
||||
`attention_mask`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with indices
|
||||
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
|
||||
input sequence length in the current batch. It's the mask that we typically use for attention when
|
||||
a batch has varying length sentences.
|
||||
`cache`: TODO
|
||||
`head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
|
||||
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
|
||||
|
||||
|
||||
Returns:
|
||||
A ``tuple(encoded_layers, pooled_output)``, with
|
||||
|
||||
``encoded_layers``: controlled by ``output_all_encoded_layers`` argument:
|
||||
|
||||
- ``output_all_encoded_layers=True``: outputs a list of the full sequences of encoded-hidden-states at the end \
|
||||
of each attention block (i.e. 12 full sequences for XLM-base, 24 for XLM-large), each \
|
||||
encoded-hidden-state is a ``torch.FloatTensor`` of size [batch_size, sequence_length, hidden_size],
|
||||
|
||||
- ``output_all_encoded_layers=False``: outputs only the full sequence of hidden-states corresponding \
|
||||
to the last attention block of shape [batch_size, sequence_length, hidden_size],
|
||||
|
||||
``pooled_output``: a ``torch.FloatTensor`` of size [batch_size, hidden_size] which is the output of a
|
||||
classifier pre-trained on top of the hidden state associated to the first character of the
|
||||
input (`CLS`) to train on the Next-Sentence task (see XLM's paper).
|
||||
|
||||
Example::
|
||||
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
# or
|
||||
all_encoder_layers, pooled_output = model.forward(input_ids, token_type_ids, input_mask)
|
||||
"""
|
||||
if lengths is None:
|
||||
lengths = (input_ids != self.pad_index).sum(dim=1).long()
|
||||
# mask = input_ids != self.pad_index
|
||||
@ -563,18 +580,15 @@ class XLMModel(XLMPreTrainedModel):
|
||||
# if self.is_decoder and src_enc is not None:
|
||||
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
|
||||
|
||||
# positions
|
||||
if positions is None:
|
||||
positions = input_ids.new((slen,)).long()
|
||||
positions = torch.arange(slen, out=positions).unsqueeze(0)
|
||||
# position_ids
|
||||
if position_ids is None:
|
||||
position_ids = input_ids.new((slen,)).long()
|
||||
position_ids = torch.arange(slen, out=position_ids).unsqueeze(0)
|
||||
else:
|
||||
assert positions.size() == (bs, slen) # (slen, bs)
|
||||
# positions = positions.transpose(0, 1)
|
||||
assert position_ids.size() == (bs, slen) # (slen, bs)
|
||||
# position_ids = position_ids.transpose(0, 1)
|
||||
|
||||
# langs
|
||||
assert langs is None or token_type_ids is None, "You can only use one among langs and token_type_ids"
|
||||
if token_type_ids is not None:
|
||||
langs = token_type_ids
|
||||
if langs is not None:
|
||||
assert langs.size() == (bs, slen) # (slen, bs)
|
||||
# langs = langs.transpose(0, 1)
|
||||
@ -598,7 +612,7 @@ class XLMModel(XLMPreTrainedModel):
|
||||
if cache is not None:
|
||||
_slen = slen - cache['slen']
|
||||
input_ids = input_ids[:, -_slen:]
|
||||
positions = positions[:, -_slen:]
|
||||
position_ids = position_ids[:, -_slen:]
|
||||
if langs is not None:
|
||||
langs = langs[:, -_slen:]
|
||||
mask = mask[:, -_slen:]
|
||||
@ -606,9 +620,11 @@ class XLMModel(XLMPreTrainedModel):
|
||||
|
||||
# embeddings
|
||||
tensor = self.embeddings(input_ids)
|
||||
tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
|
||||
tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor)
|
||||
if langs is not None:
|
||||
tensor = tensor + self.lang_embeddings(langs)
|
||||
if token_type_ids is not None:
|
||||
tensor = tensor + self.embeddings(token_type_ids)
|
||||
tensor = self.layer_norm_emb(tensor)
|
||||
tensor = F.dropout(tensor, p=self.dropout, training=self.training)
|
||||
tensor *= mask.unsqueeze(-1).to(tensor.dtype)
|
||||
@ -702,25 +718,40 @@ class XLMPredLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
@add_start_docstrings("""The XLM Model transformer with a language modeling head on top
|
||||
(linear layer with weights tied to the input embeddings). """,
|
||||
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
|
||||
class XLMWithLMHeadModel(XLMPreTrainedModel):
|
||||
""" XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Labels for language modeling.
|
||||
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
|
||||
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
|
||||
All labels set to ``-1`` are ignored (masked), the loss is only
|
||||
computed for labels in ``[0, ..., config.vocab_size]``
|
||||
|
||||
Paper: https://arxiv.org/abs/1901.07291
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Language modeling loss.
|
||||
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
|
||||
Original code: https://github.com/facebookresearch/XLM
|
||||
Examples::
|
||||
|
||||
Args:
|
||||
`config`: a XLMConfig class instance with the configuration to build a new model
|
||||
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
|
||||
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
|
||||
This can be used to compute head importance metrics. Default: False
|
||||
>>> config = XLMConfig.from_pretrained('xlm-mlm-en-2048')
|
||||
>>> tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
|
||||
>>> model = XLMWithLMHeadModel(config)
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids)
|
||||
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
Example::
|
||||
|
||||
config = modeling.XLMConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
|
||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
||||
|
||||
model = modeling.XLMModel(config=config)
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(XLMWithLMHeadModel, self).__init__(config)
|
||||
@ -735,57 +766,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
|
||||
"""
|
||||
self._tie_or_clone_weights(self.pred_layer.proj, self.transformer.embeddings)
|
||||
|
||||
def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
|
||||
def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
|
||||
attention_mask=None, cache=None, labels=None, head_mask=None):
|
||||
"""
|
||||
Args:
|
||||
`input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length]
|
||||
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
||||
`run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
|
||||
`lengths`: TODO
|
||||
`positions`: TODO
|
||||
`langs`: TODO
|
||||
`token_type_ids`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with the token
|
||||
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
||||
a `sentence B` token (see XLM paper for more details).
|
||||
`attention_mask`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with indices
|
||||
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
|
||||
input sequence length in the current batch. It's the mask that we typically use for attention when
|
||||
a batch has varying length sentences.
|
||||
`cache`: TODO
|
||||
`labels`: TODO
|
||||
`head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
|
||||
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
|
||||
|
||||
|
||||
Returns:
|
||||
A ``tuple(encoded_layers, pooled_output)``, with
|
||||
|
||||
``encoded_layers``: controlled by ``output_all_encoded_layers`` argument:
|
||||
|
||||
If ``output_all_encoded_layers=True``: outputs a list of the full sequences of encoded-hidden-states \
|
||||
at the end of each attention block (i.e. 12 full sequences for XLM-base, 24 for XLM-large), each \
|
||||
encoded-hidden-state is a ``torch.FloatTensor`` of size [batch_size, sequence_length, hidden_size],
|
||||
|
||||
If ``output_all_encoded_layers=False``: outputs only the full sequence of hidden-states corresponding \
|
||||
to the last attention block of shape [batch_size, sequence_length, hidden_size],
|
||||
|
||||
``pooled_output``: a ``torch.FloatTensor`` of size [batch_size, hidden_size] which is the output of a \
|
||||
classifier pre-trained on top of the hidden state associated to the first character of the \
|
||||
input (`CLS`) to train on the Next-Sentence task (see XLM's paper).
|
||||
|
||||
Example::
|
||||
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
# or
|
||||
all_encoder_layers, pooled_output = model.forward(input_ids, token_type_ids, input_mask)
|
||||
"""
|
||||
transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids,
|
||||
transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, token_type_ids=token_type_ids,
|
||||
langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
|
||||
|
||||
output = transformer_outputs[0]
|
||||
@ -795,25 +778,40 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
|
||||
return outputs
|
||||
|
||||
|
||||
@add_start_docstrings("""XLM Model with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
|
||||
class XLMForSequenceClassification(XLMPreTrainedModel):
|
||||
"""XLM model ("XLM: Generalized Autoregressive Pretraining for Language Understanding").
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for computing the sequence classification/regression loss.
|
||||
Indices should be in ``[0, ..., config.num_labels]``.
|
||||
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
|
||||
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Args:
|
||||
`config`: a XLMConfig class instance with the configuration to build a new model
|
||||
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
|
||||
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
|
||||
This can be used to compute head importance metrics. Default: False
|
||||
`summary_type`: str, "last", "first", "mean", or "attn". The method
|
||||
to pool the input to get a vector representation. Default: last
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
|
||||
Examples::
|
||||
|
||||
|
||||
Example::
|
||||
|
||||
config = modeling.XLMConfig(vocab_size_or_config_json_file=32000, d_model=768,
|
||||
n_layer=12, num_attention_heads=12, intermediate_size=3072)
|
||||
|
||||
model = modeling.XLMModel(config=config)
|
||||
>>> config = XLMConfig.from_pretrained('xlm-mlm-en-2048')
|
||||
>>> tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
|
||||
>>>
|
||||
>>> model = XLMForSequenceClassification(config)
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
>>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids, labels=labels)
|
||||
>>> loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
@ -825,42 +823,9 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
|
||||
def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
|
||||
attention_mask=None, cache=None, labels=None, head_mask=None):
|
||||
"""
|
||||
Args:
|
||||
input_ids: TODO
|
||||
lengths: TODO
|
||||
positions: TODO
|
||||
langs: TODO
|
||||
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
|
||||
attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask`
|
||||
but with 1 for real tokens and 0 for padding.
|
||||
Added for easy compatibility with the XLM model (which uses this negative masking).
|
||||
You can only uses one among `input_mask` and `attention_mask`
|
||||
cache: TODO
|
||||
labels: TODO
|
||||
head_mask: TODO
|
||||
|
||||
|
||||
Returns:
|
||||
A ``tuple(logits_or_loss, new_mems)``. If ``labels`` is ``None``, return token logits with shape
|
||||
[batch_size, sequence_length]. If it isn't ``None``, return the ``CrossEntropy`` loss with the targets.
|
||||
|
||||
``new_mems`` is a list (num layers) of updated mem states at the entry of each layer \
|
||||
each mem state is a ``torch.FloatTensor`` of size [self.config.mem_len, batch_size, self.config.d_model] \
|
||||
Note that the first two dimensions are transposed in ``mems`` with regards to ``input_ids`` and ``labels``
|
||||
|
||||
Example::
|
||||
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
"""
|
||||
transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids,
|
||||
transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, token_type_ids=token_type_ids,
|
||||
langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
|
||||
|
||||
output = transformer_outputs[0]
|
||||
@ -881,26 +846,53 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
|
||||
return outputs
|
||||
|
||||
|
||||
@add_start_docstrings("""XLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
|
||||
class XLMForQuestionAnswering(XLMPreTrainedModel):
|
||||
"""
|
||||
XLM model for Question Answering (span extraction).
|
||||
This module is composed of the XLM model with a linear layer on top of
|
||||
the sequence output that computes start_logits and end_logits
|
||||
r"""
|
||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
**end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
**is_impossible**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels whether a question has an answer or no answer (SQuAD 2.0)
|
||||
**cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for position (index) of the classification token to use as input for computing plausibility of the answer.
|
||||
**p_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...)
|
||||
|
||||
Args:
|
||||
`config`: a XLMConfig class instance with the configuration to build a new model
|
||||
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
|
||||
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
|
||||
This can be used to compute head importance metrics. Default: False
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
||||
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-start scores (before SoftMax).
|
||||
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-end scores (before SoftMax).
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> config = XLMConfig.from_pretrained('xlm-mlm-en-2048')
|
||||
>>> tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
|
||||
>>>
|
||||
>>> model = XLMForQuestionAnswering(config)
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
>>> start_positions = torch.tensor([1])
|
||||
>>> end_positions = torch.tensor([3])
|
||||
>>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
|
||||
>>> loss, start_scores, end_scores = outputs[:2]
|
||||
|
||||
Example::
|
||||
|
||||
config = XLMConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
|
||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
||||
|
||||
model = XLMForQuestionAnswering(config)
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(XLMForQuestionAnswering, self).__init__(config)
|
||||
@ -910,63 +902,10 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
|
||||
def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
|
||||
attention_mask=None, cache=None, start_positions=None, end_positions=None,
|
||||
cls_index=None, is_impossible=None, p_mask=None, head_mask=None):
|
||||
|
||||
"""
|
||||
Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**
|
||||
|
||||
Args:
|
||||
input_ids: a ``torch.LongTensor`` of shape [batch_size, sequence_length]
|
||||
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
||||
`run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
|
||||
lengths: TODO
|
||||
positions: TODO
|
||||
langs: TODO
|
||||
token_type_ids: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with the token
|
||||
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
||||
a `sentence B` token (see XLM paper for more details).
|
||||
attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask`
|
||||
but with 1 for real tokens and 0 for padding.
|
||||
Added for easy compatibility with the XLM model (which uses this negative masking).
|
||||
You can only uses one among `input_mask` and `attention_mask`
|
||||
cache: TODO
|
||||
start_positions: position of the first token for the labeled span: ``torch.LongTensor`` of shape [batch_size].
|
||||
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
|
||||
into account for computing the loss.
|
||||
end_positions: position of the last token for the labeled span: ``torch.LongTensor`` of shape [batch_size].
|
||||
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
|
||||
into account for computing the loss.
|
||||
cls_index: TODO
|
||||
is_impossible: TODO
|
||||
p_mask: TODO
|
||||
head_mask: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
|
||||
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
|
||||
|
||||
Returns:
|
||||
Either the ``total_loss`` or a ``tuple(start_logits, end_logits)``
|
||||
|
||||
if ``start_positions`` and ``end_positions`` are not ``None``, \
|
||||
outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
|
||||
|
||||
if ``start_positions`` or ``end_positions`` is ``None``:
|
||||
Outputs a ``tuple(start_logits, end_logits)`` which are the logits respectively for the start and end
|
||||
position tokens of shape [batch_size, sequence_length].
|
||||
|
||||
Example::
|
||||
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
||||
# or
|
||||
start_logits, end_logits = model.forward(input_ids, token_type_ids, input_mask)
|
||||
"""
|
||||
|
||||
transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids,
|
||||
transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, token_type_ids=token_type_ids,
|
||||
langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
|
||||
|
||||
output = transformer_outputs[0]
|
||||
|
@ -15,8 +15,6 @@
|
||||
# limitations under the License.
|
||||
""" PyTorch XLNet model.
|
||||
"""
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import json
|
||||
@ -32,7 +30,8 @@ from torch.nn import functional as F
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
|
||||
SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits)
|
||||
SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits,
|
||||
add_start_docstrings)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -619,26 +618,105 @@ class XLNetPreTrainedModel(PreTrainedModel):
|
||||
module.mask_emb.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
|
||||
|
||||
XLNET_START_DOCSTRING = r""" The XLNet model was proposed in
|
||||
`XLNet: Generalized Autoregressive Pretraining for Language Understanding`_
|
||||
by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
||||
XLnet is an extension of the Transformer-XL model pre-trained using an autoregressive method
|
||||
to learn bidirectional contexts by maximizing the expected likelihood over all permutations
|
||||
of the input sequence factorization order.
|
||||
|
||||
The specific attention pattern can be controlled at training and test time using the `perm_mask` input.
|
||||
|
||||
Do to the difficulty of training a fully auto-regressive model over various factorization order,
|
||||
XLNet is pretrained using only a sub-set of the output tokens as target which are selected
|
||||
with the `target_mapping` input.
|
||||
|
||||
To use XLNet for sequential decoding (i.e. not in fully bi-directional setting), use the `perm_mask` and
|
||||
`target_mapping` inputs to control the attention span and outputs (see examples in `examples/run_generation.py`)
|
||||
|
||||
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
|
||||
refer to the PyTorch documentation for all matter related to general usage and behavior.
|
||||
|
||||
.. _`XLNet: Generalized Autoregressive Pretraining for Language Understanding`:
|
||||
http://arxiv.org/abs/1906.08237
|
||||
|
||||
.. _`torch.nn.Module`:
|
||||
https://pytorch.org/docs/stable/nn.html#module
|
||||
|
||||
Parameters:
|
||||
config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
|
||||
"""
|
||||
|
||||
XLNET_INPUTS_DOCSTRING = r"""
|
||||
Inputs:
|
||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
Indices can be obtained using :class:`pytorch_transformers.XLNetTokenizer`.
|
||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
A parallel sequence of tokens (can be used to indicate various portions of the inputs).
|
||||
The embeddings from these tokens will be summed with the respective token embeddings.
|
||||
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
|
||||
**attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
**input_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Negative of `attention_mask`, i.e. with 0 for real tokens and 1 for padding.
|
||||
Kept for compatibility with the original code base.
|
||||
You can only uses one of `input_mask` and `attention_mask`
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are MASKED, ``0`` for tokens that are NOT MASKED.
|
||||
**mems**: (`optional`)
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `mems` output below). Can be used to speed up sequential decoding and attend to longer context.
|
||||
**perm_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, sequence_length)``:
|
||||
Mask to indicate the attention pattern for each input token with values selected in ``[0, 1]``:
|
||||
If ``perm_mask[k, i, j] = 0``, i attend to j in batch k;
|
||||
if ``perm_mask[k, i, j] = 1``, i does not attend to j in batch k.
|
||||
If None, each token attends to all the others (full bidirectional attention).
|
||||
Only used during pretraining (to define factorization order) or for sequential decoding (generation).
|
||||
**target_mapping**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_predict, sequence_length)``:
|
||||
Mask to indicate the output tokens to use.
|
||||
If ``target_mapping[k, i, j] = 1``, the i-th predict in batch k is on the j-th token.
|
||||
Only used during pretraining for partial prediction or for sequential decoding (generation).
|
||||
**head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare XLNet Model transformer outputing raw hidden-states without any specific head on top.",
|
||||
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
|
||||
class XLNetModel(XLNetPreTrainedModel):
|
||||
"""XLNet model ("XLNet: Generalized Autoregressive Pretraining for Language Understanding").
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
||||
Sequence of hidden-states at the last layer of the model.
|
||||
**mems**:
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
|
||||
TODO Lysandre filled: this was copied from the XLNetLMHeadModel, check that it's ok.
|
||||
Examples::
|
||||
|
||||
Args:
|
||||
`config`: a XLNetConfig class instance with the configuration to build a new model
|
||||
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
|
||||
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
|
||||
This can be used to compute head importance metrics. Default: False
|
||||
>>> config = XLNetConfig.from_pretrained('xlnet-large-cased')
|
||||
>>> tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
|
||||
>>> model = XLNetModel(config)
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids)
|
||||
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
|
||||
Example::
|
||||
|
||||
config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
|
||||
n_layer=12, num_attention_heads=12, intermediate_size=3072)
|
||||
|
||||
model = modeling.XLNetModel(config=config)
|
||||
|
||||
TODO Lysandre filled: Added example usage
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(XLNetModel, self).__init__(config)
|
||||
@ -765,50 +843,6 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||
mems=None, perm_mask=None, target_mapping=None, head_mask=None):
|
||||
"""
|
||||
Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**
|
||||
|
||||
Args:
|
||||
input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
|
||||
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
|
||||
input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
|
||||
0 for real tokens and 1 for padding.
|
||||
attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask`
|
||||
but with 1 for real tokens and 0 for padding.
|
||||
Added for easy compatibility with the BERT model (which uses this negative masking).
|
||||
You can only uses one among `input_mask` and `attention_mask`
|
||||
mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
|
||||
from previous batches. The length of the list equals n_layer.
|
||||
If None, no memory is used.
|
||||
perm_mask: [optional] float32 Tensor in shape [bsz, len, len].
|
||||
If perm_mask[k, i, j] = 0, i attend to j in batch k;
|
||||
if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
|
||||
If None, each position attends to all the others.
|
||||
target_mapping: [optional] float32 Tensor in shape [bsz, num_predict, len].
|
||||
If target_mapping[k, i, j] = 1, the i-th predict in batch k is
|
||||
on the j-th token.
|
||||
Only used during pretraining for partial prediction.
|
||||
Set to None during finetuning.
|
||||
head_mask: TODO Lysandre didn't fill
|
||||
|
||||
|
||||
Returns:
|
||||
TODO Lysandre didn't fill: Missing returns!
|
||||
|
||||
Example::
|
||||
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
# or
|
||||
all_encoder_layers, pooled_output = model.forward(input_ids, token_type_ids, input_mask)
|
||||
|
||||
TODO Lysandre filled: Filled with the LMHead example, is probably different since it has a different output
|
||||
|
||||
"""
|
||||
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
|
||||
# but we want a unified interface in the library with the batch size on the first dimension
|
||||
# so we move here the first dimension (batch) to the end
|
||||
@ -952,23 +986,49 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
return outputs # outputs, new_mems, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""XLNet Model with a language modeling head on top
|
||||
(linear layer with weights tied to the input embeddings). """,
|
||||
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
|
||||
class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
"""XLNet model ("XLNet: Generalized Autoregressive Pretraining for Language Understanding").
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Labels for language modeling.
|
||||
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
|
||||
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
|
||||
All labels set to ``-1`` are ignored (masked), the loss is only
|
||||
computed for labels in ``[0, ..., config.vocab_size]``
|
||||
|
||||
Args:
|
||||
`config`: a XLNetConfig class instance with the configuration to build a new model
|
||||
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
|
||||
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
|
||||
This can be used to compute head importance metrics. Default: False
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Language modeling loss.
|
||||
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
**mems**:
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
|
||||
Example::
|
||||
Examples::
|
||||
|
||||
config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
|
||||
n_layer=12, num_attention_heads=12, intermediate_size=3072)
|
||||
>>> config = XLNetConfig.from_pretrained('xlnet-large-cased')
|
||||
>>> tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
|
||||
>>> model = XLNetLMHeadModel(config)
|
||||
>>> # We show how to setup inputs to predict a next token using a bi-directional context.
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is very <mask>")).unsqueeze(0) # We will predict the masked token
|
||||
>>> perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
|
||||
>>> perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
|
||||
>>> target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float) # Shape [1, 1, seq_length] => let's predict one token
|
||||
>>> target_mapping[0, 0, -1] = 1.0 # Our first (and only) prediction will be the last token of the sequence (the masked token)
|
||||
>>> outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
|
||||
>>> next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
|
||||
|
||||
model = modeling.XLNetLMHeadModel(config=config)
|
||||
|
||||
TODO Lysandre modified: Changed XLNetModel to XLNetLMHeadModel in the example
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(XLNetLMHeadModel, self).__init__(config)
|
||||
@ -989,58 +1049,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||
mems=None, perm_mask=None, target_mapping=None,
|
||||
labels=None, head_mask=None):
|
||||
"""
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
|
||||
Args:
|
||||
input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
|
||||
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
|
||||
input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
|
||||
0 for real tokens and 1 for padding.
|
||||
attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask`
|
||||
but with 1 for real tokens and 0 for padding.
|
||||
Added for easy compatibility with the BERT model (which uses this negative masking).
|
||||
You can only uses one among `input_mask` and `attention_mask`
|
||||
mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
|
||||
from previous batches. The length of the list equals n_layer.
|
||||
If None, no memory is used.
|
||||
perm_mask: [optional] float32 Tensor in shape [bsz, len, len].
|
||||
If perm_mask[k, i, j] = 0, i attend to j in batch k;
|
||||
if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
|
||||
If None, each position attends to all the others.
|
||||
target_mapping: [optional] float32 Tensor in shape [bsz, num_predict, len].
|
||||
If target_mapping[k, i, j] = 1, the i-th predict in batch k is
|
||||
on the j-th token.
|
||||
Only used during pretraining for partial prediction.
|
||||
Set to None during finetuning.
|
||||
|
||||
Returns:
|
||||
A ``tuple(encoded_layers, pooled_output)``, with
|
||||
|
||||
``encoded_layers``: controlled by ``output_all_encoded_layers`` argument:
|
||||
|
||||
- ``output_all_encoded_layers=True``: outputs a list of the full sequences of encoded-hidden-states \
|
||||
at the end of each attention block (i.e. 12 full sequences for XLNet-base, 24 for XLNet-large), \
|
||||
each encoded-hidden-state is a ``torch.FloatTensor`` of size [batch_size, sequence_length, d_model],
|
||||
|
||||
- ``output_all_encoded_layers=False``: outputs only the full sequence of hidden-states corresponding \
|
||||
to the last attention block of shape [batch_size, sequence_length, d_model],
|
||||
|
||||
``pooled_output``: a ``torch.FloatTensor`` of size [batch_size, d_model] which is the output of a \
|
||||
classifier pretrained on top of the hidden state associated to the first character of the \
|
||||
input (`CLS`) to train on the Next-Sentence task (see XLNet's paper).
|
||||
|
||||
Example::
|
||||
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
# or
|
||||
all_encoder_layers, pooled_output = model.forward(input_ids, token_type_ids, input_mask)
|
||||
"""
|
||||
transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
|
||||
mems, perm_mask, target_mapping, head_mask)
|
||||
|
||||
@ -1055,30 +1063,48 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
labels.view(-1))
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
|
||||
return outputs # return (loss), logits, mems, (hidden states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
|
||||
class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
||||
"""XLNet model ("XLNet: Generalized Autoregressive Pretraining for Language Understanding").
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for computing the sequence classification/regression loss.
|
||||
Indices should be in ``[0, ..., config.num_labels]``.
|
||||
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
|
||||
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Args:
|
||||
`config`: a XLNetConfig class instance with the configuration to build a new model
|
||||
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
|
||||
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
|
||||
This can be used to compute head importance metrics. Default: False
|
||||
`summary_type`: str, "last", "first", "mean", or "attn". The method
|
||||
to pool the input to get a vector representation. Default: last
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
**mems**:
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> config = XLNetConfig.from_pretrained('xlnet-large-cased')
|
||||
>>> tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
|
||||
>>>
|
||||
>>> model = XLNetForSequenceClassification(config)
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
>>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids, labels=labels)
|
||||
>>> loss, logits = outputs[:2]
|
||||
|
||||
Example::
|
||||
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(XLNetForSequenceClassification, self).__init__(config)
|
||||
@ -1093,57 +1119,6 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
||||
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||
mems=None, perm_mask=None, target_mapping=None,
|
||||
labels=None, head_mask=None):
|
||||
"""
|
||||
Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**
|
||||
|
||||
Args:
|
||||
input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
|
||||
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
|
||||
input_mask: float32 Tensor in shape [bsz, len], the input mask.
|
||||
0 for real tokens and 1 for padding.
|
||||
attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask`
|
||||
but with 1 for real tokens and 0 for padding.
|
||||
Added for easy compatibility with the BERT model (which uses this negative masking).
|
||||
You can only uses one among `input_mask` and `attention_mask`
|
||||
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
|
||||
from previous batches. The length of the list equals n_layer.
|
||||
If None, no memory is used.
|
||||
perm_mask: float32 Tensor in shape [bsz, len, len].
|
||||
If perm_mask[k, i, j] = 0, i attend to j in batch k;
|
||||
if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
|
||||
If None, each position attends to all the others.
|
||||
target_mapping: float32 Tensor in shape [bsz, num_predict, len].
|
||||
If target_mapping[k, i, j] = 1, the i-th predict in batch k is
|
||||
on the j-th token.
|
||||
Only used during pre-training for partial prediction.
|
||||
Set to None during fine-tuning.
|
||||
labels: TODO Lysandre didn't fill
|
||||
head_mask: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
|
||||
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
|
||||
|
||||
|
||||
Returns:
|
||||
A ``tuple(logits_or_loss, mems)``
|
||||
|
||||
``logits_or_loss``: if ``labels`` is ``None``, ``logits_or_loss`` corresponds to token logits with shape \
|
||||
[batch_size, sequence_length]. If it is not ``None``, it corresponds to the ``CrossEntropy`` loss \
|
||||
with the targets.
|
||||
|
||||
``new_mems``: list (num layers) of updated mem states at the entry of each layer \
|
||||
each mem state is a ``torch.FloatTensor`` of size [self.config.mem_len, batch_size, self.config.d_model] \
|
||||
Note that the first two dimensions are transposed in ``mems`` with regards to ``input_ids`` and ``labels``
|
||||
|
||||
Example::
|
||||
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
# or
|
||||
all_encoder_layers, pooled_output = model.forward(input_ids, token_type_ids, input_mask)
|
||||
"""
|
||||
transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
|
||||
mems, perm_mask, target_mapping, head_mask)
|
||||
output = transformer_outputs[0]
|
||||
@ -1163,28 +1138,60 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
|
||||
return outputs # return (loss), logits, mems, (hidden states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
|
||||
class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||
"""
|
||||
XLNet model for Question Answering (span extraction).
|
||||
r"""
|
||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
**end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
**is_impossible**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels whether a question has an answer or no answer (SQuAD 2.0)
|
||||
**cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for position (index) of the classification token to use as input for computing plausibility of the answer.
|
||||
**p_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...)
|
||||
|
||||
This module is composed of the XLNet model with a linear layer on top of
|
||||
the sequence output that computes ``start_logits`` and ``end_logits``
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
||||
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-start scores (before SoftMax).
|
||||
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-end scores (before SoftMax).
|
||||
**mems**:
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
|
||||
Args:
|
||||
`config`: a XLNetConfig class instance with the configuration to build a new model
|
||||
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
|
||||
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
|
||||
This can be used to compute head importance metrics. Default: False
|
||||
Examples::
|
||||
|
||||
Example::
|
||||
>>> config = XLMConfig.from_pretrained('xlm-mlm-en-2048')
|
||||
>>> tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
|
||||
>>>
|
||||
>>> model = XLMForQuestionAnswering(config)
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
>>> start_positions = torch.tensor([1])
|
||||
>>> end_positions = torch.tensor([3])
|
||||
>>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
|
||||
>>> loss, start_scores, end_scores = outputs[:2]
|
||||
|
||||
config = XLNetConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
|
||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
||||
|
||||
model = XLNetForQuestionAnswering(config)
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(XLNetForQuestionAnswering, self).__init__(config)
|
||||
@ -1202,53 +1209,6 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||
mems=None, perm_mask=None, target_mapping=None,
|
||||
start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
|
||||
head_mask=None):
|
||||
|
||||
"""
|
||||
Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**
|
||||
|
||||
Args:
|
||||
`input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length]
|
||||
with the word token indices in the vocabulary(see the tokens pre-processing logic in the scripts
|
||||
`run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
|
||||
`token_type_ids`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with the token
|
||||
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
||||
a `sentence B` token (see XLNet paper for more details).
|
||||
`attention_mask`: [optional] float32 Tensor, SAME FUNCTION as `input_mask`
|
||||
but with 1 for real tokens and 0 for padding.
|
||||
Added for easy compatibility with the BERT model (which uses this negative masking).
|
||||
You can only uses one among ``input_mask`` and ``attention_mask``
|
||||
`input_mask`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with indices
|
||||
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
|
||||
input sequence length in the current batch. It's the mask that we typically use for attention when
|
||||
a batch has varying length sentences.
|
||||
`start_positions`: position of the first token for the labeled span: ``torch.LongTensor`` of shape [batch_size].
|
||||
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
|
||||
into account for computing the loss.
|
||||
`end_positions`: position of the last token for the labeled span: ``torch.LongTensor`` of shape [batch_size].
|
||||
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
|
||||
into account for computing the loss.
|
||||
`head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
|
||||
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
|
||||
|
||||
Returns:
|
||||
if ``start_positions`` and ``end_positions`` are not ``None``, outputs the total_loss which is the sum of the \
|
||||
``CrossEntropy`` loss for the start and end token positions.
|
||||
|
||||
if ``start_positions`` or ``end_positions`` is ``None``, outputs a tuple of ``start_logits``, ``end_logits``
|
||||
which are the logits respectively for the start and end position tokens of shape \
|
||||
[batch_size, sequence_length].
|
||||
|
||||
Example::
|
||||
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
||||
# or
|
||||
start_logits, end_logits = model.forward(input_ids, token_type_ids, input_mask)
|
||||
"""
|
||||
transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
|
||||
mems, perm_mask, target_mapping, head_mask)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
Loading…
Reference in New Issue
Block a user