# Mirror of https://github.com/huggingface/transformers.git
# (synced 2025-07-15 18:48:24 +06:00) - 719 lines, 34 KiB, Python
# coding=utf-8
|
|
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""PyTorch OpenAI GPT model."""
|
|
|
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
|
|
import collections
|
|
import json
|
|
import logging
|
|
import math
|
|
import os
|
|
import sys
|
|
from io import open
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import CrossEntropyLoss
|
|
from torch.nn.parameter import Parameter
|
|
|
|
from .modeling_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
|
|
PreTrainedModel, prune_conv1d_layer, SequenceSummary,
|
|
add_start_docstrings)
|
|
from .modeling_bert import BertLayerNorm as LayerNorm
|
|
|
|
logger = logging.getLogger(__name__)

# Shortcut name -> download URL of the pretrained PyTorch weights.
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin",
}
# Shortcut name -> download URL of the matching JSON configuration.
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json",
}
|
|
|
|
|
|
def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
    """ Load tf pre-trained weights in a pytorch model (from NumPy arrays here)

    The folder must contain ``parameters_names.json``, ``params_shapes.json`` and
    the ten shards ``params_0.npy`` ... ``params_9.npy``.

    Args:
        model: PyTorch model to populate in place. It must expose ``tokens_embed`` and
            ``positions_embed`` plus sub-modules reachable through the attribute paths
            encoded in the TF variable names.
        config: not used by this function (kept for a uniform loader signature).
        openai_checkpoint_folder_path: checkpoint folder; if the path contains
            ``.ckpt`` its parent directory is used instead.

    Returns:
        ``model``, with its parameters overwritten.

    Raises:
        AssertionError: when an array's shape differs from the target parameter's;
            both shapes are appended to the exception's ``args``.
    """
    import re
    import numpy as np

    if '.ckpt' in openai_checkpoint_folder_path:
        openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path)

    logger.info("Loading weights from {}".format(openai_checkpoint_folder_path))

    # Context managers guarantee the metadata files are closed (the previous
    # ``json.load(open(...))`` leaked both handles).
    with open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8') as names_file:
        names = json.load(names_file)
    with open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8') as shapes_file:
        shapes = json.load(shapes_file)
    offsets = np.cumsum([np.prod(shape) for shape in shapes])
    init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)]
    # The shards form one flat buffer: concatenate, cut at the cumulative sizes and
    # drop the trailing remainder that np.split yields after the last offset.
    init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
    init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
    init_params = [arr.squeeze() for arr in init_params]

    # Array 0 holds the position embeddings, array 1 the token embeddings.
    try:
        assert model.tokens_embed.weight.shape == init_params[1].shape
        assert model.positions_embed.weight.shape == init_params[0].shape
    except AssertionError as e:
        e.args += (model.tokens_embed.weight.shape, init_params[1].shape)
        e.args += (model.positions_embed.weight.shape, init_params[0].shape)
        raise

    model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
    model.positions_embed.weight.data = torch.from_numpy(init_params[0])
    # NOTE(review): one name is dropped but two arrays are; presumably the names
    # file lists the two embeddings under a single entry - confirm against the
    # checkpoint format.
    names.pop(0)
    # Pop position and token embedding arrays
    init_params.pop(0)
    init_params.pop(0)

    for name, array in zip(names, init_params):
        name = name[6:]  # skip "model/"
        assert name[-2:] == ":0"
        name = name[:-2]
        name = name.split('/')
        pointer = model
        for m_name in name:
            # Names such as "h3" mean "attribute h, element 3".
            if re.fullmatch(r'[A-Za-z]+\d+', m_name):
                l = re.split(r'(\d+)', m_name)
            else:
                l = [m_name]
            # TF short names: g (LayerNorm gain) / w -> weight, b -> bias.
            if l[0] == 'g':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'b':
                pointer = getattr(pointer, 'bias')
            elif l[0] == 'w':
                pointer = getattr(pointer, 'weight')
            else:
                pointer = getattr(pointer, l[0])
            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
|
|
|
|
|
|
def gelu(x):
    """Tanh approximation of the Gaussian Error Linear Unit, element-wise.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ** 3)))``.
    """
    inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (1 + torch.tanh(inner))
|
|
|
|
|
|
def swish(x):
    """Swish activation: ``x * sigmoid(x)``, applied element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
