Merge pull request #287 from huggingface/gpt2

Gpt2
This commit is contained in:
Thomas Wolf 2019-02-18 11:38:05 +01:00 committed by GitHub
commit 0856a231c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 1588 additions and 16 deletions

203
README.md
View File

@ -5,8 +5,9 @@
This repository contains op-for-op PyTorch reimplementations, pre-trained models and fine-tuning examples for:
- [Google's BERT model](https://github.com/google-research/bert),
- [OpenAI's GPT model](https://github.com/openai/finetune-transformer-lm), and
- [Google/CMU's Transformer-XL model](https://github.com/kimiyoung/transformer-xl).
- [OpenAI's GPT model](https://github.com/openai/finetune-transformer-lm),
- [Google/CMU's Transformer-XL model](https://github.com/kimiyoung/transformer-xl), and
- [OpenAI's GPT-2 model](https://blog.openai.com/better-language-models/),
These implementations have been tested on several datasets (see the examples) and should match the performances of the associated TensorFlow implementations (e.g. ~91 F1 on SQuAD for BERT, ~88 F1 on RocStories for OpenAI GPT and ~18.3 perplexity on WikiText 103 for the Transformer-XL). You can find more details in the [Examples](#examples) section below.
@ -21,6 +22,10 @@ This PyTorch implementation of OpenAI GPT is an adaptation of the [PyTorch imple
**Google/CMU's Transformer-XL** was released together with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](http://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
This PyTorch implementation of Transformer-XL is an adaptation of the original [PyTorch implementation](https://github.com/kimiyoung/transformer-xl) which has been slightly modified to match the performances of the TensforFlow implementation and allow to re-use the pretrained weights. A command-line interface is provided to convert TensorFlow checkpoints in PyTorch models.
**OpenAI GPT-2** was released together with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, JeffreyWu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
This PyTorch implementation of OpenAI GPT-2 is an adaptation of the [OpenAI's implementation](https://github.com/openai/gpt-2) and is provided with [OpenAI's pre-trained model](https://github.com/openai/gpt-2 and a command-line interface that was used to convert the TensorFlow checkpoint in PyTorch.
## Content
| Section | Description |
@ -98,6 +103,11 @@ This package comprises the following classes that can be imported in Python and
- [`TransfoXLModel`](./pytorch_pretrained_bert/modeling_transfo_xl.py#L974) - Transformer-XL model which outputs the last hidden state and memory cells (**fully pre-trained**),
- [`TransfoXLLMHeadModel`](./pytorch_pretrained_bert/modeling_transfo_xl.py#L1236) - Transformer-XL with the tied adaptive softmax head on top for language modeling which outputs the logits/loss and memory cells (**fully pre-trained**),
- Three **OpenAI GPT-2** PyTorch models (`torch.nn.Module`) with pre-trained weights (in the [`modeling_gpt2.py`](./pytorch_pretrained_bert/modeling_gpt2.py) file):
- [`GPT2Model`](./pytorch_pretrained_bert/modeling_gpt2.py#L537) - raw OpenAI GPT-2 Transformer model (**fully pre-trained**),
- [`GPT2LMHeadModel`](./pytorch_pretrained_bert/modeling_gpt2.py#L691) - OpenAI GPT-2 Transformer with the tied language modeling head on top (**fully pre-trained**),
- [`GPT2DoubleHeadsModel`](./pytorch_pretrained_bert/modeling_gpt2.py#L752) - OpenAI GPT-2 Transformer with the tied language modeling head and a multiple choice classification head on top (OpenAI GPT-2 Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**),
- Tokenizers for **BERT** (using word-piece) (in the [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) file):
- `BasicTokenizer` - basic tokenization (punctuation splitting, lower casing, etc.),
- `WordpieceTokenizer` - WordPiece tokenization,
@ -109,6 +119,9 @@ This package comprises the following classes that can be imported in Python and
- Tokenizer for **Transformer-XL** (word tokens ordered by frequency for adaptive softmax) (in the [`tokenization_transfo_xl.py`](./pytorch_pretrained_bert/tokenization_transfo_xl.py) file):
- `OpenAIGPTTokenizer` - perform word tokenization and can order words by frequency in a corpus for use in an adaptive softmax.
- Tokenizer for **OpenAI GPT-2** (using byte-level Byte-Pair-Encoding) (in the [`tokenization_gpt2.py`](./pytorch_pretrained_bert/tokenization_gpt2.py) file):
- `GPT2Tokenizer` - perform byte-level Byte-Pair-Encoding (BPE) tokenization.
- Optimizer for **BERT** (in the [`optimization.py`](./pytorch_pretrained_bert/optimization.py) file):
- `BertAdam` - Bert version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.
@ -135,6 +148,9 @@ The repository further comprises:
- One example on how to use **Transformer-XL** (in the [`examples` folder](./examples)):
- [`run_transfo_xl.py`](./examples/run_transfo_xl.py) - Show how to load and evaluate a pre-trained model of `TransfoXLLMHeadModel` on WikiText 103.
- One example on how to use **OpenAI GPT-2** in the unconditional and interactive mode (in the [`examples` folder](./examples)):
- [`run_gpt2.py`](./examples/run_gpt2.py) - Show how to use OpenAI GPT-2 an instance of `GPT2LMHeadModel` to generate text (same as the original OpenAI GPT-2 examples).
These examples are detailed in the [Examples](#examples) section of this readme.
- Three notebooks that were used to check that the TensorFlow and PyTorch models behave identically (in the [`notebooks` folder](./notebooks)):
@ -367,6 +383,67 @@ predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'who'
```
### OpenAI GPT-2
Here is a quick-start example using `GPT2Tokenizer`, `GPT2Model` and `GPT2LMHeadModel` class with OpenAI's pre-trained model. See the [doc section](#doc) below for all the details on these classes.
First let's prepare a tokenized input with `GPT2Tokenizer`
```python
import torch
from pytorch_pretrained_bert import GPT2Tokenizer, GPT2Model, GPT2LMHeadModel
# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
import logging
logging.basicConfig(level=logging.INFO)
# Load pre-trained model tokenizer (vocabulary)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Encode input
text = "Who was Jim Henson ? Jim Henson was a puppeteer"
indexed_tokens = tokenizer.encode(text)
# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
```
Let's see how to use `GPT2Model` to get hidden states
```python
# Load pre-trained model (weights)
model = GPT2Model.from_pretrained('gpt2')
model.eval()
# If you have a GPU, put everything on cuda
tokens_tensor = tokens_tensor.to('cuda')
model.to('cuda')
# Predict hidden states features for each layer
with torch.no_grad():
hidden_states = model(tokens_tensor)
```
And how to use `GPT2LMHeadModel`
```python
# Load pre-trained model (weights)
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()
# If you have a GPU, put everything on cuda
tokens_tensor = tokens_tensor.to('cuda')
model.to('cuda')
# Predict all tokens
with torch.no_grad():
predictions = model(tokens_tensor)
# get the predicted last token
predicted_index = torch.argmax(predictions[0, -1, :]).item()
predicted_token = tokenizer.decode([predicted_index])
```
## Doc
Here is a detailed documentation of the classes in the package and how to use them:
@ -402,11 +479,12 @@ where
- `bert-base-chinese`: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters
- `openai-gpt`: OpenAI English model, 12-layer, 768-hidden, 12-heads, 110M parameters
- `transfo-xl-wt103`: Transformer-XL English model trained on wikitext-103, 18-layer, 1024-hidden, 16-heads, 257M parameters
- `gpt2`: OpenAI GPT-2 English model, 12-layer, 768-hidden, 12-heads, 117M parameters
- a path or url to a pretrained model archive containing:
- `bert_config.json` or `openai_gpt_config.json` a configuration file for the model, and
- `pytorch_model.bin` a PyTorch dump of a pre-trained instance of `BertForPreTraining`, `OpenAIGPTModel` or `TransfoXLModel` (saved with the usual `torch.save()`)
- `pytorch_model.bin` a PyTorch dump of a pre-trained instance of `BertForPreTraining`, `OpenAIGPTModel`, `TransfoXLModel`, `GPT2LMHeadModel` (saved with the usual `torch.save()`)
If `PRE_TRAINED_MODEL_NAME_OR_PATH` is a shortcut name, the pre-trained weights will be downloaded from AWS S3 (see the links [here](pytorch_pretrained_bert/modeling.py)) and stored in a cache folder to avoid future download (the cache folder can be found at `~/.pytorch_pretrained_bert/`).
- `cache_dir` can be an optional path to a specific directory to download and cache the pre-trained model weights. This option is useful in particular when you are using distributed training: to avoid concurrent access to the same weights you can set for example `cache_dir='./pretrained_model_{}'.format(args.local_rank)` (see the section on distributed training for more information).
@ -428,6 +506,11 @@ model = OpenAIGPTModel.from_pretrained('openai-gpt')
# Transformer-XL
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
# OpenAI GPT-2
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')
```
### PyTorch models
@ -649,6 +732,60 @@ all_hidden_states = lower_hidden_states + [hidden_states]
- else: log probabilities of tokens, shape [batch_size, sequence_length, n_tokens]
- `new_mems`: list (num layers) of updated mem states at the entry of each layer each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]. Note that the first two dimensions are transposed in `mems` with regards to `input_ids`.
#### 14. `GPT2Model`
`GPT2Model` is the OpenAI GPT-2 Transformer model with a layer of summed token and position embeddings followed by a series of 12 identical self-attention blocks.
The inputs and output are **identical to the TensorFlow model inputs and outputs**.
We detail them here. This model takes as *inputs*:
[`modeling_gpt2.py`](./pytorch_pretrained_bert/modeling_gpt2.py)
- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length] were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, vocab_size[
- `position_ids`: an optional torch.LongTensor with the same shape as input_ids
with the position indices (selected in the range [0, config.n_positions - 1[.
- `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
You can use it to add a third type of embedding to each input token in the sequence
(the previous two being the word and position embeddings). The input, position and token_type embeddings are summed inside the Transformer before the first self-attention block.
- `past`: an optional list of torch.LongTensor that contains pre-computed hidden-states (key and values in the attention blocks) to speed up sequential decoding (this is the `presents` output of the model, cf. below).
This model *outputs*:
- `hidden_states`: the encoded-hidden-states at the top of the model as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size] (or more generally [d_1, ..., d_n, hidden_size] were d_1 ... d_n are the dimension of input_ids)
- `presents`: a list of pre-computed hidden-states (key and values in each attention blocks) as a torch.FloatTensors. They can be reused to speed up sequential decoding (see the `run_gpt2.py` example).
#### 15. `GPT2LMHeadModel`
`GPT2LMHeadModel` includes the `GPT2Model` Transformer followed by a language modeling head with weights tied to the input embeddings (no additional parameters).
*Inputs* are the same as the inputs of the [`GPT2Model`](#-14.-`GPT2Model`) class plus optional labels:
- `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss is only computed for the labels set in [0, ..., vocab_size].
*Outputs*:
- if `lm_labels` is not `None`:
Outputs the language modeling loss.
- else: a tupple of
- `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, total_tokens_embeddings] (or more generally [d_1, ..., d_n, total_tokens_embeddings] were d_1 ... d_n are the dimension of input_ids)
- `presents`: a list of pre-computed hidden-states (key and values in each attention blocks) as a torch.FloatTensors. They can be reused to speed up sequential decoding (see the `run_gpt2.py` example).
#### 16. `GPT2DoubleHeadsModel`
`GPT2DoubleHeadsModel` includes the `GPT2Model` Transformer followed by two heads:
- a language modeling head with weights tied to the input embeddings (no additional parameters) and:
- a multiple choice classifier (linear layer that take as input a hidden state in a sequence to compute a score, see details in paper).
*Inputs* are the same as the inputs of the [`GPT2Model`](#-14.-`GPT2Model`) class plus a classification mask and two optional labels:
- `multiple_choice_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token whose hidden state should be used as input for the multiple choice classifier (usually the [CLS] token for each choice).
- `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss is only computed for the labels set in [0, ..., vocab_size].
- `multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size] with indices selected in [0, ..., num_choices].
*Outputs*:
- if `lm_labels` and `multiple_choice_labels` are not `None`:
Outputs a tuple of losses with the language modeling loss and the multiple choice loss.
- else Outputs a tuple with:
- `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, total_tokens_embeddings]
- `multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]
- `presents`: a list of pre-computed hidden-states (key and values in each attention blocks) as a torch.FloatTensors. They can be reused to speed up sequential decoding (see the `run_gpt2.py` example).
### Tokenizers:
#### `BertTokenizer`
@ -697,6 +834,24 @@ Please refer to the doc strings and code in [`tokenization_openai.py`](./pytorch
Please refer to the doc strings and code in [`tokenization_transfo_xl.py`](./pytorch_pretrained_bert/tokenization_transfo_xl.py) for the details of these additional methods in `TransfoXLTokenizer`.
#### `GPT2Tokenizer`
`GPT2Tokenizer` perform byte-level Byte-Pair-Encoding (BPE) tokenization.
This class has three arguments:
- `vocab_file`: path to a vocabulary file.
- `merges_file`: path to a file containing the BPE merges.
- `errors`: How to handle unicode decoding errors. **Default = `replace`**
and two methods:
- `encode(text)`: convert a `str` in a list of `int` tokens by performing byte-level BPE.
- `decode(tokens)`: convert back a list of `int` tokens in a `str`.
Please refer to [`tokenization_gpt2.py`](./pytorch_pretrained_bert/tokenization_gpt2.py) for more details on the `GPT2Tokenizer`.
### Optimizers:
#### `BertAdam`
@ -896,12 +1051,13 @@ python run_lm_finetuning.py \
--max_seq_length 128 \
```
### OpenAI GPT and Transformer-XL: running the examples
### OpenAI GPT, Transformer-XL and GPT-2: running the examples
We provide two examples of scripts for OpenAI GPT and Transformer-XL based on (and extended from) the respective original implementations:
We provide three examples of scripts for OpenAI GPT, Transformer-XL and OpenAI GPT-2 based on (and extended from) the respective original implementations:
- fine-tuning OpenAI GPT on the ROCStories dataset
- evaluating Transformer-XL on Wikitext 103
- unconditional and conditional generation from a pre-trained OpenAI GPT-2 model
#### Fine-tuning OpenAI GPT on the RocStories dataset
@ -936,6 +1092,22 @@ python run_transfo_xl.py --work_dir ../log
This command runs in about 1 min on a V100 and gives an evaluation perplexity of 18.22 on WikiText-103 (the authors report a perplexity of about 18.3 on this dataset with the TensorFlow code).
#### Unconditional and conditional generation from OpenAI's GPT-2 model
This example code is identical to the original unconditional and conditional generation codes.
Conditional generation:
```shell
python run_gpt2.py
```
Unconditional generation:
```shell
python run_gpt2.py --unconditional
```
The same option as in the original scripts are provided, please refere to the code of the example and the original repository of OpenAI.
## Fine-tuning BERT-large on GPUs
The options we list above allow to fine-tune BERT-large rather easily on GPU(s) instead of the TPU used by the original implementation.
@ -1050,12 +1222,25 @@ pytorch_pretrained_bert convert_openai_checkpoint \
Here is an example of the conversion process for a pre-trained Transformer-XL model (see [here](https://github.com/kimiyoung/transformer-xl/tree/master/tf#obtain-and-evaluate-pretrained-sota-models))
```shell
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
export TRANSFO_XL_CHECKPOINT_FOLDER_PATH=/path/to/transfo/xl/checkpoint
pytorch_pretrained_bert convert_openai_checkpoint \
$OPENAI_GPT_CHECKPOINT_FOLDER_PATH \
pytorch_pretrained_bert convert_transfo_xl_checkpoint \
$TRANSFO_XL_CHECKPOINT_FOLDER_PATH \
$PYTORCH_DUMP_OUTPUT \
[OPENAI_GPT_CONFIG]
[TRANSFO_XL_CONFIG]
```
### GPT-2
Here is an example of the conversion process for a pre-trained OpenAI's GPT-2 model.
```shell
export GPT2_DIR=/path/to/gpt2/checkpoint
pytorch_pretrained_bert convert_gpt2_checkpoint \
$GPT2_DIR/model.ckpt \
$PYTORCH_DUMP_OUTPUT \
[GPT2_CONFIG]
```
## TPU

105
examples/run_gpt2.py Normal file
View File

@ -0,0 +1,105 @@
#!/usr/bin/env python3
import argparse
import logging
from tqdm import trange
import torch
import torch.nn.functional as F
import numpy as np
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logger = logging.getLogger(__name__)
def top_k_logits(logits, k):
if k == 0:
return logits
values, _ = torch.topk(logits, k)
min_values = values[:, -1]
return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True):
if start_token is None:
assert context is not None, 'Specify exactly one of start_token and context!'
context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
else:
assert context is None, 'Specify exactly one of start_token and context!'
context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
prev = context
output = context
past = None
with torch.no_grad():
for i in trange(length):
logits, past = model(prev, past=past)
logits = logits[:, -1, :] / temperature
logits = top_k_logits(logits, k=top_k)
log_probs = F.softmax(logits, dim=-1)
if sample:
prev = torch.multinomial(log_probs, num_samples=1)
else:
_, prev = torch.topk(log_probs, k=1, dim=-1)
output = torch.cat((output, prev), dim=1)
return output
def run_model():
parser = argparse.ArgumentParser()
parser.add_argument('--model_name_or_path', type=str, default='gpt2', help='pretrained model name or path to local checkpoint')
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--nsamples", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=-1)
parser.add_argument("--length", type=int, default=-1)
parser.add_argument("--temperature", type=int, default=1)
parser.add_argument("--top_k", type=int, default=0)
parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.')
args = parser.parse_args()
print(args)
if args.batch_size == -1:
args.batch_size = 1
assert args.nsamples % args.batch_size == 0
np.random.seed(args.seed)
torch.random.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
model.to(device)
model.eval()
if args.length == -1:
args.length = model.config.n_ctx // 2
elif args.length > model.config.n_ctx:
raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
while not args.unconditional:
if not args.unconditional:
raw_text = input("Model prompt >>> ")
while not raw_text:
print('Prompt should not be empty!')
raw_text = input("Model prompt >>> ")
context_tokens = enc.encode(raw_text)
generated = 0
for _ in range(args.nsamples // args.batch_size):
out = sample_sequence(
model=model, length=args.length,
context=context_tokens if not args.unconditional else None,
start_token=enc.encoder['<|endoftext|>'] if args.unconditional else None,
batch_size=args.batch_size,
temperature=args.temperature, top_k=args.top_k, device=device
)
out = out[:, len(context_tokens):].tolist()
for i in range(args.batch_size):
generated += 1
text = enc.decode(out[i])
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
print(text)
print("=" * 80)
if __name__ == '__main__':
run_model()

View File

@ -0,0 +1,88 @@
#!/usr/bin/env python3
import argparse
import logging
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import trange
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logger = logging.getLogger(__name__)
def top_k_logits(logits, k):
if k == 0:
return logits
values, _ = torch.topk(logits, k)
min_values = values[:, -1]
return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
if start_token is None:
assert context is not None, 'Specify exactly one of start_token and context!'
context = torch.tensor(context, device=device, dtype=torch.long)
else:
assert context is None, 'Specify exactly one of start_token and context!'
context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
prev = context
output = context
past = None
with torch.no_grad():
for i in trange(length):
logits, past = model(prev, past=past)
logits = logits[:, -1, :] / temperature
logits = top_k_logits(logits, k=top_k)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1)
output = torch.cat((output, prev), dim=1)
return output
def sample_model():
parser = argparse.ArgumentParser()
parser.add_argument('--model_name_or_path', type=str, default='gpt2', help='pretrained model name or path to local checkpoint')
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--nsamples", type=int, default=0)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--length", type=int, default=-1)
parser.add_argument("--temperature", type=int, default=1)
parser.add_argument("--top_k", type=int, default=0)
args = parser.parse_args()
print(args)
np.random.seed(args.seed)
torch.random.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
model.to(device)
model.eval()
if args.length == -1:
args.length = model.config.n_ctx
elif args.length > model.config.n_ctx:
raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
generated = 0
while args.nsamples == 0 or generated < args.nsamples:
out = sample_sequence(
model=model, length=args.length,
start_token=enc.encoder['<|endoftext|>'],
batch_size=args.batch_size,
temperature=args.temperature, top_k=args.top_k, device=device
)
out = out.tolist()
for i in range(args.batch_size):
generated += args.batch_size
text = enc.decode(out[i])
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
print(text)
if __name__ == '__main__':
sample_model()

View File

@ -1,7 +1,8 @@
__version__ = "0.5.1"
__version__ = "0.6.0"
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
from .tokenization_gpt2 import GPT2Tokenizer
from .modeling import (BertConfig, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction,
@ -13,6 +14,9 @@ from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
load_tf_weights_in_openai_gpt)
from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
load_tf_weights_in_transfo_xl)
from .modeling_gpt2 import (GPT2Config, GPT2Model,
GPT2LMHeadModel, GPT2DoubleHeadsModel,
load_tf_weights_in_gpt2)
from .optimization import BertAdam
from .optimization_openai import OpenAIAdam

View File

@ -4,13 +4,15 @@ def main():
if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
"convert_tf_checkpoint_to_pytorch",
"convert_openai_checkpoint",
"convert_transfo_xl_checkpoint"
"convert_transfo_xl_checkpoint",
"convert_gpt2_checkpoint",
]:
print(
"Should be used as one of: \n"
">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]` or \n"
">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`")
else:
if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
try:
@ -40,7 +42,7 @@ def main():
convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
OPENAI_GPT_CONFIG,
PYTORCH_DUMP_OUTPUT)
else:
elif sys.argv[1] == "convert_transfo_xl_checkpoint":
try:
from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
except ImportError:
@ -61,5 +63,21 @@ def main():
else:
TF_CONFIG = ""
convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
else:
try:
from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
except ImportError:
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
TF_CHECKPOINT = sys.argv[2]
PYTORCH_DUMP_OUTPUT = sys.argv[3]
if len(sys.argv) == 5:
TF_CONFIG = sys.argv[4]
else:
TF_CONFIG = ""
convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,72 @@
# coding=utf-8
# Copyright 2018 The HugginFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert OpenAI GPT checkpoint."""
from __future__ import absolute_import, division, print_function
import argparse
from io import open
import torch
from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME,
GPT2Config,
GPT2Model,
load_tf_weights_in_gpt2)
def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
# Construct model
if gpt2_config_file == "":
config = GPT2Config()
else:
config = GPT2Config(gpt2_config_file)
model = GPT2Model(config)
# Load weights from numpy
load_tf_weights_in_gpt2(model, gpt2_checkpoint_path)
# Save pytorch-model
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
torch.save(model.state_dict(), pytorch_weights_dump_path)
print("Save configuration file to {}".format(pytorch_config_dump_path))
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
f.write(config.to_json_string())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--gpt2_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path the TensorFlow checkpoint path.")
parser.add_argument("--pytorch_dump_folder_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
parser.add_argument("--gpt2_config_file",
default = "",
type = str,
help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
"This specifies the model architecture.")
args = parser.parse_args()
convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path,
args.gpt2_config_file,
args.pytorch_dump_folder_path)

View File

@ -0,0 +1,684 @@
# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HugginFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""
import collections
import copy
import json
import logging
import math
import os
import shutil
import tarfile
import tempfile
import sys
from io import open
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from .file_utils import cached_path
from .modeling import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"}
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"}
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
""" Load tf checkpoints in a pytorch model
"""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
tf_path = os.path.abspath(gpt2_checkpoint_path)
print("Converting TensorFlow checkpoint from {}".format(tf_path))
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array.squeeze())
for name, array in zip(names, arrays):
name = name[6:] # skip "model/"
name = name.split('/')
pointer = model
for m_name in name:
if re.fullmatch(r'[A-Za-z]+\d+', m_name):
l = re.split(r'(\d+)', m_name)
else:
l = [m_name]
if l[0] == 'w' or l[0] == 'g':
pointer = getattr(pointer, 'weight')
elif l[0] == 'b':
pointer = getattr(pointer, 'bias')
elif l[0] == 'wpe' or l[0] == 'wte':
pointer = getattr(pointer, l[0])
pointer = getattr(pointer, 'weight')
else:
pointer = getattr(pointer, l[0])
if len(l) >= 2:
num = int(l[1])
pointer = pointer[num]
try:
assert pointer.shape == array.shape
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
return model
def gelu(x):
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
class GPT2Config(object):
"""Configuration class to store the configuration of a `GPT2Model`.
"""
def __init__(
self,
vocab_size_or_config_json_file=50257,
n_positions=1024,
n_ctx=1024,
n_embd=768,
n_layer=12,
n_head=12,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
):
"""Constructs GPT2Config.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
n_positions: Number of positional embeddings.
n_ctx: Size of the causal mask (usually same as n_positions).
n_embd: Dimensionality of the embeddings and hidden states.
n_layer: Number of hidden layers in the Transformer encoder.
n_head: Number of attention heads for each attention layer in
the Transformer encoder.
layer_norm_epsilon: epsilon to use in the layer norm layers
initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices.
"""
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.n_ctx = n_ctx
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
else:
raise ValueError(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
@classmethod
def from_dict(cls, json_object):
"""Constructs a `GPT2Config` from a Python dictionary of parameters."""
config = GPT2Config(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `GPT2Config` from a json file of parameters."""
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class Conv1D(nn.Module):
def __init__(self, nf, nx):
super(Conv1D, self).__init__()
self.nf = nf
w = torch.empty(nx, nf)
nn.init.normal_(w, std=0.02)
self.weight = Parameter(w)
self.bias = Parameter(torch.zeros(nf))
def forward(self, x):
size_out = x.size()[:-1] + (self.nf,)
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(*size_out)
return x
class Attention(nn.Module):
def __init__(self, nx, n_ctx, config, scale=False):
super(Attention, self).__init__()
n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
assert n_state % config.n_head == 0
self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
self.n_head = config.n_head
self.split_size = n_state
self.scale = scale
self.c_attn = Conv1D(n_state * 3, nx)
self.c_proj = Conv1D(n_state, nx)
def _attn(self, q, k, v):
w = torch.matmul(q, k)
if self.scale:
w = w / math.sqrt(v.size(-1))
nd, ns = w.size(-2), w.size(-1)
b = self.bias[:, :, ns-nd:ns, :ns]
w = w * b - 1e10 * (1 - b)
w = nn.Softmax(dim=-1)(w)
return torch.matmul(w, v)
def merge_heads(self, x):
x = x.permute(0, 2, 1, 3).contiguous()
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
def split_heads(self, x, k=False):
new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
if k:
return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
else:
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
def forward(self, x, layer_past=None):
x = self.c_attn(x)
query, key, value = x.split(self.split_size, dim=2)
query = self.split_heads(query)
key = self.split_heads(key, k=True)
value = self.split_heads(value)
if layer_past is not None:
past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
key = torch.cat((past_key, key), dim=-1)
value = torch.cat((past_value, value), dim=-2)
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
a = self._attn(query, key, value)
a = self.merge_heads(a)
a = self.c_proj(a)
return a, present
class MLP(nn.Module):
def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
super(MLP, self).__init__()
nx = config.n_embd
self.c_fc = Conv1D(n_state, nx)
self.c_proj = Conv1D(nx, n_state)
self.act = gelu
def forward(self, x):
h = self.act(self.c_fc(x))
h2 = self.c_proj(h)
return h2
class Block(nn.Module):
def __init__(self, n_ctx, config, scale=False):
super(Block, self).__init__()
nx = config.n_embd
self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
self.attn = Attention(nx, n_ctx, config, scale)
self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
self.mlp = MLP(4 * nx, config)
def forward(self, x, layer_past=None):
a, present = self.attn(self.ln_1(x), layer_past=layer_past)
x = x + a
m = self.mlp(self.ln_2(x))
x = x + m
return x, present
class GPT2LMHead(nn.Module):
""" Language Model Head for the transformer """
def __init__(self, model_embeddings_weights, config):
super(GPT2LMHead, self).__init__()
self.n_embd = config.n_embd
self.set_embeddings_weights(model_embeddings_weights)
def set_embeddings_weights(self, model_embeddings_weights):
embed_shape = model_embeddings_weights.shape
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
self.decoder.weight = model_embeddings_weights # Tied weights
def forward(self, hidden_state):
# Truncated Language modeling logits (we remove the last token)
# h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
lm_logits = self.decoder(hidden_state)
return lm_logits
class GPT2MultipleChoiceHead(nn.Module):
""" Classifier Head for the transformer """
def __init__(self, config):
super(GPT2MultipleChoiceHead, self).__init__()
self.n_embd = config.n_embd
self.linear = nn.Linear(config.n_embd, 1)
nn.init.normal_(self.linear.weight, std=0.02)
nn.init.normal_(self.linear.bias, 0)
def forward(self, hidden_states, mc_token_ids):
# Classification logits
# hidden_state (bsz, num_choices, seq_length, hidden_size)
# mc_token_ids (bsz, num_choices)
mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
# (bsz, num_choices, 1, hidden_size)
multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
# (bsz, num_choices, hidden_size)
multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1)
# (bsz, num_choices)
return multiple_choice_logits
class GPT2PreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(GPT2PreTrainedModel, self).__init__()
if not isinstance(config, GPT2Config):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
)
)
self.config = config
def set_tied(self):
pass
def init_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
):
"""
Instantiate a GPT2PreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `openai-gpt`
- a path or url to a pretrained model archive containing:
. `gpt2_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a GPT2Model instance
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. a TensorFlow checkpoint with trained weights
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
"""
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
archive_file, config_file
)
)
return None
if resolved_archive_file == archive_file and resolved_config_file == config_file:
logger.info("loading weights file {}".format(archive_file))
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
# Load config
config = GPT2Config.from_json_file(resolved_config_file)
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None)
if from_tf:
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
return load_tf_weights_in_gpt2(model, resolved_archive_file)
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if key.endswith(".g"):
new_key = key[:-2] + ".weight"
elif key.endswith(".b"):
new_key = key[:-2] + ".bias"
elif key.endswith(".w"):
new_key = key[:-2] + ".weight"
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
start_model = model
if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
start_model = model.transformer
load(start_model, prefix="")
if len(missing_keys) > 0:
logger.info(
"Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
)
if len(unexpected_keys) > 0:
logger.info(
"Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
)
if len(error_msgs) > 0:
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
)
# Make sure we are still sharing the output and input embeddings after loading weights
model.set_tied()
return model
class GPT2Model(GPT2PreTrainedModel):
"""OpenAI GPT-2 model ("Language Models are Unsupervised Multitask Learners").
Params:
config: a GPT2Config class instance with the configuration to build a new model
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
with the position indices (selected in the range [0, config.n_positions - 1[.
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
You can use it to add a third type of embedding to each input token in the sequence
(the previous two being the word and position embeddings).
The input, position and token_type embeddings are summed inside the Transformer before the first
self-attention block.
Outputs:
`hidden_states`: the encoded-hidden-states at the top of the model
as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size]
(or more generally [d_1, ..., d_n, hidden_size] were d_1 ... d_n are the dimension of input_ids)
Example usage:
```python
# Already been converted into BPE token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
config = modeling_gpt2.GPT2Config()
model = modeling_gpt2.GPT2Model(config)
hidden_states = model(input_ids)
```
"""
def __init__(self, config):
super(GPT2Model, self).__init__(config)
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
block = Block(config.n_ctx, config, scale=True)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.apply(self.init_weights)
def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
if past is None:
past_length = 0
past = [None] * len(self.h)
else:
past_length = past[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_ids.size(-1))
position_ids = position_ids.view(-1, position_ids.size(-1))
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
token_type_embeds = self.wte(token_type_ids)
else:
token_type_embeds = 0
hidden_states = inputs_embeds + position_embeds + token_type_embeds
presents = []
for block, layer_past in zip(self.h, past):
hidden_states, present = block(hidden_states, layer_past)
presents.append(present)
hidden_states = self.ln_f(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
return hidden_states.view(*output_shape), presents
class GPT2LMHeadModel(GPT2PreTrainedModel):
"""OpenAI GPT-2 model with a Language Modeling head ("Language Models are Unsupervised Multitask Learners").
Params:
config: a GPT2Config class instance with the configuration to build a new model
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
with the position indices (selected in the range [0, config.n_positions - 1[.
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
You can use it to add a third type of embedding to each input token in the sequence
(the previous two being the word and position embeddings).
The input, position and token_type embeddings are summed inside the Transformer before the first
self-attention block.
`lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
is only computed for the labels set in [0, ..., vocab_size]
Outputs:
if `lm_labels` is not `None`:
Outputs the language modeling loss.
else:
`lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, config.vocab_size]
(or more generally [d_1, ..., d_n, config.vocab_size] were d_1 ... d_n are the dimension of input_ids)
Example usage:
```python
# Already been converted into BPE token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
config = modeling_gpt2.GPT2Config()
model = modeling_gpt2.GPT2LMHeadModel(config)
lm_logits = model(input_ids)
```
"""
def __init__(self, config):
super(GPT2LMHeadModel, self).__init__(config)
self.transformer = GPT2Model(config)
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
self.apply(self.init_weights)
def set_tied(self):
""" Make sure we are sharing the embeddings
"""
self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
lm_logits = self.lm_head(hidden_states)
if lm_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
return loss
return lm_logits, presents
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
"""OpenAI GPT-2 model with a Language Modeling and a Multiple Choice head ("Language Models are Unsupervised Multitask Learners").
Params:
config: a GPT2Config class instance with the configuration to build a new model
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with the BPE token
indices selected in the range [0, config.vocab_size[
`mc_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token from
which we should take the hidden state to feed the multiple choice classifier (usually last token of the sequence)
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
with the position indices (selected in the range [0, config.n_positions - 1[.
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
You can use it to add a third type of embedding to each input token in the sequence
(the previous two being the word and position embeddings).
The input, position and token_type embeddings are summed inside the Transformer before the first
self-attention block.
`lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with indices selected in [-1, 0, ..., config.vocab_size]. All labels set to -1 are ignored (masked), the loss
is only computed for the labels set in [0, ..., config.vocab_size]
`multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_choices].
Outputs:
if `lm_labels` and `multiple_choice_labels` are not `None`:
Outputs a tuple of losses with the language modeling loss and the multiple choice loss.
else: a tuple with
`lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, config.vocab_size]
`multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]
Example usage:
```python
# Already been converted into BPE token ids
input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]]) # (bsz, number of choice, seq length)
mc_token_ids = torch.LongTensor([[2], [1]]) # (bsz, number of choice)
config = modeling_gpt2.GPT2Config()
model = modeling_gpt2.GPT2LMHeadModel(config)
lm_logits, multiple_choice_logits = model(input_ids, mc_token_ids)
```
"""
def __init__(self, config):
super(GPT2DoubleHeadsModel, self).__init__(config)
self.transformer = GPT2Model(config)
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
self.multiple_choice_head = GPT2MultipleChoiceHead(config)
self.apply(self.init_weights)
def set_tied(self):
""" Make sure we are sharing the embeddings
"""
self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
losses = []
if lm_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)))
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
if losses:
return losses
return lm_logits, mc_logits, presents

View File

@ -56,7 +56,7 @@ def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
# Thsi as used when we had a single embedding matrix for positions and tokens
# This was used when we had a single embedding matrix for positions and tokens
# init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
# del init_params[1]
init_params = [arr.squeeze() for arr in init_params]

View File

@ -0,0 +1,206 @@
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HugginFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import json
import logging
import os
import regex as re
from io import open
try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
def lru_cache():
return lambda func: func
from .file_utils import cached_path
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'gpt2': 1024,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class GPT2Tokenizer(object):
"""
GPT-2 BPE tokenizer. Peculiarities:
- Byte-level BPE
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file, merges_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
logger.info("loading vocabulary file {}".format(vocab_file))
logger.info("loading merges file {}".format(merges_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
logger.info("loading merges file {} from cache at {}".format(
merges_file, resolved_merges_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
def __len__(self):
return len(self.encoder)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
if len(bpe_tokens) > self.max_len:
raise ValueError(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT-2 model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len)
)
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text

View File

@ -38,7 +38,7 @@ from setuptools import find_packages, setup
setup(
name="pytorch_pretrained_bert",
version="0.5.1",
version="0.6.0",
author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors, Open AI team Authors",
author_email="thomas@huggingface.co",
description="PyTorch version of Google AI BERT model with script to load Google pre-trained models",

210
tests/modeling_gpt2_test.py Normal file
View File

@ -0,0 +1,210 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import json
import random
import torch
from pytorch_pretrained_bert import (GPT2Config, GPT2Model,
GPT2LMHeadModel, GPT2DoubleHeadsModel)
class GPT2ModelTest(unittest.TestCase):
class GPT2ModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_position_ids=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
n_positions=33,
n_embd=32,
n_layer=5,
n_head=4,
n_choices=3,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
scope=None):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_position_ids = use_position_ids
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
self.n_choices = n_choices
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.vocab_size)
position_ids = None
if self.use_position_ids:
position_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_positions)
token_type_ids = None
if self.use_token_type_ids:
total_voc = self.vocab_size
token_type_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc)
mc_labels = None
lm_labels = None
mc_token_ids = None
if self.use_labels:
mc_labels = GPT2ModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
lm_labels = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
mc_token_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices], self.seq_length)
config = GPT2Config(
vocab_size_or_config_json_file=self.vocab_size,
n_positions=self.n_positions,
n_embd=self.n_embd,
n_layer=self.n_layer,
n_head=self.n_head,
initializer_range=self.initializer_range)
return (config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids)
def create_gpt2_model(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
model = GPT2Model(config)
model.eval()
hidden_states, presents = model(input_ids, position_ids, token_type_ids)
outputs = {
"hidden_states": hidden_states,
"presents": presents,
}
return outputs
def check_gpt2_model_output(self, result):
self.parent.assertListEqual(
list(result["hidden_states"].size()),
[self.batch_size, self.n_choices, self.seq_length, self.n_embd])
def create_gpt2_lm_head(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
model = GPT2LMHeadModel(config)
model.eval()
loss = model(input_ids, position_ids, token_type_ids, lm_labels)
lm_logits, presents = model(input_ids, position_ids, token_type_ids)
outputs = {
"loss": loss,
"lm_logits": lm_logits,
"presents": presents,
}
return outputs
def check_gpt2_lm_head_output(self, result):
total_voc = self.vocab_size
self.parent.assertListEqual(
list(result["lm_logits"].size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc])
def check_gpt2_lm_head_loss_output(self, result):
self.parent.assertListEqual(
list(result["loss"].size()),
[])
def create_gpt2_double_heads(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
model = GPT2DoubleHeadsModel(config)
model.eval()
loss = model(input_ids, mc_token_ids,
lm_labels=lm_labels, mc_labels=mc_labels,
token_type_ids=token_type_ids, position_ids=position_ids)
lm_logits, mc_logits, presents = model(input_ids, mc_token_ids, position_ids=position_ids, token_type_ids=token_type_ids)
outputs = {
"loss": loss,
"lm_logits": lm_logits,
"mc_logits": mc_logits,
"presents": presents,
}
return outputs
def check_gpt2_double_heads_output(self, result):
total_voc = self.vocab_size
self.parent.assertListEqual(
list(result["lm_logits"].size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc])
self.parent.assertListEqual(
list(result["mc_logits"].size()),
[self.batch_size, self.n_choices])
def check_gpt2_double_heads_loss_output(self, result):
self.parent.assertListEqual(
[list(l.size()) for l in result["loss"]],
[[], []])
def test_default(self):
self.run_tester(GPT2ModelTest.GPT2ModelTester(self))
def test_config_to_json_string(self):
config = GPT2Config(vocab_size_or_config_json_file=99, n_embd=37)
obj = json.loads(config.to_json_string())
self.assertEqual(obj["vocab_size"], 99)
self.assertEqual(obj["n_embd"], 37)
def run_tester(self, tester):
config_and_inputs = tester.prepare_config_and_inputs()
output_result = tester.create_gpt2_model(*config_and_inputs)
tester.check_gpt2_model_output(output_result)
output_result = tester.create_gpt2_lm_head(*config_and_inputs)
tester.check_gpt2_lm_head_output(output_result)
tester.check_gpt2_lm_head_loss_output(output_result)
output_result = tester.create_gpt2_double_heads(*config_and_inputs)
tester.check_gpt2_double_heads_output(output_result)
tester.check_gpt2_double_heads_loss_output(output_result)
@classmethod
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:
rng = random.Random()
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
if __name__ == "__main__":
unittest.main()