added doc for openai GPT

thomwolf 2019-07-15 09:58:01 +02:00
parent 62b8eb43c1
commit 4cb489457f
2 changed files with 158 additions and 237 deletions


@@ -154,6 +154,7 @@ class BertConfig(PretrainedConfig):
:class:`~pytorch_transformers.BertConfig` is the configuration class to store the configuration of a
`BertModel`.
Arguments:
vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
@@ -193,31 +194,6 @@ class BertConfig(PretrainedConfig):
initializer_range=0.02,
layer_norm_eps=1e-12,
**kwargs):
"""Constructs BertConfig.
Arguments:
vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
layer_norm_eps: The epsilon used by LayerNorm.
"""
super(BertConfig, self).__init__(**kwargs)
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
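For context, the check above dispatches on the type of ``vocab_size_or_config_json_file``; a minimal sketch of the behaviour it enables (attribute handling simplified and assumed, Python 3 only)::

    import json

    class BertConfigSketch(object):
        """Hypothetical, simplified stand-in for BertConfig's constructor logic."""
        def __init__(self, vocab_size_or_config_json_file=30522, hidden_size=768):
            if isinstance(vocab_size_or_config_json_file, str):
                # A string is treated as the path to a JSON file of configuration values
                with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
                    for key, value in json.loads(reader.read()).items():
                        setattr(self, key, value)
            elif isinstance(vocab_size_or_config_json_file, int):
                # An int is treated directly as the vocabulary size
                self.vocab_size = vocab_size_or_config_json_file
                self.hidden_size = hidden_size
            else:
                raise ValueError("First argument must be an int (vocabulary size) "
                                 "or the path to a JSON configuration file")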


@@ -379,47 +379,73 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
OPENAI_GPT_START_DOCSTRING = r""" OpenAI GPT model was proposed in
`Improving Language Understanding by Generative Pre-Training`_
by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever.
It's a causal (unidirectional) transformer pre-trained using language modeling on a large
corpus with long range dependencies, the Toronto Book Corpus.
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matter related to general usage and behavior.
.. _`Improving Language Understanding by Generative Pre-Training`:
https://openai.com/blog/language-unsupervised/
.. _`torch.nn.Module`:
https://pytorch.org/docs/stable/nn.html#module
Parameters:
config (:class:`~pytorch_transformers.OpenAIGPTConfig`): Model configuration class with all the parameters of the model.
"""
OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`pytorch_transformers.OpenAIGPTTokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
A parallel sequence of tokens (can be used to indicate various portions of the inputs).
The embeddings from these tokens will be summed with the respective token embeddings.
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
**attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask indices selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask indices selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@add_start_docstrings("The bare OpenAI GPT transformer model outputing raw hidden-states without any specific head on top.",
OPENAI_GPT_START_DOCSTRING, GPT2_INPUTS_DOCTRING)
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
"""OpenAI GPT model ("Improving Language Understanding by Generative Pre-Training").
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
Examples::

    >>> config = OpenAIGPTConfig.from_pretrained('openai-gpt')
    >>> tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
    >>> model = OpenAIGPTModel(config)
    >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
    >>> outputs = model(input_ids)
    >>> last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

OpenAI GPT uses a single embedding matrix to store the word and special embeddings.
Special tokens embeddings are additional tokens that are not pre-trained, such as: [SEP], [CLS]...
Special tokens need to be trained during the fine-tuning if you use them.
The number of special embeddings can be controlled using the ``set_num_special_tokens(num_special_tokens)`` function.
The embeddings are ordered as follows in the token embeddings matrix:
::

    [0,                                  ----------------------
     ...                                  -> word embeddings
     config.vocab_size - 1,              ______________________
     config.vocab_size,
     ...                                  -> special embeddings
     config.vocab_size + n_special - 1]  ______________________
where ``total_tokens_embeddings`` is:
::
total_tokens_embeddings = config.vocab_size + n_special
You should use the associated indices to index the embeddings.
Args:
`config`: an OpenAIGPTConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Example::
config = modeling_openai.OpenAIGPTConfig()
model = modeling_openai.OpenAIGPTModel(config)
"""
def __init__(self, config):
super(OpenAIGPTModel, self).__init__(config)
self.output_attentions = config.output_attentions
@@ -444,37 +470,6 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
self.h[layer].attn.prune_heads(heads)
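# For reference, pruning is requested through a dictionary mapping a layer index to
# the list of head indices to nullify in that layer; a hedged usage sketch (assuming
# the dict-based `prune_heads` helper on the PreTrainedModel base class):
#   model.prune_heads({0: [0, 2], 2: [1]})  # prune heads 0 and 2 of layer 0, and head 1 of layer 2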
def forward(self, input_ids, position_ids=None, token_type_ids=None, head_mask=None):
"""
Performs a model forward pass. **Can be invoked by calling the model instance directly, once it has been instantiated.**
Args:
`input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, total_tokens_embeddings[
`position_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
with the position indices (selected in the range [0, config.n_positions - 1]).
`token_type_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
You can use it to add a third type of embedding to each input token in the sequence
(the previous two being the word and position embeddings).
The input, position and token_type embeddings are summed inside the Transformer before the first
self-attention block.
`head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is not masked, 0.0 => head is masked.
Returns:
``hidden_states``, a list of all the encoded-hidden-states in the model (length of the list is number
of layers + 1 for the output of the embeddings)
as ``torch.FloatTensor`` of size [batch_size, sequence_length, hidden_size]
(or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)
Example::
# Already been converted into BPE token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
hidden_states = model(input_ids)
# or
hidden_states = model.forward(input_ids)
"""
if position_ids is None:
# This was used when we had a single embedding matrix for position and token embeddings
# start = self.config.vocab_size + self.config.n_special
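# A sketch of the default used instead: when `position_ids` is None the positions are
# simply 0..seq_len-1 broadcast over the batch (the exact arange/expand pattern below
# is an assumption about the current code path):
#   position_ids = torch.arange(input_ids.size(-1), dtype=torch.long, device=input_ids.device)
#   position_ids = position_ids.unsqueeze(0).expand_as(input_ids)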
@@ -536,46 +531,40 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
return outputs # last hidden state, (all hidden states), (all attentions)
@add_start_docstrings("""OpenAI GPT Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """, OPENAI_GPT_START_DOCSTRING, OPENAI_GPT_INPUTS_DOCSTRING)
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
"""OpenAI GPT model with a Language Modeling head ("Improving Language Understanding by Generative Pre-Training").
r"""
**lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for language modeling.
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
All labels set to ``-1`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
OpenAI GPT uses a single embedding matrix to store the word and special embeddings.
Special tokens embeddings are additional tokens that are not pre-trained: [SEP], [CLS]...
Special tokens need to be trained during the fine-tuning if you use them. The number of special embeddings
can be controlled using the ``set_num_special_tokens(num_special_tokens)`` function.
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Language modeling loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
Examples::

    >>> config = OpenAIGPTConfig.from_pretrained('openai-gpt')
    >>> tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
    >>> model = OpenAIGPTLMHeadModel(config)
    >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
    >>> outputs = model(input_ids, lm_labels=input_ids)
    >>> loss, logits = outputs[:2]

The embeddings are ordered as follows in the token embeddings matrix:
::

    [0,                                         ----------------------
     ...                                         -> word embeddings
     config.vocab_size - 1,                     ______________________
     config.vocab_size,
     ...                                         -> special embeddings
     config.vocab_size + config.n_special - 1]  ______________________

where ``total_tokens_embeddings`` can be obtained as ``config.total_tokens_embeddings`` and is:
::

    total_tokens_embeddings = config.vocab_size + config.n_special

You should use the associated indices to index the embeddings.
Args:
`config`: an OpenAIGPTConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Example::
config = modeling_openai.OpenAIGPTConfig()
model = modeling_openai.OpenAIGPTLMHeadModel(config)
"""
def __init__(self, config):
super(OpenAIGPTLMHeadModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config)
@@ -592,40 +581,6 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
self.transformer.tokens_embed)
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
"""
Performs a model forward pass. **Can be invoked by calling the model instance directly, once it has been instantiated.**
Args:
`input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, total_tokens_embeddings[
`position_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
with the position indices (selected in the range [0, config.n_positions - 1]).
`token_type_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
You can use it to add a third type of embedding to each input token in the sequence
(the previous two being the word and position embeddings).
The input, position and token_type embeddings are summed inside the Transformer before the first
self-attention block.
`lm_labels`: optional language modeling labels: ``torch.LongTensor`` of shape [batch_size, sequence_length]
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
is only computed for the labels set in [0, ..., vocab_size]
`head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is not masked, 0.0 => head is masked.
Returns:
if ``lm_labels`` is not ``None``, outputs the language modeling loss. Otherwise, outputs ``lm_logits``,
the language modeling logits as a ``torch.FloatTensor`` of size [batch_size, sequence_length,
total_tokens_embeddings] (or more generally [d_1, ..., d_n, total_tokens_embeddings] where d_1 ... d_n are
the dimensions of input_ids)
Example::
# Already been converted into BPE token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
lm_logits = model(input_ids)
# or
lm_logits = model.forward(input_ids)
"""
transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
@@ -644,46 +599,80 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
return outputs # (loss), lm_logits, (all hidden states), (all attentions)
@add_start_docstrings("""OpenAI GPT Model transformer with a language modeling and a multiple-choice classification
head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
The language modeling head has its weights tied to the input embeddings,
the classification head takes as input the hidden state at a specified classification token index in the input sequence.
""", OPENAI_GPT_START_DOCSTRING)
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
"""OpenAI GPT model with a Language Modeling and a Multiple Choice head ("Improving Language Understanding by Generative Pre-Training").
r""" Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
Indices can be obtained using :class:`pytorch_transformers.OpenAIGPTTokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**mc_token_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices)``:
Index of the classification token in each input sequence.
Selected in the range ``[0, input_ids.size(-1) - 1]``.
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
A parallel sequence of tokens (can be used to indicate various portions of the inputs).
The embeddings from these tokens will be summed with the respective token embeddings.
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
**attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask indices selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask indices selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for language modeling.
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
All labels set to ``-1`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
**multiple_choice_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size)``:
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
    **lm_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
        Language modeling loss.
    **mc_loss**: (`optional`, returned when ``multiple_choice_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
        Multiple choice classification loss.
    **lm_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length, config.vocab_size)``
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    **mc_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)``
        Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
    **attentions**: (`optional`, returned when ``config.output_attentions=True``)
        list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
        Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
    **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
        list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
        of shape ``(batch_size, sequence_length, hidden_size)``:
        Hidden-states of the model at the output of each layer plus the initial embedding outputs.

Examples::

    >>> config = OpenAIGPTConfig.from_pretrained('openai-gpt')
    >>> tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
    >>> model = OpenAIGPTDoubleHeadsModel(config)
    >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]  # Assume you've added [CLS] to the vocabulary
    >>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
    >>> mc_token_ids = torch.tensor([-1, -1]).unsqueeze(0)  # Batch size 1
    >>> outputs = model(input_ids, mc_token_ids)
    >>> lm_prediction_scores, mc_prediction_scores = outputs[:2]

OpenAI GPT uses a single embedding matrix to store the word and special embeddings.
Special tokens embeddings are additional tokens that are not pre-trained: [SEP], [CLS]...
Special tokens need to be trained during the fine-tuning if you use them.
The number of special embeddings can be controlled using the ``set_num_special_tokens(num_special_tokens)`` function.
`multiple_choice_labels`: optional multiple choice labels: ``torch.LongTensor`` of shape [batch_size]
    with indices selected in [0, ..., num_choices - 1].
The embeddings are ordered as follows in the token embeddings matrix:
::

    [0,                                  ----------------------
     ...                                  -> word embeddings
     config.vocab_size - 1,              ______________________
     config.vocab_size,
     ...                                  -> special embeddings
     config.vocab_size + n_special - 1]  ______________________

where ``total_tokens_embeddings`` is:
::

    total_tokens_embeddings = config.vocab_size + config.n_special

You should use the associated indices to index the embeddings.
Args:
`config`: an OpenAIGPTConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Example::
config = modeling_openai.OpenAIGPTConfig()
model = modeling_openai.OpenAIGPTDoubleHeadsModel(config)
"""
def __init__(self, config):
super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
@@ -703,50 +692,6 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
position_ids=None, head_mask=None):
"""
Performs a model forward pass. **Can be invoked by calling the model instance directly, once it has been instantiated.**
Args:
`input_ids`: a ``torch.LongTensor`` of shape [batch_size, num_choices, sequence_length] with the BPE token
indices selected in the range [0, total_tokens_embeddings[
`mc_token_ids`: a ``torch.LongTensor`` of shape [batch_size, num_choices] with the index of the token from
which we should take the hidden state to feed the multiple choice classifier (usually last token of the sequence)
`position_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
with the position indices (selected in the range [0, config.n_positions - 1]).
`token_type_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
You can use it to add a third type of embedding to each input token in the sequence
(the previous two being the word and position embeddings).
The input, position and token_type embeddings are summed inside the Transformer before the first
self-attention block.
`lm_labels`: optional language modeling labels: ``torch.LongTensor`` of shape [batch_size, num_choices, sequence_length]
with indices selected in [-1, 0, ..., total_tokens_embeddings]. All labels set to -1 are ignored (masked), the loss
is only computed for the labels set in [0, ..., total_tokens_embeddings]
`multiple_choice_labels`: optional multiple choice labels: ``torch.LongTensor`` of shape [batch_size]
with indices selected in [0, ..., num_choices - 1].
`head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is not masked, 0.0 => head is masked.
Returns:
if ``lm_labels`` and ``multiple_choice_labels`` are not ``None``, outputs a tuple of losses with the
language modeling loss and the multiple choice loss. Otherwise, returns a
``tuple(lm_logits, multiple_choice_logits)``.
``lm_logits`` are the language modeling logits as a ``torch.FloatTensor`` of size
[batch_size, num_choices, sequence_length, total_tokens_embeddings]
``multiple_choice_logits``: the multiple choice logits as a ``torch.FloatTensor`` of
size [batch_size, num_choices]
Example::
# Already been converted into BPE token ids
input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]]) # (bsz, number of choice, seq length)
mc_token_ids = torch.LongTensor([[2, 1]])  # (bsz, number of choices)
lm_logits, multiple_choice_logits = model(input_ids, mc_token_ids)
# or
lm_logits, multiple_choice_logits = model.forward(input_ids, mc_token_ids)
"""
transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
hidden_states = transformer_outputs[0]
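# When both `lm_labels` and `mc_labels` are supplied, the two losses come first in the
# returned tuple (per the Returns description above); a hedged training-step sketch:
#   outputs = model(input_ids, mc_token_ids=mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels)
#   lm_loss, mc_loss = outputs[:2]
#   (lm_loss + mc_loss).backward()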