mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
updating GLUE utils for compatibility with XLNet
This commit is contained in:
parent
24ed0b9346
commit
62d78aa37e
126
README.md
126
README.md
@ -137,9 +137,9 @@ This package comprises the following classes that can be imported in Python and
|
||||
The repository further comprises:
|
||||
|
||||
- Five examples on how to use **BERT** (in the [`examples` folder](./examples)):
|
||||
- [`extract_features.py`](./examples/extract_features.py) - Show how to extract hidden states from an instance of `BertModel`,
|
||||
- [`run_classifier.py`](./examples/run_classifier.py) - Show how to fine-tune an instance of `BertForSequenceClassification` on GLUE's MRPC task,
|
||||
- [`run_squad.py`](./examples/run_squad.py) - Show how to fine-tune an instance of `BertForQuestionAnswering` on SQuAD v1.0 and SQuAD v2.0 tasks.
|
||||
- [`run_bert_extract_features.py`](./examples/run_bert_extract_features.py) - Show how to extract hidden states from an instance of `BertModel`,
|
||||
- [`run_bert_classifier.py`](./examples/run_bert_classifier.py) - Show how to fine-tune an instance of `BertForSequenceClassification` on GLUE's MRPC task,
|
||||
- [`run_bert_squad.py`](./examples/run_bert_squad.py) - Show how to fine-tune an instance of `BertForQuestionAnswering` on SQuAD v1.0 and SQuAD v2.0 tasks.
|
||||
- [`run_swag.py`](./examples/run_swag.py) - Show how to fine-tune an instance of `BertForMultipleChoice` on Swag task.
|
||||
- [`simple_lm_finetuning.py`](./examples/lm_finetuning/simple_lm_finetuning.py) - Show how to fine-tune an instance of `BertForPretraining` on a target text corpus.
|
||||
|
||||
@ -541,7 +541,7 @@ where
|
||||
- `bert-base-german-cased`: Trained on German data only, 12-layer, 768-hidden, 12-heads, 110M parameters [Performance Evaluation](https://deepset.ai/german-bert)
|
||||
- `bert-large-uncased-whole-word-masking`: 24-layer, 1024-hidden, 16-heads, 340M parameters - Trained with Whole Word Masking (mask all of the the tokens corresponding to a word at once)
|
||||
- `bert-large-cased-whole-word-masking`: 24-layer, 1024-hidden, 16-heads, 340M parameters - Trained with Whole Word Masking (mask all of the the tokens corresponding to a word at once)
|
||||
- `bert-large-uncased-whole-word-masking-finetuned-squad`: The `bert-large-uncased-whole-word-masking` model finetuned on SQuAD (using the `run_squad.py` examples). Results: *exact_match: 86.91579943235573, f1: 93.1532499015869*
|
||||
- `bert-large-uncased-whole-word-masking-finetuned-squad`: The `bert-large-uncased-whole-word-masking` model finetuned on SQuAD (using the `run_bert_squad.py` examples). Results: *exact_match: 86.91579943235573, f1: 93.1532499015869*
|
||||
- `openai-gpt`: OpenAI GPT English model, 12-layer, 768-hidden, 12-heads, 110M parameters
|
||||
- `gpt2`: OpenAI GPT-2 English model, 12-layer, 768-hidden, 12-heads, 117M parameters
|
||||
- `gpt2-medium`: OpenAI GPT-2 English model, 24-layer, 1024-hidden, 16-heads, 345M parameters
|
||||
@ -720,7 +720,7 @@ The inputs and output are **identical to the TensorFlow model inputs and outputs
|
||||
|
||||
We detail them here. This model takes as *inputs*:
|
||||
[`modeling.py`](./pytorch_pretrained_bert/modeling.py)
|
||||
- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts [`extract_features.py`](./examples/extract_features.py), [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`](./examples/run_squad.py)), and
|
||||
- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts [`run_bert_extract_features.py`](./examples/run_bert_extract_features.py), [`run_bert_classifier.py`](./examples/run_bert_classifier.py) and [`run_bert_squad.py`](./examples/run_bert_squad.py)), and
|
||||
- `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
|
||||
- `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if some input sequence lengths are smaller than the max input sequence length of the current batch. It's the mask that we typically use for attention when a batch has varying length sentences.
|
||||
- `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
|
||||
@ -735,7 +735,7 @@ This model *outputs* a tuple composed of:
|
||||
|
||||
- `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a classifier pretrained on top of the hidden state associated to the first character of the input (`CLF`) to train on the Next-Sentence task (see BERT's paper).
|
||||
|
||||
An example on how to use this class is given in the [`extract_features.py`](./examples/extract_features.py) script which can be used to extract the hidden states of the model for a given input.
|
||||
An example on how to use this class is given in the [`run_bert_extract_features.py`](./examples/run_bert_extract_features.py) script which can be used to extract the hidden states of the model for a given input.
|
||||
|
||||
#### 2. `BertForPreTraining`
|
||||
|
||||
@ -792,7 +792,7 @@ An example on how to use this class is given in the [`run_lm_finetuning.py`](./e
|
||||
|
||||
The sequence-level classifier is a linear layer that takes as input the last hidden state of the first character in the input sequence (see Figures 3a and 3b in the BERT paper).
|
||||
|
||||
An example on how to use this class is given in the [`run_classifier.py`](./examples/run_classifier.py) script which can be used to fine-tune a single sequence (or pair of sequence) classifier using BERT, for example for the MRPC task.
|
||||
An example on how to use this class is given in the [`run_bert_classifier.py`](./examples/run_bert_classifier.py) script which can be used to fine-tune a single sequence (or pair of sequence) classifier using BERT, for example for the MRPC task.
|
||||
|
||||
#### 6. `BertForMultipleChoice`
|
||||
|
||||
@ -816,7 +816,7 @@ The token-level classifier is a linear layer that takes as input the last hidden
|
||||
|
||||
The token-level classifier takes as input the full sequence of the last hidden state and compute several (e.g. two) scores for each tokens that can for example respectively be the score that a given token is a `start_span` and a `end_span` token (see Figures 3c and 3d in the BERT paper).
|
||||
|
||||
An example on how to use this class is given in the [`run_squad.py`](./examples/run_squad.py) script which can be used to fine-tune a token classifier using BERT, for example for the SQuAD task.
|
||||
An example on how to use this class is given in the [`run_bert_squad.py`](./examples/run_bert_squad.py) script which can be used to fine-tune a token classifier using BERT, for example for the SQuAD task.
|
||||
|
||||
#### 9. `OpenAIGPTModel`
|
||||
|
||||
@ -1138,7 +1138,7 @@ An overview of the implemented schedules:
|
||||
| Sub-section | Description |
|
||||
|-|-|
|
||||
| [Training large models: introduction, tools and examples](#Training-large-models-introduction,-tools-and-examples) | How to use gradient-accumulation, multi-gpu training, distributed training, optimize on CPU and 16-bits training to train Bert models |
|
||||
| [Fine-tuning with BERT: running the examples](#Fine-tuning-with-BERT-running-the-examples) | Running the examples in [`./examples`](./examples/): `extract_classif.py`, `run_classifier.py`, `run_squad.py` and `run_lm_finetuning.py` |
|
||||
| [Fine-tuning with BERT: running the examples](#Fine-tuning-with-BERT-running-the-examples) | Running the examples in [`./examples`](./examples/): `extract_classif.py`, `run_bert_classifier.py`, `run_bert_squad.py` and `run_lm_finetuning.py` |
|
||||
| [Fine-tuning with OpenAI GPT, Transformer-XL and GPT-2](#openai-gpt-transformer-xl-and-gpt-2-running-the-examples) | Running the examples in [`./examples`](./examples/): `run_openai_gpt.py`, `run_transfo_xl.py` and `run_gpt2.py` |
|
||||
| [Fine-tuning BERT-large on GPUs](#Fine-tuning-BERT-large-on-GPUs) | How to fine tune `BERT large`|
|
||||
|
||||
@ -1146,7 +1146,7 @@ An overview of the implemented schedules:
|
||||
|
||||
BERT-base and BERT-large are respectively 110M and 340M parameters models and it can be difficult to fine-tune them on a single GPU with the recommended batch size for good performance (in most case a batch size of 32).
|
||||
|
||||
To help with fine-tuning these models, we have included several techniques that you can activate in the fine-tuning scripts [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`](./examples/run_squad.py): gradient-accumulation, multi-gpu training, distributed training and 16-bits training . For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month.
|
||||
To help with fine-tuning these models, we have included several techniques that you can activate in the fine-tuning scripts [`run_bert_classifier.py`](./examples/run_bert_classifier.py) and [`run_bert_squad.py`](./examples/run_bert_squad.py): gradient-accumulation, multi-gpu training, distributed training and 16-bits training . For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month.
|
||||
|
||||
Here is how to use these techniques in our scripts:
|
||||
|
||||
@ -1159,7 +1159,7 @@ To use 16-bits training and distributed training, you need to install NVIDIA's a
|
||||
|
||||
Note: To use *Distributed Training*, you will need to run one training script on each of your machines. This can be done for example by running the following command on each server (see [the above mentioned blog post]((https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255)) for more details):
|
||||
```bash
|
||||
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=$THIS_MACHINE_INDEX --master_addr="192.168.1.1" --master_port=1234 run_classifier.py (--arg1 --arg2 --arg3 and all other arguments of the run_classifier script)
|
||||
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=$THIS_MACHINE_INDEX --master_addr="192.168.1.1" --master_port=1234 run_bert_classifier.py (--arg1 --arg2 --arg3 and all other arguments of the run_classifier script)
|
||||
```
|
||||
Where `$THIS_MACHINE_INDEX` is an sequential index assigned to each of your machine (0, 1, 2...) and the machine with rank 0 has an IP address `192.168.1.1` and an open port `1234`.
|
||||
|
||||
@ -1201,7 +1201,7 @@ and unpack it to some directory `$GLUE_DIR`.
|
||||
export GLUE_DIR=/path/to/glue
|
||||
export TASK_NAME=MRPC
|
||||
|
||||
python run_classifier.py \
|
||||
python run_bert_classifier.py \
|
||||
--task_name $TASK_NAME \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
@ -1234,7 +1234,7 @@ and unpack it to some directory `$GLUE_DIR`.
|
||||
```shell
|
||||
export GLUE_DIR=/path/to/glue
|
||||
|
||||
python run_classifier.py \
|
||||
python run_bert_classifier.py \
|
||||
--task_name MRPC \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
@ -1256,7 +1256,7 @@ Then run
|
||||
```shell
|
||||
export GLUE_DIR=/path/to/glue
|
||||
|
||||
python run_classifier.py \
|
||||
python run_bert_classifier.py \
|
||||
--task_name MRPC \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
@ -1275,7 +1275,7 @@ python run_classifier.py \
|
||||
Here is an example using distributed training on 8 V100 GPUs and Bert Whole Word Masking model to reach a F1 > 92 on MRPC:
|
||||
|
||||
```bash
|
||||
python -m torch.distributed.launch --nproc_per_node 8 run_classifier.py --bert_model bert-large-uncased-whole-word-masking --task_name MRPC --do_train --do_eval --do_lower_case --data_dir $GLUE_DIR/MRPC/ --max_seq_length 128 --train_batch_size 8 --learning_rate 2e-5 --num_train_epochs 3.0 --output_dir /tmp/mrpc_output/
|
||||
python -m torch.distributed.launch --nproc_per_node 8 run_bert_classifier.py --bert_model bert-large-uncased-whole-word-masking --task_name MRPC --do_train --do_eval --do_lower_case --data_dir $GLUE_DIR/MRPC/ --max_seq_length 128 --train_batch_size 8 --learning_rate 2e-5 --num_train_epochs 3.0 --output_dir /tmp/mrpc_output/
|
||||
```
|
||||
|
||||
Training with these hyper-parameters gave us the following results:
|
||||
@ -1291,7 +1291,7 @@ Training with these hyper-parameters gave us the following results:
|
||||
Here is an example on MNLI:
|
||||
|
||||
```bash
|
||||
python -m torch.distributed.launch --nproc_per_node 8 run_classifier.py --bert_model bert-large-uncased-whole-word-masking --task_name mnli --do_train --do_eval --do_lower_case --data_dir /datadrive/bert_data/glue_data//MNLI/ --max_seq_length 128 --train_batch_size 8 --learning_rate 2e-5 --num_train_epochs 3.0 --output_dir ../models/wwm-uncased-finetuned-mnli/ --overwrite_output_dir
|
||||
python -m torch.distributed.launch --nproc_per_node 8 run_bert_classifier.py --bert_model bert-large-uncased-whole-word-masking --task_name mnli --do_train --do_eval --do_lower_case --data_dir /datadrive/bert_data/glue_data//MNLI/ --max_seq_length 128 --train_batch_size 8 --learning_rate 2e-5 --num_train_epochs 3.0 --output_dir ../models/wwm-uncased-finetuned-mnli/ --overwrite_output_dir
|
||||
```
|
||||
|
||||
```bash
|
||||
@ -1324,7 +1324,7 @@ The data for SQuAD can be downloaded with the following links and should be save
|
||||
```shell
|
||||
export SQUAD_DIR=/path/to/SQUAD
|
||||
|
||||
python run_squad.py \
|
||||
python run_bert_squad.py \
|
||||
--bert_model bert-base-uncased \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
@ -1351,7 +1351,7 @@ Here is an example using distributed training on 8 V100 GPUs and Bert Whole Word
|
||||
|
||||
```bash
|
||||
python -m torch.distributed.launch --nproc_per_node=8 \
|
||||
run_squad.py \
|
||||
run_bert_squad.py \
|
||||
--bert_model bert-large-uncased-whole-word-masking \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
@ -1378,7 +1378,7 @@ This is the model provided as `bert-large-uncased-whole-word-masking-finetuned-s
|
||||
And here is the model provided as `bert-large-cased-whole-word-masking-finetuned-squad`:
|
||||
|
||||
```bash
|
||||
python -m torch.distributed.launch --nproc_per_node=8 run_squad.py --bert_model bert-large-cased-whole-word-masking --do_train --do_predict --do_lower_case --train_file $SQUAD_DIR/train-v1.1.json --predict_file $SQUAD_DIR/dev-v1.1.json --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --output_dir ../models/wwm_cased_finetuned_squad/ --train_batch_size 24 --gradient_accumulation_steps 12
|
||||
python -m torch.distributed.launch --nproc_per_node=8 run_bert_squad.py --bert_model bert-large-cased-whole-word-masking --do_train --do_predict --do_lower_case --train_file $SQUAD_DIR/train-v1.1.json --predict_file $SQUAD_DIR/dev-v1.1.json --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --output_dir ../models/wwm_cased_finetuned_squad/ --train_batch_size 24 --gradient_accumulation_steps 12
|
||||
```
|
||||
|
||||
Training with these hyper-parameters gave us the following results:
|
||||
@ -1499,7 +1499,7 @@ Here is the full list of hyper-parameters for this run:
|
||||
```bash
|
||||
export SQUAD_DIR=/path/to/SQUAD
|
||||
|
||||
python ./run_squad.py \
|
||||
python ./run_bert_squad.py \
|
||||
--bert_model bert-large-uncased \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
@ -1521,7 +1521,7 @@ Here is an example of hyper-parameters for a FP16 run we tried:
|
||||
```bash
|
||||
export SQUAD_DIR=/path/to/SQUAD
|
||||
|
||||
python ./run_squad.py \
|
||||
python ./run_bert_squad.py \
|
||||
--bert_model bert-large-uncased \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
@ -1547,7 +1547,7 @@ Here is an example with the recent `bert-large-uncased-whole-word-masking`:
|
||||
|
||||
```bash
|
||||
python -m torch.distributed.launch --nproc_per_node=8 \
|
||||
run_squad.py \
|
||||
run_bert_squad.py \
|
||||
--bert_model bert-large-uncased-whole-word-masking \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
@ -1563,6 +1563,86 @@ python -m torch.distributed.launch --nproc_per_node=8 \
|
||||
--gradient_accumulation_steps 2
|
||||
```
|
||||
|
||||
## Fine-tuning XLNet
|
||||
|
||||
#### STS-B
|
||||
|
||||
This example code fine-tunes XLNet on the STS-B corpus.
|
||||
|
||||
Before running this example you should download the
|
||||
[GLUE data](https://gluebenchmark.com/tasks) by running
|
||||
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
|
||||
and unpack it to some directory `$GLUE_DIR`.
|
||||
|
||||
```shell
|
||||
export GLUE_DIR=/path/to/glue
|
||||
|
||||
python run_xlnet_classifier.py \
|
||||
--task_name STS-B \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--do_lower_case \
|
||||
--data_dir $GLUE_DIR/STS-B/ \
|
||||
--max_seq_length 128 \
|
||||
--train_batch_size 8 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--learning_rate 5e-5 \
|
||||
--num_train_epochs 3.0 \
|
||||
--output_dir /tmp/mrpc_output/
|
||||
```
|
||||
|
||||
Our test ran on a few seeds with [the original implementation hyper-parameters](https://github.com/zihangdai/xlnet#1-sts-b-sentence-pair-relevance-regression-with-gpus) gave evaluation results between 84% and 88%.
|
||||
|
||||
**Distributed training**
|
||||
Here is an example using distributed training on 8 V100 GPUs to reach XXXX:
|
||||
|
||||
```bash
|
||||
python -m torch.distributed.launch --nproc_per_node 8 \
|
||||
run_xlnet_classifier.py \
|
||||
--task_name STS-B \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--data_dir $GLUE_DIR/STS-B/ \
|
||||
--max_seq_length 128 \
|
||||
--train_batch_size 8 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--learning_rate 5e-5 \
|
||||
--num_train_epochs 3.0 \
|
||||
--output_dir /tmp/mrpc_output/
|
||||
```
|
||||
|
||||
Training with these hyper-parameters gave us the following results:
|
||||
```bash
|
||||
acc = 0.8823529411764706
|
||||
acc_and_f1 = 0.901702786377709
|
||||
eval_loss = 0.3418912578906332
|
||||
f1 = 0.9210526315789473
|
||||
global_step = 174
|
||||
loss = 0.07231863956341798
|
||||
```
|
||||
|
||||
Here is an example on MNLI:
|
||||
|
||||
```bash
|
||||
python -m torch.distributed.launch --nproc_per_node 8 run_bert_classifier.py --bert_model bert-large-uncased-whole-word-masking --task_name mnli --do_train --do_eval --data_dir /datadrive/bert_data/glue_data//MNLI/ --max_seq_length 128 --train_batch_size 8 --learning_rate 2e-5 --num_train_epochs 3.0 --output_dir ../models/wwm-uncased-finetuned-mnli/ --overwrite_output_dir
|
||||
```
|
||||
|
||||
```bash
|
||||
***** Eval results *****
|
||||
acc = 0.8679706601466992
|
||||
eval_loss = 0.4911287787382479
|
||||
global_step = 18408
|
||||
loss = 0.04755385363816904
|
||||
|
||||
***** Eval results *****
|
||||
acc = 0.8747965825874695
|
||||
eval_loss = 0.45516540421714036
|
||||
global_step = 18408
|
||||
loss = 0.04755385363816904
|
||||
```
|
||||
|
||||
This is the example of the `bert-large-uncased-whole-word-masking-finetuned-mnli` model
|
||||
|
||||
## BERTology
|
||||
|
||||
There is a growing field of study concerned with investigating the inner working of large-scale transformers like BERT (that some call "BERTology"). Some good examples of this field are:
|
||||
@ -1599,7 +1679,7 @@ A command-line interface is provided to convert a TensorFlow checkpoint in a PyT
|
||||
|
||||
You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) in a PyTorch save file by using the [`convert_tf_checkpoint_to_pytorch.py`](./pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py ) script.
|
||||
|
||||
This CLI takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in [`extract_features.py`](./examples/extract_features.py), [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`](./examples/run_squad.py)).
|
||||
This CLI takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in [`run_bert_extract_features.py`](./examples/run_bert_extract_features.py), [`run_bert_classifier.py`](./examples/run_bert_classifier.py) and [`run_bert_squad.py`](./examples/run_bert_squad.py)).
|
||||
|
||||
You only need to run this conversion script **once** to get a PyTorch model. You can then disregard the TensorFlow checkpoint (the three files starting with `bert_model.ckpt`) but be sure to keep the configuration file (`bert_config.json`) and the vocabulary file (`vocab.txt`) as these are needed for the PyTorch model too.
|
||||
|
||||
|
@ -203,7 +203,9 @@ def main():
|
||||
train_features = pickle.load(reader)
|
||||
except:
|
||||
train_features = convert_examples_to_features(
|
||||
train_examples, label_list, args.max_seq_length, tokenizer, output_mode)
|
||||
train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
||||
cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
|
||||
sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2)
|
||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||
logger.info(" Saving train features into cached file %s", cached_train_features_file)
|
||||
with open(cached_train_features_file, "wb") as writer:
|
||||
@ -347,7 +349,9 @@ def main():
|
||||
eval_features = pickle.load(reader)
|
||||
except:
|
||||
eval_features = convert_examples_to_features(
|
||||
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
|
||||
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
||||
cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
|
||||
sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2)
|
||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||
logger.info(" Saving eval features into cached file %s", cached_eval_features_file)
|
||||
with open(cached_eval_features_file, "wb") as writer:
|
||||
|
@ -388,8 +388,15 @@ class WnliProcessor(DataProcessor):
|
||||
|
||||
|
||||
def convert_examples_to_features(examples, label_list, max_seq_length,
|
||||
tokenizer, output_mode):
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
tokenizer, output_mode,
|
||||
cls_token_at_end=False, cls_token='[CLS]',
|
||||
sep_token='[SEP]', cls_token_segment_id=0):
|
||||
""" Loads a data file into a list of `InputBatch`s
|
||||
`cls_token_at_end` define the location of the CLS token:
|
||||
- False (BERT pattern): [CLS] + A + [SEP] + B + [SEP]
|
||||
- True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
|
||||
`cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
|
||||
"""
|
||||
|
||||
label_map = {label : i for i, label in enumerate(label_list)}
|
||||
|
||||
@ -430,13 +437,20 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
|
||||
# For classification tasks, the first vector (corresponding to [CLS]) is
|
||||
# used as as the "sentence vector". Note that this only makes sense because
|
||||
# the entire model is fine-tuned.
|
||||
tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
|
||||
tokens = tokens_a + [sep_token]
|
||||
segment_ids = [0] * len(tokens)
|
||||
|
||||
if tokens_b:
|
||||
tokens += tokens_b + ["[SEP]"]
|
||||
tokens += tokens_b + [sep_token]
|
||||
segment_ids += [1] * (len(tokens_b) + 1)
|
||||
|
||||
if cls_token_at_end:
|
||||
tokens = tokens + [cls_token]
|
||||
segment_ids = segment_ids + [cls_token_segment_id]
|
||||
else:
|
||||
tokens = [cls_token] + tokens
|
||||
segment_ids = [cls_token_segment_id] + segment_ids
|
||||
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
|
||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||
|
@ -86,7 +86,7 @@
|
||||
"spec.loader.exec_module(module)\n",
|
||||
"sys.modules['modeling_tensorflow'] = module\n",
|
||||
"\n",
|
||||
"spec = importlib.util.spec_from_file_location('*', original_tf_inplem_dir + '/run_squad.py')\n",
|
||||
"spec = importlib.util.spec_from_file_location('*', original_tf_inplem_dir + '/run_bert_squad.py')\n",
|
||||
"module = importlib.util.module_from_spec(spec)\n",
|
||||
"spec.loader.exec_module(module)\n",
|
||||
"sys.modules['run_squad_tensorflow'] = module\n",
|
||||
|
@ -778,7 +778,7 @@ class BertModel(BertPreTrainedModel):
|
||||
Inputs:
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
|
||||
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
||||
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
|
||||
`run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
|
||||
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
|
||||
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
||||
a `sentence B` token (see BERT paper for more details).
|
||||
@ -905,7 +905,7 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
Inputs:
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
|
||||
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
||||
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
|
||||
`run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
|
||||
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
|
||||
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
||||
a `sentence B` token (see BERT paper for more details).
|
||||
@ -986,7 +986,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
Inputs:
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
|
||||
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
||||
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
|
||||
`run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
|
||||
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
|
||||
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
||||
a `sentence B` token (see BERT paper for more details).
|
||||
@ -1064,7 +1064,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
Inputs:
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
|
||||
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
||||
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
|
||||
`run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
|
||||
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
|
||||
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
||||
a `sentence B` token (see BERT paper for more details).
|
||||
@ -1141,7 +1141,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
Inputs:
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
|
||||
with the word token indices in the vocabulary. Items in the batch should begin with the special "CLS" token. (see the tokens preprocessing logic in the scripts
|
||||
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
|
||||
`run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
|
||||
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
|
||||
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
||||
a `sentence B` token (see BERT paper for more details).
|
||||
@ -1219,7 +1219,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
Inputs:
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
|
||||
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
||||
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
|
||||
`run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
|
||||
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
|
||||
with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
|
||||
and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
|
||||
@ -1300,7 +1300,7 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
Inputs:
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
|
||||
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
||||
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
|
||||
`run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
|
||||
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
|
||||
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
||||
a `sentence B` token (see BERT paper for more details).
|
||||
@ -1384,7 +1384,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
Inputs:
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
|
||||
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
||||
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
|
||||
`run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
|
||||
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
|
||||
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
||||
a `sentence B` token (see BERT paper for more details).
|
||||
|
@ -46,7 +46,7 @@ XLNET_CONFIG_NAME = 'xlnet_config.json'
|
||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
|
||||
|
||||
def build_tf_xlnet_to_pytorch_map(model, config):
|
||||
def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
|
||||
""" A map of modules from TF to PyTorch.
|
||||
I use a map to keep the PyTorch model as
|
||||
identical to the original PyTorch model as possible.
|
||||
@ -55,8 +55,15 @@ def build_tf_xlnet_to_pytorch_map(model, config):
|
||||
tf_to_pt_map = {}
|
||||
|
||||
if hasattr(model, 'transformer'):
|
||||
# We are loading pre-trained weights in a XLNetLMHeadModel => we will load also the output bias
|
||||
if hasattr(model, 'lm_loss'):
|
||||
# We will load also the output bias
|
||||
tf_to_pt_map['model/lm_loss/bias'] = model.lm_loss.bias
|
||||
elif hasattr(model, 'sequence_summary') and 'model/sequnece_summary/summary/kernel' in tf_weights:
|
||||
# We will load also the sequence summary
|
||||
tf_to_pt_map['model/sequnece_summary/summary/kernel'] = model.sequence_summary.summary.weight
|
||||
tf_to_pt_map['model/sequnece_summary/summary/bias'] = model.sequence_summary.summary.bias
|
||||
elif hasattr(model, 'proj_loss') and any('model/regression' in name for name in tf_weights.keys()):
|
||||
raise NotImplementedError
|
||||
# Now load the rest of the transformer
|
||||
model = model.transformer
|
||||
|
||||
@ -116,9 +123,6 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
|
||||
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
# Build TF to PyTorch weights loading map
|
||||
tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config)
|
||||
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
tf_weights = {}
|
||||
@ -127,9 +131,14 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
tf_weights[name] = array
|
||||
|
||||
# Build TF to PyTorch weights loading map
|
||||
tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config, tf_weights)
|
||||
|
||||
for name, pointer in tf_to_pt_map.items():
|
||||
print("Importing {}".format(name))
|
||||
assert name in tf_weights
|
||||
if name not in tf_weights:
|
||||
print("{} not in tf pre-trained weights, skipping".format(name))
|
||||
continue
|
||||
array = tf_weights[name]
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
# which are not required for using pretrained model
|
||||
@ -643,6 +652,11 @@ class XLNetPreTrainedModel(nn.Module):
|
||||
elif isinstance(module, XLNetLayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, XLNetRelativeAttention):
|
||||
for param in [module.q, module.k, module.v, module.o, module.r,
|
||||
module.r_r_bias, module.r_s_bias, module.r_w_bias,
|
||||
module.seg_embed]:
|
||||
param.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
@ -904,15 +918,19 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
pos_emb = pos_emb.to(next(self.parameters()))
|
||||
return pos_emb
|
||||
|
||||
def forward(self, inp_k, token_type_ids=None, attention_mask=None,
|
||||
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
||||
output_all_encoded_layers=True, head_mask=None):
|
||||
"""
|
||||
Args:
|
||||
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
|
||||
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
|
||||
attention_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
|
||||
input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
|
||||
0 for real tokens and 1 for padding.
|
||||
attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask`
|
||||
but with 1 for real tokens and 0 for padding.
|
||||
Added for easy compatibility with the BERT model (which uses this negative masking).
|
||||
You can only uses one among `input_mask` and `attention_mask`
|
||||
mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
|
||||
from previous batches. The length of the list equals n_layer.
|
||||
If None, no memory is used.
|
||||
@ -946,6 +964,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
# so we move here the first dimension (batch) to the end
|
||||
inp_k = inp_k.transpose(0, 1).contiguous()
|
||||
token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
|
||||
input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
|
||||
attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
|
||||
perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
|
||||
target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
|
||||
@ -969,11 +988,15 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
raise ValueError('Unsupported attention type: {}'.format(self.attn_type))
|
||||
|
||||
# data mask: input mask & perm mask
|
||||
if attention_mask is not None and perm_mask is not None:
|
||||
data_mask = attention_mask[None] + perm_mask
|
||||
elif attention_mask is not None and perm_mask is None:
|
||||
data_mask = attention_mask[None]
|
||||
elif attention_mask is None and perm_mask is not None:
|
||||
assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) "
|
||||
"or attention_mask (uses 0 for padding, added for compatbility with BERT). Please choose one."
|
||||
if input_mask is None and attention_mask is not None:
|
||||
input_mask = 1.0 - attention_mask
|
||||
if input_mask is not None and perm_mask is not None:
|
||||
data_mask = input_mask[None] + perm_mask
|
||||
elif input_mask is not None and perm_mask is None:
|
||||
data_mask = input_mask[None]
|
||||
elif input_mask is None and perm_mask is not None:
|
||||
data_mask = perm_mask
|
||||
else:
|
||||
data_mask = None
|
||||
@ -1077,8 +1100,12 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
Inputs:
|
||||
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
|
||||
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
|
||||
attention_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
|
||||
input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
|
||||
0 for real tokens and 1 for padding.
|
||||
attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask`
|
||||
but with 1 for real tokens and 0 for padding.
|
||||
Added for easy compatibility with the BERT model (which uses this negative masking).
|
||||
You can only uses one among `input_mask` and `attention_mask`
|
||||
mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
|
||||
from previous batches. The length of the list equals n_layer.
|
||||
If None, no memory is used.
|
||||
@ -1112,14 +1139,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
```python
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
|
||||
n_layer=12, num_attention_heads=12, intermediate_size=3072)
|
||||
|
||||
model = modeling.XLNetModel(config=config)
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, attention_mask)
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
|
||||
@ -1142,15 +1169,19 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
"""
|
||||
self.lm_loss.weight = self.transformer.word_embedding.weight
|
||||
|
||||
def forward(self, inp_k, token_type_ids=None, attention_mask=None,
|
||||
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
||||
target=None, output_all_encoded_layers=True, head_mask=None):
|
||||
"""
|
||||
Args:
|
||||
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
|
||||
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
|
||||
attention_mask: float32 Tensor in shape [bsz, len], the input mask.
|
||||
input_mask: float32 Tensor in shape [bsz, len], the input mask.
|
||||
0 for real tokens and 1 for padding.
|
||||
attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask`
|
||||
but with 1 for real tokens and 0 for padding.
|
||||
Added for easy compatibility with the BERT model (which uses this negative masking).
|
||||
You can only uses one among `input_mask` and `attention_mask`
|
||||
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
|
||||
from previous batches. The length of the list equals n_layer.
|
||||
If None, no memory is used.
|
||||
@ -1171,7 +1202,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
summary_type: str, "last", "first", "mean", or "attn". The method
|
||||
to pool the input to get a vector representation.
|
||||
"""
|
||||
output, hidden_states, new_mems = self.transformer(inp_k, token_type_ids, attention_mask,
|
||||
output, hidden_states, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
|
||||
mems, perm_mask, target_mapping, inp_q,
|
||||
output_all_encoded_layers, head_mask)
|
||||
|
||||
@ -1242,8 +1273,12 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
||||
Inputs:
|
||||
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
|
||||
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
|
||||
attention_mask: float32 Tensor in shape [bsz, len], the input mask.
|
||||
input_mask: float32 Tensor in shape [bsz, len], the input mask.
|
||||
0 for real tokens and 1 for padding.
|
||||
attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask`
|
||||
but with 1 for real tokens and 0 for padding.
|
||||
Added for easy compatibility with the BERT model (which uses this negative masking).
|
||||
You can only uses one among `input_mask` and `attention_mask`
|
||||
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
|
||||
from previous batches. The length of the list equals n_layer.
|
||||
If None, no memory is used.
|
||||
@ -1278,14 +1313,14 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
||||
```python
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
|
||||
n_layer=12, num_attention_heads=12, intermediate_size=3072)
|
||||
|
||||
model = modeling.XLNetModel(config=config)
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, attention_mask)
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
|
||||
@ -1306,15 +1341,19 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
||||
self.loss_proj = nn.Linear(config.d_model, num_labels if not is_regression else 1)
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, inp_k, token_type_ids=None, attention_mask=None,
|
||||
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
||||
target=None, output_all_encoded_layers=True, head_mask=None):
|
||||
"""
|
||||
Args:
|
||||
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
|
||||
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
|
||||
attention_mask: float32 Tensor in shape [bsz, len], the input mask.
|
||||
input_mask: float32 Tensor in shape [bsz, len], the input mask.
|
||||
0 for real tokens and 1 for padding.
|
||||
attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask`
|
||||
but with 1 for real tokens and 0 for padding.
|
||||
Added for easy compatibility with the BERT model (which uses this negative masking).
|
||||
You can only uses one among `input_mask` and `attention_mask`
|
||||
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
|
||||
from previous batches. The length of the list equals n_layer.
|
||||
If None, no memory is used.
|
||||
@ -1332,7 +1371,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
||||
Only used during pretraining for two-stream attention.
|
||||
Set to None during finetuning.
|
||||
"""
|
||||
output, _, new_mems = self.transformer(inp_k, token_type_ids, attention_mask,
|
||||
output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
|
||||
mems, perm_mask, target_mapping, inp_q,
|
||||
output_all_encoded_layers, head_mask)
|
||||
|
||||
@ -1372,11 +1411,15 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||
Inputs:
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
|
||||
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
||||
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
|
||||
`run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
|
||||
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
|
||||
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
||||
a `sentence B` token (see XLNet paper for more details).
|
||||
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
|
||||
`attention_mask`: [optional] float32 Tensor, SAME FUNCTION as `input_mask`
|
||||
but with 1 for real tokens and 0 for padding.
|
||||
Added for easy compatibility with the BERT model (which uses this negative masking).
|
||||
You can only uses one among `input_mask` and `attention_mask`
|
||||
`input_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
|
||||
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
|
||||
input sequence length in the current batch. It's the mask that we typically use for attention when
|
||||
a batch has varying length sentences.
|
||||
@ -1400,14 +1443,14 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||
```python
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
config = XLNetConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
|
||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
||||
|
||||
model = XLNetForQuestionAnswering(config)
|
||||
start_logits, end_logits = model(input_ids, token_type_ids, attention_mask)
|
||||
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
|
||||
@ -1418,11 +1461,11 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, inp_k, token_type_ids=None, attention_mask=None,
|
||||
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
||||
start_positions=None, end_positions=None,
|
||||
output_all_encoded_layers=True, head_mask=None):
|
||||
output, _, new_mems = self.transformer(inp_k, token_type_ids, attention_mask,
|
||||
output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
|
||||
mems, perm_mask, target_mapping, inp_q,
|
||||
output_all_encoded_layers, head_mask)
|
||||
|
||||
|
@ -58,7 +58,6 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
|
||||
}
|
||||
VOCAB_NAME = 'vocab.txt'
|
||||
|
||||
|
||||
def load_vocab(vocab_file):
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab = collections.OrderedDict()
|
||||
@ -116,6 +115,46 @@ class BertTokenizer(object):
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||
self.max_len = max_len if max_len is not None else int(1e12)
|
||||
|
||||
@property
|
||||
def UNK_TOKEN(self):
|
||||
return "[UNK]"
|
||||
|
||||
@property
|
||||
def SEP_TOKEN(self):
|
||||
return "[SEP]"
|
||||
|
||||
@property
|
||||
def PAD_TOKEN(self):
|
||||
return "[PAD]"
|
||||
|
||||
@property
|
||||
def CLS_TOKEN(self):
|
||||
return "[CLS]"
|
||||
|
||||
@property
|
||||
def MASK_TOKEN(self):
|
||||
return "[MASK]"
|
||||
|
||||
@property
|
||||
def UNK_ID(self):
|
||||
return self.vocab["[UNK]"]
|
||||
|
||||
@property
|
||||
def SEP_ID(self):
|
||||
return self.vocab["[SEP]"]
|
||||
|
||||
@property
|
||||
def PAD_ID(self):
|
||||
return self.vocab["[PAD]"]
|
||||
|
||||
@property
|
||||
def CLS_ID(self):
|
||||
return self.vocab["[CLS]"]
|
||||
|
||||
@property
|
||||
def MASK_ID(self):
|
||||
return self.vocab["[MASK]"]
|
||||
|
||||
def tokenize(self, text):
|
||||
split_tokens = []
|
||||
if self.do_basic_tokenize:
|
||||
|
@ -38,26 +38,6 @@ SPECIAL_TOKENS_NAME = 'special_tokens.txt'
|
||||
|
||||
SPIECE_UNDERLINE = u'▁'
|
||||
|
||||
# Tokens
|
||||
special_symbols = {
|
||||
"<unk>" : 0,
|
||||
"<s>" : 1,
|
||||
"</s>" : 2,
|
||||
"<cls>" : 3,
|
||||
"<sep>" : 4,
|
||||
"<pad>" : 5,
|
||||
"<mask>" : 6,
|
||||
"<eod>" : 7,
|
||||
"<eop>" : 8,
|
||||
}
|
||||
|
||||
VOCAB_SIZE = 32000
|
||||
UNK_ID = special_symbols["<unk>"]
|
||||
CLS_ID = special_symbols["<cls>"]
|
||||
SEP_ID = special_symbols["<sep>"]
|
||||
MASK_ID = special_symbols["<mask>"]
|
||||
EOD_ID = special_symbols["<eod>"]
|
||||
|
||||
# Segments (not really needed)
|
||||
SEG_ID_A = 0
|
||||
SEG_ID_B = 1
|
||||
@ -70,6 +50,18 @@ class XLNetTokenizer(object):
|
||||
SentencePiece based tokenizer. Peculiarities:
|
||||
- requires SentencePiece: https://github.com/google/sentencepiece
|
||||
"""
|
||||
# Tokens
|
||||
special_symbols = {
|
||||
"<unk>" : 0,
|
||||
"<s>" : 1,
|
||||
"</s>" : 2,
|
||||
"<cls>" : 3,
|
||||
"<sep>" : 4,
|
||||
"<pad>" : 5,
|
||||
"<mask>" : 6,
|
||||
"<eod>" : 7,
|
||||
"<eop>" : 8,
|
||||
}
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||
"""
|
||||
@ -147,6 +139,46 @@ class XLNetTokenizer(object):
|
||||
self.special_tokens_decoder = {}
|
||||
self.set_special_tokens(special_tokens)
|
||||
|
||||
@property
|
||||
def UNK_TOKEN(self):
|
||||
return "<unk>"
|
||||
|
||||
@property
|
||||
def SEP_TOKEN(self):
|
||||
return "<sep>"
|
||||
|
||||
@property
|
||||
def PAD_TOKEN(self):
|
||||
return "<pad>"
|
||||
|
||||
@property
|
||||
def CLS_TOKEN(self):
|
||||
return "<cls>"
|
||||
|
||||
@property
|
||||
def MASK_TOKEN(self):
|
||||
return "<mask>"
|
||||
|
||||
@property
|
||||
def UNK_ID(self):
|
||||
return self.special_symbols["<unk>"]
|
||||
|
||||
@property
|
||||
def SEP_ID(self):
|
||||
return self.special_symbols["<sep>"]
|
||||
|
||||
@property
|
||||
def PAD_ID(self):
|
||||
return self.special_symbols["<pad>"]
|
||||
|
||||
@property
|
||||
def CLS_ID(self):
|
||||
return self.special_symbols["<cls>"]
|
||||
|
||||
@property
|
||||
def MASK_ID(self):
|
||||
return self.special_symbols["<mask>"]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.encoder) + len(self.special_tokens)
|
||||
|
||||
|
@ -86,7 +86,7 @@ class XLNetModelTest(unittest.TestCase):
|
||||
inp_q = target_mapping[:, 0, :].clone() # predict last token
|
||||
|
||||
# inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
|
||||
# seg_id: int32 Tensor in shape [bsz, len], the input segment IDs.
|
||||
# token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
|
||||
# input_mask: float32 Tensor in shape [bsz, len], the input mask.
|
||||
# 0 for real tokens and 1 for padding.
|
||||
# mems: a list of float32 Tensors in shape [bsz, mem_len, d_model], memory
|
||||
@ -138,11 +138,11 @@ class XLNetModelTest(unittest.TestCase):
|
||||
model = XLNetLMHeadModel(config)
|
||||
model.eval()
|
||||
|
||||
loss_1, mems_1a = model(input_ids_1, seg_id=segment_ids, target=lm_labels)
|
||||
all_logits_1, mems_1b = model(input_ids_1, seg_id=segment_ids)
|
||||
loss_1, mems_1a = model(input_ids_1, token_type_ids=segment_ids, target=lm_labels)
|
||||
all_logits_1, mems_1b = model(input_ids_1, token_type_ids=segment_ids)
|
||||
|
||||
loss_2, mems_2a = model(input_ids_2, seg_id=segment_ids, target=lm_labels, mems=mems_1a)
|
||||
all_logits_2, mems_2b = model(input_ids_2, seg_id=segment_ids, mems=mems_1b)
|
||||
loss_2, mems_2a = model(input_ids_2, token_type_ids=segment_ids, target=lm_labels, mems=mems_1a)
|
||||
all_logits_2, mems_2b = model(input_ids_2, token_type_ids=segment_ids, mems=mems_1b)
|
||||
|
||||
logits, _ = model(input_ids_q,
|
||||
perm_mask=perm_mask,
|
||||
|
Loading…
Reference in New Issue
Block a user