|
|
|
|
|
|
# Activation functions selectable through ``OpenAIGPTConfig.afn``. MLP applies the
# selected entry directly to a tensor (``self.act(tensor)``), so every value must be
# a plain callable. "relu" previously mapped to the module *class* ``nn.ReLU``;
# calling that on a tensor constructs a module instead of applying the activation.
ACT_FNS = {"relu": nn.functional.relu, "swish": swish, "gelu": gelu}
|
|
|
|
|
|
class OpenAIGPTConfig(PretrainedConfig):
    """
    Configuration class to store the configuration of a `OpenAIGPTModel`.

    Args:
        vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `OpenAIGPTModel`, or the
            path to a configuration json file. When a path is given, every key of the JSON object
            becomes an attribute and all the other constructor arguments are ignored.
        n_positions: Number of positional embeddings.
        n_ctx: Size of the causal mask (usually same as n_positions).
        n_embd: Dimensionality of the embeddings and hidden states.
        n_layer: Number of hidden layers in the Transformer encoder.
        n_head: Number of attention heads for each attention layer in
            the Transformer encoder.
        afn: Name of the non-linear activation used in the feed-forward layers; it is
            looked up in ``ACT_FNS``, so "gelu", "relu" and "swish" are supported.
        resid_pdrop: The dropout probability applied to the output projections of the
            attention and feed-forward sub-layers.
        attn_pdrop: The dropout ratio for the attention probabilities.
        embd_pdrop: The dropout ratio for the embeddings.
        layer_norm_epsilon: epsilon to use in the layer norm layers.
        initializer_range: The stddev of the normal initializer used for
            initializing all weight matrices.
        predict_special_tokens: should we predict special tokens (when the model has a LM head).
        num_labels: number of labels of the sequence summary head.
        summary_type, summary_use_proj, summary_activation, summary_proj_to_labels,
        summary_first_dropout: options of the sequence summary head used by
            ``OpenAIGPTDoubleHeadsModel`` (presumably read by ``SequenceSummary`` -
            see modeling_utils).

    Raises:
        ValueError: if the first argument is neither an int nor a string path.
    """
    pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
        self,
        vocab_size_or_config_json_file=40478,
        n_positions=512,
        n_ctx=512,
        n_embd=768,
        n_layer=12,
        n_head=12,
        afn="gelu",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        predict_special_tokens=True,

        num_labels=1,
        summary_type='token_ids',
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
        **kwargs
    ):
        """Constructs OpenAIGPTConfig.
        """
        super(OpenAIGPTConfig, self).__init__(**kwargs)

        # ``unicode`` only exists on Python 2; the version check short-circuits first on Python 3.
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.n_ctx = n_ctx
            self.n_positions = n_positions
            self.n_embd = n_embd
            self.n_layer = n_layer
            self.n_head = n_head
            self.afn = afn
            self.resid_pdrop = resid_pdrop
            self.embd_pdrop = embd_pdrop
            self.attn_pdrop = attn_pdrop
            self.layer_norm_epsilon = layer_norm_epsilon
            self.initializer_range = initializer_range
            self.predict_special_tokens = predict_special_tokens

            self.num_labels = num_labels
            self.summary_type = summary_type
            self.summary_use_proj = summary_use_proj
            self.summary_activation = summary_activation
            self.summary_first_dropout = summary_first_dropout
            self.summary_proj_to_labels = summary_proj_to_labels
        else:
            # The two literals are concatenated: keep the separating space.
            raise ValueError(
                "First argument must be either a vocabulary size (int) "
                "or the path to a pretrained model config file (str)"
            )

    @property
    def max_position_embeddings(self):
        """Alias of ``n_positions``."""
        return self.n_positions

    @property
    def hidden_size(self):
        """Alias of ``n_embd``."""
        return self.n_embd

    @property
    def num_attention_heads(self):
        """Alias of ``n_head``."""
        return self.n_head

    @property
    def num_hidden_layers(self):
        """Alias of ``n_layer``."""
        return self.n_layer
