mirror of https://github.com/huggingface/transformers.git
adding docs and example for OpenAI GPT
This commit is contained in:
parent dc5df92fa8
commit ab90d4cddd
examples/run_openai_gpt.py | 304 lines | Normal file
@@ -0,0 +1,304 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Run OpenAI GPT on RocStories """
import argparse
import os
import random
import logging

from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle

# from analysis import rocstories as rocstories_analysis
# from datasets import rocstories
# from model_pytorch import DoubleHeadModel, load_openai_pretrained_model
# from opt import OpenAIAdam
# from text_utils import TextEncoder
# from utils import (encode_dataset, iter_data,
#                    ResultLogger, make_path)
# from loss import MultipleChoiceLossCompute

import numpy as np
import torch
import torch.nn as nn  # needed below for nn.CrossEntropyLoss and nn.DataParallel
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTDoubleHeadsModel
from pytorch_pretrained_bert.optimization_openai import OpenAIAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

def transform_roc(X1, X2, X3):
    n_batch = len(X1)
    xmb = np.zeros((n_batch, 2, n_ctx, 2), dtype=np.int32)
    mmb = np.zeros((n_batch, 2, n_ctx), dtype=np.float32)
    start = encoder['_start_']
    delimiter = encoder['_delimiter_']
    for i, (x1, x2, x3) in enumerate(zip(X1, X2, X3)):
        x12 = [start] + x1[:max_len] + [delimiter] + x2[:max_len] + [clf_token]
        x13 = [start] + x1[:max_len] + [delimiter] + x3[:max_len] + [clf_token]
        l12 = len(x12)
        l13 = len(x13)
        xmb[i, 0, :l12, 0] = x12
        xmb[i, 1, :l13, 0] = x13
        mmb[i, 0, :l12] = 1
        mmb[i, 1, :l13] = 1
    # Position information that is added to the input embeddings in the TransformerModel
    xmb[:, :, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)
    return xmb, mmb


def iter_apply(Xs, Ms, Ys):
    # fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]
    logits = []
    cost = 0
    with torch.no_grad():
        dh_model.eval()
        for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
            n = len(xmb)
            XMB = torch.tensor(xmb, dtype=torch.long).to(device)
            YMB = torch.tensor(ymb, dtype=torch.long).to(device)
            MMB = torch.tensor(mmb).to(device)
            _, clf_logits = dh_model(XMB)
            clf_logits *= n
            clf_losses = compute_loss_fct(XMB, YMB, MMB, clf_logits, only_return_losses=True)
            clf_losses *= n
            logits.append(clf_logits.to("cpu").numpy())
            cost += clf_losses.sum().item()
        logits = np.concatenate(logits, 0)
    return logits, cost


def iter_predict(Xs, Ms):
    logits = []
    with torch.no_grad():
        dh_model.eval()
        for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
            n = len(xmb)
            XMB = torch.tensor(xmb, dtype=torch.long).to(device)
            MMB = torch.tensor(mmb).to(device)
            _, clf_logits = dh_model(XMB)
            logits.append(clf_logits.to("cpu").numpy())
        logits = np.concatenate(logits, 0)
    return logits


def log(save_dir, desc):
    global best_score
    print("Logging")
    tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
    va_logits, va_cost = iter_apply(vaX, vaM, vaY)
    tr_cost = tr_cost / len(trY[:n_valid])
    va_cost = va_cost / n_valid
    tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1)) * 100.
    va_acc = accuracy_score(vaY, np.argmax(va_logits, 1)) * 100.
    logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)
    print('%d %d %.3f %.3f %.2f %.2f' % (n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
    if submit:
        score = va_acc
        if score > best_score:
            best_score = score
            path = os.path.join(save_dir, desc, 'best_params')
            torch.save(dh_model.state_dict(), make_path(path))


def predict(dataset, submission_dir):
    filename = filenames[dataset]
    pred_fn = pred_fns[dataset]
    label_decoder = label_decoders[dataset]
    predictions = pred_fn(iter_predict(teX, teM))
    if label_decoder is not None:
        predictions = [label_decoder[prediction] for prediction in predictions]
    path = os.path.join(submission_dir, filename)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'w') as f:
        f.write('{}\t{}\n'.format('index', 'prediction'))
        for i, prediction in enumerate(predictions):
            f.write('{}\t{}\n'.format(i, prediction))


def run_epoch():
    for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
                                   n_batch=n_batch_train, truncate=True, verbose=True):
        global n_updates
        dh_model.train()
        XMB = torch.tensor(xmb, dtype=torch.long).to(device)
        YMB = torch.tensor(ymb, dtype=torch.long).to(device)
        MMB = torch.tensor(mmb).to(device)
        lm_logits, clf_logits = dh_model(XMB)
        compute_loss_fct(XMB, YMB, MMB, clf_logits, lm_logits)
        n_updates += 1
        if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
            log(save_dir, desc)


argmax = lambda x: np.argmax(x, 1)

pred_fns = {
    'rocstories': argmax,
}

filenames = {
    'rocstories': 'ROCStories.tsv',
}

label_decoders = {
    'rocstories': None,
}

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--desc', type=str, help="Description")
    parser.add_argument('--dataset', type=str)
    parser.add_argument('--log_dir', type=str, default='log/')
    parser.add_argument('--save_dir', type=str, default='save/')
    parser.add_argument('--data_dir', type=str, default='data/')
    parser.add_argument('--submission_dir', type=str, default='submission/')
    parser.add_argument('--submit', action='store_true')
    parser.add_argument('--analysis', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--n_iter', type=int, default=3)
    parser.add_argument('--n_batch', type=int, default=8)
    parser.add_argument('--max_grad_norm', type=int, default=1)
    parser.add_argument('--lr', type=float, default=6.25e-5)
    parser.add_argument('--lr_warmup', type=float, default=0.002)
    parser.add_argument('--n_ctx', type=int, default=512)
    parser.add_argument('--n_embd', type=int, default=768)
    parser.add_argument('--n_head', type=int, default=12)
    parser.add_argument('--n_layer', type=int, default=12)
    parser.add_argument('--embd_pdrop', type=float, default=0.1)
    parser.add_argument('--attn_pdrop', type=float, default=0.1)
    parser.add_argument('--resid_pdrop', type=float, default=0.1)
    parser.add_argument('--clf_pdrop', type=float, default=0.1)
    parser.add_argument('--l2', type=float, default=0.01)
    parser.add_argument('--vector_l2', action='store_true')
    parser.add_argument('--opt', type=str, default='adam')
    parser.add_argument('--afn', type=str, default='gelu')
    parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
    parser.add_argument('--encoder_path', type=str, default='model/encoder_bpe_40000.json')
    parser.add_argument('--bpe_path', type=str, default='model/vocab_40000.bpe')
    parser.add_argument('--n_transfer', type=int, default=12)
    parser.add_argument('--lm_coef', type=float, default=0.5)
    parser.add_argument('--b1', type=float, default=0.9)
    parser.add_argument('--b2', type=float, default=0.999)
    parser.add_argument('--e', type=float, default=1e-8)
    parser.add_argument('--n_valid', type=int, default=374)

    args = parser.parse_args()
    print(args)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Constants
    submit = args.submit
    dataset = args.dataset
    n_ctx = args.n_ctx
    save_dir = args.save_dir
    desc = args.desc
    data_dir = args.data_dir
    log_dir = args.log_dir
    submission_dir = args.submission_dir

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    print("device", device, "n_gpu", n_gpu)

    logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__)
    text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
    encoder = text_encoder.encoder
    n_vocab = len(text_encoder.encoder)

    print("Encoding dataset...")
    ((trX1, trX2, trX3, trY),
     (vaX1, vaX2, vaX3, vaY),
     (teX1, teX2, teX3)) = encode_dataset(*rocstories(data_dir, n_valid=args.n_valid),
                                          encoder=text_encoder)
    encoder['_start_'] = len(encoder)
    encoder['_delimiter_'] = len(encoder)
    encoder['_classify_'] = len(encoder)
    clf_token = encoder['_classify_']
    n_special = 3
    max_len = n_ctx // 2 - 2
    n_ctx = min(max(
        [len(x1[:max_len]) + max(len(x2[:max_len]),
                                 len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)]
        + [len(x1[:max_len]) + max(len(x2[:max_len]),
                                   len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)]
        + [len(x1[:max_len]) + max(len(x2[:max_len]),
                                   len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)]
        ) + 3, n_ctx)
    vocab = n_vocab + n_special + n_ctx
    trX, trM = transform_roc(trX1, trX2, trX3)
    vaX, vaM = transform_roc(vaX1, vaX2, vaX3)
    if submit:
        teX, teM = transform_roc(teX1, teX2, teX3)

    n_train = len(trY)
    n_valid = len(vaY)
    n_batch_train = args.n_batch * max(n_gpu, 1)
    n_updates_total = (n_train // n_batch_train) * args.n_iter

    dh_model = DoubleHeadModel(args, clf_token, 'multiple_choice', vocab, n_ctx)

    criterion = nn.CrossEntropyLoss(reduce=False)
    model_opt = OpenAIAdam(dh_model.parameters(),
                           lr=args.lr,
                           schedule=args.lr_schedule,
                           warmup=args.lr_warmup,
                           t_total=n_updates_total,
                           b1=args.b1,
                           b2=args.b2,
                           e=args.e,
                           l2=args.l2,
                           vector_l2=args.vector_l2,
                           max_grad_norm=args.max_grad_norm)
    compute_loss_fct = MultipleChoiceLossCompute(criterion,
                                                 criterion,
                                                 args.lm_coef,
                                                 model_opt)
    load_openai_pretrained_model(dh_model.transformer, n_ctx=n_ctx, n_special=n_special)

    dh_model.to(device)
    dh_model = nn.DataParallel(dh_model)

    n_updates = 0
    n_epochs = 0
    if dataset != 'stsb':
        trYt = trY
    if submit:
        path = os.path.join(save_dir, desc, 'best_params')
        torch.save(dh_model.state_dict(), make_path(path))
    best_score = 0
    for i in range(args.n_iter):
        print("running epoch", i)
        run_epoch()
        n_epochs += 1
        log(save_dir, desc)
    if submit:
        path = os.path.join(save_dir, desc, 'best_params')
        dh_model.load_state_dict(torch.load(path))
        predict(dataset, args.submission_dir)
        if args.analysis:
            rocstories_analysis(data_dir, os.path.join(args.submission_dir, 'ROCStories.tsv'),
                                os.path.join(log_dir, 'rocstories.jsonl'))
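A minimal sketch (separate from the committed script) of the input layout that `transform_roc` above builds: each ROCStories item becomes two candidate sequences (story + ending 1, story + ending 2) packed into a tensor of shape [n_batch, 2, n_ctx, 2] carrying token ids and position ids, plus a float mask over the real tokens. Sizes and ids here are toy assumptions standing in for the real BPE vocabulary.

```python
# Toy illustration of the transform_roc layout; not part of the commit.
import numpy as np

n_vocab, n_special, n_ctx = 10, 3, 8           # toy sizes (assumption)
start, delimiter, clf_token = 10, 11, 12       # stand-ins for '_start_', '_delimiter_', '_classify_'

x1, x2, x3 = [1, 2, 3], [4, 5], [6, 7]         # story, true ending, distractor (toy ids)
xmb = np.zeros((1, 2, n_ctx, 2), dtype=np.int32)   # [batch, 2 choices, n_ctx, (token id, position id)]
mmb = np.zeros((1, 2, n_ctx), dtype=np.float32)    # mask marking the real (non-padding) tokens

x12 = [start] + x1 + [delimiter] + x2 + [clf_token]
x13 = [start] + x1 + [delimiter] + x3 + [clf_token]
xmb[0, 0, :len(x12), 0] = x12
xmb[0, 1, :len(x13), 0] = x13
mmb[0, 0, :len(x12)] = 1
mmb[0, 1, :len(x13)] = 1
# position ids live above the word and special token ids in the shared embedding matrix
xmb[:, :, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)
```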
@@ -659,10 +659,10 @@ class BertForPreTraining(BertPreTrainedModel):
             selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
             input sequence length in the current batch. It's the mask that we typically use for attention when
             a batch has varying length sentences.
-        `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
+        `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
             with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
             is only computed for the labels set in [0, ..., vocab_size]
-        `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
+        `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size]
             with indices selected in [0, 1].
             0 => next sentence is the continuation, 1 => next sentence is a random sentence.
@@ -149,19 +149,19 @@ class Conv1D(nn.Module):


 class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, cfg, scale=False):
+    def __init__(self, nx, n_ctx, config, scale=False):
         super(Attention, self).__init__()
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
-        assert n_state % cfg.n_head == 0
+        assert n_state % config.n_head == 0
         self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
-        self.n_head = cfg.n_head
+        self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
         self.c_attn = Conv1D(n_state * 3, 1, nx)
         self.c_proj = Conv1D(n_state, 1, nx)
-        self.attn_dropout = nn.Dropout(cfg.attn_pdrop)
-        self.resid_dropout = nn.Dropout(cfg.resid_pdrop)
+        self.attn_dropout = nn.Dropout(config.attn_pdrop)
+        self.resid_dropout = nn.Dropout(config.resid_pdrop)

     def _attn(self, q, k, v):
         w = torch.matmul(q, k)
@@ -203,13 +203,13 @@ class Attention(nn.Module):


 class MLP(nn.Module):
-    def __init__(self, n_state, cfg):  # in MLP: n_state=3072 (4 * n_embd)
+    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
         super(MLP, self).__init__()
-        nx = cfg.n_embd
+        nx = config.n_embd
         self.c_fc = Conv1D(n_state, 1, nx)
         self.c_proj = Conv1D(nx, 1, n_state)
-        self.act = ACT_FNS[cfg.afn]
-        self.dropout = nn.Dropout(cfg.resid_pdrop)
+        self.act = ACT_FNS[config.afn]
+        self.dropout = nn.Dropout(config.resid_pdrop)

     def forward(self, x):
         h = self.act(self.c_fc(x))
@@ -218,12 +218,12 @@ class MLP(nn.Module):


 class Block(nn.Module):
-    def __init__(self, n_ctx, cfg, scale=False):
+    def __init__(self, n_ctx, config, scale=False):
         super(Block, self).__init__()
-        nx = cfg.n_embd
-        self.attn = Attention(nx, n_ctx, cfg, scale)
+        nx = config.n_embd
+        self.attn = Attention(nx, n_ctx, config, scale)
         self.ln_1 = LayerNorm(nx)
-        self.mlp = MLP(4 * nx, cfg)
+        self.mlp = MLP(4 * nx, config)
         self.ln_2 = LayerNorm(nx)

     def forward(self, x):
@@ -237,9 +237,9 @@ class Block(nn.Module):
 class OpenAIGPTLMHead(nn.Module):
     """ Language Model Head for the transformer """

-    def __init__(self, model_embeddings_weights, cfg):
+    def __init__(self, model_embeddings_weights, config):
         super(OpenAIGPTLMHead, self).__init__()
-        self.n_embd = cfg.n_embd
+        self.n_embd = config.n_embd
         self.set_embeddings_weights(model_embeddings_weights)

     def set_embeddings_weights(self, model_embeddings_weights):
@@ -257,12 +257,12 @@ class OpenAIGPTLMHead(nn.Module):
 class OpenAIGPTMultipleChoiceHead(nn.Module):
     """ Classifier Head for the transformer """

-    def __init__(self, cfg):
+    def __init__(self, config):
         super(OpenAIGPTMultipleChoiceHead, self).__init__()
-        self.n_embd = cfg.n_embd
+        self.n_embd = config.n_embd
         # self.multiple_choice_token = multiple_choice_token
-        self.dropout = nn.Dropout2d(cfg.resid_pdrop)  # To reproduce the noise_shape parameter of TF implementation
-        self.linear = nn.Linear(cfg.n_embd, 1)
+        self.dropout = nn.Dropout2d(config.resid_pdrop)  # To reproduce the noise_shape parameter of TF implementation
+        self.linear = nn.Linear(config.n_embd, 1)

         nn.init.normal_(self.linear.weight, std=0.02)
         nn.init.normal_(self.linear.bias, 0)
@@ -428,15 +428,63 @@ class OpenAIGPTPreTrainedModel(nn.Module):


 class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
-    """ OpenAI GPT model """
-
-    def __init__(self, cfg):
-        super(OpenAIGPTModel, self).__init__(cfg)
-        total_embeddings_size = cfg.vocab_size + cfg.n_special + cfg.n_ctx
-        self.embed = nn.Embedding(total_embeddings_size, cfg.n_embd)
-        self.drop = nn.Dropout(cfg.embd_pdrop)
-        block = Block(cfg.n_ctx, cfg, scale=True)
-        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])
+    """OpenAI GPT model ("Improving Language Understanding by Generative Pre-Training").
+
+    The main implementation difference between BERT and OpenAI GPT is the use, in OpenAI GPT, of a single embedding matrix
+    to store the word, special ([SEP], [CLS]...) and position embeddings.
+    The embeddings are ordered as follows in the word embedding matrix:
+        [0,                                          ----------------------
+         ...                                         -> word embeddings
+         config.vocab_size - 1,                      ______________________
+         config.vocab_size,
+         ...                                         -> special embeddings
+         config.vocab_size + config.n_special - 1,   ______________________
+         config.vocab_size + config.n_special,
+         ...                                         -> position embeddings
+         total_num_embeddings - 1]                   ______________________
+
+    where total_num_embeddings can be obtained as config.total_num_embeddings and is:
+        total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
+    You should use the associated indices to index the embeddings.
+
+    The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during the fine-tuning if you use them.
+    The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.
+
+    Params:
+        config: an OpenAIGPTConfig class instance with the configuration to build a new model
+
+    Inputs:
+        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
+            where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
+        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
+            with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx - 1[.
+        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
+            You can use it to add a third embedding (the previous two being the word and position embeddings)
+            to each token in the sentence.
+
+    Outputs:
+        `hidden_states`: the encoded hidden states at the top of the model
+            as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size]
+            (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)
+
+    Example usage:
+    ```python
+    # Already been converted into BPE token ids
+    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+
+    config = modeling_openai.OpenAIGPTConfig()
+
+    model = modeling_openai.OpenAIGPTModel(config)
+    hidden_states = model(input_ids)
+    ```
+    """
+    def __init__(self, config):
+        super(OpenAIGPTModel, self).__init__(config)
+        total_embeddings_size = config.vocab_size + config.n_special + config.n_ctx
+        self.embed = nn.Embedding(total_embeddings_size, config.n_embd)
+        self.drop = nn.Dropout(config.embd_pdrop)
+        block = Block(config.n_ctx, config, scale=True)
+        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])

         self.apply(self.init_weights)
         # nn.init.normal_(self.embed.weight, std=0.02)
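The docstring added above describes a single embedding matrix holding word, special and position embeddings in consecutive index ranges. A minimal sketch (not part of the commit) of how those ranges can be combined in a lookup; the sizes are toy assumptions.

```python
# Toy illustration of the shared word/special/position embedding matrix; not part of the commit.
import torch
import torch.nn as nn

vocab_size, n_special, n_ctx, n_embd = 10, 3, 8, 4
embed = nn.Embedding(vocab_size + n_special + n_ctx, n_embd)   # total_num_embeddings rows

input_ids = torch.tensor([[1, 2, 10, 3]])                      # index 10 falls in the special-token range
position_ids = torch.arange(vocab_size + n_special,
                            vocab_size + n_special + input_ids.size(-1)).expand_as(input_ids)

hidden = embed(input_ids) + embed(position_ids)                # word/special embeddings + position embeddings
```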
@@ -480,11 +528,67 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         return hidden_states.view(*input_shape, hidden_states.size(-1))

 class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
-    """ OpenAI GPT model with language model and classification heads """
-    def __init__(self, cfg):
-        super(OpenAIGPTLMHeadModel, self).__init__(cfg)
-        self.transformer = OpenAIGPTModel(cfg)
-        self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, cfg)
+    """OpenAI GPT model with a Language Modeling head ("Improving Language Understanding by Generative Pre-Training").
+
+    There are two main implementation differences between BERT and OpenAI GPT:
+    - the use of an LM loss in OpenAI GPT, which means the Transformer is trained to predict the NEXT token for each input token
+        vs. predicting the SAME token for BERT (i.e. you need to shift your labels to the right)
+    - the use, in OpenAI GPT, of a single embedding matrix to store the word, special ([SEP], [CLS]...) and position embeddings.
+        The embeddings are ordered as follows in the word embedding matrix:
+            [0,                                          ----------------------
+             ...                                         -> word embeddings
+             config.vocab_size - 1,                      ______________________
+             config.vocab_size,
+             ...                                         -> special embeddings
+             config.vocab_size + config.n_special - 1,   ______________________
+             config.vocab_size + config.n_special,
+             ...                                         -> position embeddings
+             total_num_embeddings - 1]                   ______________________
+
+    where total_num_embeddings can be obtained as config.total_num_embeddings and is:
+        total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
+    You should use these indices to index the word, special and position embeddings.
+
+    The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during the fine-tuning if you use them.
+    The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.
+
+    Params:
+        config: an OpenAIGPTConfig class instance with the configuration to build a new model
+
+    Inputs:
+        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
+            where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
+        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
+            with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx - 1[.
+        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
+            You can use it to add a third embedding (the previous two being the word and position embeddings)
+            to each token in the sentence.
+        `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
+            with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
+            is only computed for the labels set in [0, ..., vocab_size]
+
+    Outputs:
+        if `lm_labels` is not `None`:
+            Outputs the language modeling loss.
+        else:
+            `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, total_num_embeddings]
+                (or more generally [d_1, ..., d_n, total_num_embeddings] where d_1 ... d_n are the dimensions of input_ids)
+
+    Example usage:
+    ```python
+    # Already been converted into BPE token ids
+    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+
+    config = modeling_openai.OpenAIGPTConfig()
+
+    model = modeling_openai.OpenAIGPTLMHeadModel(config)
+    lm_logits = model(input_ids)
+    ```
+    """
+    def __init__(self, config):
+        super(OpenAIGPTLMHeadModel, self).__init__(config)
+        self.transformer = OpenAIGPTModel(config)
+        self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, config)
         self.apply(self.init_weights)

     def set_num_special_tokens(self, num_special_tokens):
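The docstring above notes that the LM head is trained to predict the NEXT token, and that labels set to -1 are ignored. A minimal sketch (not part of the commit) of that objective, with toy sizes and a plain CrossEntropyLoss standing in for the library's loss computation:

```python
# Toy illustration of the next-token objective with -1 as the ignore index; not part of the commit.
import torch
import torch.nn as nn

batch, seq_len, n_vocab_total = 2, 5, 20                 # toy sizes (assumption)
lm_logits = torch.randn(batch, seq_len, n_vocab_total)
input_ids = torch.randint(0, n_vocab_total, (batch, seq_len))

lm_labels = input_ids.clone()
lm_labels[:, :-1] = input_ids[:, 1:]                     # each position predicts the following token
lm_labels[:, -1] = -1                                    # no target for the last position

loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(lm_logits.view(-1, n_vocab_total), lm_labels.view(-1))
```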
@@ -502,12 +606,74 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         return lm_logits

 class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
-    """ OpenAI GPT model with language model and classification heads """
-    def __init__(self, cfg):
-        super(OpenAIGPTDoubleHeadsModel, self).__init__(cfg)
-        self.transformer = OpenAIGPTModel(cfg)
-        self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, cfg)
-        self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(cfg)
+    """OpenAI GPT model with a Language Modeling and a Multiple Choice head ("Improving Language Understanding by Generative Pre-Training").
+
+    There are two main implementation differences between BERT and OpenAI GPT:
+    - the use of an LM loss in OpenAI GPT, which means the Transformer is trained to predict the NEXT token for each input token
+        vs. predicting the SAME token for BERT (i.e. you need to shift your labels to the right)
+    - the use, in OpenAI GPT, of a single embedding matrix to store the word, special ([SEP], [CLS]...) and position embeddings.
+        The embeddings are ordered as follows in the word embedding matrix:
+            [0,                                          ----------------------
+             ...                                         -> word embeddings
+             config.vocab_size - 1,                      ______________________
+             config.vocab_size,
+             ...                                         -> special embeddings
+             config.vocab_size + config.n_special - 1,   ______________________
+             config.vocab_size + config.n_special,
+             ...                                         -> position embeddings
+             total_num_embeddings - 1]                   ______________________
+
+    where total_num_embeddings can be obtained as config.total_num_embeddings and is:
+        total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
+    You should use these indices to index the word, special and position embeddings.
+
+    The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during the fine-tuning if you use them.
+    The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.
+
+    Params:
+        config: an OpenAIGPTConfig class instance with the configuration to build a new model
+
+    Inputs:
+        `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
+            with the word BPE token indices selected in the range [0, config.vocab_size[
+        `multiple_choice_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
+            with a value of 1 where the last hidden state is (usually the [CLS] token) and 0 otherwise.
+        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
+            with the position indices (selected in the range [config.vocab_size + config.n_special,
+            config.vocab_size + config.n_special + config.n_ctx - 1[.
+        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
+            You can use it to add a third embedding (the previous two being the word and position embeddings)
+            to each token in the sentence.
+        `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, num_choices, sequence_length]
+            with indices selected in [-1, 0, ..., total_num_embeddings]. All labels set to -1 are ignored (masked), the loss
+            is only computed for the labels set in [0, ..., total_num_embeddings]
+        `multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size]
+            with indices selected in [0, ..., num_choices].
+
+    Outputs:
+        if `lm_labels` and `multiple_choice_labels` are not `None`:
+            Outputs a tuple of losses with the language modeling loss and the multiple choice loss.
+        else: a tuple with
+            `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, total_num_embeddings]
+            `multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]
+
+    Example usage:
+    ```python
+    # Already been converted into BPE token ids
+    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+    multiple_choice_token_mask = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+
+    config = modeling_openai.OpenAIGPTConfig()
+
+    model = modeling_openai.OpenAIGPTDoubleHeadsModel(config)
+    lm_logits, multiple_choice_logits = model(input_ids, multiple_choice_token_mask)
+    ```
+    """
+    def __init__(self, config):
+        super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
+        self.transformer = OpenAIGPTModel(config)
+        self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, config)
+        self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
         self.apply(self.init_weights)

     def set_num_special_tokens(self, num_special_tokens):
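The docstring above describes `multiple_choice_token_mask` as marking, for each choice, the position whose hidden state feeds the classification head. A minimal sketch (not part of the commit) of that gathering step, with assumed toy shapes:

```python
# Toy illustration of selecting the classification-token hidden state per choice; not part of the commit.
import torch

batch, num_choices, seq_len, hidden_size = 2, 2, 5, 4
hidden_states = torch.randn(batch, num_choices, seq_len, hidden_size)
mask = torch.zeros(batch, num_choices, seq_len)
mask[:, :, -1] = 1                                             # here the classification token sits at the last position

# keep only the masked position, then reduce over the sequence dimension
clf_states = (hidden_states * mask.unsqueeze(-1)).sum(dim=2)   # [batch, num_choices, hidden_size]
```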
@@ -517,9 +683,6 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):

     def forward(self, input_ids, multiple_choice_token_mask, position_ids=None, token_type_ids=None,
                 lm_labels=None, multiple_choice_labels=None):
-        """ input_ids should be of shape B x C x S
-            lm_labels can be masked using the -1 value
-        """
         hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
         lm_logits = self.lm_head(hidden_states)
         multiple_choice_logits = self.multiple_choice_head(hidden_states, multiple_choice_token_mask)
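In the example script, `MultipleChoiceLossCompute(criterion, criterion, args.lm_coef, model_opt)` combines the losses of the two heads; its implementation is not part of this commit, so the following is only a hedged sketch of one plausible combination (classification loss plus `lm_coef` times the LM loss), with toy sizes as assumptions:

```python
# Hedged sketch of combining the double-heads losses; not the library's MultipleChoiceLossCompute.
import torch
import torch.nn as nn

batch, num_choices, seq_len, n_vocab_total = 2, 2, 5, 20       # toy sizes (assumption)
lm_logits = torch.randn(batch, num_choices, seq_len, n_vocab_total)
clf_logits = torch.randn(batch, num_choices)
lm_labels = torch.randint(0, n_vocab_total, (batch, num_choices, seq_len))
mc_labels = torch.randint(0, num_choices, (batch,))
lm_coef = 0.5                                                  # role of args.lm_coef in the script

lm_loss = nn.CrossEntropyLoss(ignore_index=-1)(lm_logits.view(-1, n_vocab_total), lm_labels.view(-1))
clf_loss = nn.CrossEntropyLoss()(clf_logits, mc_labels)
total_loss = clf_loss + lm_coef * lm_loss                      # assumed weighting before backward + optimizer step
```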