adding docs and example for OpenAI GPT

thomwolf 2019-01-09 00:12:43 +01:00
parent dc5df92fa8
commit ab90d4cddd
3 changed files with 510 additions and 43 deletions

examples/run_openai_gpt.py (new file, 304 lines added)

@@ -0,0 +1,304 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
" Run OpenAI GPT on RocStories"
import argparse
import os
import random
import logging
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle
# from analysis import rocstories as rocstories_analysis
# from datasets import rocstories
# from model_pytorch import DoubleHeadModel, load_openai_pretrained_model
# from opt import OpenAIAdam
# from text_utils import TextEncoder
# from utils import (encode_dataset, iter_data,
# ResultLogger, make_path)
# from loss import MultipleChoiceLossCompute
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTDoubleHeadsModel
from pytorch_pretrained_bert.optimization_openai import OpenAIAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logger = logging.getLogger(__name__)
def transform_roc(X1, X2, X3):
n_batch = len(X1)
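# xmb packs, for each story, the two candidate (context + ending) sequences:
#   last axis index 0 holds the BPE token ids, index 1 holds the position ids (filled in below)
# mmb is the matching mask, 1 over real tokens and 0 over padding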
xmb = np.zeros((n_batch, 2, n_ctx, 2), dtype=np.int32)
mmb = np.zeros((n_batch, 2, n_ctx), dtype=np.float32)
start = encoder['_start_']
delimiter = encoder['_delimiter_']
for i, (x1, x2, x3) in enumerate(zip(X1, X2, X3)):
x12 = [start] + x1[:max_len] + [delimiter] + x2[:max_len] + [clf_token]
x13 = [start] + x1[:max_len] + [delimiter] + x3[:max_len] + [clf_token]
l12 = len(x12)
l13 = len(x13)
xmb[i, 0, :l12, 0] = x12
xmb[i, 1, :l13, 0] = x13
mmb[i, 0, :l12] = 1
mmb[i, 1, :l13] = 1
# Position information that is added to the input embeddings in the TransformerModel
xmb[:, :, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)
return xmb, mmb
def iter_apply(Xs, Ms, Ys):
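# Run the model in eval mode over (Xs, Ms, Ys) and return the concatenated
# classification logits together with the accumulated classification loss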
# fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]
logits = []
cost = 0
with torch.no_grad():
dh_model.eval()
for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
n = len(xmb)
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
YMB = torch.tensor(ymb, dtype=torch.long).to(device)
MMB = torch.tensor(mmb).to(device)
_, clf_logits = dh_model(XMB)
clf_logits *= n
clf_losses = compute_loss_fct(XMB, YMB, MMB, clf_logits, only_return_losses=True)
clf_losses *= n
logits.append(clf_logits.to("cpu").numpy())
cost += clf_losses.sum().item()
logits = np.concatenate(logits, 0)
return logits, cost
def iter_predict(Xs, Ms):
logits = []
with torch.no_grad():
dh_model.eval()
for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
n = len(xmb)
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
MMB = torch.tensor(mmb).to(device)
_, clf_logits = dh_model(XMB)
logits.append(clf_logits.to("cpu").numpy())
logits = np.concatenate(logits, 0)
return logits
def log(save_dir, desc):
global best_score
print("Logging")
tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
va_logits, va_cost = iter_apply(vaX, vaM, vaY)
tr_cost = tr_cost / len(trY[:n_valid])
va_cost = va_cost / n_valid
tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1)) * 100.
va_acc = accuracy_score(vaY, np.argmax(va_logits, 1)) * 100.
logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)
print('%d %d %.3f %.3f %.2f %.2f' % (n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
if submit:
score = va_acc
if score > best_score:
best_score = score
path = os.path.join(save_dir, desc, 'best_params')
torch.save(dh_model.state_dict(), make_path(path))
def predict(dataset, submission_dir):
filename = filenames[dataset]
pred_fn = pred_fns[dataset]
label_decoder = label_decoders[dataset]
predictions = pred_fn(iter_predict(teX, teM))
if label_decoder is not None:
predictions = [label_decoder[prediction] for prediction in predictions]
path = os.path.join(submission_dir, filename)
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, 'w') as f:
f.write('{}\t{}\n'.format('index', 'prediction'))
for i, prediction in enumerate(predictions):
f.write('{}\t{}\n'.format(i, prediction))
def run_epoch():
for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
n_batch=n_batch_train, truncate=True, verbose=True):
global n_updates
dh_model.train()
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
YMB = torch.tensor(ymb, dtype=torch.long).to(device)
MMB = torch.tensor(mmb).to(device)
lm_logits, clf_logits = dh_model(XMB)
compute_loss_fct(XMB, YMB, MMB, clf_logits, lm_logits)
n_updates += 1
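# During the first epoch, also evaluate and log at a few intermediate update counts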
if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
log(save_dir, desc)
argmax = lambda x: np.argmax(x, 1)
pred_fns = {
'rocstories': argmax,
}
filenames = {
'rocstories': 'ROCStories.tsv',
}
label_decoders = {
'rocstories': None,
}
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--desc', type=str, help="Description")
parser.add_argument('--dataset', type=str)
parser.add_argument('--log_dir', type=str, default='log/')
parser.add_argument('--save_dir', type=str, default='save/')
parser.add_argument('--data_dir', type=str, default='data/')
parser.add_argument('--submission_dir', type=str, default='submission/')
parser.add_argument('--submit', action='store_true')
parser.add_argument('--analysis', action='store_true')
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--n_iter', type=int, default=3)
parser.add_argument('--n_batch', type=int, default=8)
parser.add_argument('--max_grad_norm', type=int, default=1)
parser.add_argument('--lr', type=float, default=6.25e-5)
parser.add_argument('--lr_warmup', type=float, default=0.002)
parser.add_argument('--n_ctx', type=int, default=512)
parser.add_argument('--n_embd', type=int, default=768)
parser.add_argument('--n_head', type=int, default=12)
parser.add_argument('--n_layer', type=int, default=12)
parser.add_argument('--embd_pdrop', type=float, default=0.1)
parser.add_argument('--attn_pdrop', type=float, default=0.1)
parser.add_argument('--resid_pdrop', type=float, default=0.1)
parser.add_argument('--clf_pdrop', type=float, default=0.1)
parser.add_argument('--l2', type=float, default=0.01)
parser.add_argument('--vector_l2', action='store_true')
parser.add_argument('--opt', type=str, default='adam')
parser.add_argument('--afn', type=str, default='gelu')
parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
parser.add_argument('--encoder_path', type=str, default='model/encoder_bpe_40000.json')
parser.add_argument('--bpe_path', type=str, default='model/vocab_40000.bpe')
parser.add_argument('--n_transfer', type=int, default=12)
parser.add_argument('--lm_coef', type=float, default=0.5)
parser.add_argument('--b1', type=float, default=0.9)
parser.add_argument('--b2', type=float, default=0.999)
parser.add_argument('--e', type=float, default=1e-8)
parser.add_argument('--n_valid', type=int, default=374)
args = parser.parse_args()
print(args)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
# Constants
submit = args.submit
dataset = args.dataset
n_ctx = args.n_ctx
save_dir = args.save_dir
desc = args.desc
data_dir = args.data_dir
log_dir = args.log_dir
submission_dir = args.submission_dir
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
print("device", device, "n_gpu", n_gpu)
logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__)
text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
encoder = text_encoder.encoder
n_vocab = len(text_encoder.encoder)
print("Encoding dataset...")
((trX1, trX2, trX3, trY),
(vaX1, vaX2, vaX3, vaY),
(teX1, teX2, teX3)) = encode_dataset(*rocstories(data_dir, n_valid=args.n_valid),
encoder=text_encoder)
encoder['_start_'] = len(encoder)
encoder['_delimiter_'] = len(encoder)
encoder['_classify_'] = len(encoder)
clf_token = encoder['_classify_']
n_special = 3
max_len = n_ctx // 2 - 2
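# Effective context length: longest (context + ending) pair across all splits,
# plus 3 for the _start_, _delimiter_ and _classify_ special tokens, capped at n_ctx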
n_ctx = min(max(
[len(x1[:max_len]) + max(len(x2[:max_len]),
len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)]
+ [len(x1[:max_len]) + max(len(x2[:max_len]),
len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)]
+ [len(x1[:max_len]) + max(len(x2[:max_len]),
len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)]
) + 3, n_ctx)
vocab = n_vocab + n_special + n_ctx
trX, trM = transform_roc(trX1, trX2, trX3)
vaX, vaM = transform_roc(vaX1, vaX2, vaX3)
if submit:
teX, teM = transform_roc(teX1, teX2, teX3)
n_train = len(trY)
n_valid = len(vaY)
n_batch_train = args.n_batch * max(n_gpu, 1)
n_updates_total = (n_train // n_batch_train) * args.n_iter
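# Total number of optimization steps, used as t_total by the warmup learning rate schedule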
dh_model = DoubleHeadModel(args, clf_token, 'multiple_choice', vocab, n_ctx)
criterion = nn.CrossEntropyLoss(reduce=False)
model_opt = OpenAIAdam(dh_model.parameters(),
lr=args.lr,
schedule=args.lr_schedule,
warmup=args.lr_warmup,
t_total=n_updates_total,
b1=args.b1,
b2=args.b2,
e=args.e,
l2=args.l2,
vector_l2=args.vector_l2,
max_grad_norm=args.max_grad_norm)
compute_loss_fct = MultipleChoiceLossCompute(criterion,
criterion,
args.lm_coef,
model_opt)
load_openai_pretrained_model(dh_model.transformer, n_ctx=n_ctx, n_special=n_special)
dh_model.to(device)
dh_model = nn.DataParallel(dh_model)
n_updates = 0
n_epochs = 0
if dataset != 'stsb':
trYt = trY
if submit:
path = os.path.join(save_dir, desc, 'best_params')
torch.save(dh_model.state_dict(), make_path(path))
best_score = 0
for i in range(args.n_iter):
print("running epoch", i)
run_epoch()
n_epochs += 1
log(save_dir, desc)
if submit:
path = os.path.join(save_dir, desc, 'best_params')
dh_model.load_state_dict(torch.load(path))
predict(dataset, args.submission_dir)
if args.analysis:
rocstories_analysis(data_dir, os.path.join(args.submission_dir, 'ROCStories.tsv'),
os.path.join(log_dir, 'rocstories.jsonl'))

pytorch_pretrained_bert/modeling.py

@@ -659,10 +659,10 @@ class BertForPreTraining(BertPreTrainedModel):
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
`masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
is only computed for the labels set in [0, ..., vocab_size]
`next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
`next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size]
with indices selected in [0, 1].
0 => next sentence is the continuation, 1 => next sentence is a random sentence.
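For illustration, labels in this format could be constructed as follows (a minimal sketch; the token ids and the masked position are made up, not taken from a real tokenizer):
```python
import torch

# One already-tokenized sequence; values are illustrative only
input_ids = torch.LongTensor([[101, 7592, 103, 2088, 102]])   # position 2 is a [MASK]-style token
# Masked LM labels: -1 everywhere except the masked position, which holds the original token id
masked_lm_labels = torch.LongTensor([[-1, -1, 2157, -1, -1]])
# Next sentence label: 0 => the second segment really follows the first, 1 => random segment
next_sentence_label = torch.LongTensor([0])
```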

pytorch_pretrained_bert/modeling_openai.py

@@ -149,19 +149,19 @@ class Conv1D(nn.Module):
class Attention(nn.Module):
def __init__(self, nx, n_ctx, cfg, scale=False):
def __init__(self, nx, n_ctx, config, scale=False):
super(Attention, self).__init__()
n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
assert n_state % cfg.n_head == 0
assert n_state % config.n_head == 0
self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
self.n_head = cfg.n_head
self.n_head = config.n_head
self.split_size = n_state
self.scale = scale
self.c_attn = Conv1D(n_state * 3, 1, nx)
self.c_proj = Conv1D(n_state, 1, nx)
self.attn_dropout = nn.Dropout(cfg.attn_pdrop)
self.resid_dropout = nn.Dropout(cfg.resid_pdrop)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
def _attn(self, q, k, v):
w = torch.matmul(q, k)
@@ -203,13 +203,13 @@ class Attention(nn.Module):
class MLP(nn.Module):
def __init__(self, n_state, cfg): # in MLP: n_state=3072 (4 * n_embd)
def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
super(MLP, self).__init__()
nx = cfg.n_embd
nx = config.n_embd
self.c_fc = Conv1D(n_state, 1, nx)
self.c_proj = Conv1D(nx, 1, n_state)
self.act = ACT_FNS[cfg.afn]
self.dropout = nn.Dropout(cfg.resid_pdrop)
self.act = ACT_FNS[config.afn]
self.dropout = nn.Dropout(config.resid_pdrop)
def forward(self, x):
h = self.act(self.c_fc(x))
@@ -218,12 +218,12 @@ class MLP(nn.Module):
class Block(nn.Module):
def __init__(self, n_ctx, cfg, scale=False):
def __init__(self, n_ctx, config, scale=False):
super(Block, self).__init__()
nx = cfg.n_embd
self.attn = Attention(nx, n_ctx, cfg, scale)
nx = config.n_embd
self.attn = Attention(nx, n_ctx, config, scale)
self.ln_1 = LayerNorm(nx)
self.mlp = MLP(4 * nx, cfg)
self.mlp = MLP(4 * nx, config)
self.ln_2 = LayerNorm(nx)
def forward(self, x):
@@ -237,9 +237,9 @@ class Block(nn.Module):
class OpenAIGPTLMHead(nn.Module):
""" Language Model Head for the transformer """
def __init__(self, model_embeddings_weights, cfg):
def __init__(self, model_embeddings_weights, config):
super(OpenAIGPTLMHead, self).__init__()
self.n_embd = cfg.n_embd
self.n_embd = config.n_embd
self.set_embeddings_weights(model_embeddings_weights)
def set_embeddings_weights(self, model_embeddings_weights):
@@ -257,12 +257,12 @@ class OpenAIGPTLMHead(nn.Module):
class OpenAIGPTMultipleChoiceHead(nn.Module):
""" Classifier Head for the transformer """
def __init__(self, cfg):
def __init__(self, config):
super(OpenAIGPTMultipleChoiceHead, self).__init__()
self.n_embd = cfg.n_embd
self.n_embd = config.n_embd
# self.multiple_choice_token = multiple_choice_token
self.dropout = nn.Dropout2d(cfg.resid_pdrop) # To reproduce the noise_shape parameter of TF implementation
self.linear = nn.Linear(cfg.n_embd, 1)
self.dropout = nn.Dropout2d(config.resid_pdrop) # To reproduce the noise_shape parameter of TF implementation
self.linear = nn.Linear(config.n_embd, 1)
nn.init.normal_(self.linear.weight, std = 0.02)
nn.init.normal_(self.linear.bias, 0)
@@ -428,15 +428,63 @@ class OpenAIGPTPreTrainedModel(nn.Module):
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
""" OpenAI GPT model """
"""OpenAI GPT model ("Improving Language Understanding by Generative Pre-Training").
def __init__(self, cfg):
super(OpenAIGPTModel, self).__init__(cfg)
total_embeddings_size = cfg.vocab_size + cfg.n_special + cfg.n_ctx
self.embed = nn.Embedding(total_embeddings_size, cfg.n_embd)
self.drop = nn.Dropout(cfg.embd_pdrop)
block = Block(cfg.n_ctx, cfg, scale=True)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])
The main implementation difference between BERT and OpenAI GPT is the use, in OpenAI GPT, of a single embedding matrix
to store the word, special ([SEP], [CLS]...) and position embeddings.
The embeddings are ordered as follows in the embedding matrix:
[0, ----------------------
... -> word embeddings
config.vocab_size - 1, ______________________
config.vocab_size,
... -> special embeddings
config.vocab_size + config.n_special - 1, ______________________
config.vocab_size + config.n_special,
... -> position embeddings
total_num_embeddings - 1] ______________________
where total_num_embeddings can be obtained as config.total_num_embeddings and is:
total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
You should use the associated indices to index the embeddings.
The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during fine-tuning if you use them.
The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.
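For example, position indices consistent with this layout can be built as follows (a minimal sketch; the sizes below are illustrative, not read from a real config):
```python
import torch

vocab_size, n_special, n_ctx = 40478, 3, 512   # illustrative sizes
batch_size, sequence_length = 2, 5
# Position ids start right after the word and special-token rows of the embedding matrix
position_ids = torch.arange(vocab_size + n_special,
                            vocab_size + n_special + sequence_length)
position_ids = position_ids.unsqueeze(0).expand(batch_size, sequence_length)
```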
Params:
config: an OpenAIGPTConfig class instance with the configuration to build a new model
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx[).
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
You can use it to add a third embedding (the previous two being the word and position embeddings)
to each token in the sentence.
Outputs:
`hidden_states`: the encoded-hidden-states at the top of the model
as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size]
(or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)
Example usage:
```python
# Already been converted into BPE token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
config = modeling_openai.OpenAIGPTConfig()
model = modeling_openai.OpenAIGPTModel(config)
hidden_states = model(input_ids)
```
"""
def __init__(self, config):
super(OpenAIGPTModel, self).__init__(config)
total_embeddings_size = config.vocab_size + config.n_special + config.n_ctx
self.embed = nn.Embedding(total_embeddings_size, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop)
block = Block(config.n_ctx, config, scale=True)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
self.apply(self.init_weights)
# nn.init.normal_(self.embed.weight, std=0.02)
@@ -480,11 +528,67 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
return hidden_states.view(*input_shape, hidden_states.size(-1))
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
""" OpenAI GPT model with language model and classification heads """
def __init__(self, cfg):
super(OpenAIGPTLMHeadModel, self).__init__(cfg)
self.transformer = OpenAIGPTModel(cfg)
self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, cfg)
"""OpenAI GPT model with a Language Modeling head ("Improving Language Understanding by Generative Pre-Training").
There are two main implementation differences between BERT and OpenAI GPT:
- the use of an LM loss in OpenAI GPT, which means the Transformer is trained to predict the NEXT token for each input token
vs. the SAME token for BERT (i.e. you need to shift your labels to the right; see the sketch after this overview)
- the use, in OpenAI GPT, of a single embedding matrix to store the word, special ([SEP], [CLS]...) and position embeddings.
The embeddings are ordered as follows in the embedding matrix:
[0, ----------------------
... -> word embeddings
config.vocab_size - 1, ______________________
config.vocab_size,
... -> special embeddings
config.vocab_size + config.n_special - 1, ______________________
config.vocab_size + config.n_special,
... -> position embeddings
total_num_embeddings - 1] ______________________
where total_num_embeddings can be obtained as config.total_num_embeddings and is:
total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
You should use these indices to index the word, special and position embeddings.
The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during fine-tuning if you use them.
The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.
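A minimal sketch of building next-token labels, assuming the loss compares the logits at position t directly with lm_labels at the same position (i.e. no internal shift is applied; the token ids are illustrative only):
```python
import torch

input_ids = torch.LongTensor([[31, 51, 99, 17]])   # illustrative BPE token ids
lm_labels = torch.full_like(input_ids, -1)          # -1 positions are ignored by the loss
lm_labels[:, :-1] = input_ids[:, 1:]                # label at position t is the token at t + 1
```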
Params:
config: an OpenAIGPTConfig class instance with the configuration to build a new model
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx[).
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
You can use it to add a third embedding (the previous two being the word and position embeddings)
to each token in the sentence.
`lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
is only computed for the labels set in [0, ..., vocab_size]
Outputs:
if `lm_labels` is not `None`:
Outputs the language modeling loss.
else:
`lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, total_num_embeddings]
(or more generally [d_1, ..., d_n, total_num_embeddings] where d_1 ... d_n are the dimensions of input_ids)
Example usage:
```python
# Already been converted into BPE token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
config = modeling_openai.OpenAIGPTConfig()
model = modeling_openai.OpenAIGPTLMHeadModel(config)
lm_logits = model(input_ids)
```
"""
def __init__(self, config):
super(OpenAIGPTLMHeadModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config)
self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, config)
self.apply(self.init_weights)
def set_num_special_tokens(self, num_special_tokens):
@@ -502,12 +606,74 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
return lm_logits
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
""" OpenAI GPT model with language model and classification heads """
def __init__(self, cfg):
super(OpenAIGPTDoubleHeadsModel, self).__init__(cfg)
self.transformer = OpenAIGPTModel(cfg)
self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, cfg)
self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(cfg)
"""OpenAI GPT model with a Language Modeling and a Multiple Choice heads ("Improving Language Understanding by Generative Pre-Training").
There are two main implementation differences between BERT and OpenAI GPT:
- the use of an LM loss in OpenAI GPT which means the Transformer is trained to predict the NEXT token for each input token
vs. predict the SAME token for BERT (i.e. you need to shift your labels to the right)
- the use, in OpenAI GPT, of a single embedding matrix to store the word, special ([SEP], [CLS]...) and position embeddings.
The embeddings are ordered as follows in the embedding matrix:
[0, ----------------------
... -> word embeddings
config.vocab_size - 1, ______________________
config.vocab_size,
... -> special embeddings
config.vocab_size + config.n_special - 1, ______________________
config.vocab_size + config.n_special,
... -> position embeddings
total_num_embeddings - 1] ______________________
where total_num_embeddings can be obtained as config.total_num_embeddings and is:
total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
You should use these indices to index the word, special and position embeddings.
The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during fine-tuning if you use them.
The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.
Params:
config: an OpenAIGPTConfig class instance with the configuration to build a new model
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with the word BPE token indices selected in the range [0, config.vocab_size[
`multiple_choice_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with a value of 1 at the position of the token whose hidden state should be used for the multiple choice classification (usually the classification token) and 0 otherwise.
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
with the position indices (selected in the range [config.vocab_size + config.n_special,
config.vocab_size + config.n_special + config.n_ctx[).
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
You can use it to add a third embedding (the previous two being the word and position embeddings)
to each token in the sentence.
`lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with indices selected in [-1, 0, ..., total_num_embeddings]. All labels set to -1 are ignored (masked), the loss
is only computed for the labels set in [0, ..., total_num_embeddings]
`multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_choices - 1].
Outputs:
if `lm_labels` and `multiple_choice_labels` are not `None`:
Outputs a tuple of losses with the language modeling loss and the multiple choice loss.
else: a tuple with
`lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, total_num_embeddings]
`multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]
Example usage:
```python
# Already been converted into BPE token ids
input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]])  # (batch_size, num_choices, sequence_length)
multiple_choice_token_mask = torch.LongTensor([[[0, 0, 1], [0, 1, 0]]])  # 1 marks the classification token of each choice
config = modeling_openai.OpenAIGPTConfig()
model = modeling_openai.OpenAIGPTDoubleHeadsModel(config)
lm_logits, multiple_choice_logits = model(input_ids, multiple_choice_token_mask)
```
"""
def __init__(self, config):
super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config)
self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, config)
self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
self.apply(self.init_weights)
def set_num_special_tokens(self, num_special_tokens):
@@ -517,9 +683,6 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
def forward(self, input_ids, multiple_choice_token_mask, position_ids=None, token_type_ids=None,
lm_labels=None, multiple_choice_labels=None):
""" input_ids should be of shape B x C x S
lm_labels can be masked using the -1 value
"""
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
lm_logits = self.lm_head(hidden_states)
multiple_choice_logits = self.multiple_choice_head(hidden_states, multiple_choice_token_mask)