#!/usr/bin/env python3
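"""Generate unconditional samples from a pretrained GPT-2 model.

Samples start from the '<|endoftext|>' token and are drawn with temperature
scaling and optional top-k filtering.
"""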
import argparse
import logging
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import trange
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


def top_k_logits(logits, k):
    if k == 0:
        # k == 0 disables top-k filtering.
        return logits
    values, _ = torch.topk(logits, k)
    # Broadcast the k-th largest value per row against the vocab dimension;
    # the flat values[:, -1] only happened to work for batch_size == 1.
    min_values = values[:, -1].unsqueeze(1)
    return torch.where(logits < min_values,
                       torch.ones_like(logits, dtype=logits.dtype) * -1e10,
                       logits)
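
# Illustrative example of the filter's effect (values chosen arbitrarily):
# with k=2 the two largest logits per row survive and the rest are pushed
# toward -inf, so softmax assigns them ~zero probability:
#   top_k_logits(torch.tensor([[1.0, 3.0, 2.0, 0.5]]), k=2)
#   -> tensor([[-1e10, 3.0, 2.0, -1e10]])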

def sample_sequence(model, length, start_token=None, batch_size=None, context=None,
                    temperature=1, top_k=0, device='cuda'):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
        # Tile the prompt so every sequence in the batch starts from it.
        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
    prev = context
    output = context
    past = None
    with torch.no_grad():
        for i in trange(length):
            # Feed only the newest token; `past` caches the attention keys/values.
            logits, past = model(prev, past=past)
            logits = logits[:, -1, :] / temperature
            logits = top_k_logits(logits, k=top_k)
            # F.softmax returns probabilities, which is what multinomial expects.
            probs = F.softmax(logits, dim=-1)
            prev = torch.multinomial(probs, num_samples=1)
            output = torch.cat((output, prev), dim=1)
    return output
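
# Usage sketch: conditional generation would pass encoded prompt tokens as
# `context` instead of a start token, e.g. (prompt chosen arbitrarily):
#   context = enc.encode("The quick brown fox")
#   out = sample_sequence(model, length=40, context=context, batch_size=1,
#                         temperature=0.9, top_k=40, device=device)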

def sample_model():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='gpt2',
                        help='pretrained model name or path to local checkpoint')
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--nsamples", type=int, default=0,
                        help="number of samples to generate (0 = sample forever)")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--length", type=int, default=-1)
    # Temperature must be a float; an int type would truncate values like 0.8.
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    args = parser.parse_args()
    print(args)

    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
    model.to(device)
    model.eval()

    if args.length == -1:
        # Default to the model's full context window.
        args.length = model.config.n_ctx
    elif args.length > model.config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
    generated = 0
    while args.nsamples == 0 or generated < args.nsamples:
        out = sample_sequence(
            model=model, length=args.length,
            start_token=enc.encoder['<|endoftext|>'],
            batch_size=args.batch_size,
            temperature=args.temperature, top_k=args.top_k, device=device
        )
        out = out.tolist()
        for i in range(args.batch_size):
            # One decoded sample per batch row, so count them one at a time.
            generated += 1
            text = enc.decode(out[i])
            print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
            print(text)


if __name__ == '__main__':
    sample_model()
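
# Example invocation (flag values are illustrative):
#   python run_gpt2_generate_unconditional_samples.py --nsamples 2 --top_k 40 --temperature 0.8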