Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)
Commit 0856a231c0

README.md (203)
@ -5,8 +5,9 @@
|
||||
This repository contains op-for-op PyTorch reimplementations, pre-trained models and fine-tuning examples for:
|
||||
|
||||
- [Google's BERT model](https://github.com/google-research/bert),
|
||||
- [OpenAI's GPT model](https://github.com/openai/finetune-transformer-lm), and
|
||||
- [Google/CMU's Transformer-XL model](https://github.com/kimiyoung/transformer-xl).
|
||||
- [OpenAI's GPT model](https://github.com/openai/finetune-transformer-lm),
|
||||
- [Google/CMU's Transformer-XL model](https://github.com/kimiyoung/transformer-xl), and
|
||||
- [OpenAI's GPT-2 model](https://blog.openai.com/better-language-models/),
|
||||
|
||||
These implementations have been tested on several datasets (see the examples) and should match the performances of the associated TensorFlow implementations (e.g. ~91 F1 on SQuAD for BERT, ~88 F1 on RocStories for OpenAI GPT and ~18.3 perplexity on WikiText 103 for the Transformer-XL). You can find more details in the [Examples](#examples) section below.
|
||||
|
||||
@ -21,6 +22,10 @@ This PyTorch implementation of OpenAI GPT is an adaptation of the [PyTorch imple
|
||||
**Google/CMU's Transformer-XL** was released together with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](http://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
||||
This PyTorch implementation of Transformer-XL is an adaptation of the original [PyTorch implementation](https://github.com/kimiyoung/transformer-xl) which has been slightly modified to match the performances of the TensorFlow implementation and allow re-using the pretrained weights. A command-line interface is provided to convert TensorFlow checkpoints into PyTorch models.
|
||||
|
||||
**OpenAI GPT-2** was released together with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
|
||||
This PyTorch implementation of OpenAI GPT-2 is an adaptation of [OpenAI's implementation](https://github.com/openai/gpt-2) and is provided with [OpenAI's pre-trained model](https://github.com/openai/gpt-2) and a command-line interface that was used to convert the TensorFlow checkpoint into PyTorch.
|
||||
|
||||
|
||||
## Content
|
||||
|
||||
| Section | Description |
|
||||
@ -98,6 +103,11 @@ This package comprises the following classes that can be imported in Python and
|
||||
- [`TransfoXLModel`](./pytorch_pretrained_bert/modeling_transfo_xl.py#L974) - Transformer-XL model which outputs the last hidden state and memory cells (**fully pre-trained**),
|
||||
- [`TransfoXLLMHeadModel`](./pytorch_pretrained_bert/modeling_transfo_xl.py#L1236) - Transformer-XL with the tied adaptive softmax head on top for language modeling which outputs the logits/loss and memory cells (**fully pre-trained**),
|
||||
|
||||
- Three **OpenAI GPT-2** PyTorch models (`torch.nn.Module`) with pre-trained weights (in the [`modeling_gpt2.py`](./pytorch_pretrained_bert/modeling_gpt2.py) file):
|
||||
- [`GPT2Model`](./pytorch_pretrained_bert/modeling_gpt2.py#L537) - raw OpenAI GPT-2 Transformer model (**fully pre-trained**),
|
||||
- [`GPT2LMHeadModel`](./pytorch_pretrained_bert/modeling_gpt2.py#L691) - OpenAI GPT-2 Transformer with the tied language modeling head on top (**fully pre-trained**),
|
||||
- [`GPT2DoubleHeadsModel`](./pytorch_pretrained_bert/modeling_gpt2.py#L752) - OpenAI GPT-2 Transformer with the tied language modeling head and a multiple choice classification head on top (OpenAI GPT-2 Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**),
|
||||
|
||||
- Tokenizers for **BERT** (using word-piece) (in the [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) file):
|
||||
- `BasicTokenizer` - basic tokenization (punctuation splitting, lower casing, etc.),
|
||||
- `WordpieceTokenizer` - WordPiece tokenization,
|
||||
@ -109,6 +119,9 @@ This package comprises the following classes that can be imported in Python and
|
||||
- Tokenizer for **Transformer-XL** (word tokens ordered by frequency for adaptive softmax) (in the [`tokenization_transfo_xl.py`](./pytorch_pretrained_bert/tokenization_transfo_xl.py) file):
|
||||
- `TransfoXLTokenizer` - performs word tokenization and can order words by frequency in a corpus for use in an adaptive softmax.
|
||||
|
||||
- Tokenizer for **OpenAI GPT-2** (using byte-level Byte-Pair-Encoding) (in the [`tokenization_gpt2.py`](./pytorch_pretrained_bert/tokenization_gpt2.py) file):
|
||||
- `GPT2Tokenizer` - performs byte-level Byte-Pair-Encoding (BPE) tokenization.
|
||||
|
||||
- Optimizer for **BERT** (in the [`optimization.py`](./pytorch_pretrained_bert/optimization.py) file):
|
||||
- `BertAdam` - Bert version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.
|
||||
|
||||
@ -135,6 +148,9 @@ The repository further comprises:
|
||||
- One example on how to use **Transformer-XL** (in the [`examples` folder](./examples)):
|
||||
- [`run_transfo_xl.py`](./examples/run_transfo_xl.py) - Show how to load and evaluate a pre-trained `TransfoXLLMHeadModel` on WikiText 103.
|
||||
|
||||
- One example on how to use **OpenAI GPT-2** in the unconditional and interactive mode (in the [`examples` folder](./examples)):
|
||||
- [`run_gpt2.py`](./examples/run_gpt2.py) - Show how to use an instance of `GPT2LMHeadModel` (OpenAI GPT-2) to generate text (same as the original OpenAI GPT-2 examples).
|
||||
|
||||
These examples are detailed in the [Examples](#examples) section of this readme.
|
||||
|
||||
- Three notebooks that were used to check that the TensorFlow and PyTorch models behave identically (in the [`notebooks` folder](./notebooks)):
|
||||
@ -367,6 +383,67 @@ predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
|
||||
assert predicted_token == 'who'
|
||||
```
|
||||
|
||||
### OpenAI GPT-2
|
||||
|
||||
Here is a quick-start example using the `GPT2Tokenizer`, `GPT2Model` and `GPT2LMHeadModel` classes with OpenAI's pre-trained model. See the [doc section](#doc) below for all the details on these classes.
|
||||
|
||||
First let's prepare a tokenized input with `GPT2Tokenizer`
|
||||
|
||||
```python
|
||||
import torch
|
||||
from pytorch_pretrained_bert import GPT2Tokenizer, GPT2Model, GPT2LMHeadModel
|
||||
|
||||
# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
|
||||
import logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# Load pre-trained model tokenizer (vocabulary)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
|
||||
# Encode input
|
||||
text = "Who was Jim Henson ? Jim Henson was a puppeteer"
|
||||
indexed_tokens = tokenizer.encode(text)
|
||||
|
||||
# Convert inputs to PyTorch tensors
|
||||
tokens_tensor = torch.tensor([indexed_tokens])
|
||||
```
|
||||
|
||||
Let's see how to use `GPT2Model` to get hidden states
|
||||
|
||||
```python
|
||||
# Load pre-trained model (weights)
|
||||
model = GPT2Model.from_pretrained('gpt2')
|
||||
model.eval()
|
||||
|
||||
# If you have a GPU, put everything on cuda
|
||||
tokens_tensor = tokens_tensor.to('cuda')
|
||||
model.to('cuda')
|
||||
|
||||
# Predict hidden states features for each layer
|
||||
with torch.no_grad():
|
||||
    hidden_states, past = model(tokens_tensor)
|
||||
```
|
||||
|
||||
And how to use `GPT2LMHeadModel`
|
||||
|
||||
```python
|
||||
# Load pre-trained model (weights)
|
||||
model = GPT2LMHeadModel.from_pretrained('gpt2')
|
||||
model.eval()
|
||||
|
||||
# If you have a GPU, put everything on cuda
|
||||
tokens_tensor = tokens_tensor.to('cuda')
|
||||
model.to('cuda')
|
||||
|
||||
# Predict all tokens
|
||||
with torch.no_grad():
|
||||
    predictions, past = model(tokens_tensor)
|
||||
|
||||
# Get the predicted next token
|
||||
predicted_index = torch.argmax(predictions[0, -1, :]).item()
|
||||
predicted_token = tokenizer.decode([predicted_index])
|
||||
```
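To generate more than one token, the key/value cache returned as the second output can be fed back through the `past` argument so that each step only has to process the newly generated token. Below is a minimal greedy-decoding sketch, not part of the original examples, that reuses the `model`, `tokenizer` and `tokens_tensor` objects defined above:

```python
# Minimal sketch of greedy generation with the key/value cache (`past`).
generated = tokens_tensor
past = None
with torch.no_grad():
    for _ in range(10):
        # After the first step, only the newly generated token needs to be fed in.
        input_t = generated if past is None else generated[:, -1:]
        logits, past = model(input_t, past=past)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        generated = torch.cat((generated, next_token), dim=1)
print(tokenizer.decode(generated[0].tolist()))
```

For a full sampling loop with temperature and top-k filtering, see [`run_gpt2.py`](./examples/run_gpt2.py).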
|
||||
|
||||
## Doc
|
||||
|
||||
Here is a detailed documentation of the classes in the package and how to use them:
|
||||
@ -402,11 +479,12 @@ where
|
||||
- `bert-base-chinese`: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters
|
||||
- `openai-gpt`: OpenAI English model, 12-layer, 768-hidden, 12-heads, 110M parameters
|
||||
- `transfo-xl-wt103`: Transformer-XL English model trained on wikitext-103, 18-layer, 1024-hidden, 16-heads, 257M parameters
|
||||
- `gpt2`: OpenAI GPT-2 English model, 12-layer, 768-hidden, 12-heads, 117M parameters
|
||||
|
||||
- a path or url to a pretrained model archive containing:
|
||||
|
||||
- `bert_config.json` or `openai_gpt_config.json` a configuration file for the model, and
|
||||
- `pytorch_model.bin` a PyTorch dump of a pre-trained instance of `BertForPreTraining`, `OpenAIGPTModel` or `TransfoXLModel` (saved with the usual `torch.save()`)
|
||||
- `pytorch_model.bin` a PyTorch dump of a pre-trained instance of `BertForPreTraining`, `OpenAIGPTModel`, `TransfoXLModel`, `GPT2LMHeadModel` (saved with the usual `torch.save()`)
|
||||
|
||||
If `PRE_TRAINED_MODEL_NAME_OR_PATH` is a shortcut name, the pre-trained weights will be downloaded from AWS S3 (see the links [here](pytorch_pretrained_bert/modeling.py)) and stored in a cache folder to avoid future download (the cache folder can be found at `~/.pytorch_pretrained_bert/`).
|
||||
- `cache_dir` can be an optional path to a specific directory to download and cache the pre-trained model weights. This option is useful in particular when you are using distributed training: to avoid concurrent access to the same weights you can set for example `cache_dir='./pretrained_model_{}'.format(args.local_rank)` (see the section on distributed training for more information).
|
||||
@ -428,6 +506,11 @@ model = OpenAIGPTModel.from_pretrained('openai-gpt')
|
||||
# Transformer-XL
|
||||
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
|
||||
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
|
||||
|
||||
# OpenAI GPT-2
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
model = GPT2Model.from_pretrained('gpt2')
|
||||
|
||||
```
|
||||
|
||||
### PyTorch models
|
||||
@ -649,6 +732,60 @@ all_hidden_states = lower_hidden_states + [hidden_states]
|
||||
- else: log probabilities of tokens, shape [batch_size, sequence_length, n_tokens]
|
||||
- `new_mems`: list (num layers) of updated mem states at the entry of each layer each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]. Note that the first two dimensions are transposed in `mems` with regards to `input_ids`.
|
||||
|
||||
#### 14. `GPT2Model`
|
||||
|
||||
`GPT2Model` is the OpenAI GPT-2 Transformer model with a layer of summed token and position embeddings followed by a series of 12 identical self-attention blocks.
|
||||
|
||||
The inputs and output are **identical to the TensorFlow model inputs and outputs**.
|
||||
|
||||
We detail them here (see [`modeling_gpt2.py`](./pytorch_pretrained_bert/modeling_gpt2.py)). This model takes as *inputs*:
|
||||
- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length] where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, vocab_size[
|
||||
- `position_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||
with the position indices (selected in the range [0, config.n_positions - 1[).
|
||||
- `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||
You can use it to add a third type of embedding to each input token in the sequence
|
||||
(the previous two being the word and position embeddings). The input, position and token_type embeddings are summed inside the Transformer before the first self-attention block.
|
||||
- `past`: an optional list of torch.LongTensor that contains pre-computed hidden-states (key and values in the attention blocks) to speed up sequential decoding (this is the `presents` output of the model, cf. below).
|
||||
|
||||
This model *outputs*:
|
||||
- `hidden_states`: the encoded-hidden-states at the top of the model as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size] (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)
|
||||
- `presents`: a list of pre-computed hidden-states (key and values in each attention block) as torch.FloatTensors. They can be reused to speed up sequential decoding (see the `run_gpt2.py` example).
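As an illustration (a sketch that is not part of the original documentation, reusing the `GPT2Model` instance and `tokens_tensor` from the quick-start above), the cache returned in `presents` can be passed back as `past` so that only the new positions are computed:

```python
with torch.no_grad():
    hidden_states, presents = model(tokens_tensor)               # hidden_states: [1, seq_len, 768] for `gpt2`
    next_token = tokens_tensor[:, -1:]                           # stand-in for a newly generated token
    hidden_states, presents = model(next_token, past=presents)   # only one new position is processed
assert len(presents) == 12                                       # one cached (key, value) stack per layer for `gpt2`
```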
|
||||
|
||||
#### 15. `GPT2LMHeadModel`
|
||||
|
||||
`GPT2LMHeadModel` includes the `GPT2Model` Transformer followed by a language modeling head with weights tied to the input embeddings (no additional parameters).
|
||||
|
||||
*Inputs* are the same as the inputs of the [`GPT2Model`](#-14.-`GPT2Model`) class plus optional labels:
|
||||
- `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss is only computed for the labels set in [0, ..., vocab_size].
|
||||
|
||||
*Outputs*:
|
||||
- if `lm_labels` is not `None`:
|
||||
Outputs the language modeling loss.
|
||||
- else: a tuple of
|
||||
- `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, total_tokens_embeddings] (or more generally [d_1, ..., d_n, total_tokens_embeddings] where d_1 ... d_n are the dimensions of input_ids)
|
||||
- `presents`: a list of pre-computed hidden-states (key and values in each attention block) as torch.FloatTensors. They can be reused to speed up sequential decoding (see the `run_gpt2.py` example).
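For example, a sketch with toy inputs (assuming `model = GPT2LMHeadModel.from_pretrained('gpt2')`) showing the two output modes:

```python
input_ids = torch.tensor([[31, 51, 99]])        # already BPE-encoded, shape [batch_size, sequence_length]

loss = model(input_ids, lm_labels=input_ids)    # scalar language modeling loss computed against `lm_labels`
lm_logits, presents = model(input_ids)          # logits of shape [1, 3, vocab_size] plus the key/value cache
```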
|
||||
|
||||
#### 16. `GPT2DoubleHeadsModel`
|
||||
|
||||
`GPT2DoubleHeadsModel` includes the `GPT2Model` Transformer followed by two heads:
|
||||
- a language modeling head with weights tied to the input embeddings (no additional parameters) and:
|
||||
- a multiple choice classifier (a linear layer that takes as input the hidden state of a selected token in the sequence to compute a score; see details in the paper).
|
||||
|
||||
*Inputs* are the same as the inputs of the [`GPT2Model`](#-14.-`GPT2Model`) class plus a classification mask and two optional labels:
|
||||
- `multiple_choice_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token whose hidden state should be used as input for the multiple choice classifier (usually the [CLS] token for each choice).
|
||||
- `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss is only computed for the labels set in [0, ..., vocab_size].
|
||||
- `multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size] with indices selected in [0, ..., num_choices].
|
||||
|
||||
*Outputs*:
|
||||
- if `lm_labels` and `multiple_choice_labels` are not `None`:
|
||||
Outputs a tuple of losses with the language modeling loss and the multiple choice loss.
|
||||
- else Outputs a tuple with:
|
||||
- `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, total_tokens_embeddings]
|
||||
- `multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]
|
||||
- `presents`: a list of pre-computed hidden-states (key and values in each attention block) as torch.FloatTensors. They can be reused to speed up sequential decoding (see the `run_gpt2.py` example).
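A shape-level sketch with toy ids (assuming `model = GPT2DoubleHeadsModel.from_pretrained('gpt2')`; the classification token indices are passed as the second positional argument, called `mc_token_ids` in [`modeling_gpt2.py`](./pytorch_pretrained_bert/modeling_gpt2.py), and the multiple choice head is randomly initialized, so its scores are only meaningful after fine-tuning):

```python
input_ids = torch.tensor([[[31, 51, 99], [31, 51, 97]]])   # [batch_size=1, num_choices=2, seq_length=3]
mc_token_ids = torch.tensor([[2, 2]])                      # index of the token whose hidden state feeds the classifier
lm_logits, mc_logits, presents = model(input_ids, mc_token_ids)
# lm_logits: [1, 2, 3, vocab_size], mc_logits: [1, 2]
```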
|
||||
|
||||
|
||||
### Tokenizers:
|
||||
|
||||
#### `BertTokenizer`
|
||||
@ -697,6 +834,24 @@ Please refer to the doc strings and code in [`tokenization_openai.py`](./pytorch
|
||||
|
||||
Please refer to the doc strings and code in [`tokenization_transfo_xl.py`](./pytorch_pretrained_bert/tokenization_transfo_xl.py) for the details of these additional methods in `TransfoXLTokenizer`.
|
||||
|
||||
#### `GPT2Tokenizer`
|
||||
|
||||
`GPT2Tokenizer` performs byte-level Byte-Pair-Encoding (BPE) tokenization.
|
||||
|
||||
This class has three arguments:
|
||||
|
||||
- `vocab_file`: path to a vocabulary file.
|
||||
- `merges_file`: path to a file containing the BPE merges.
|
||||
- `errors`: How to handle unicode decoding errors. **Default = `replace`**
|
||||
|
||||
and two methods:
|
||||
|
||||
- `encode(text)`: convert a `str` into a list of `int` tokens by performing byte-level BPE.
- `decode(tokens)`: convert a list of `int` tokens back into a `str`.
|
||||
|
||||
Please refer to [`tokenization_gpt2.py`](./pytorch_pretrained_bert/tokenization_gpt2.py) for more details on the `GPT2Tokenizer`.
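A minimal round-trip sketch:

```python
from pytorch_pretrained_bert import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
ids = tokenizer.encode("Who was Jim Henson?")   # list of ints (byte-level BPE ids)
text = tokenizer.decode(ids)                    # back to the original string
```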
|
||||
|
||||
|
||||
### Optimizers:
|
||||
|
||||
#### `BertAdam`
|
||||
@ -896,12 +1051,13 @@ python run_lm_finetuning.py \
|
||||
--max_seq_length 128 \
|
||||
```
|
||||
|
||||
### OpenAI GPT and Transformer-XL: running the examples
|
||||
### OpenAI GPT, Transformer-XL and GPT-2: running the examples
|
||||
|
||||
We provide two examples of scripts for OpenAI GPT and Transformer-XL based on (and extended from) the respective original implementations:
|
||||
We provide three examples of scripts for OpenAI GPT, Transformer-XL and OpenAI GPT-2 based on (and extended from) the respective original implementations:
|
||||
|
||||
- fine-tuning OpenAI GPT on the ROCStories dataset
|
||||
- evaluating Transformer-XL on Wikitext 103
|
||||
- unconditional and conditional generation from a pre-trained OpenAI GPT-2 model
|
||||
|
||||
#### Fine-tuning OpenAI GPT on the RocStories dataset
|
||||
|
||||
@ -936,6 +1092,22 @@ python run_transfo_xl.py --work_dir ../log
|
||||
|
||||
This command runs in about 1 min on a V100 and gives an evaluation perplexity of 18.22 on WikiText-103 (the authors report a perplexity of about 18.3 on this dataset with the TensorFlow code).
|
||||
|
||||
#### Unconditional and conditional generation from OpenAI's GPT-2 model
|
||||
|
||||
This example code is identical to the original unconditional and conditional generation code.
|
||||
|
||||
Conditional generation:
|
||||
```shell
|
||||
python run_gpt2.py
|
||||
```
|
||||
|
||||
Unconditional generation:
|
||||
```shell
|
||||
python run_gpt2.py --unconditional
|
||||
```
|
||||
|
||||
The same options as in the original scripts are provided; please refer to the code of the example and the original OpenAI repository for more details.
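For example, combining a few of the flags defined in `run_gpt2.py` (values are purely illustrative):

```shell
python run_gpt2.py --nsamples 2 --batch_size 2 --length 100 --top_k 40 --seed 42
```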
|
||||
|
||||
## Fine-tuning BERT-large on GPUs
|
||||
|
||||
The options we list above allow fine-tuning BERT-large rather easily on GPU(s) instead of the TPU used by the original implementation.
|
||||
@ -1050,12 +1222,25 @@ pytorch_pretrained_bert convert_openai_checkpoint \
|
||||
Here is an example of the conversion process for a pre-trained Transformer-XL model (see [here](https://github.com/kimiyoung/transformer-xl/tree/master/tf#obtain-and-evaluate-pretrained-sota-models))
|
||||
|
||||
```shell
|
||||
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
|
||||
export TRANSFO_XL_CHECKPOINT_FOLDER_PATH=/path/to/transfo/xl/checkpoint
|
||||
|
||||
pytorch_pretrained_bert convert_openai_checkpoint \
|
||||
$OPENAI_GPT_CHECKPOINT_FOLDER_PATH \
|
||||
pytorch_pretrained_bert convert_transfo_xl_checkpoint \
|
||||
$TRANSFO_XL_CHECKPOINT_FOLDER_PATH \
|
||||
$PYTORCH_DUMP_OUTPUT \
|
||||
[OPENAI_GPT_CONFIG]
|
||||
[TRANSFO_XL_CONFIG]
|
||||
```
|
||||
|
||||
### GPT-2
|
||||
|
||||
Here is an example of the conversion process for a pre-trained OpenAI GPT-2 model.
|
||||
|
||||
```shell
|
||||
export GPT2_DIR=/path/to/gpt2/checkpoint
|
||||
|
||||
pytorch_pretrained_bert convert_gpt2_checkpoint \
|
||||
$GPT2_DIR/model.ckpt \
|
||||
$PYTORCH_DUMP_OUTPUT \
|
||||
[GPT2_CONFIG]
|
||||
```
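The resulting folder can then be loaded directly with `from_pretrained` (a sketch; it assumes `$PYTORCH_DUMP_OUTPUT` now contains the `pytorch_model.bin` and `config.json` files written by the conversion script):

```python
from pytorch_pretrained_bert import GPT2Model

model = GPT2Model.from_pretrained('/path/to/pytorch_dump_output')
```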
|
||||
|
||||
## TPU
|
||||
|
examples/run_gpt2.py (new file, 105 lines)
@ -0,0 +1,105 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from tqdm import trange
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def top_k_logits(logits, k):
    """Keep only the k largest logits per row; everything else is pushed down to -1e10 (k == 0 disables filtering)."""
|
||||
if k == 0:
|
||||
return logits
|
||||
values, _ = torch.topk(logits, k)
|
||||
min_values = values[:, -1]
|
||||
return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)
|
||||
|
||||
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True):
    """Autoregressively generate `length` tokens from `context` or `start_token`, reusing the cached key/values (`past`) at each step."""
|
||||
if start_token is None:
|
||||
assert context is not None, 'Specify exactly one of start_token and context!'
|
||||
context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
|
||||
else:
|
||||
assert context is None, 'Specify exactly one of start_token and context!'
|
||||
context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
|
||||
prev = context
|
||||
output = context
|
||||
past = None
|
||||
with torch.no_grad():
|
||||
for i in trange(length):
|
||||
logits, past = model(prev, past=past)
|
||||
logits = logits[:, -1, :] / temperature
|
||||
logits = top_k_logits(logits, k=top_k)
|
||||
log_probs = F.softmax(logits, dim=-1)
|
||||
if sample:
|
||||
prev = torch.multinomial(log_probs, num_samples=1)
|
||||
else:
|
||||
_, prev = torch.topk(log_probs, k=1, dim=-1)
|
||||
output = torch.cat((output, prev), dim=1)
|
||||
return output
|
||||
|
||||
def run_model():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model_name_or_path', type=str, default='gpt2', help='pretrained model name or path to local checkpoint')
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--nsamples", type=int, default=1)
|
||||
parser.add_argument("--batch_size", type=int, default=-1)
|
||||
parser.add_argument("--length", type=int, default=-1)
|
||||
parser.add_argument("--temperature", type=int, default=1)
|
||||
parser.add_argument("--top_k", type=int, default=0)
|
||||
parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.')
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
if args.batch_size == -1:
|
||||
args.batch_size = 1
|
||||
assert args.nsamples % args.batch_size == 0
|
||||
|
||||
np.random.seed(args.seed)
|
||||
torch.random.manual_seed(args.seed)
|
||||
torch.cuda.manual_seed(args.seed)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
|
||||
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if args.length == -1:
|
||||
args.length = model.config.n_ctx // 2
|
||||
elif args.length > model.config.n_ctx:
|
||||
raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
|
||||
|
||||
while not args.unconditional:
|
||||
if not args.unconditional:
|
||||
raw_text = input("Model prompt >>> ")
|
||||
while not raw_text:
|
||||
print('Prompt should not be empty!')
|
||||
raw_text = input("Model prompt >>> ")
|
||||
context_tokens = enc.encode(raw_text)
|
||||
generated = 0
|
||||
for _ in range(args.nsamples // args.batch_size):
|
||||
out = sample_sequence(
|
||||
model=model, length=args.length,
|
||||
context=context_tokens if not args.unconditional else None,
|
||||
start_token=enc.encoder['<|endoftext|>'] if args.unconditional else None,
|
||||
batch_size=args.batch_size,
|
||||
temperature=args.temperature, top_k=args.top_k, device=device
|
||||
)
|
||||
out = out[:, len(context_tokens):].tolist()
|
||||
for i in range(args.batch_size):
|
||||
generated += 1
|
||||
text = enc.decode(out[i])
|
||||
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
||||
print(text)
|
||||
print("=" * 80)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_model()
|
examples/run_gpt2_generate_unconditional_samples.py (new file, 88 lines)
@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from tqdm import trange
|
||||
|
||||
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def top_k_logits(logits, k):
|
||||
if k == 0:
|
||||
return logits
|
||||
values, _ = torch.topk(logits, k)
|
||||
min_values = values[:, -1]
|
||||
return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)
|
||||
|
||||
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
|
||||
if start_token is None:
|
||||
assert context is not None, 'Specify exactly one of start_token and context!'
|
||||
context = torch.tensor(context, device=device, dtype=torch.long)
|
||||
else:
|
||||
assert context is None, 'Specify exactly one of start_token and context!'
|
||||
context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
|
||||
prev = context
|
||||
output = context
|
||||
past = None
|
||||
with torch.no_grad():
|
||||
for i in trange(length):
|
||||
logits, past = model(prev, past=past)
|
||||
logits = logits[:, -1, :] / temperature
|
||||
logits = top_k_logits(logits, k=top_k)
|
||||
log_probs = F.softmax(logits, dim=-1)
|
||||
prev = torch.multinomial(log_probs, num_samples=1)
|
||||
output = torch.cat((output, prev), dim=1)
|
||||
return output
|
||||
|
||||
def sample_model():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model_name_or_path', type=str, default='gpt2', help='pretrained model name or path to local checkpoint')
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--nsamples", type=int, default=0)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
parser.add_argument("--length", type=int, default=-1)
|
||||
parser.add_argument("--temperature", type=int, default=1)
|
||||
parser.add_argument("--top_k", type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
np.random.seed(args.seed)
|
||||
torch.random.manual_seed(args.seed)
|
||||
torch.cuda.manual_seed(args.seed)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
|
||||
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if args.length == -1:
|
||||
args.length = model.config.n_ctx
|
||||
elif args.length > model.config.n_ctx:
|
||||
raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
|
||||
|
||||
generated = 0
|
||||
while args.nsamples == 0 or generated < args.nsamples:
|
||||
out = sample_sequence(
|
||||
model=model, length=args.length,
|
||||
start_token=enc.encoder['<|endoftext|>'],
|
||||
batch_size=args.batch_size,
|
||||
temperature=args.temperature, top_k=args.top_k, device=device
|
||||
)
|
||||
out = out.tolist()
|
||||
for i in range(args.batch_size):
|
||||
generated += args.batch_size
|
||||
text = enc.decode(out[i])
|
||||
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
||||
print(text)
|
||||
|
||||
if __name__ == '__main__':
|
||||
sample_model()
|
@ -1,7 +1,8 @@
|
||||
__version__ = "0.5.1"
|
||||
__version__ = "0.6.0"
|
||||
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
|
||||
from .tokenization_openai import OpenAIGPTTokenizer
|
||||
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
|
||||
from .tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
from .modeling import (BertConfig, BertModel, BertForPreTraining,
|
||||
BertForMaskedLM, BertForNextSentencePrediction,
|
||||
@ -13,6 +14,9 @@ from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
|
||||
load_tf_weights_in_openai_gpt)
|
||||
from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
|
||||
load_tf_weights_in_transfo_xl)
|
||||
from .modeling_gpt2 import (GPT2Config, GPT2Model,
|
||||
GPT2LMHeadModel, GPT2DoubleHeadsModel,
|
||||
load_tf_weights_in_gpt2)
|
||||
|
||||
from .optimization import BertAdam
|
||||
from .optimization_openai import OpenAIAdam
|
||||
|
@ -4,13 +4,15 @@ def main():
|
||||
if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
|
||||
"convert_tf_checkpoint_to_pytorch",
|
||||
"convert_openai_checkpoint",
|
||||
"convert_transfo_xl_checkpoint"
|
||||
"convert_transfo_xl_checkpoint",
|
||||
"convert_gpt2_checkpoint",
|
||||
]:
|
||||
print(
|
||||
"Should be used as one of: \n"
|
||||
">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
|
||||
">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]` or \n"
|
||||
">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
|
||||
">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
|
||||
">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
|
||||
">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`")
|
||||
else:
|
||||
if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
|
||||
try:
|
||||
@ -40,7 +42,7 @@ def main():
|
||||
convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
|
||||
OPENAI_GPT_CONFIG,
|
||||
PYTORCH_DUMP_OUTPUT)
|
||||
else:
|
||||
elif sys.argv[1] == "convert_transfo_xl_checkpoint":
|
||||
try:
|
||||
from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
|
||||
except ImportError:
|
||||
@ -61,5 +63,21 @@ def main():
|
||||
else:
|
||||
TF_CONFIG = ""
|
||||
convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
|
||||
else:
|
||||
try:
|
||||
from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
|
||||
except ImportError:
|
||||
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||
"In that case, it requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
|
||||
TF_CHECKPOINT = sys.argv[2]
|
||||
PYTORCH_DUMP_OUTPUT = sys.argv[3]
|
||||
if len(sys.argv) == 5:
|
||||
TF_CONFIG = sys.argv[4]
|
||||
else:
|
||||
TF_CONFIG = ""
|
||||
convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py (new executable file, 72 lines)
@ -0,0 +1,72 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert OpenAI GPT checkpoint."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
from io import open
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME,
|
||||
GPT2Config,
|
||||
GPT2Model,
|
||||
load_tf_weights_in_gpt2)
|
||||
|
||||
|
||||
def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
|
||||
# Construct model
|
||||
if gpt2_config_file == "":
|
||||
config = GPT2Config()
|
||||
else:
|
||||
config = GPT2Config(gpt2_config_file)
|
||||
model = GPT2Model(config)
|
||||
|
||||
# Load weights from numpy
|
||||
load_tf_weights_in_gpt2(model, gpt2_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
|
||||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||
print("Save configuration file to {}".format(pytorch_config_dump_path))
|
||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||
f.write(config.to_json_string())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
parser.add_argument("--gpt2_checkpoint_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path the TensorFlow checkpoint path.")
|
||||
parser.add_argument("--pytorch_dump_folder_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the output PyTorch model.")
|
||||
parser.add_argument("--gpt2_config_file",
|
||||
default = "",
|
||||
type = str,
|
||||
help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
|
||||
"This specifies the model architecture.")
|
||||
args = parser.parse_args()
|
||||
convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path,
|
||||
args.gpt2_config_file,
|
||||
args.pytorch_dump_folder_path)
|
pytorch_pretrained_bert/modeling_gpt2.py (new file, 684 lines)
@ -0,0 +1,684 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch OpenAI GPT-2 model."""
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
import tarfile
|
||||
import tempfile
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .modeling import BertLayerNorm as LayerNorm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"}
|
||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"}
|
||||
|
||||
CONFIG_NAME = "config.json"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
|
||||
def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
|
||||
""" Load tf checkpoints in a pytorch model
|
||||
"""
|
||||
try:
|
||||
import re
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
tf_path = os.path.abspath(gpt2_checkpoint_path)
|
||||
print("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
for name, shape in init_vars:
|
||||
print("Loading TF weight {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
names.append(name)
|
||||
arrays.append(array.squeeze())
|
||||
|
||||
for name, array in zip(names, arrays):
|
||||
name = name[6:] # skip "model/"
|
||||
name = name.split('/')
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+\d+', m_name):
|
||||
l = re.split(r'(\d+)', m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
if l[0] == 'w' or l[0] == 'g':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'b':
|
||||
pointer = getattr(pointer, 'bias')
|
||||
elif l[0] == 'wpe' or l[0] == 'wte':
|
||||
pointer = getattr(pointer, l[0])
|
||||
pointer = getattr(pointer, 'weight')
|
||||
else:
|
||||
pointer = getattr(pointer, l[0])
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
pointer = pointer[num]
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
return model
|
||||
|
||||
|
||||
def gelu(x):
|
||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
|
||||
class GPT2Config(object):
|
||||
"""Configuration class to store the configuration of a `GPT2Model`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size_or_config_json_file=50257,
|
||||
n_positions=1024,
|
||||
n_ctx=1024,
|
||||
n_embd=768,
|
||||
n_layer=12,
|
||||
n_head=12,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
):
|
||||
"""Constructs GPT2Config.
|
||||
|
||||
Args:
|
||||
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
|
||||
n_positions: Number of positional embeddings.
|
||||
n_ctx: Size of the causal mask (usually same as n_positions).
|
||||
n_embd: Dimensionality of the embeddings and hidden states.
|
||||
n_layer: Number of hidden layers in the Transformer encoder.
|
||||
n_head: Number of attention heads for each attention layer in
|
||||
the Transformer encoder.
|
||||
layer_norm_epsilon: epsilon to use in the layer norm layers
|
||||
            initializer_range: The stddev of the truncated_normal_initializer for
|
||||
initializing all weight matrices.
|
||||
"""
|
||||
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
|
||||
and isinstance(vocab_size_or_config_json_file, unicode)):
|
||||
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
|
||||
json_config = json.loads(reader.read())
|
||||
for key, value in json_config.items():
|
||||
self.__dict__[key] = value
|
||||
elif isinstance(vocab_size_or_config_json_file, int):
|
||||
self.vocab_size = vocab_size_or_config_json_file
|
||||
self.n_ctx = n_ctx
|
||||
self.n_positions = n_positions
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
self.n_head = n_head
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_range = initializer_range
|
||||
else:
|
||||
raise ValueError(
|
||||
"First argument must be either a vocabulary size (int)"
|
||||
"or the path to a pretrained model config file (str)"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, json_object):
|
||||
"""Constructs a `GPT2Config` from a Python dictionary of parameters."""
|
||||
config = GPT2Config(vocab_size_or_config_json_file=-1)
|
||||
for key, value in json_object.items():
|
||||
config.__dict__[key] = value
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file):
|
||||
"""Constructs a `GPT2Config` from a json file of parameters."""
|
||||
with open(json_file, "r", encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
return cls.from_dict(json.loads(text))
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.to_json_string())
|
||||
|
||||
def to_dict(self):
|
||||
"""Serializes this instance to a Python dictionary."""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
return output
|
||||
|
||||
def to_json_string(self):
|
||||
"""Serializes this instance to a JSON string."""
|
||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||
|
||||
|
||||
class Conv1D(nn.Module):
|
||||
def __init__(self, nf, nx):
|
||||
super(Conv1D, self).__init__()
|
||||
self.nf = nf
|
||||
w = torch.empty(nx, nf)
|
||||
nn.init.normal_(w, std=0.02)
|
||||
self.weight = Parameter(w)
|
||||
self.bias = Parameter(torch.zeros(nf))
|
||||
|
||||
def forward(self, x):
|
||||
size_out = x.size()[:-1] + (self.nf,)
|
||||
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
|
||||
x = x.view(*size_out)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, nx, n_ctx, config, scale=False):
|
||||
super(Attention, self).__init__()
|
||||
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
||||
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
|
||||
assert n_state % config.n_head == 0
|
||||
self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
|
||||
self.n_head = config.n_head
|
||||
self.split_size = n_state
|
||||
self.scale = scale
|
||||
self.c_attn = Conv1D(n_state * 3, nx)
|
||||
self.c_proj = Conv1D(n_state, nx)
|
||||
|
||||
def _attn(self, q, k, v):
|
||||
w = torch.matmul(q, k)
|
||||
if self.scale:
|
||||
w = w / math.sqrt(v.size(-1))
|
||||
nd, ns = w.size(-2), w.size(-1)
|
||||
b = self.bias[:, :, ns-nd:ns, :ns]
|
||||
w = w * b - 1e10 * (1 - b)
|
||||
|
||||
w = nn.Softmax(dim=-1)(w)
|
||||
return torch.matmul(w, v)
|
||||
|
||||
def merge_heads(self, x):
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
|
||||
return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
|
||||
|
||||
def split_heads(self, x, k=False):
|
||||
new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
|
||||
x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
|
||||
if k:
|
||||
return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
|
||||
else:
|
||||
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
||||
|
||||
def forward(self, x, layer_past=None):
|
||||
x = self.c_attn(x)
|
||||
query, key, value = x.split(self.split_size, dim=2)
|
||||
query = self.split_heads(query)
|
||||
key = self.split_heads(key, k=True)
|
||||
value = self.split_heads(value)
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
|
||||
key = torch.cat((past_key, key), dim=-1)
|
||||
value = torch.cat((past_value, value), dim=-2)
|
||||
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
|
||||
a = self._attn(query, key, value)
|
||||
a = self.merge_heads(a)
|
||||
a = self.c_proj(a)
|
||||
return a, present
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
|
||||
super(MLP, self).__init__()
|
||||
nx = config.n_embd
|
||||
self.c_fc = Conv1D(n_state, nx)
|
||||
self.c_proj = Conv1D(nx, n_state)
|
||||
self.act = gelu
|
||||
|
||||
def forward(self, x):
|
||||
h = self.act(self.c_fc(x))
|
||||
h2 = self.c_proj(h)
|
||||
return h2
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, n_ctx, config, scale=False):
|
||||
super(Block, self).__init__()
|
||||
nx = config.n_embd
|
||||
self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
|
||||
self.attn = Attention(nx, n_ctx, config, scale)
|
||||
self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
|
||||
self.mlp = MLP(4 * nx, config)
|
||||
|
||||
def forward(self, x, layer_past=None):
|
||||
a, present = self.attn(self.ln_1(x), layer_past=layer_past)
|
||||
x = x + a
|
||||
m = self.mlp(self.ln_2(x))
|
||||
x = x + m
|
||||
return x, present
|
||||
|
||||
|
||||
class GPT2LMHead(nn.Module):
|
||||
""" Language Model Head for the transformer """
|
||||
|
||||
def __init__(self, model_embeddings_weights, config):
|
||||
super(GPT2LMHead, self).__init__()
|
||||
self.n_embd = config.n_embd
|
||||
self.set_embeddings_weights(model_embeddings_weights)
|
||||
|
||||
def set_embeddings_weights(self, model_embeddings_weights):
|
||||
embed_shape = model_embeddings_weights.shape
|
||||
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
|
||||
self.decoder.weight = model_embeddings_weights # Tied weights
|
||||
|
||||
def forward(self, hidden_state):
|
||||
# Truncated Language modeling logits (we remove the last token)
|
||||
# h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
|
||||
lm_logits = self.decoder(hidden_state)
|
||||
return lm_logits
|
||||
|
||||
|
||||
class GPT2MultipleChoiceHead(nn.Module):
|
||||
""" Classifier Head for the transformer """
|
||||
|
||||
def __init__(self, config):
|
||||
super(GPT2MultipleChoiceHead, self).__init__()
|
||||
self.n_embd = config.n_embd
|
||||
self.linear = nn.Linear(config.n_embd, 1)
|
||||
|
||||
nn.init.normal_(self.linear.weight, std=0.02)
|
||||
nn.init.normal_(self.linear.bias, 0)
|
||||
|
||||
def forward(self, hidden_states, mc_token_ids):
|
||||
# Classification logits
|
||||
# hidden_state (bsz, num_choices, seq_length, hidden_size)
|
||||
# mc_token_ids (bsz, num_choices)
|
||||
mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
|
||||
# (bsz, num_choices, 1, hidden_size)
|
||||
multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
|
||||
# (bsz, num_choices, hidden_size)
|
||||
multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1)
|
||||
# (bsz, num_choices)
|
||||
return multiple_choice_logits
|
||||
|
||||
|
||||
class GPT2PreTrainedModel(nn.Module):
|
||||
""" An abstract class to handle weights initialization and
|
||||
        a simple interface for downloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(GPT2PreTrainedModel, self).__init__()
|
||||
if not isinstance(config, GPT2Config):
|
||||
raise ValueError(
|
||||
"Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
|
||||
"To create a model from a pretrained model use "
|
||||
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
|
||||
self.__class__.__name__, self.__class__.__name__
|
||||
)
|
||||
)
|
||||
self.config = config
|
||||
|
||||
def set_tied(self):
|
||||
pass
|
||||
|
||||
def init_weights(self, module):
|
||||
""" Initialize the weights.
|
||||
"""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
elif isinstance(module, LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
|
||||
):
|
||||
"""
|
||||
Instantiate a GPT2PreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
|
||||
Params:
|
||||
pretrained_model_name_or_path: either:
|
||||
- a str with the name of a pre-trained model to load selected in the list of:
|
||||
                    . `gpt2`
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `gpt2_config.json` a configuration file for the model
|
||||
. `pytorch_model.bin` a PyTorch dump of a GPT2Model instance
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `bert_config.json` a configuration file for the model
|
||||
. a TensorFlow checkpoint with trained weights
|
||||
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
||||
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
||||
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
|
||||
*inputs, **kwargs: additional input for the specific Bert class
|
||||
(ex: num_labels for BertForSequenceClassification)
|
||||
"""
|
||||
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
||||
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
else:
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
|
||||
except EnvironmentError:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
||||
"at this path or url.".format(
|
||||
pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
|
||||
archive_file, config_file
|
||||
)
|
||||
)
|
||||
return None
|
||||
if resolved_archive_file == archive_file and resolved_config_file == config_file:
|
||||
logger.info("loading weights file {}".format(archive_file))
|
||||
logger.info("loading configuration file {}".format(config_file))
|
||||
else:
|
||||
logger.info("loading weights file {} from cache at {}".format(
|
||||
archive_file, resolved_archive_file))
|
||||
logger.info("loading configuration file {} from cache at {}".format(
|
||||
config_file, resolved_config_file))
|
||||
# Load config
|
||||
config = GPT2Config.from_json_file(resolved_config_file)
|
||||
logger.info("Model config {}".format(config))
|
||||
# Instantiate model.
|
||||
model = cls(config, *inputs, **kwargs)
|
||||
if state_dict is None and not from_tf:
|
||||
state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None)
|
||||
if from_tf:
|
||||
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
|
||||
return load_tf_weights_in_gpt2(model, resolved_archive_file)
|
||||
|
||||
old_keys = []
|
||||
new_keys = []
|
||||
for key in state_dict.keys():
|
||||
new_key = None
|
||||
if key.endswith(".g"):
|
||||
new_key = key[:-2] + ".weight"
|
||||
elif key.endswith(".b"):
|
||||
new_key = key[:-2] + ".bias"
|
||||
elif key.endswith(".w"):
|
||||
new_key = key[:-2] + ".weight"
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
for old_key, new_key in zip(old_keys, new_keys):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
missing_keys = []
|
||||
unexpected_keys = []
|
||||
error_msgs = []
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, "_metadata", None)
|
||||
state_dict = state_dict.copy()
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
def load(module, prefix=""):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
module._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + ".")
|
||||
|
||||
start_model = model
|
||||
if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
|
||||
start_model = model.transformer
|
||||
load(start_model, prefix="")
|
||||
|
||||
if len(missing_keys) > 0:
|
||||
logger.info(
|
||||
"Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
|
||||
)
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.info(
|
||||
"Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
|
||||
)
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError(
|
||||
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
|
||||
)
|
||||
|
||||
# Make sure we are still sharing the output and input embeddings after loading weights
|
||||
model.set_tied()
|
||||
return model
|
||||
|
||||
|
||||
class GPT2Model(GPT2PreTrainedModel):
|
||||
"""OpenAI GPT-2 model ("Language Models are Unsupervised Multitask Learners").
|
||||
|
||||
Params:
|
||||
config: a GPT2Config class instance with the configuration to build a new model
|
||||
|
||||
Inputs:
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
|
||||
were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
|
||||
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||
with the position indices (selected in the range [0, config.n_positions - 1[.
|
||||
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||
You can use it to add a third type of embedding to each input token in the sequence
|
||||
(the previous two being the word and position embeddings).
|
||||
The input, position and token_type embeddings are summed inside the Transformer before the first
|
||||
self-attention block.
|
||||
|
||||
Outputs:
|
||||
`hidden_states`: the encoded-hidden-states at the top of the model
|
||||
as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size]
|
||||
(or more generally [d_1, ..., d_n, hidden_size] were d_1 ... d_n are the dimension of input_ids)
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
# Already been converted into BPE token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
|
||||
config = modeling_gpt2.GPT2Config()
|
||||
|
||||
model = modeling_gpt2.GPT2Model(config)
|
||||
    hidden_states, presents = model(input_ids)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(GPT2Model, self).__init__(config)
|
||||
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
|
||||
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
|
||||
block = Block(config.n_ctx, config, scale=True)
|
||||
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
|
||||
self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
|
||||
if past is None:
|
||||
past_length = 0
|
||||
past = [None] * len(self.h)
|
||||
else:
|
||||
past_length = past[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||
position_ids = position_ids.view(-1, position_ids.size(-1))
|
||||
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
else:
|
||||
token_type_embeds = 0
|
||||
hidden_states = inputs_embeds + position_embeds + token_type_embeds
|
||||
presents = []
|
||||
for block, layer_past in zip(self.h, past):
|
||||
hidden_states, present = block(hidden_states, layer_past)
|
||||
presents.append(present)
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
return hidden_states.view(*output_shape), presents
|
||||
|
||||
|
||||
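
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): how the `past`/`presents`
# cache above can be used for incremental decoding. Only the newly generated
# token is fed through the model on each step; the keys/values of previous
# tokens are reused from `presents`. The token values are made up.
# ---------------------------------------------------------------------------
def _gpt2_incremental_decoding_sketch():
    config = GPT2Config()
    model = GPT2Model(config)
    model.eval()
    with torch.no_grad():
        prompt_ids = torch.LongTensor([[31, 51, 99]])            # already BPE-encoded prompt
        hidden_states, presents = model(prompt_ids)              # presents: one cached (key, value) per block
        next_token = torch.LongTensor([[15]])                    # a single new token
        hidden_states, presents = model(next_token, past=presents)
    return hidden_states, presents
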
class GPT2LMHeadModel(GPT2PreTrainedModel):
    """OpenAI GPT-2 model with a Language Modeling head ("Language Models are Unsupervised Multitask Learners").

    Params:
        config: a GPT2Config class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
            where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
            with the position indices (selected in the range [0, config.n_positions - 1[).
        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
            You can use it to add a third type of embedding to each input token in the sequence
            (the previous two being the word and position embeddings).
            The input, position and token_type embeddings are summed inside the Transformer before the first
            self-attention block.
        `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
            with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
            is only computed for the labels set in [0, ..., vocab_size]
        `past`: an optional list of pre-computed keys and values (one tensor per layer), as returned in
            `presents` by a previous forward pass, to speed up sequential decoding.

    Outputs:
        if `lm_labels` is not `None`:
            Outputs the language modeling loss.
        else, a tuple of:
            `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, config.vocab_size]
                (or more generally [d_1, ..., d_n, config.vocab_size] where d_1 ... d_n are the dimensions of input_ids)
            `presents`: a list with the pre-computed keys and values of each attention block,
                which can be fed back as `past` on the next call.

    Example usage:
    ```python
    # Already been converted into BPE token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])

    config = modeling_gpt2.GPT2Config()

    model = modeling_gpt2.GPT2LMHeadModel(config)
    lm_logits, presents = model(input_ids)
    ```
    """

    def __init__(self, config):
        super(GPT2LMHeadModel, self).__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
        self.apply(self.init_weights)

    def set_tied(self):
        """ Make sure we are sharing the embeddings
        """
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
        lm_logits = self.lm_head(hidden_states)
        if lm_labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
            return loss
        return lm_logits, presents
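
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): GPT2LMHeadModel returns a
# scalar loss when `lm_labels` is given and (logits, presents) otherwise; an
# argmax over the last position is one simple way to pick a next token. Values
# are made up for the example.
# ---------------------------------------------------------------------------
def _gpt2_lm_head_usage_sketch():
    config = GPT2Config()
    model = GPT2LMHeadModel(config)
    model.eval()
    input_ids = torch.LongTensor([[31, 51, 99]])
    loss = model(input_ids, lm_labels=input_ids)                 # scalar language modeling loss
    with torch.no_grad():
        lm_logits, presents = model(input_ids)                   # [batch, seq_len, vocab_size]
        next_token = torch.argmax(lm_logits[:, -1, :], dim=-1)   # greedy choice for the next token
    return loss, next_token
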
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
    """OpenAI GPT-2 model with a Language Modeling and a Multiple Choice head ("Language Models are Unsupervised Multitask Learners").

    Params:
        config: a GPT2Config class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with the BPE token
            indices selected in the range [0, config.vocab_size[
        `mc_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token from
            which we should take the hidden state to feed the multiple choice classifier (usually the last token of the sequence)
        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
            with the position indices (selected in the range [0, config.n_positions - 1[).
        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
            You can use it to add a third type of embedding to each input token in the sequence
            (the previous two being the word and position embeddings).
            The input, position and token_type embeddings are summed inside the Transformer before the first
            self-attention block.
        `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with indices selected in [-1, 0, ..., config.vocab_size]. All labels set to -1 are ignored (masked), the loss
            is only computed for the labels set in [0, ..., config.vocab_size]
        `multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size]
            with indices selected in [0, ..., num_choices].
        `past`: an optional list of pre-computed keys and values (one tensor per layer), as returned in
            `presents` by a previous forward pass, to speed up sequential decoding.

    Outputs:
        if `lm_labels` and `multiple_choice_labels` are not `None`:
            Outputs a tuple of losses with the language modeling loss and the multiple choice loss.
        else, a tuple of:
            `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, config.vocab_size]
            `multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]
            `presents`: a list with the pre-computed keys and values of each attention block,
                which can be fed back as `past` on the next call.

    Example usage:
    ```python
    # Already been converted into BPE token ids
    input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]])  # (bsz, number of choices, seq length)
    mc_token_ids = torch.LongTensor([[2, 1]])  # (bsz, number of choices)

    config = modeling_gpt2.GPT2Config()

    model = modeling_gpt2.GPT2DoubleHeadsModel(config)
    lm_logits, multiple_choice_logits, presents = model(input_ids, mc_token_ids)
    ```
    """

    def __init__(self, config):
        super(GPT2DoubleHeadsModel, self).__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
        self.multiple_choice_head = GPT2MultipleChoiceHead(config)
        self.apply(self.init_weights)

    def set_tied(self):
        """ Make sure we are sharing the embeddings
        """
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)

    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
        losses = []
        if lm_labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)))
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
        if losses:
            return losses
        return lm_logits, mc_logits, presents
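
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): inputs to
# GPT2DoubleHeadsModel carry an extra num_choices dimension, and `mc_token_ids`
# selects, for every choice, the position whose hidden state feeds the multiple
# choice classifier (usually the last token of each choice). Values are made up.
# ---------------------------------------------------------------------------
def _gpt2_double_heads_usage_sketch():
    config = GPT2Config()
    model = GPT2DoubleHeadsModel(config)
    model.eval()
    input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]])   # (batch, num_choices, seq_len)
    mc_token_ids = torch.LongTensor([[2, 2]])                    # (batch, num_choices): last position of each choice
    with torch.no_grad():
        lm_logits, mc_logits, presents = model(input_ids, mc_token_ids)
    return lm_logits, mc_logits
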
@@ -56,7 +56,7 @@ def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
     init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
     init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]

-    # Thsi as used when we had a single embedding matrix for positions and tokens
+    # This was used when we had a single embedding matrix for positions and tokens
     # init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
     # del init_params[1]
     init_params = [arr.squeeze() for arr in init_params]

pytorch_pretrained_bert/tokenization_gpt2.py (new file, 206 lines)
@@ -0,0 +1,206 @@
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT-2."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import json
import logging
import os
import regex as re
from io import open

try:
    from functools import lru_cache
except ImportError:
    # Just a dummy decorator to get the checks to run on python2
    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
    def lru_cache():
        return lambda func: func

from .file_utils import cached_path

logger = logging.getLogger(__name__)

PRETRAINED_VOCAB_ARCHIVE_MAP = {
    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
    'gpt2': 1024,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'


@lru_cache()
def bytes_to_unicode():
    """
    Returns a mapping from utf-8 bytes to unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
    and we avoid mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
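
# ---------------------------------------------------------------------------
# Illustrative check (not part of the original file): bytes_to_unicode() is a
# reversible mapping from the 256 byte values to printable unicode characters,
# so any UTF-8 string can be round-tripped through the alphabet the BPE merges
# operate on.
# ---------------------------------------------------------------------------
def _bytes_to_unicode_roundtrip_sketch(text=u"café ☕"):
    byte_encoder = bytes_to_unicode()
    byte_decoder = {v: k for k, v in byte_encoder.items()}
    mapped = ''.join(byte_encoder[b] for b in bytearray(text.encode('utf-8')))
    recovered = bytearray([byte_decoder[c] for c in mapped]).decode('utf-8')
    assert recovered == text
    return mapped
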
def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


class GPT2Tokenizer(object):
    """
    GPT-2 BPE tokenizer. Peculiarities:
        - Byte-level BPE
    """
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a GPT2Tokenizer from a pre-trained model file.
        Download and cache the pre-trained model file if needed.
        """
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find files {} and {} "
                "at this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                    pretrained_model_name_or_path,
                    vocab_file, merges_file))
            return None
        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
            logger.info("loading vocabulary file {}".format(vocab_file))
            logger.info("loading merges file {}".format(merges_file))
        else:
            logger.info("loading vocabulary file {} from cache at {}".format(
                vocab_file, resolved_vocab_file))
            logger.info("loading merges file {} from cache at {}".format(
                merges_file, resolved_merges_file))
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
            # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
            # than the number of positional embeddings
            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate the tokenizer.
        tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
        return tokenizer

    def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
        self.max_len = max_len if max_len is not None else int(1e12)
        self.encoder = json.load(open(vocab_file))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def __len__(self):
        return len(self.encoder)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        if len(bpe_tokens) > self.max_len:
            raise ValueError(
                "Token indices sequence length is longer than the specified maximum "
                "sequence length for this OpenAI GPT-2 model ({} > {}). Running this"
                " sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len)
            )
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text
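
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): byte-level BPE encoding round-trips
# the exact input string, including whitespace and non-ASCII characters.
# `from_pretrained('gpt2')` downloads (or loads from the cache) the vocabulary and
# merges files listed in the archive maps above.
# ---------------------------------------------------------------------------
def _gpt2_tokenizer_roundtrip_sketch():
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    text = "Hello world! Byte-level BPE also handles café."
    token_ids = tokenizer.encode(text)
    assert tokenizer.decode(token_ids) == text
    return token_ids
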
setup.py
@@ -38,7 +38,7 @@ from setuptools import find_packages, setup

 setup(
     name="pytorch_pretrained_bert",
-    version="0.5.1",
+    version="0.6.0",
     author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors, Open AI team Authors",
     author_email="thomas@huggingface.co",
     description="PyTorch version of Google AI BERT model with script to load Google pre-trained models",

tests/modeling_gpt2_test.py (new file, 210 lines)
@@ -0,0 +1,210 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import json
import random

import torch

from pytorch_pretrained_bert import (GPT2Config, GPT2Model,
                                     GPT2LMHeadModel, GPT2DoubleHeadsModel)


class GPT2ModelTest(unittest.TestCase):
    class GPT2ModelTester(object):

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     is_training=True,
                     use_position_ids=True,
                     use_token_type_ids=True,
                     use_labels=True,
                     vocab_size=99,
                     n_positions=33,
                     n_embd=32,
                     n_layer=5,
                     n_head=4,
                     n_choices=3,
                     type_sequence_label_size=2,
                     initializer_range=0.02,
                     num_labels=3,
                     scope=None):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_position_ids = use_position_ids
            self.use_token_type_ids = use_token_type_ids
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.n_positions = n_positions
            self.n_embd = n_embd
            self.n_layer = n_layer
            self.n_head = n_head
            self.n_choices = n_choices
            self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
            self.num_labels = num_labels
            self.scope = scope

        def prepare_config_and_inputs(self):
            input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.vocab_size)

            position_ids = None
            if self.use_position_ids:
                position_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_positions)

            token_type_ids = None
            if self.use_token_type_ids:
                total_voc = self.vocab_size
                token_type_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc)

            mc_labels = None
            lm_labels = None
            mc_token_ids = None
            if self.use_labels:
                mc_labels = GPT2ModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
                lm_labels = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
                mc_token_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices], self.seq_length)

            config = GPT2Config(
                vocab_size_or_config_json_file=self.vocab_size,
                n_positions=self.n_positions,
                n_embd=self.n_embd,
                n_layer=self.n_layer,
                n_head=self.n_head,
                initializer_range=self.initializer_range)

            return (config, input_ids, token_type_ids, position_ids,
                    mc_labels, lm_labels, mc_token_ids)

        def create_gpt2_model(self, config, input_ids, token_type_ids, position_ids,
                              mc_labels, lm_labels, mc_token_ids):
            model = GPT2Model(config)
            model.eval()
            hidden_states, presents = model(input_ids, position_ids, token_type_ids)
            outputs = {
                "hidden_states": hidden_states,
                "presents": presents,
            }
            return outputs

        def check_gpt2_model_output(self, result):
            self.parent.assertListEqual(
                list(result["hidden_states"].size()),
                [self.batch_size, self.n_choices, self.seq_length, self.n_embd])

        def create_gpt2_lm_head(self, config, input_ids, token_type_ids, position_ids,
                                mc_labels, lm_labels, mc_token_ids):
            model = GPT2LMHeadModel(config)
            model.eval()
            loss = model(input_ids, position_ids, token_type_ids, lm_labels)
            lm_logits, presents = model(input_ids, position_ids, token_type_ids)
            outputs = {
                "loss": loss,
                "lm_logits": lm_logits,
                "presents": presents,
            }
            return outputs

        def check_gpt2_lm_head_output(self, result):
            total_voc = self.vocab_size
            self.parent.assertListEqual(
                list(result["lm_logits"].size()),
                [self.batch_size, self.n_choices, self.seq_length, total_voc])

        def check_gpt2_lm_head_loss_output(self, result):
            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])

        def create_gpt2_double_heads(self, config, input_ids, token_type_ids, position_ids,
                                     mc_labels, lm_labels, mc_token_ids):
            model = GPT2DoubleHeadsModel(config)
            model.eval()
            loss = model(input_ids, mc_token_ids,
                         lm_labels=lm_labels, mc_labels=mc_labels,
                         token_type_ids=token_type_ids, position_ids=position_ids)
            lm_logits, mc_logits, presents = model(input_ids, mc_token_ids, position_ids=position_ids, token_type_ids=token_type_ids)
            outputs = {
                "loss": loss,
                "lm_logits": lm_logits,
                "mc_logits": mc_logits,
                "presents": presents,
            }
            return outputs

        def check_gpt2_double_heads_output(self, result):
            total_voc = self.vocab_size
            self.parent.assertListEqual(
                list(result["lm_logits"].size()),
                [self.batch_size, self.n_choices, self.seq_length, total_voc])
            self.parent.assertListEqual(
                list(result["mc_logits"].size()),
                [self.batch_size, self.n_choices])

        def check_gpt2_double_heads_loss_output(self, result):
            self.parent.assertListEqual(
                [list(l.size()) for l in result["loss"]],
                [[], []])

    def test_default(self):
        self.run_tester(GPT2ModelTest.GPT2ModelTester(self))

    def test_config_to_json_string(self):
        config = GPT2Config(vocab_size_or_config_json_file=99, n_embd=37)
        obj = json.loads(config.to_json_string())
        self.assertEqual(obj["vocab_size"], 99)
        self.assertEqual(obj["n_embd"], 37)

    def run_tester(self, tester):
        config_and_inputs = tester.prepare_config_and_inputs()
        output_result = tester.create_gpt2_model(*config_and_inputs)
        tester.check_gpt2_model_output(output_result)

        output_result = tester.create_gpt2_lm_head(*config_and_inputs)
        tester.check_gpt2_lm_head_output(output_result)
        tester.check_gpt2_lm_head_loss_output(output_result)

        output_result = tester.create_gpt2_double_heads(*config_and_inputs)
        tester.check_gpt2_double_heads_output(output_result)
        tester.check_gpt2_double_heads_loss_output(output_result)

    @classmethod
    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
        """Creates a random int32 tensor of the shape within the vocab size."""
        if rng is None:
            rng = random.Random()

        total_dims = 1
        for dim in shape:
            total_dims *= dim

        values = []
        for _ in range(total_dims):
            values.append(rng.randint(0, vocab_size - 1))

        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()


if __name__ == "__main__":
    unittest.main()