added gpt2 doc

thomwolf 2019-07-15 09:40:05 +02:00
parent 183fedfed5
commit 5bc3d0cc5b
2 changed files with 210 additions and 211 deletions

pytorch_transformers/modeling_bert.py

@ -277,10 +277,11 @@ class BertEmbeddings(nn.Module):
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None):
def forward(self, input_ids, position_ids=None, token_type_ids=None):
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
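When `position_ids` is omitted, the new code path falls back to consecutive positions. A minimal standalone sketch of that default (plain PyTorch, no model required):

    import torch

    input_ids = torch.tensor([[31, 51, 99], [15, 5, 0]])  # (batch_size=2, seq_length=3)
    seq_length = input_ids.size(1)
    # same fallback as the forward above when position_ids is None
    position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    # position_ids is now [[0, 1, 2], [0, 1, 2]]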
@ -624,6 +625,9 @@ BERT_INPUTS_DOCSTRING = r"""
Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of the position of each input sequence token in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
@ -687,7 +691,7 @@ class BertModel(BertPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, head_mask=None):
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, head_mask=None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
@ -723,7 +727,7 @@ class BertModel(BertPreTrainedModel):
else:
head_mask = [None] * self.config.num_hidden_layers
embedding_output = self.embeddings(input_ids, token_type_ids)
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids)
encoder_outputs = self.encoder(embedding_output,
extended_attention_mask,
head_mask=head_mask)
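With the argument threaded through to the embeddings, custom positions can now be supplied end-to-end. A minimal sketch, assuming a randomly initialized model (no pretrained weights needed):

    import torch
    from pytorch_transformers import BertConfig, BertModel

    config = BertConfig()                        # default BERT-base sized configuration
    model = BertModel(config)
    model.eval()

    input_ids = torch.tensor([[31, 51, 99]])     # (1, 3)
    position_ids = torch.tensor([[10, 11, 12]])  # hypothetical non-default positions
    with torch.no_grad():
        outputs = model(input_ids, position_ids=position_ids)
    last_hidden_state = outputs[0]               # (1, 3, hidden_size)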
@ -773,7 +777,7 @@ class BertForPreTraining(BertPreTrainedModel):
>>> model = BertForPreTraining(config)
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
>>> outputs = model(input_ids)
>>> prediction_scores, seq_relationship_scores = outputs[:1]
>>> prediction_scores, seq_relationship_scores = outputs[:2]
"""
def __init__(self, config):
@ -792,9 +796,9 @@ class BertForPreTraining(BertPreTrainedModel):
self._tie_or_clone_weights(self.cls.predictions.decoder,
self.bert.embeddings.word_embeddings)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
next_sentence_label=None, head_mask=None):
outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
@ -842,7 +846,7 @@ class BertForMaskedLM(BertPreTrainedModel):
>>> model = BertForMaskedLM(config)
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
>>> outputs = model(input_ids, masked_lm_labels=input_ids)
>>> loss, prediction_scores = outputs[:1]
>>> loss, prediction_scores = outputs[:2]
"""
def __init__(self, config):
@ -861,8 +865,8 @@ class BertForMaskedLM(BertPreTrainedModel):
self._tie_or_clone_weights(self.cls.predictions.decoder,
self.bert.embeddings.word_embeddings)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
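When `masked_lm_labels` is given, these scores feed a cross-entropy loss in which positions labelled `-1` are ignored. A standalone sketch of that computation with toy shapes (the real vocabulary size is much larger):

    import torch
    from torch.nn import CrossEntropyLoss

    vocab_size = 5                                     # toy value for illustration
    prediction_scores = torch.randn(1, 3, vocab_size)  # (batch, seq_length, vocab)
    masked_lm_labels = torch.tensor([[-1, 2, -1]])     # only the middle position contributes

    loss_fct = CrossEntropyLoss(ignore_index=-1)       # -1 labels are masked out of the loss
    loss = loss_fct(prediction_scores.view(-1, vocab_size), masked_lm_labels.view(-1))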
@ -918,8 +922,8 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None):
outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None):
outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
pooled_output = outputs[1]
seq_relationship_score = self.cls(pooled_output)
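The next-sentence head is a binary classifier over the pooled `[CLS]` representation; a sketch with stand-in tensors (768 stands in for `hidden_size`):

    import torch

    pooled_output = torch.randn(1, 768)               # stand-in for BertPooler's [CLS] output
    cls_head = torch.nn.Linear(768, 2)                # 2 classes: is-next / not-next
    seq_relationship_score = cls_head(pooled_output)  # (1, 2) logits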
@ -966,7 +970,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
>>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
>>> outputs = model(input_ids, labels=labels)
>>> loss, logits = outputs[:1]
>>> loss, logits = outputs[:2]
"""
def __init__(self, config):
@ -979,8 +983,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
@ -1071,7 +1075,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
>>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
>>> labels = torch.tensor(1).unsqueeze(0) # Batch size 1
>>> outputs = model(input_ids, labels=labels)
>>> loss, classification_scores = outputs[:1]
>>> loss, classification_scores = outputs[:2]
"""
def __init__(self, config):
@ -1083,13 +1087,14 @@ class BertForMultipleChoice(BertPreTrainedModel):
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
num_choices = input_ids.shape[1]
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, head_mask=head_mask)
outputs = self.bert(flat_input_ids, flat_position_ids, flat_token_type_ids, flat_attention_mask, head_mask=head_mask)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
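The multiple-choice model folds the choice dimension into the batch before running BERT, then unfolds the per-choice scores afterwards. A sketch of just the reshaping (768 stands in for `hidden_size`):

    import torch

    batch_size, num_choices, seq_length = 2, 3, 5
    input_ids = torch.randint(0, 100, (batch_size, num_choices, seq_length))

    # fold choices into the batch dimension, as in the forward above
    flat_input_ids = input_ids.view(-1, input_ids.size(-1))     # (batch_size * num_choices, seq_length)

    # after BERT, each flattened row yields one pooled vector; classify and unfold
    pooled_output = torch.randn(batch_size * num_choices, 768)  # stand-in for BERT's pooled output
    classifier = torch.nn.Linear(768, 1)
    reshaped_logits = classifier(pooled_output).view(-1, num_choices)  # (batch_size, num_choices)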
@ -1137,7 +1142,7 @@ class BertForTokenClassification(BertPreTrainedModel):
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
>>> labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
>>> outputs = model(input_ids, labels=labels)
>>> loss, scores = outputs[:1]
>>> loss, scores = outputs[:2]
"""
def __init__(self, config):
@ -1150,8 +1155,8 @@ class BertForTokenClassification(BertPreTrainedModel):
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
@ -1177,7 +1182,7 @@ class BertForTokenClassification(BertPreTrainedModel):
the hidden-states output to compute `span start logits` and `span end logits`). """,
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForQuestionAnswering(BertPreTrainedModel):
r"""
__doc__ = r"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
@ -1224,9 +1229,9 @@ class BertForQuestionAnswering(BertPreTrainedModel):
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, start_positions=None,
end_positions=None, head_mask=None):
outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
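`qa_outputs` assigns every token two scores, which are then separated into span-start and span-end logits. A sketch of that split:

    import torch

    hidden_size, seq_length = 768, 7
    sequence_output = torch.randn(1, seq_length, hidden_size)
    qa_outputs = torch.nn.Linear(hidden_size, 2)     # one score each for start and end

    logits = qa_outputs(sequence_output)             # (1, seq_length, 2)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)          # (1, seq_length)
    end_logits = end_logits.squeeze(-1)              # (1, seq_length)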

pytorch_transformers/modeling_gpt2.py

@ -365,44 +365,81 @@ class GPT2PreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
GPT2_START_DOCSTRING = r""" OpenAI GPT-2 model was proposed in
`Language Models are Unsupervised Multitask Learners`_
by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
It's a causal (unidirectional) transformer pre-trained using language modeling on a very large
corpus of ~40 GB of text data.
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matters related to general usage and behavior.
.. _`Language Models are Unsupervised Multitask Learners`:
https://openai.com/blog/better-language-models/
.. _`torch.nn.Module`:
https://pytorch.org/docs/stable/nn.html#module
Parameters:
config (:class:`~pytorch_transformers.GPT2Config`): Model configuration class with all the parameters of the model.
"""
GPT2_INPUTS_DOCSTRING = r""" Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`pytorch_transformers.GPT2Tokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of the position of each input sequence token in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
A parallel sequence of tokens (can be used to indicate various portions of the inputs).
The embeddings from these tokens will be summed with the respective token embeddings.
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
**past**:
list of ``torch.FloatTensor`` (one for each layer)
containing pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past` output below). Can be used to speed up sequential decoding.
**attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask indices selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask indices selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@add_start_docstrings("The bare GPT2 Model transformer outputing raw hidden-states without any specific head on top.",
GPT2_START_DOCSTRING, GPT2_INPUTS_DOCTRING)
class GPT2Model(GPT2PreTrainedModel):
"""OpenAI GPT-2 model ("Language Models are Unsupervised Multitask Learners").
__doc__ = r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model.
**past**:
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``
containing pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
GPT-2 uses a single embedding matrix to store the word and special embeddings.
Special token embeddings are additional tokens that are not pre-trained: [SEP], [CLS]...
Special tokens need to be trained during fine-tuning if you use them.
The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.
Examples::
The embeddings are ordered as follows in the token embeddings matrix:
::
>>> config = GPT2Config.from_pretrained('gpt2')
>>> tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
>>> model = GPT2Model(config)
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
>>> outputs = model(input_ids)
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
[0,                                    ----------------------
 ...                                    -> word embeddings
 config.vocab_size - 1,                ______________________
 config.vocab_size,
 ...                                    -> special embeddings
 config.vocab_size + n_special - 1]    ______________________
where total_tokens_embeddings is equal to
::
total_tokens_embeddings = vocab_size + n_special
You should use the associated indices to index the embeddings.
Args:
`config`: a GPT2Config class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
Example::
config = modeling_gpt2.GPT2Config()
model = modeling_gpt2.GPT2Model(config)
"""
def __init__(self, config):
super(GPT2Model, self).__init__(config)
self.output_hidden_states = config.output_hidden_states
@ -428,47 +465,6 @@ class GPT2Model(GPT2PreTrainedModel):
self.h[layer].attn.prune_heads(heads)
def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None, head_mask=None):
"""
Performs a model forward pass. **Invoke it by calling the model instance directly once it has been instantiated.**
Args:
`input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size - 1]
`position_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
with the position indices (selected in the range [0, config.n_positions - 1]).
`token_type_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
You can use it to add a third type of embedding to each input token in the sequence
(the previous two being the word and position embeddings).
The input, position and token_type embeddings are summed inside the Transformer before the first
self-attention block.
`past`: an optional list of ``torch.FloatTensor`` that contains pre-computed hidden-states
(key and values in the attention blocks) to speed up sequential decoding
(this is the `presents` output of the model, cf. below).
`head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is not masked, 0.0 => head is masked.
Returns:
A tuple consisting of ``hidden_states`` and ``presents``.
``hidden_states`` is a list of all the encoded hidden-states in the model (length of the list: number of
layers + 1 for the output of the embeddings) as ``torch.FloatTensor`` of size [batch_size, sequence_length,
hidden_size] (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of
input_ids).
``presents`` are a list of pre-computed hidden-states (key and values in each attention blocks) as
torch.FloatTensors. They can be reused to speed up sequential decoding.
Example::
# Already been converted into BPE token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
hidden_states, presents = model(input_ids)
# or
hidden_states, presents = model.forward(input_ids)
"""
if past is None:
past_length = 0
past = [None] * len(self.h)
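This `past` bookkeeping is what makes token-by-token generation cheap: once `past` is populated, only the newest token needs to be fed. A greedy-decoding sketch, assuming the pretrained `gpt2` weights can be downloaded:

    import torch
    from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    model.eval()

    generated = tokenizer.encode("The Manhattan bridge")
    context = torch.tensor([generated])
    past = None
    with torch.no_grad():
        for _ in range(5):
            lm_logits, past = model(context, past=past)[:2]
            token = torch.argmax(lm_logits[0, -1, :]).item()  # greedy choice of next token
            generated.append(token)
            context = torch.tensor([[token]])                 # feed only the new token from here on
    print(tokenizer.decode(generated))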
@ -540,21 +536,44 @@ class GPT2Model(GPT2PreTrainedModel):
return outputs # last hidden state, presents, (all hidden_states), (attentions)
@add_start_docstrings("""The GPT2 Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """, GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
class GPT2LMHeadModel(GPT2PreTrainedModel):
"""OpenAI GPT-2 model with a Language Modeling head ("Language Models are Unsupervised Multitask Learners").
__doc__ = r"""
**lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for language modeling.
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
Indices are selected in ``[-1, 0, ..., config.vocab_size - 1]``.
All labels set to ``-1`` are ignored (masked); the loss is only
computed for labels in ``[0, ..., config.vocab_size - 1]``.
Args:
`config`: a GPT2Config class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Language modeling loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**past**:
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``
containing pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
Example::
Examples::
>>> config = GPT2Config.from_pretrained('gpt2')
>>> tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
>>> model = GPT2LMHeadModel(config)
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
>>> outputs = model(input_ids, lm_labels=input_ids)
>>> loss, logits = outputs[:2]
config = modeling_gpt2.GPT2Config()
model = modeling_gpt2.GPT2LMHeadModel(config)
"""
def __init__(self, config):
super(GPT2LMHeadModel, self).__init__(config)
self.transformer = GPT2Model(config)
@ -571,49 +590,6 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
self.transformer.wte)
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None, head_mask=None):
"""
Performs a model forward pass. **Invoke it by calling the model instance directly once it has been instantiated.**
Args:
`input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size - 1]
`position_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
with the position indices (selected in the range [0, config.n_positions - 1]).
`token_type_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
You can use it to add a third type of embedding to each input token in the sequence
(the previous two being the word and position embeddings).
The input, position and token_type embeddings are summed inside the Transformer before the first
self-attention block.
`lm_labels`: optional language modeling labels: ``torch.LongTensor`` of shape [batch_size, sequence_length]
with indices selected in [-1, 0, ..., vocab_size - 1]. All labels set to -1 are ignored (masked); the loss
is only computed for the labels set in [0, ..., vocab_size - 1]
`past`: an optional list of ``torch.FloatTensor`` that contains pre-computed hidden-states
(key and values in the attention blocks) to speed up sequential decoding
(this is the `presents` output of the model, cf. below).
`head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is not masked, 0.0 => head is masked.
Returns:
If ``lm_labels`` is not ``None``, returns the language modeling loss. If ``lm_labels`` is ``None``, returns
a tuple of (``lm_logits``, ``presents``).
``lm_logits`` is the language modeling logits as a ``torch.FloatTensor`` of size [batch_size,
sequence_length, config.vocab_size] (or more generally [d_1, ..., d_n, config.vocab_size] where d_1 ...
d_n are the dimensions of input_ids).
``presents`` is a list of pre-computed hidden-states (key and values in each attention blocks) as
torch.FloatTensors. They can be reused to speed up sequential decoding.
Example::
# Already been converted into BPE token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
lm_logits, presents = model(input_ids)
# or
lm_logits, presents = model.forward(input_ids)
"""
transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
hidden_states = transformer_outputs[0]
@ -633,21 +609,88 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
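The optional `(loss)` element comes from labels the model shifts internally so that tokens `< n` predict token `n`. A standalone sketch of that computation with toy shapes:

    import torch
    from torch.nn import CrossEntropyLoss

    vocab_size = 50257
    lm_logits = torch.randn(1, 5, vocab_size)                # (batch, seq_length, vocab)
    lm_labels = torch.tensor([[15496, 11, 616, 3290, 318]])  # e.g. lm_labels = input_ids

    # shift so that position i is scored against the token at position i + 1
    shift_logits = lm_logits[:, :-1, :].contiguous()
    shift_labels = lm_labels[:, 1:].contiguous()
    loss_fct = CrossEntropyLoss(ignore_index=-1)             # -1 labels would be masked out
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))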
@add_start_docstrings("""The GPT2 Model transformer with a language modeling and a multiple-choice classification
head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
The language modeling head has its weights tied to the input embeddings,
the classification head takes as input the hidden state of a specified classification token in the input sequence.
""", GPT2_START_DOCSTRING)
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
"""OpenAI GPT-2 model with a Language Modeling and a Multiple Choice head ("Language Models are Unsupervised Multitask Learners").
__doc__ = r""" Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
Indices can be obtained using :class:`pytorch_transformers.GPT2Tokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**mc_token_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices)``:
Index of the classification token in each input sequence.
Selected in the range ``[0, input_ids.size(-1) - 1]``.
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Indices of the position of each input sequence token in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
A parallel sequence of tokens (can be used to indicate various portions of the inputs).
The embeddings from these tokens will be summed with the respective token embeddings.
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
**past**:
list of ``torch.FloatTensor`` (one for each layer)
containing pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past` output below). Can be used to speed up sequential decoding.
**attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask indices selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask indices selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for language modeling.
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
Indices are selected in ``[-1, 0, ..., config.vocab_size - 1]``.
All labels set to ``-1`` are ignored (masked); the loss is only
computed for labels in ``[0, ..., config.vocab_size - 1]``.
**multiple_choice_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size)``:
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
of the input tensors (see `input_ids` above).
Args:
`config`: a GPT2Config class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
`multiple_choice_labels`: optional multiple choice labels: ``torch.LongTensor`` of shape [batch_size]
with indices selected in [0, ..., num_choices - 1].
Example::
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**lm_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Language modeling loss.
**mc_loss**: (`optional`, returned when ``multiple_choice_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Multiple choice classification loss.
**lm_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**mc_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)``
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
**past**:
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``
containing pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
Examples::
>>> config = GPT2Config.from_pretrained('gpt2')
>>> tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
>>> model = GPT2DoubleHeadsModel(config)
>>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] # Assume you've added [CLS] to the vocabulary
>>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
>>> mc_token_ids = torch.tensor([input_ids.size(-1) - 1, input_ids.size(-1) - 1]).unsqueeze(0)  # Batch size 1, index of the last token
>>> outputs = model(input_ids, mc_token_ids)
>>> lm_prediction_scores, mc_prediction_scores = outputs[:2]
config = modeling_gpt2.GPT2Config()
model = modeling_gpt2.GPT2DoubleHeadsModel(config)
"""
def __init__(self, config):
super(GPT2DoubleHeadsModel, self).__init__(config)
self.transformer = GPT2Model(config)
@ -665,55 +708,6 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
position_ids=None, past=None, head_mask=None):
"""
Performs a model forward pass. **Invoke it by calling the model instance directly once it has been instantiated.**
Args:
`input_ids`: a ``torch.LongTensor`` of shape [batch_size, num_choices, sequence_length] with the BPE token
indices selected in the range [0, config.vocab_size - 1]
`mc_token_ids`: a ``torch.LongTensor`` of shape [batch_size, num_choices] with the index of the token from
which we should take the hidden state to feed the multiple choice classifier (usually last token of the sequence)
`position_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
with the position indices (selected in the range [0, config.n_positions - 1]).
`token_type_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
You can use it to add a third type of embedding to each input token in the sequence
(the previous two being the word and position embeddings).
The input, position and token_type embeddings are summed inside the Transformer before the first
self-attention block.
`lm_labels`: optional language modeling labels: ``torch.LongTensor`` of shape [batch_size, num_choices, sequence_length]
with indices selected in [-1, 0, ..., config.vocab_size - 1]. All labels set to -1 are ignored (masked); the loss
is only computed for the labels set in [0, ..., config.vocab_size - 1]
`multiple_choice_labels`: optional multiple choice labels: ``torch.LongTensor`` of shape [batch_size]
with indices selected in [0, ..., num_choices - 1].
`past`: an optional list of ``torch.FloatTensor`` that contains pre-computed hidden-states
(key and values in the attention blocks) to speed up sequential decoding
(this is the `presents` output of the model, cf. below).
`head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is not masked, 0.0 => head is masked.
Returns:
If ``lm_labels`` and ``multiple_choice_labels`` are not ``None``, outputs a
``tuple(language_modeling_loss, multiple_choice_loss)``. If they are ``None``, outputs a
``tuple(lm_logits, multiple_choice_logits, presents)``.
``lm_logits``: the language modeling logits as a ``torch.FloatTensor`` of size [batch_size, num_choices, sequence_length, config.vocab_size]
``multiple_choice_logits``: the multiple choice logits as a ``torch.FloatTensor`` of size [batch_size, num_choices]
``presents``: a list of pre-computed hidden-states (key and values in each attention blocks) as
torch.FloatTensors. They can be reused to speed up sequential decoding.
Example::
# Already been converted into BPE token ids
input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]]) # (bsz, number of choices, seq length)
mc_token_ids = torch.LongTensor([[2, 1]]) # (bsz, number of choices)
lm_logits, multiple_choice_logits, presents = model(input_ids, mc_token_ids)
# or
lm_logits, multiple_choice_logits, presents = model.forward(input_ids, mc_token_ids)
"""
transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
hidden_states = transformer_outputs[0]