mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Pytorch GPT
This commit is contained in:
parent
1487b840d3
commit
850795c487
@ -1,6 +1,38 @@
|
||||
OpenAI GPT
|
||||
----------------------------------------------------
|
||||
|
||||
OpenAI GPT model was proposed in `Improving Language Understanding by Generative Pre-Training`_
|
||||
by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. It's a causal (unidirectional)
|
||||
transformer pre-trained using language modeling on a large corpus will long range dependencies, the Toronto Book Corpus.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Natural language understanding comprises a wide range of diverse tasks such
|
||||
as textual entailment, question answering, semantic similarity assessment, and
|
||||
document classification. Although large unlabeled text corpora are abundant,
|
||||
labeled data for learning these specific tasks is scarce, making it challenging for
|
||||
discriminatively trained models to perform adequately. We demonstrate that large
|
||||
gains on these tasks can be realized by generative pre-training of a language model
|
||||
on a diverse corpus of unlabeled text, followed by discriminative fine-tuning on each
|
||||
specific task. In contrast to previous approaches, we make use of task-aware input
|
||||
transformations during fine-tuning to achieve effective transfer while requiring
|
||||
minimal changes to the model architecture. We demonstrate the effectiveness of
|
||||
our approach on a wide range of benchmarks for natural language understanding.
|
||||
Our general task-agnostic model outperforms discriminatively trained models that
|
||||
use architectures specifically crafted for each task, significantly improving upon the
|
||||
state of the art in 9 out of the 12 tasks studied.*
|
||||
|
||||
Tips:
|
||||
|
||||
- GPT is a model with absolute position embeddings so it's usually advised to pad the inputs on
|
||||
the right rather than the left.
|
||||
- GPT was trained with a causal language modeling (CLM) objective and is therefore powerful at predicting the next
|
||||
token in a sequence. Leveraging this feature allows GPT-2 to generate syntactically coherent text as
|
||||
it can be observed in the `run_generation.py` example script.
|
||||
|
||||
`Write With Transformer <https://transformer.huggingface.co/doc/gpt>`__ is a webapp created and hosted by
|
||||
Hugging Face showcasing the generative capabilities of several models. GPT is one of them.
|
||||
|
||||
``OpenAIGPTConfig``
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
@ -283,7 +283,7 @@ GPT2_START_DOCSTRING = r"""
|
||||
|
||||
GPT2_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_id (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
|
||||
|
@ -26,7 +26,7 @@ import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from .configuration_openai import OpenAIGPTConfig
|
||||
from .file_utils import add_start_docstrings
|
||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
|
||||
|
||||
|
||||
@ -279,12 +279,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
OPENAI_GPT_START_DOCSTRING = r""" OpenAI GPT model was proposed in
|
||||
`Improving Language Understanding by Generative Pre-Training`_
|
||||
by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever.
|
||||
It's a causal (unidirectional) transformer pre-trained using language modeling on a large
|
||||
corpus will long range dependencies, the Toronto Book Corpus.
|
||||
|
||||
OPENAI_GPT_START_DOCSTRING = r"""
|
||||
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
|
||||
refer to the PyTorch documentation for all matter related to general usage and behavior.
|
||||
|
||||
@ -300,31 +295,39 @@ OPENAI_GPT_START_DOCSTRING = r""" OpenAI GPT model was proposed in
|
||||
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||
"""
|
||||
|
||||
OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs:
|
||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
GPT is a model with absolute position embeddings so it's usually advised to pad the inputs on
|
||||
the right rather than the left.
|
||||
Indices can be obtained using :class:`transformers.BPT2Tokenizer`.
|
||||
OPENAI_GPT_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using :class:`transformers.OpenAIGPTTokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
A parallel sequence of tokens (can be used to indicate various portions of the inputs).
|
||||
The embeddings from these tokens will be summed with the respective token embeddings.
|
||||
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices)
|
||||
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Segment token indices to indicate first and second portions of the inputs.
|
||||
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
||||
corresponds to a `sentence B` token
|
||||
|
||||
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
||||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Indices of positions of each input sequence tokens in the position embeddings.
|
||||
Selected in the range ``[0, config.max_position_embeddings - 1]``.
|
||||
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
|
||||
`What are position IDs? <../glossary.html#position-ids>`_
|
||||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
|
||||
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
|
||||
:obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
|
||||
input_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
"""
|
||||
@ -333,30 +336,8 @@ OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs:
|
||||
@add_start_docstrings(
|
||||
"The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",
|
||||
OPENAI_GPT_START_DOCSTRING,
|
||||
OPENAI_GPT_INPUTS_DOCSTRING,
|
||||
)
|
||||
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
||||
Sequence of hidden-states at the last layer of the model.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
|
||||
model = OpenAIGPTModel.from_pretrained('openai-gpt')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -383,6 +364,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.h[layer].attn.prune_heads(heads)
|
||||
|
||||
@add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
@ -392,6 +374,32 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (config) and inputs:
|
||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the last layer of the model.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
|
||||
model = OpenAIGPTModel.from_pretrained('openai-gpt')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
"""
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@ -481,41 +489,10 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
|
||||
@add_start_docstrings(
|
||||
"""OpenAI GPT Model transformer with a language modeling head on top
|
||||
(linear layer with weights tied to the input embeddings). """,
|
||||
(linear layer with weights tied to the input embeddings). """,
|
||||
OPENAI_GPT_START_DOCSTRING,
|
||||
OPENAI_GPT_INPUTS_DOCSTRING,
|
||||
)
|
||||
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Labels for language modeling.
|
||||
Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
|
||||
Indices are selected in ``[-100, 0, ..., config.vocab_size]``
|
||||
All labels set to ``-100`` are ignored (masked), the loss is only
|
||||
computed for labels in ``[0, ..., config.vocab_size]``
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Language modeling loss.
|
||||
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
|
||||
model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, labels=input_ids)
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -527,6 +504,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
@add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
@ -537,6 +515,45 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for language modeling.
|
||||
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
|
||||
Indices are selected in ``[-100, 0, ..., config.vocab_size]``
|
||||
All labels set to ``-100`` are ignored (masked), the loss is only
|
||||
computed for labels in ``[0, ..., config.vocab_size]``
|
||||
|
||||
Return:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:obj:`~transformers.OpenAIGPTConfig`) and inputs:
|
||||
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
|
||||
Language modeling loss.
|
||||
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks).
|
||||
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||
should not be passed as input ids as they have already been computed.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
|
||||
model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, labels=input_ids)
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
@ -563,48 +580,80 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
|
||||
@add_start_docstrings(
|
||||
"""OpenAI GPT Model transformer with a language modeling and a multiple-choice classification
|
||||
head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
|
||||
The language modeling head has its weights tied to the input embeddings,
|
||||
the classification head takes as input the input of a specified classification token index in the input sequence).
|
||||
head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
|
||||
The language modeling head has its weights tied to the input embeddings,
|
||||
the classification head takes as input the input of a specified classification token index in the input sequence).
|
||||
""",
|
||||
OPENAI_GPT_START_DOCSTRING,
|
||||
OPENAI_GPT_INPUTS_DOCSTRING,
|
||||
)
|
||||
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
r"""
|
||||
**mc_token_ids**: (`optional`, default to index of the last token of the input) ``torch.LongTensor`` of shape ``(batch_size, num_choices)``:
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
config.num_labels = 1
|
||||
self.transformer = OpenAIGPTModel(config)
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
self.multiple_choice_head = SequenceSummary(config)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
@add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
mc_token_ids=None,
|
||||
lm_labels=None,
|
||||
mc_labels=None,
|
||||
):
|
||||
r"""
|
||||
mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input)
|
||||
Index of the classification token in each input sequence.
|
||||
Selected in the range ``[0, input_ids.size(-1) - 1[``.
|
||||
**lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`)
|
||||
Labels for language modeling.
|
||||
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
|
||||
Indices are selected in ``[-100, 0, ..., config.vocab_size]``
|
||||
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
|
||||
All labels set to ``-100`` are ignored (masked), the loss is only
|
||||
computed for labels in ``[0, ..., config.vocab_size]``
|
||||
**mc_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size)``:
|
||||
mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`, defaults to :obj:`None`)
|
||||
Labels for computing the multiple choice classification loss.
|
||||
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
|
||||
of the input tensors. (see `input_ids` above)
|
||||
|
||||
`multiple_choice_labels`: optional multiple choice labels: ``torch.LongTensor`` of shape [batch_size]
|
||||
with indices selected in [0, ..., num_choices].
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**lm_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Return:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:obj:`~transformers.OpenAIGPTConfig`) and inputs:
|
||||
lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``lm_labels`` is provided):
|
||||
Language modeling loss.
|
||||
**mc_loss**: (`optional`, returned when ``multiple_choice_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`multiple_choice_labels` is provided):
|
||||
Multiple choice classification loss.
|
||||
**lm_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length, config.vocab_size)``
|
||||
lm_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
**mc_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)``
|
||||
Prediction scores of the multiplechoice classification head (scores for each choice before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
mc_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
||||
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
||||
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks).
|
||||
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||
should not be passed as input ids as they have already been computed.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
Examples::
|
||||
|
||||
@ -621,32 +670,6 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
lm_prediction_scores, mc_prediction_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
config.num_labels = 1
|
||||
self.transformer = OpenAIGPTModel(config)
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
self.multiple_choice_head = SequenceSummary(config)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
mc_token_ids=None,
|
||||
lm_labels=None,
|
||||
mc_labels=None,
|
||||
):
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
|
Loading…
Reference in New Issue
Block a user