|
|
|
|
|
|
class Attention(nn.Module):
    """Causal multi-head self-attention.

    Args:
        nx: hidden size of the input and output (``config.n_embd`` when built by ``Block``).
        n_ctx: side length of the pre-computed lower-triangular causal mask.
        config: configuration providing ``n_head``, ``attn_pdrop``, ``resid_pdrop``
            and ``output_attentions``.
        scale: if True, attention scores are divided by ``sqrt(head_dim)``.
    """

    def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        # Lower-triangular causal mask. The buffer name is part of the state_dict
        # keys, so it stays "bias" even though it is a mask.
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.output_attentions = config.output_attentions

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

    def prune_heads(self, heads):
        """Remove attention heads in place.

        Args:
            heads: iterable of head indices, numbered relative to the heads currently
                present. Duplicate indices are ignored; previously they were counted
                twice, which left ``n_head``/``split_size`` inconsistent with the pruned
                projection layers.
        """
        heads = set(heads)
        if len(heads) == 0:
            return
        head_size = self.split_size // self.n_head
        mask = torch.ones(self.n_head, head_size)
        for head in heads:
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # c_attn packs query, key and value side by side: keep the same columns in each slice.
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
        # Update hyper params
        self.split_size = head_size * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)

    def _attn(self, q, k, v, head_mask=None):
        """Masked scaled dot-product attention.

        ``k`` is expected already transposed by ``split_heads(..., k=True)``.
        Returns ``[context]`` or ``[context, attention_probs]`` when
        ``self.output_attentions`` is set.
        """
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        # The mask buffer is n_ctx x n_ctx; crop it to the current query/key lengths.
        b = self.bias[:, :, : w.size(-2), : w.size(-1)]
        # TF implem method (mask_attn_weights): masked scores become -1e9.
        w = w * b + -1e9 * (1 - b)

        w = nn.functional.softmax(w, dim=-1)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        """Inverse of ``split_heads`` for non-key tensors (TF implem: merge_states)."""
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)

    def split_heads(self, x, k=False):
        """Split the last dimension into heads (TF implem: split_states).

        Non-key tensors are permuted to (batch, head, seq, head_dim); keys (``k=True``)
        to (batch, head, head_dim, seq) so they can be multiplied directly.
        """
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x, head_mask=None):
        """Return ``[output]`` or ``[output, attention_probs]``."""
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

        attn_outputs = self._attn(query, key, value, head_mask)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)

        outputs = [a] + attn_outputs[1:]
        return outputs  # a, (attentions)
|
|
|
|
|
|
class MLP(nn.Module):
    """Position-wise feed-forward sub-layer: expand, activate, project back, then dropout.

    Args:
        n_state: inner dimension (``Block`` passes ``4 * n_embd``).
        config: configuration providing ``n_embd``, ``afn`` (a key of ``ACT_FNS``)
            and ``resid_pdrop``.
    """

    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        hidden = config.n_embd
        self.c_fc = Conv1D(n_state, hidden)
        self.c_proj = Conv1D(hidden, n_state)
        self.act = ACT_FNS[config.afn]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        return self.dropout(self.c_proj(self.act(self.c_fc(x))))
|
|
|
|
|
|
class Block(nn.Module):
    """One transformer layer with post-sub-layer LayerNorm.

    ``ln_1(x + attn(x))`` is fed to the MLP, and ``ln_2`` is applied to the
    residual sum of that intermediate and the MLP output.

    Args:
        n_ctx: causal mask size forwarded to ``Attention``.
        config: configuration providing ``n_embd`` and ``layer_norm_epsilon``.
        scale: forwarded to ``Attention`` (scale scores by ``sqrt(head_dim)``).
    """

    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        hidden = config.n_embd
        self.attn = Attention(hidden, n_ctx, config, scale)
        self.ln_1 = LayerNorm(hidden, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * hidden, config)
        self.ln_2 = LayerNorm(hidden, eps=config.layer_norm_epsilon)

    def forward(self, x, head_mask=None):
        """Return ``[hidden]`` followed by any extra outputs of the attention layer."""
        attn_results = self.attn(x, head_mask=head_mask)
        normed = self.ln_1(x + attn_results[0])
        hidden = self.ln_2(normed + self.mlp(normed))
        return [hidden] + attn_results[1:]
|
|
|
|
|
|
class OpenAIGPTPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = OpenAIGPTConfig
    pretrained_model_archive_map = OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_openai_gpt
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super(OpenAIGPTPreTrainedModel, self).__init__(*inputs, **kwargs)

    def init_weights(self, module):
        """ Initialize the weights of one sub-module (meant for ``self.apply``).

        Linear / Embedding / Conv1D weights are drawn from N(0, initializer_range)
        (the TF version uses a truncated normal, cf
        https://github.com/pytorch/pytorch/pull/5617) and their biases, when present,
        are zeroed. LayerNorm gets weight 1 and bias 0. Other modules are untouched.
        """
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            has_bias = isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None
            if has_bias:
                module.bias.data.zero_()
            return
        if isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
