From b65f07d8c0d535050fbdfd64b73baef3837751b5 Mon Sep 17 00:00:00 2001
From: thomwolf
Date: Mon, 18 Feb 2019 00:55:33 +0100
Subject: [PATCH] adding examples

---
 ...run_gpt2_generate_unconditional_samples.py | 81 ++++++++++++++++
 ...un_gpt2_interactive_conditional_samples.py | 94 +++++++++++++++++++
 2 files changed, 175 insertions(+)
 create mode 100644 examples/run_gpt2_generate_unconditional_samples.py
 create mode 100644 examples/run_gpt2_interactive_conditional_samples.py

diff --git a/examples/run_gpt2_generate_unconditional_samples.py b/examples/run_gpt2_generate_unconditional_samples.py
new file mode 100644
index 00000000000..7300bb2f5e8
--- /dev/null
+++ b/examples/run_gpt2_generate_unconditional_samples.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python3
+
+import argparse
+import logging
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
+
+logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                    datefmt = '%m/%d/%Y %H:%M:%S',
+                    level = logging.INFO)
+logger = logging.getLogger(__name__)
+
+def top_k_logits(logits, k):
+    # Mask every logit outside the top k with a large negative value.
+    if k == 0:
+        return logits
+    values, _ = torch.topk(logits, k)
+    min_values = values[:, -1]
+    return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)
+
+def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
+    if start_token is None:
+        assert context is not None, 'Specify exactly one of start_token and context!'
+        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
+    else:
+        assert context is None, 'Specify exactly one of start_token and context!'
+        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
+    prev = context
+    output = context
+    past = None
+    with torch.no_grad():
+        for i in range(length):
+            logits, past = model(prev, past=past)
+            logits = logits[:, -1, :] / temperature
+            logits = top_k_logits(logits, k=top_k)
+            probs = F.softmax(logits, dim=-1)
+            prev = torch.multinomial(probs, num_samples=1)
+            output = torch.cat((output, prev), dim=1)
+    return output
+
+def sample_model():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model_name_or_path', type=str, default='gpt2', help='pretrained model name or path to local checkpoint')
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--nsamples", type=int, default=0, help='number of samples to generate (0 = keep sampling forever)')
+    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--length", type=int, default=-1)
+    parser.add_argument("--temperature", type=float, default=1.0)
+    parser.add_argument("--top_k", type=int, default=0)
+    args = parser.parse_args()
+    print(args)
+
+    np.random.seed(args.seed)
+    torch.random.manual_seed(args.seed)
+    torch.cuda.manual_seed(args.seed)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
+    model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
+    model.to(device)
+    model.eval()
+
+    if args.length == -1:
+        args.length = model.config.n_ctx
+    elif args.length > model.config.n_ctx:
+        raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
+
+    generated = 0
+    while args.nsamples == 0 or generated < args.nsamples:
+        out = sample_sequence(
+            model=model, length=args.length,
+            start_token=enc.encoder['<|endoftext|>'],
+            batch_size=args.batch_size,
+            temperature=args.temperature, top_k=args.top_k, device=device
+        )
+        for i in range(args.batch_size):
+            generated += 1
+            text = enc.decode(out[i].tolist())
+            print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
+            print(text)
+
+if __name__ == '__main__':
+    sample_model()
diff --git a/examples/run_gpt2_interactive_conditional_samples.py b/examples/run_gpt2_interactive_conditional_samples.py
new file mode 100644
index 00000000000..e631864a273
--- /dev/null
+++ b/examples/run_gpt2_interactive_conditional_samples.py
@@ -0,0 +1,94 @@
+#!/usr/bin/env python3
+
+import argparse
+import logging
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
+
+logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                    datefmt = '%m/%d/%Y %H:%M:%S',
+                    level = logging.INFO)
+logger = logging.getLogger(__name__)
+
+def top_k_logits(logits, k):
+    # Mask every logit outside the top k with a large negative value.
+    if k == 0:
+        return logits
+    values, _ = torch.topk(logits, k)
+    min_values = values[:, -1]
+    return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)
+
+def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
+    if start_token is None:
+        assert context is not None, 'Specify exactly one of start_token and context!'
+        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
+    else:
+        assert context is None, 'Specify exactly one of start_token and context!'
+        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
+    prev = context
+    output = context
+    past = None
+    with torch.no_grad():
+        for i in range(length):
+            logits, past = model(prev, past=past)
+            logits = logits[:, -1, :] / temperature
+            logits = top_k_logits(logits, k=top_k)
+            probs = F.softmax(logits, dim=-1)
+            prev = torch.multinomial(probs, num_samples=1)
+            output = torch.cat((output, prev), dim=1)
+    return output
+
+def interact_model():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model_name_or_path', type=str, default='gpt2', help='pretrained model name or path to local checkpoint')
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--nsamples", type=int, default=1)
+    parser.add_argument("--batch_size", type=int, default=-1)
+    parser.add_argument("--length", type=int, default=-1)
+    parser.add_argument("--temperature", type=float, default=1.0)
+    parser.add_argument("--top_k", type=int, default=0)
+    args = parser.parse_args()
+    print(args)
+
+    if args.batch_size == -1:
+        args.batch_size = 1
+    assert args.nsamples % args.batch_size == 0
+
+    np.random.seed(args.seed)
+    torch.random.manual_seed(args.seed)
+    torch.cuda.manual_seed(args.seed)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
+    model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
+    model.to(device)
+    model.eval()
+
+    if args.length == -1:
+        args.length = model.config.n_ctx // 2
+    elif args.length > model.config.n_ctx:
+        raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
+
+    while True:
+        raw_text = input("Model prompt >>> ")
+        while not raw_text:
+            print('Prompt should not be empty!')
+            raw_text = input("Model prompt >>> ")
+        context_tokens = enc.encode(raw_text)
+        generated = 0
+        for _ in range(args.nsamples // args.batch_size):
+            out = sample_sequence(
+                model=model, length=args.length,
+                context=context_tokens,
+                batch_size=args.batch_size,
+                temperature=args.temperature, top_k=args.top_k, device=device
+            )
+            out = out[:, len(context_tokens):]
+            for i in range(args.batch_size):
+                generated += 1
+                text = enc.decode(out[i].tolist())
+                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
+                print(text)
+        print("=" * 80)
+
+if __name__ == '__main__':
+    interact_model()
+