mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
added gpt2 doc
This commit is contained in:
parent
183fedfed5
commit
5bc3d0cc5b
@ -277,10 +277,11 @@ class BertEmbeddings(nn.Module):
|
||||
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None):
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None):
|
||||
seq_length = input_ids.size(1)
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
@ -624,6 +625,9 @@ BERT_INPUTS_DOCSTRING = r"""
|
||||
Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
|
||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Indices of positions of each input sequence tokens in the position embeddings.
|
||||
Selected in the range ``[0, config.max_position_embeddings - 1[``.
|
||||
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Segment token indices to indicate first and second portions of the inputs.
|
||||
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
||||
@ -687,7 +691,7 @@ class BertModel(BertPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, head_mask=None):
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, head_mask=None):
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if token_type_ids is None:
|
||||
@ -723,7 +727,7 @@ class BertModel(BertPreTrainedModel):
|
||||
else:
|
||||
head_mask = [None] * self.config.num_hidden_layers
|
||||
|
||||
embedding_output = self.embeddings(input_ids, token_type_ids)
|
||||
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids)
|
||||
encoder_outputs = self.encoder(embedding_output,
|
||||
extended_attention_mask,
|
||||
head_mask=head_mask)
|
||||
@ -773,7 +777,7 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
>>> model = BertForPreTraining(config)
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids)
|
||||
>>> prediction_scores, seq_relationship_scores = outputs[:1]
|
||||
>>> prediction_scores, seq_relationship_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
@ -792,9 +796,9 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
self._tie_or_clone_weights(self.cls.predictions.decoder,
|
||||
self.bert.embeddings.word_embeddings)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
|
||||
next_sentence_label=None, head_mask=None):
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
|
||||
outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
|
||||
|
||||
sequence_output, pooled_output = outputs[:2]
|
||||
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
||||
@ -842,7 +846,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
>>> model = BertForMaskedLM(config)
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids, masked_lm_labels=input_ids)
|
||||
>>> loss, prediction_scores = outputs[:1]
|
||||
>>> loss, prediction_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
@ -861,8 +865,8 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
self._tie_or_clone_weights(self.cls.predictions.decoder,
|
||||
self.bert.embeddings.word_embeddings)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
|
||||
outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.cls(sequence_output)
|
||||
@ -918,8 +922,8 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None):
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None):
|
||||
outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
|
||||
pooled_output = outputs[1]
|
||||
|
||||
seq_relationship_score = self.cls(pooled_output)
|
||||
@ -966,7 +970,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
>>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids, labels=labels)
|
||||
>>> loss, logits = outputs[:1]
|
||||
>>> loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
@ -979,8 +983,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
|
||||
outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
@ -1071,7 +1075,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
>>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
|
||||
>>> labels = torch.tensor(1).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids, labels=labels)
|
||||
>>> loss, classification_scores = outputs[:1]
|
||||
>>> loss, classification_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
@ -1083,13 +1087,14 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
|
||||
num_choices = input_ids.shape[1]
|
||||
|
||||
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
||||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||
outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, head_mask=head_mask)
|
||||
outputs = self.bert(flat_input_ids, flat_position_ids, flat_token_type_ids, flat_attention_mask, head_mask=head_mask)
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
@ -1137,7 +1142,7 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
>>> labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids, labels=labels)
|
||||
>>> loss, scores = outputs[:1]
|
||||
>>> loss, scores = outputs[:2]
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
@ -1150,8 +1155,8 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
|
||||
outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
|
||||
sequence_output = outputs[0]
|
||||
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
@ -1177,7 +1182,7 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
|
||||
class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
r"""
|
||||
__doc__ = r"""
|
||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Position (index) of the start of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
@ -1224,9 +1229,9 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, start_positions=None,
|
||||
end_positions=None, head_mask=None):
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
|
||||
outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
|
||||
sequence_output = outputs[0]
|
||||
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
|
@ -365,44 +365,81 @@ class GPT2PreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
GPT2_START_DOCSTRING = r""" OpenAI GPT-2 model was proposed in
|
||||
`Language Models are Unsupervised Multitask Learners`_
|
||||
by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
|
||||
It's a causal (unidirectional) transformer pre-trained using language modeling on a very large
|
||||
corpus of ~40 GB of text data.
|
||||
|
||||
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
|
||||
refer to the PyTorch documentation for all matter related to general usage and behavior.
|
||||
|
||||
.. _`Language Models are Unsupervised Multitask Learners`:
|
||||
https://openai.com/blog/better-language-models/
|
||||
|
||||
.. _`torch.nn.Module`:
|
||||
https://pytorch.org/docs/stable/nn.html#module
|
||||
|
||||
Parameters:
|
||||
config (:class:`~pytorch_transformers.BertConfig`): Model configuration class with all the parameters of the model.
|
||||
"""
|
||||
|
||||
GPT2_INPUTS_DOCTRING = r""" Inputs:
|
||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
Indices can be obtained using :class:`pytorch_transformers.BPT2Tokenizer`.
|
||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Indices of positions of each input sequence tokens in the position embeddings.
|
||||
Selected in the range ``[0, config.max_position_embeddings - 1[``.
|
||||
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
A parallel sequence of tokens (can be used to indicate various portions of the inputs).
|
||||
The embeddings from these tokens will be summed with the respective token embeddings.
|
||||
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
|
||||
**past**:
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `past` output below). Can be used to speed up sequential decoding.
|
||||
**attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask indices selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
**head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask indices selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare GPT2 Model transformer outputing raw hidden-states without any specific head on top.",
|
||||
GPT2_START_DOCSTRING, GPT2_INPUTS_DOCTRING)
|
||||
class GPT2Model(GPT2PreTrainedModel):
|
||||
"""OpenAI GPT-2 model ("Language Models are Unsupervised Multitask Learners").
|
||||
__doc__ = r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
||||
Sequence of hidden-states at the last layer of the model.
|
||||
**past**:
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks).
|
||||
Can be used (see `past` input) to speed up sequential decoding.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
|
||||
GPT-2 use a single embedding matrix to store the word and special embeddings.
|
||||
Special tokens embeddings are additional tokens that are not pre-trained: [SEP], [CLS]...
|
||||
Special tokens need to be trained during the fine-tuning if you use them.
|
||||
The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.
|
||||
Examples::
|
||||
|
||||
The embeddings are ordered as follow in the token embeddings matrix:
|
||||
::
|
||||
>>> config = GPT2Config.from_pretrained('gpt2')
|
||||
>>> tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
>>> model = GPT2Model(config)
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids)
|
||||
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
[0, ----------------------
|
||||
... -> word embeddings
|
||||
config.vocab_size - 1, ______________________
|
||||
config.vocab_size,
|
||||
... -> special embeddings
|
||||
config.vocab_size + n_special - 1] ______________________
|
||||
|
||||
where total_tokens_embeddings is equal to
|
||||
|
||||
::
|
||||
|
||||
total_tokens_embeddings = vocab_size + n_special
|
||||
|
||||
You should use the associated indices to index the embeddings.
|
||||
|
||||
Args:
|
||||
`config`: a GPT2Config class instance with the configuration to build a new model
|
||||
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
|
||||
|
||||
|
||||
|
||||
Example::
|
||||
|
||||
config = modeling_gpt2.GPT2Config()
|
||||
model = modeling_gpt2.GPT2Model(config)
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(GPT2Model, self).__init__(config)
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
@ -428,47 +465,6 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
self.h[layer].attn.prune_heads(heads)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None, head_mask=None):
|
||||
"""
|
||||
Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**
|
||||
|
||||
Args:
|
||||
`input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
|
||||
were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
|
||||
`position_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
|
||||
with the position indices (selected in the range [0, config.n_positions - 1[.
|
||||
`token_type_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
|
||||
You can use it to add a third type of embedding to each input token in the sequence
|
||||
(the previous two being the word and position embeddings).
|
||||
The input, position and token_type embeddings are summed inside the Transformer before the first
|
||||
self-attention block.
|
||||
`past`: an optional list of ``torch.LongTensor`` that contains pre-computed hidden-states
|
||||
(key and values in the attention blocks) to speed up sequential decoding
|
||||
(this is the presents output of the model, cf. below).
|
||||
`head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
|
||||
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
|
||||
|
||||
Returns:
|
||||
A tuple consisting of ``hidden_states`` and ``presents``.
|
||||
|
||||
``hidden_states`` are a list of all the encoded-hidden-states in the model (length of the list: number of
|
||||
layers + 1 for the output of the embeddings) as ``torch.FloatTensor`` of size [batch_size, sequence_length,
|
||||
hidden_size] (or more generally [d_1, ..., d_n, hidden_size] were d_1 ... d_n are the dimension of
|
||||
input_ids).
|
||||
|
||||
``presents`` are a list of pre-computed hidden-states (key and values in each attention blocks) as
|
||||
torch.FloatTensors. They can be reused to speed up sequential decoding.
|
||||
|
||||
|
||||
Example::
|
||||
|
||||
# Already been converted into BPE token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
|
||||
hidden_states, presents = model(input_ids)
|
||||
# or
|
||||
hidden_states, presents = model.forward(input_ids)
|
||||
|
||||
"""
|
||||
if past is None:
|
||||
past_length = 0
|
||||
past = [None] * len(self.h)
|
||||
@ -540,21 +536,44 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
return outputs # last hidden state, presents, (all hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""The GPT2 Model transformer with a language modeling head on top
|
||||
(linear layer with weights tied to the input embeddings). """, GPT2_START_DOCSTRING, GPT2_INPUTS_DOCTRING)
|
||||
class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
"""OpenAI GPT-2 model with a Language Modeling head ("Language Models are Unsupervised Multitask Learners").
|
||||
__doc__ = r"""
|
||||
**lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Labels for language modeling.
|
||||
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
|
||||
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
|
||||
All labels set to ``-1`` are ignored (masked), the loss is only
|
||||
computed for labels in ``[0, ..., config.vocab_size]``
|
||||
|
||||
Args:
|
||||
`config`: a GPT2Config class instance with the configuration to build a new model
|
||||
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
|
||||
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
|
||||
This can be used to compute head importance metrics. Default: False
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Language modeling loss.
|
||||
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
**past**:
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks).
|
||||
Can be used (see `past` input) to speed up sequential decoding.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
|
||||
Example::
|
||||
Examples::
|
||||
|
||||
>>> config = GPT2Config.from_pretrained('gpt2')
|
||||
>>> tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
>>> model = GPT2LMHeadModel(config)
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids, lm_labels=input_ids)
|
||||
>>> loss, logits = outputs[:2]
|
||||
|
||||
config = modeling_gpt2.GPT2Config()
|
||||
model = modeling_gpt2.GPT2LMHeadModel(config)
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(GPT2LMHeadModel, self).__init__(config)
|
||||
self.transformer = GPT2Model(config)
|
||||
@ -571,49 +590,6 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
self.transformer.wte)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None, head_mask=None):
|
||||
"""
|
||||
Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**
|
||||
|
||||
Args:
|
||||
`input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
|
||||
were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
|
||||
`position_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
|
||||
with the position indices (selected in the range [0, config.n_positions - 1[.
|
||||
`token_type_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
|
||||
You can use it to add a third type of embedding to each input token in the sequence
|
||||
(the previous two being the word and position embeddings).
|
||||
The input, position and token_type embeddings are summed inside the Transformer before the first
|
||||
self-attention block.
|
||||
`lm_labels`: optional language modeling labels: ``torch.LongTensor`` of shape [batch_size, sequence_length]
|
||||
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
|
||||
is only computed for the labels set in [0, ..., vocab_size]
|
||||
`past`: an optional list of ``torch.LongTensor`` that contains pre-computed hidden-states
|
||||
(key and values in the attention blocks) to speed up sequential decoding
|
||||
(this is the presents output of the model, cf. below).
|
||||
`head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
|
||||
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
|
||||
|
||||
Returns:
|
||||
If ``lm_labels`` is not ``None``, returns the language modeling loss. It ``lm_labels`` is ``None``, returns
|
||||
a tuple of (``lm_logits``, ``presents``).
|
||||
|
||||
``lm_logits`` is the language modeling logits as a ``torch.FloatTensor`` of size [batch_size,
|
||||
sequence_length, config.vocab_size] (or more generally [d_1, ..., d_n, config.vocab_size] were d_1 ...
|
||||
d_n are the dimension of input_ids).
|
||||
|
||||
``presents`` is a list of pre-computed hidden-states (key and values in each attention blocks) as
|
||||
torch.FloatTensors. They can be reused to speed up sequential decoding.
|
||||
|
||||
Example::
|
||||
|
||||
# Already been converted into BPE token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
|
||||
lm_logits, presents = model(input_ids)
|
||||
# or
|
||||
lm_logits, presents = model.forward(input_ids)
|
||||
|
||||
"""
|
||||
transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
@ -633,21 +609,88 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""The GPT2 Model transformer with a language modeling and a multiple-choice classification
|
||||
head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
|
||||
The language modeling head has its weights tied to the input embeddings,
|
||||
the classification head takes as input the input of a specified classification token index in the intput sequence).
|
||||
""", GPT2_START_DOCSTRING)
|
||||
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
||||
"""OpenAI GPT-2 model with a Language Modeling and a Multiple Choice head ("Language Models are Unsupervised Multitask Learners").
|
||||
__doc__ = r""" Inputs:
|
||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
The second dimension of the input (`num_choices`) indicates the number of choices to score.
|
||||
Indices can be obtained using :class:`pytorch_transformers.BPT2Tokenizer`.
|
||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
**mc_token_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices)``:
|
||||
Index of the classification token in each input sequence.
|
||||
Selected in the range ``[0, input_ids.size(-1) - 1[``.
|
||||
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
|
||||
Indices of positions of each input sequence tokens in the position embeddings.
|
||||
Selected in the range ``[0, config.max_position_embeddings - 1[``.
|
||||
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
|
||||
A parallel sequence of tokens (can be used to indicate various portions of the inputs).
|
||||
The embeddings from these tokens will be summed with the respective token embeddings.
|
||||
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
|
||||
**past**:
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `past` output below). Can be used to speed up sequential decoding.
|
||||
**attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, num_choices, sequence_length)``:
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask indices selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
**head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask indices selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
**lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Labels for language modeling.
|
||||
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
|
||||
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
|
||||
All labels set to ``-1`` are ignored (masked), the loss is only
|
||||
computed for labels in ``[0, ..., config.vocab_size]``
|
||||
**multiple_choice_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size)``:
|
||||
Labels for computing the multiple choice classification loss.
|
||||
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
|
||||
of the input tensors. (see `input_ids` above)
|
||||
|
||||
Args:
|
||||
`config`: a GPT2Config class instance with the configuration to build a new model
|
||||
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
|
||||
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
|
||||
This can be used to compute head importance metrics. Default: False
|
||||
`multiple_choice_labels`: optional multiple choice labels: ``torch.LongTensor`` of shape [batch_size]
|
||||
with indices selected in [0, ..., num_choices].
|
||||
|
||||
Example::
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**lm_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Language modeling loss.
|
||||
**mc_loss**: (`optional`, returned when ``multiple_choice_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Multiple choice classification loss.
|
||||
**lm_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length, config.vocab_size)``
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
**mc_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)``
|
||||
Prediction scores of the multiplechoice classification head (scores for each choice before SoftMax).
|
||||
**past**:
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks).
|
||||
Can be used (see `past` input) to speed up sequential decoding.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> config = GPT2Config.from_pretrained('gpt2')
|
||||
>>> tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
>>> model = GPT2DoubleHeadsModel(config)
|
||||
>>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] # Assume you've added [CLS] to the vocabulary
|
||||
>>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
|
||||
>>> mc_token_ids = torch.tensor([-1, -1]).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids, mc_token_ids)
|
||||
>>> lm_prediction_scores, mc_prediction_scores = outputs[:2]
|
||||
|
||||
config = modeling_gpt2.GPT2Config()
|
||||
model = modeling_gpt2.GPT2DoubleHeadsModel(config)
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(GPT2DoubleHeadsModel, self).__init__(config)
|
||||
self.transformer = GPT2Model(config)
|
||||
@ -665,55 +708,6 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
||||
|
||||
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
|
||||
position_ids=None, past=None, head_mask=None):
|
||||
"""
|
||||
Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**
|
||||
|
||||
Args:
|
||||
`input_ids`: a ``torch.LongTensor`` of shape [batch_size, num_choices, sequence_length] with the BPE token
|
||||
indices selected in the range [0, config.vocab_size[
|
||||
`mc_token_ids`: a ``torch.LongTensor`` of shape [batch_size, num_choices] with the index of the token from
|
||||
which we should take the hidden state to feed the multiple choice classifier (usually last token of the sequence)
|
||||
`position_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
|
||||
with the position indices (selected in the range [0, config.n_positions - 1[.
|
||||
`token_type_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
|
||||
You can use it to add a third type of embedding to each input token in the sequence
|
||||
(the previous two being the word and position embeddings).
|
||||
The input, position and token_type embeddings are summed inside the Transformer before the first
|
||||
self-attention block.
|
||||
`lm_labels`: optional language modeling labels: ``torch.LongTensor`` of shape [batch_size, num_choices, sequence_length]
|
||||
with indices selected in [-1, 0, ..., config.vocab_size]. All labels set to -1 are ignored (masked), the loss
|
||||
is only computed for the labels set in [0, ..., config.vocab_size]
|
||||
`multiple_choice_labels`: optional multiple choice labels: ``torch.LongTensor`` of shape [batch_size]
|
||||
with indices selected in [0, ..., num_choices].
|
||||
`past`: an optional list of ``torch.LongTensor`` that contains pre-computed hidden-states
|
||||
(key and values in the attention blocks) to speed up sequential decoding
|
||||
(this is the presents output of the model, cf. below).
|
||||
`head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
|
||||
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
|
||||
|
||||
Returns:
|
||||
If ``lm_labels`` and ``multiple_choice_labels`` are not ``None``, outputs a
|
||||
``tuple(language_modeling_loss, multiple_choice_loss)``. If they are not ``None``, outputs a
|
||||
``tuple(lm_logits, multiple_choice_logits, presents)``.
|
||||
|
||||
``lm_logits``: the language modeling logits as a ``torch.FloatTensor`` of size [batch_size, num_choices, sequence_length, config.vocab_size]
|
||||
|
||||
``multiple_choice_logits``: the multiple choice logits as a ``torch.FloatTensor`` of size [batch_size, num_choices]
|
||||
|
||||
``presents``: a list of pre-computed hidden-states (key and values in each attention blocks) as
|
||||
torch.FloatTensors. They can be reused to speed up sequential decoding.
|
||||
|
||||
Example::
|
||||
|
||||
# Already been converted into BPE token ids
|
||||
input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]]) # (bsz, number of choice, seq length)
|
||||
mc_token_ids = torch.LongTensor([[2], [1]]) # (bsz, number of choice)
|
||||
|
||||
lm_logits, multiple_choice_logits, presents = model(input_ids, mc_token_ids)
|
||||
# or
|
||||
lm_logits, multiple_choice_logits, presents = model.forward(input_ids, mc_token_ids)
|
||||
|
||||
"""
|
||||
transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user