|
|
|
|
|
|
# Shared docstring fragments, passed to ``add_start_docstrings`` on the model classes.
OPENAI_GPT_START_DOCSTRING = r""" OpenAI GPT model was proposed in
    `Improving Language Understanding by Generative Pre-Training`_
    by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever.
    It's a causal (unidirectional) transformer pre-trained using language modeling on a large
    corpus with long range dependencies, the Toronto Book Corpus.

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matters related to general usage and behavior.

    .. _`Improving Language Understanding by Generative Pre-Training`:
        https://openai.com/blog/language-unsupervised/

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~pytorch_transformers.OpenAIGPTConfig`): Model configuration class with all the parameters of the model.
"""

OPENAI_GPT_INPUTS_DOCSTRING = r"""    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            Indices can be obtained using :class:`pytorch_transformers.OpenAIGPTTokenizer`.
            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            A parallel sequence of tokens (can be used to indicate various portions of the inputs).
            The embeddings from these tokens will be summed with the respective token embeddings.
            Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
|
|
|
|
@add_start_docstrings("The bare OpenAI GPT transformer model outputing raw hidden-states without any specific head on top.",
                      OPENAI_GPT_START_DOCSTRING, OPENAI_GPT_INPUTS_DOCSTRING)
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the last layer of the model.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        >>> config = OpenAIGPTConfig.from_pretrained('openai-gpt')
        >>> tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
        >>> model = OpenAIGPTModel(config)
        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        >>> outputs = model(input_ids)
        >>> last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config):
        super(OpenAIGPTModel, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd)
        self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])

        self.apply(self.init_weights)

    def _resize_token_embeddings(self, new_num_tokens):
        """Swap ``tokens_embed`` for a resized copy and return the new embedding."""
        resized = self._get_resized_embeddings(self.tokens_embed, new_num_tokens)
        self.tokens_embed = resized
        return self.tokens_embed

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer_index, layer_heads in heads_to_prune.items():
            self.h[layer_index].attn.prune_heads(layer_heads)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, head_mask=None):
        if position_ids is None:
            # Default positions are 0 .. seq_len - 1, broadcast over every leading dimension.
            seq_len = input_ids.size(-1)
            position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        # Head mask: 1.0 keeps a head. Attention probabilities are bsz x n_heads x N x N,
        # so a 1-D mask (n_heads,) is broadcast to n_layer x 1 x n_heads x 1 x 1 and a
        # 2-D mask (n_layer, n_heads) to n_layer x 1 x n_heads x 1 x 1.
        if head_mask is None:
            head_mask = [None] * self.config.n_layer
        else:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            # Match the parameters' dtype (float conversion + fp16 compatibility).
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)

        # Flatten any extra leading dimensions (e.g. num_choices) into the batch.
        input_shape = input_ids.size()
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_position_ids = position_ids.view(-1, position_ids.size(-1))

        hidden_states = self.tokens_embed(flat_input_ids) + self.positions_embed(flat_position_ids)
        if token_type_ids is None:
            hidden_states = hidden_states + 0
        else:
            # Token types share the token embedding matrix.
            flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            hidden_states = hidden_states + self.tokens_embed(flat_token_type_ids)
        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        all_attentions = ()
        all_hidden_states = ()
        for layer_index, block in enumerate(self.h):
            if self.output_hidden_states:
                all_hidden_states += (hidden_states.view(*output_shape),)

            block_outputs = block(hidden_states, head_mask[layer_index])
            hidden_states = block_outputs[0]
            if self.output_attentions:
                all_attentions += (block_outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states += (hidden_states.view(*output_shape),)

        outputs = (hidden_states.view(*output_shape),)
        if self.output_hidden_states:
            outputs += (all_hidden_states,)
        if self.output_attentions:
            outputs += (all_attentions,)
        return outputs  # last hidden state, (all hidden states), (all attentions)
|
|
|
|
|
|
@add_start_docstrings("""OpenAI GPT Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """, OPENAI_GPT_START_DOCSTRING, OPENAI_GPT_INPUTS_DOCSTRING)
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-1, 0, ..., config.vocab_size - 1]``
            All labels set to ``-1`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size - 1]``

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Language modeling loss.
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        >>> config = OpenAIGPTConfig.from_pretrained('openai-gpt')
        >>> tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
        >>> model = OpenAIGPTLMHeadModel(config)
        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        >>> outputs = model(input_ids, labels=input_ids)
        >>> loss, logits = outputs[:2]

    """
    def __init__(self, config):
        super(OpenAIGPTLMHeadModel, self).__init__(config)
        self.transformer = OpenAIGPTModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.apply(self.init_weights)
        self.tie_weights()

    def tie_weights(self):
        """ Share the LM head weights with the input token embeddings.
            (TorchScript export cannot handle parameter sharing, so the weights are cloned in that case.)
        """
        self._tie_or_clone_weights(self.lm_head, self.transformer.tokens_embed)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, labels=None, head_mask=None):
        transformer_outputs = self.transformer(input_ids, position_ids=position_ids,
                                               token_type_ids=token_type_ids, head_mask=head_mask)
        lm_logits = self.lm_head(transformer_outputs[0])
        outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is None:
            return outputs  # lm_logits, (all hidden states), (all attentions)

        # Shift so that tokens < n predict n, then flatten for the loss.
        vocab_size = lm_logits.size(-1)
        shift_logits = lm_logits[..., :-1, :].contiguous().view(-1, vocab_size)
        shift_labels = labels[..., 1:].contiguous().view(-1)
        loss = CrossEntropyLoss(ignore_index=-1)(shift_logits, shift_labels)
        return (loss,) + outputs  # loss, lm_logits, (all hidden states), (all attentions)
|
|
|
|
|
|
@add_start_docstrings("""OpenAI GPT Model transformer with a language modeling and a multiple-choice classification
    head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
    The language modeling head has its weights tied to the input embeddings,
    the classification head takes as input the hidden state of a specified classification token index in the input sequence).
""", OPENAI_GPT_START_DOCSTRING)
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
    r"""    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            The second dimension of the input (`num_choices`) indicates the number of choices to score.
            Indices can be obtained using :class:`pytorch_transformers.OpenAIGPTTokenizer`.
            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **mc_token_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices)``:
            Index of the classification token in each input sequence.
            Selected in the range ``[0, input_ids.size(-1) - 1]``.
        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
            A parallel sequence of tokens (can be used to indicate various portions of the inputs).
            The embeddings from these tokens will be summed with the respective token embeddings.
            Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
        **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
            Indices are selected in ``[-1, 0, ..., config.vocab_size - 1]``
            All labels set to ``-1`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size - 1]``
        **mc_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size)``:
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **lm_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Language modeling loss.
        **mc_loss**: (`optional`, returned when ``mc_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Multiple choice classification loss.
        **lm_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **mc_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)``
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, num_choices, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape
            ``(batch_size * num_choices, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        >>> config = OpenAIGPTConfig.from_pretrained('openai-gpt')
        >>> tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
        >>> model = OpenAIGPTDoubleHeadsModel(config)
        >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]  # Assume you've added [CLS] to the vocabulary
        >>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
        >>> last_index = input_ids.size(-1) - 1  # [CLS] is the last token of each choice
        >>> mc_token_ids = torch.tensor([last_index, last_index]).unsqueeze(0)  # Batch size 1
        >>> outputs = model(input_ids, mc_token_ids)
        >>> lm_prediction_scores, mc_prediction_scores = outputs[:2]

    """
    def __init__(self, config):
        super(OpenAIGPTDoubleHeadsModel, self).__init__(config)

        self.transformer = OpenAIGPTModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = SequenceSummary(config)

        self.apply(self.init_weights)
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.tokens_embed)

    def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
                position_ids=None, head_mask=None):
        transformer_outputs = self.transformer(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
                                               head_mask=head_mask)
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        # Summarize each choice at its mc_token_ids position. The trailing dimension is
        # presumably 1 with the default num_labels=1, so it is squeezed away - confirm
        # against SequenceSummary for other configurations.
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
        # Losses are prepended mc first, then lm, so the final order is (lm_loss, mc_loss, ...).
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)),
                            mc_labels.view(-1))
            outputs = (loss,) + outputs
        if lm_labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (lm loss), (mc loss), lm logits, mc logits, (all hidden_states), (attentions)
|