Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 10:12:23 +06:00

commit 725a56329d

Merge remote-tracking branch 'upstream/master' into optim

# Conflicts:
#   pytorch_pretrained_bert/optimization.py

- updated docs for optimization

README.md (106 lines changed)
@@ -85,28 +85,28 @@ python -m pytest -sv tests/
This package comprises the following classes that can be imported in Python and are detailed in the [Doc](#doc) section of this readme:

- Eight **Bert** PyTorch models (`torch.nn.Module`) with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
  - [`BertModel`](./pytorch_pretrained_bert/modeling.py#L556) - raw BERT Transformer model (**fully pre-trained**),
  - [`BertForMaskedLM`](./pytorch_pretrained_bert/modeling.py#L710) - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
  - [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L771) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
  - [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L639) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
  - [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L833) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
  - [`BertForMultipleChoice`](./pytorch_pretrained_bert/modeling.py#L899) - BERT Transformer with a multiple choice head on top (used for task like Swag) (BERT Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**),
  - [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L969) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**),
  - [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L1034) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
  - [`BertModel`](./pytorch_pretrained_bert/modeling.py#L639) - raw BERT Transformer model (**fully pre-trained**),
  - [`BertForMaskedLM`](./pytorch_pretrained_bert/modeling.py#L793) - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
  - [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L854) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
  - [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L722) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
  - [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L916) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
  - [`BertForMultipleChoice`](./pytorch_pretrained_bert/modeling.py#L982) - BERT Transformer with a multiple choice head on top (used for task like Swag) (BERT Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**),
  - [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L1051) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**),
  - [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L1124) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).

- Three **OpenAI GPT** PyTorch models (`torch.nn.Module`) with pre-trained weights (in the [`modeling_openai.py`](./pytorch_pretrained_bert/modeling_openai.py) file):
  - [`OpenAIGPTModel`](./pytorch_pretrained_bert/modeling_openai.py#L537) - raw OpenAI GPT Transformer model (**fully pre-trained**),
  - [`OpenAIGPTLMHeadModel`](./pytorch_pretrained_bert/modeling_openai.py#L691) - OpenAI GPT Transformer with the tied language modeling head on top (**fully pre-trained**),
  - [`OpenAIGPTDoubleHeadsModel`](./pytorch_pretrained_bert/modeling_openai.py#L752) - OpenAI GPT Transformer with the tied language modeling head and a multiple choice classification head on top (OpenAI GPT Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**),
  - [`OpenAIGPTModel`](./pytorch_pretrained_bert/modeling_openai.py#L536) - raw OpenAI GPT Transformer model (**fully pre-trained**),
  - [`OpenAIGPTLMHeadModel`](./pytorch_pretrained_bert/modeling_openai.py#L643) - OpenAI GPT Transformer with the tied language modeling head on top (**fully pre-trained**),
  - [`OpenAIGPTDoubleHeadsModel`](./pytorch_pretrained_bert/modeling_openai.py#L722) - OpenAI GPT Transformer with the tied language modeling head and a multiple choice classification head on top (OpenAI GPT Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**),

- Two **Transformer-XL** PyTorch models (`torch.nn.Module`) with pre-trained weights (in the [`modeling_transfo_xl.py`](./pytorch_pretrained_bert/modeling_transfo_xl.py) file):
  - [`TransfoXLModel`](./pytorch_pretrained_bert/modeling_transfo_xl.py#L974) - Transformer-XL model which outputs the last hidden state and memory cells (**fully pre-trained**),
  - [`TransfoXLLMHeadModel`](./pytorch_pretrained_bert/modeling_transfo_xl.py#L1236) - Transformer-XL with the tied adaptive softmax head on top for language modeling which outputs the logits/loss and memory cells (**fully pre-trained**),
  - [`TransfoXLModel`](./pytorch_pretrained_bert/modeling_transfo_xl.py#L983) - Transformer-XL model which outputs the last hidden state and memory cells (**fully pre-trained**),
  - [`TransfoXLLMHeadModel`](./pytorch_pretrained_bert/modeling_transfo_xl.py#L1260) - Transformer-XL with the tied adaptive softmax head on top for language modeling which outputs the logits/loss and memory cells (**fully pre-trained**),

- Three **OpenAI GPT-2** PyTorch models (`torch.nn.Module`) with pre-trained weights (in the [`modeling_gpt2.py`](./pytorch_pretrained_bert/modeling_gpt2.py) file):
  - [`GPT2Model`](./pytorch_pretrained_bert/modeling_gpt2.py#L537) - raw OpenAI GPT-2 Transformer model (**fully pre-trained**),
  - [`GPT2LMHeadModel`](./pytorch_pretrained_bert/modeling_gpt2.py#L691) - OpenAI GPT-2 Transformer with the tied language modeling head on top (**fully pre-trained**),
  - [`GPT2DoubleHeadsModel`](./pytorch_pretrained_bert/modeling_gpt2.py#L752) - OpenAI GPT-2 Transformer with the tied language modeling head and a multiple choice classification head on top (OpenAI GPT-2 Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**),
  - [`GPT2Model`](./pytorch_pretrained_bert/modeling_gpt2.py#L479) - raw OpenAI GPT-2 Transformer model (**fully pre-trained**),
  - [`GPT2LMHeadModel`](./pytorch_pretrained_bert/modeling_gpt2.py#L559) - OpenAI GPT-2 Transformer with the tied language modeling head on top (**fully pre-trained**),
  - [`GPT2DoubleHeadsModel`](./pytorch_pretrained_bert/modeling_gpt2.py#L624) - OpenAI GPT-2 Transformer with the tied language modeling head and a multiple choice classification head on top (OpenAI GPT-2 Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**),

- Tokenizers for **BERT** (using word-piece) (in the [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) file):
  - `BasicTokenizer` - basic tokenization (punctuation splitting, lower casing, etc.),
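As a rough sketch of how the BERT tokenization classes mentioned above are typically used (the shortcut name and example sentence are illustrative, not part of this diff):

```python
from pytorch_pretrained_bert import BertTokenizer

# Load the word-piece vocabulary matching a pre-trained checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# BertTokenizer runs basic tokenization followed by word-piece splitting.
tokens = tokenizer.tokenize("Who was Jim Henson ? Jim Henson was a puppeteer")
input_ids = tokenizer.convert_tokens_to_ids(tokens)
```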
@@ -140,7 +140,7 @@ The repository further comprises:
  - [`run_classifier.py`](./examples/run_classifier.py) - Show how to fine-tune an instance of `BertForSequenceClassification` on GLUE's MRPC task,
  - [`run_squad.py`](./examples/run_squad.py) - Show how to fine-tune an instance of `BertForQuestionAnswering` on SQuAD v1.0 and SQuAD v2.0 tasks.
  - [`run_swag.py`](./examples/run_swag.py) - Show how to fine-tune an instance of `BertForMultipleChoice` on Swag task.
  - [`run_lm_finetuning.py`](./examples/run_lm_finetuning.py) - Show how to fine-tune an instance of `BertForPretraining` on a target text corpus.
  - [`simple_lm_finetuning.py`](./examples/lm_finetuning/simple_lm_finetuning.py) - Show how to fine-tune an instance of `BertForPretraining` on a target text corpus.

- One example on how to use **OpenAI GPT** (in the [`examples` folder](./examples)):
  - [`run_openai_gpt.py`](./examples/run_openai_gpt.py) - Show how to fine-tune an instance of `OpenGPTDoubleHeadsModel` on the RocStories task.
@@ -461,12 +461,12 @@ Here is a detailed documentation of the classes in the package and how to use th

| Sub-section | Description |
|-|-|
| [Loading Google AI's/OpenAI's pre-trained weigths](#loading-google-ai-or-openai-pre-trained-weigths-or-pytorch-dump) | How to load Google AI/OpenAI's pre-trained weight or a PyTorch saved instance |
| [Loading Google AI's/OpenAI's pre-trained weights](#loading-google-ai-or-openai-pre-trained-weights-or-pytorch-dump) | How to load Google AI/OpenAI's pre-trained weight or a PyTorch saved instance |
| [PyTorch models](#PyTorch-models) | API of the BERT, GPT, GPT-2 and Transformer-XL PyTorch model classes |
| [Tokenizers](#tokenizers) | API of the tokenizers class for BERT, GPT, GPT-2 and Transformer-XL|
| [Optimizers](#optimizerss) | API of the optimizers |

### Loading Google AI or OpenAI pre-trained weigths or PyTorch dump
### Loading Google AI or OpenAI pre-trained weights or PyTorch dump

To load one of Google AI's, OpenAI's pre-trained models or a PyTorch saved model (an instance of `BertForPreTraining` saved with `torch.save()`), the PyTorch model classes and the tokenizer can be instantiated as
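The instantiation block itself sits just outside this hunk; as a hedged sketch of what it amounts to (the `bert-base-uncased` shortcut is only one of the available model names):

```python
from pytorch_pretrained_bert import BertForPreTraining, BertTokenizer

# A shortcut name downloads and caches the pre-trained weights and vocabulary;
# a path to a model dump produced by the conversion script or torch.save() also works.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForPreTraining.from_pretrained('bert-base-uncased')
```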
@@ -927,11 +927,60 @@ Where `$THIS_MACHINE_INDEX` is an sequential index assigned to each of your mach

We showcase several fine-tuning examples based on (and extended from) [the original implementation](https://github.com/google-research/bert/):

- a *sequence-level classifier* on the MRPC classification corpus,
- a *sequence-level classifier* on nine different GLUE tasks,
- a *token-level classifier* on the question answering dataset SQuAD, and
- a *sequence-level multiple-choice classifier* on the SWAG classification corpus.
- a *BERT language model* on another target corpus

#### GLUE results on dev set

We get the following results on the dev set of GLUE benchmark with an uncased BERT base
model. All experiments were run on a P100 GPU with a batch size of 32.

| Task | Metric | Result |
|-|-|-|
| CoLA | Matthew's corr. | 57.29 |
| SST-2 | accuracy | 93.00 |
| MRPC | F1/accuracy | 88.85/83.82 |
| STS-B | Pearson/Spearman corr. | 89.70/89.37 |
| QQP | accuracy/F1 | 90.72/87.41 |
| MNLI | matched acc./mismatched acc.| 83.95/84.39 |
| QNLI | accuracy | 89.04 |
| RTE | accuracy | 61.01 |
| WNLI | accuracy | 53.52 |

Some of these results are significantly different from the ones reported on the test set
of GLUE benchmark on the website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.

Before running any one of these GLUE tasks you should download the
[GLUE data](https://gluebenchmark.com/tasks) by running
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
and unpack it to some directory `$GLUE_DIR`.

```shell
export GLUE_DIR=/path/to/glue
export TASK_NAME=MRPC

python run_classifier.py \
  --task_name $TASK_NAME \
  --do_train \
  --do_eval \
  --do_lower_case \
  --data_dir $GLUE_DIR/$TASK_NAME \
  --bert_model bert-base-uncased \
  --max_seq_length 128 \
  --train_batch_size 32 \
  --learning_rate 2e-5 \
  --num_train_epochs 3.0 \
  --output_dir /tmp/$TASK_NAME/
```

where task name can be one of CoLA, SST-2, MRPC, STS-B, QQP, MNLI, QNLI, RTE, WNLI.

The dev set results will be present within the text file 'eval_results.txt' in the specified output_dir. In case of MNLI, since there are two separate dev sets, matched and mismatched, there will be a separate output folder called '/tmp/MNLI-MM/' in addition to '/tmp/MNLI/'.

The code has not been tested with half-precision training with apex on any GLUE task apart from MRPC, MNLI, CoLA, SST-2. The following section provides details on how to run half-precision training with MRPC. With that being said, there shouldn't be any issues in running half-precision training with the remaining GLUE tasks as well, since the data processor for each task inherits from the base class DataProcessor.

#### MRPC

This example code fine-tunes BERT on the Microsoft Research Paraphrase
@@ -1051,18 +1100,7 @@ You can download an [exemplary training corpus](https://ext-bert-sample.obs.eu-d
Training one epoch on this corpus takes about 1:20h on 4 x NVIDIA Tesla P100 with `train_batch_size=200` and `max_seq_length=128`:


```shell
python run_lm_finetuning.py \
  --bert_model bert-base-uncased \
  --do_lower_case \
  --do_train \
  --train_file ../samples/sample_text.txt \
  --output_dir models \
  --num_train_epochs 5.0 \
  --learning_rate 3e-5 \
  --train_batch_size 32 \
  --max_seq_length 128 \
```
Thanks to the work of @Rocketknight1 and @tholor there are now **several scripts** that can be used to fine-tune BERT using the pretraining objective (combination of masked-language modeling and next sentence prediction loss). These scripts are detailed in the [`README`](./examples/lm_finetuning/README.md) of the [`examples/lm_finetuning/`](./examples/lm_finetuning/) folder.

### OpenAI GPT, Transformer-XL and GPT-2: running the examples
@@ -1196,9 +1234,9 @@ A command-line interface is provided to convert a TensorFlow checkpoint in a PyT

### BERT

You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) in a PyTorch save file by using the [`./pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py`](convert_tf_checkpoint_to_pytorch.py) script.
You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) in a PyTorch save file by using the [`convert_tf_checkpoint_to_pytorch.py`](./pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py) script.

This CLI takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in [`extract_features.py`](./examples/extract_features.py), [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`]((./examples/run_squad.py))).
This CLI takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in [`extract_features.py`](./examples/extract_features.py), [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`](./examples/run_squad.py)).

You only need to run this conversion script **once** to get a PyTorch model. You can then disregard the TensorFlow checkpoint (the three files starting with `bert_model.ckpt`) but be sure to keep the configuration file (`bert_config.json`) and the vocabulary file (`vocab.txt`) as these are needed for the PyTorch model too.
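For orientation, an invocation of the conversion script could look roughly like this (a sketch: `$BERT_BASE_DIR` is a placeholder and the flag names are assumed from the script's argument parser):

```shell
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12

python pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py \
  --tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \
  --bert_config_file $BERT_BASE_DIR/bert_config.json \
  --pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin
```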
@@ -57,7 +57,7 @@ class InputFeatures(object):


def convert_examples_to_features(examples, seq_length, tokenizer):
    """Loads a data file into a list of `InputBatch`s."""
    """Loads a data file into a list of `InputFeature`s."""

    features = []
    for (ex_index, example) in enumerate(examples):
examples/lm_finetuning/README.md (new file, 63 lines)
@@ -0,0 +1,63 @@
# BERT Model Finetuning using Masked Language Modeling objective

## Introduction

The three example scripts in this folder can be used to **fine-tune** a pre-trained BERT model using the pretraining objective (combination of masked language modeling and next sentence prediction loss). In general, pretrained models like BERT are first trained with a pretraining objective (masked language modeling and next sentence prediction for BERT) on a large and general natural language corpus. A classifier head is then added on top of the pre-trained architecture and the model is quickly fine-tuned on a target task, while still (hopefully) retaining its general language understanding. This greatly reduces overfitting and yields state-of-the-art results, especially when training data for the target task are limited.

The [ULMFiT paper](https://arxiv.org/abs/1801.06146) took a slightly different approach, however, and added an intermediate step in which the model is fine-tuned on text **from the same domain as the target task and using the pretraining objective** before the final stage in which the classifier head is added and the model is trained on the target task itself. This paper reported significantly improved results from this step, and found that they could get high-quality classifications even with only tiny numbers (<1000) of labelled training examples, as long as they had a lot of unlabelled data from the target domain.

The BERT model has more capacity than the LSTM models used in the ULMFiT work, but the [BERT paper](https://arxiv.org/abs/1810.04805) did not test finetuning using the pretraining objective and at the present stage there aren't many examples of this approach being used for Transformer-based language models. As such, it's hard to predict what effect this step will have on final model performance, but it's reasonable to conjecture that this approach can improve the final classification performance, especially when a large unlabelled corpus from the target domain is available, labelled data is limited, or the target domain is very unusual and different from 'normal' English text. If you are aware of any literature on this subject, please feel free to add it in here, or open an issue and tag me (@Rocketknight1) and I'll include it.

## Input format

The scripts in this folder expect a single file as input, consisting of untokenized text, with one **sentence** per line, and one blank line between documents. The reason for the sentence splitting is that part of BERT's training involves a _next sentence_ objective in which the model must predict whether two sequences of text are contiguous text from the same document or not, and to avoid making the task _too easy_, the split point between the sequences is always at the end of a sentence. The linebreaks in the file are therefore necessary to mark the points where the text can be split.
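For illustration (not part of the committed file), an input corpus in this format could look like:

```
This is the first sentence of the first document.
Here is a second sentence belonging to the same document.

The second document starts after a single blank line.
It also has one sentence per line.
```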
## Usage

There are two ways to fine-tune a language model using these scripts. The first _quick_ approach is to use [`simple_lm_finetuning.py`](./simple_lm_finetuning.py). This script does everything in a single script, but generates training instances that consist of just two sentences. This is quite different from the BERT paper, where (confusingly) the NextSentence task concatenated sentences together from each document to form two long multi-sentences, which the paper just referred to as _sentences_. The difference between this simple approach and the original paper approach can have a significant effect for long sequences since two sentences will be much shorter than the max sequence length. In this case, most of each training example will just consist of blank padding characters, which wastes a lot of computation and results in a model that isn't really training on long sequences.

As such, the preferred approach (assuming you have documents containing multiple contiguous sentences from your target domain) is to use [`pregenerate_training_data.py`](./pregenerate_training_data.py) to pre-process your data into training examples following the methodology used for LM training in the original BERT paper and repository. Since there is a significant random component to training data generation for BERT, this script includes an option to generate multiple _epochs_ of pre-processed data, to avoid training on the same random splits each epoch. Generating an epoch of data for each training epoch should result in a better final model, and so we recommend doing so.

You can then train on the pregenerated data using [`finetune_on_pregenerated.py`](./finetune_on_pregenerated.py), and pointing it to the folder created by [`pregenerate_training_data.py`](./pregenerate_training_data.py). Note that you should use the same `bert_model` and case options for both! Also note that `max_seq_len` does not need to be specified for the [`finetune_on_pregenerated.py`](./finetune_on_pregenerated.py) script, as it is inferred from the training examples.

There are various options that can be tweaked, but they are mostly set to the values from the BERT paper/repository and default values should make sense. The most relevant ones are:

- `--max_seq_len`: Controls the length of training examples (in wordpiece tokens) seen by the model. Defaults to 128 but can be set as high as 512. Higher values may yield stronger language models at the cost of slower and more memory-intensive training.
- `--fp16`: Enables fast half-precision training on recent GPUs.

In addition, if memory usage is an issue, especially when training on a single GPU, reducing `--train_batch_size` from the default 32 to a lower number (4-16) can be helpful, or leaving `--train_batch_size` at the default and increasing `--gradient_accumulation_steps` to 2-8. Changing `--gradient_accumulation_steps` may be preferable as alterations to the batch size may require corresponding changes in the learning rate to compensate. There is also a `--reduce_memory` option for both the `pregenerate_training_data.py` and `finetune_on_pregenerated.py` scripts that spills data to disc in shelf objects or numpy memmaps rather than retaining it in memory, which significantly reduces memory usage with little performance impact.
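As a sketch combining the memory options described above (the values are illustrative; the flags are the ones defined by `finetune_on_pregenerated.py`):

```
python3 finetune_on_pregenerated.py
--pregenerated_data training/
--bert_model bert-base-uncased
--do_lower_case
--output_dir finetuned_lm/
--epochs 3
--train_batch_size 8
--gradient_accumulation_steps 4
--reduce_memory
```

This keeps the effective batch size at the default 32 (8 x 4) while holding only 8 sequences on the GPU at a time, and stores the pregenerated examples in on-disc memmaps rather than in memory.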
## Examples

### Simple fine-tuning

```
python3 simple_lm_finetuning.py
--train_corpus my_corpus.txt
--bert_model bert-base-uncased
--do_lower_case
--output_dir finetuned_lm/
```

### Pregenerating training data

```
python3 pregenerate_training_data.py
--train_corpus my_corpus.txt
--bert_model bert-base-uncased
--do_lower_case
--output_dir training/
--epochs_to_generate 3
--max_seq_len 256
```

### Training on pregenerated data

```
python3 finetune_on_pregenerated.py
--pregenerated_data training/
--bert_model bert-base-uncased
--do_lower_case
--output_dir finetuned_lm/
--epochs 3
```
examples/lm_finetuning/finetune_on_pregenerated.py (new file, 334 lines)
@@ -0,0 +1,334 @@
from argparse import ArgumentParser
from pathlib import Path
import torch
import logging
import json
import random
import numpy as np
from collections import namedtuple
from tempfile import TemporaryDirectory

from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from pytorch_pretrained_bert.modeling import BertForPreTraining
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear

InputFeatures = namedtuple("InputFeatures", "input_ids input_mask segment_ids lm_label_ids is_next")

log_format = '%(asctime)-10s: %(message)s'
logging.basicConfig(level=logging.INFO, format=log_format)


def convert_example_to_features(example, tokenizer, max_seq_length):
    tokens = example["tokens"]
    segment_ids = example["segment_ids"]
    is_random_next = example["is_random_next"]
    masked_lm_positions = example["masked_lm_positions"]
    masked_lm_labels = example["masked_lm_labels"]

    assert len(tokens) == len(segment_ids) <= max_seq_length  # The preprocessed data should be already truncated
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    masked_label_ids = tokenizer.convert_tokens_to_ids(masked_lm_labels)

    input_array = np.zeros(max_seq_length, dtype=np.int)
    input_array[:len(input_ids)] = input_ids

    mask_array = np.zeros(max_seq_length, dtype=np.bool)
    mask_array[:len(input_ids)] = 1

    segment_array = np.zeros(max_seq_length, dtype=np.bool)
    segment_array[:len(segment_ids)] = segment_ids

    lm_label_array = np.full(max_seq_length, dtype=np.int, fill_value=-1)
    lm_label_array[masked_lm_positions] = masked_label_ids

    features = InputFeatures(input_ids=input_array,
                             input_mask=mask_array,
                             segment_ids=segment_array,
                             lm_label_ids=lm_label_array,
                             is_next=is_random_next)
    return features


class PregeneratedDataset(Dataset):
    def __init__(self, training_path, epoch, tokenizer, num_data_epochs, reduce_memory=False):
        self.vocab = tokenizer.vocab
        self.tokenizer = tokenizer
        self.epoch = epoch
        self.data_epoch = epoch % num_data_epochs
        data_file = training_path / f"epoch_{self.data_epoch}.json"
        metrics_file = training_path / f"epoch_{self.data_epoch}_metrics.json"
        assert data_file.is_file() and metrics_file.is_file()
        metrics = json.loads(metrics_file.read_text())
        num_samples = metrics['num_training_examples']
        seq_len = metrics['max_seq_len']
        self.temp_dir = None
        self.working_dir = None
        if reduce_memory:
            self.temp_dir = TemporaryDirectory()
            self.working_dir = Path(self.temp_dir.name)
            input_ids = np.memmap(filename=self.working_dir/'input_ids.memmap',
                                  mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
            input_masks = np.memmap(filename=self.working_dir/'input_masks.memmap',
                                    shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
            segment_ids = np.memmap(filename=self.working_dir/'segment_ids.memmap',
                                    shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
            lm_label_ids = np.memmap(filename=self.working_dir/'lm_label_ids.memmap',
                                     shape=(num_samples, seq_len), mode='w+', dtype=np.int32)
            lm_label_ids[:] = -1
            is_nexts = np.memmap(filename=self.working_dir/'is_nexts.memmap',
                                 shape=(num_samples,), mode='w+', dtype=np.bool)
        else:
            input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
            input_masks = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
            segment_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
            lm_label_ids = np.full(shape=(num_samples, seq_len), dtype=np.int32, fill_value=-1)
            is_nexts = np.zeros(shape=(num_samples,), dtype=np.bool)
        logging.info(f"Loading training examples for epoch {epoch}")
        with data_file.open() as f:
            for i, line in enumerate(tqdm(f, total=num_samples, desc="Training examples")):
                line = line.strip()
                example = json.loads(line)
                features = convert_example_to_features(example, tokenizer, seq_len)
                input_ids[i] = features.input_ids
                segment_ids[i] = features.segment_ids
                input_masks[i] = features.input_mask
                lm_label_ids[i] = features.lm_label_ids
                is_nexts[i] = features.is_next
        assert i == num_samples - 1  # Assert that the sample count metric was true
        logging.info("Loading complete!")
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.input_ids = input_ids
        self.input_masks = input_masks
        self.segment_ids = segment_ids
        self.lm_label_ids = lm_label_ids
        self.is_nexts = is_nexts

    def __len__(self):
        return self.num_samples

    def __getitem__(self, item):
        return (torch.tensor(self.input_ids[item].astype(np.int64)),
                torch.tensor(self.input_masks[item].astype(np.int64)),
                torch.tensor(self.segment_ids[item].astype(np.int64)),
                torch.tensor(self.lm_label_ids[item].astype(np.int64)),
                torch.tensor(self.is_nexts[item].astype(np.int64)))


def main():
    parser = ArgumentParser()
    parser.add_argument('--pregenerated_data', type=Path, required=True)
    parser.add_argument('--output_dir', type=Path, required=True)
    parser.add_argument("--bert_model", type=str, required=True,
                        choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
                                 "bert-base-multilingual", "bert-base-chinese"])
    parser.add_argument("--do_lower_case", action="store_true")
    parser.add_argument("--reduce_memory", action="store_true",
                        help="Store training data as on-disc memmaps to massively reduce memory usage")

    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train for")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    args = parser.parse_args()
    assert args.pregenerated_data.is_dir(), \
        "--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!"

    samples_per_epoch = []
    for i in range(args.epochs):
        epoch_file = args.pregenerated_data / f"epoch_{i}.json"
        metrics_file = args.pregenerated_data / f"epoch_{i}_metrics.json"
        if epoch_file.is_file() and metrics_file.is_file():
            metrics = json.loads(metrics_file.read_text())
            samples_per_epoch.append(metrics['num_training_examples'])
        else:
            if i == 0:
                exit("No training data was found!")
            print(f"Warning! There are fewer epochs of pregenerated data ({i}) than training epochs ({args.epochs}).")
            print("This script will loop over the available data, but training diversity may be negatively impacted.")
            num_data_epochs = i
            break
    else:
        num_data_epochs = args.epochs

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logging.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if args.output_dir.is_dir() and list(args.output_dir.iterdir()):
        logging.warning(f"Output directory ({args.output_dir}) already exists and is not empty!")
    args.output_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    total_train_examples = 0
    for i in range(args.epochs):
        # The modulo takes into account the fact that we may loop over limited epochs of data
        total_train_examples += samples_per_epoch[i % len(samples_per_epoch)]

    num_train_optimization_steps = int(
        total_train_examples / args.train_batch_size / args.gradient_accumulation_steps)
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare model
    model = BertForPreTraining.from_pretrained(args.bert_model)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    logging.info("***** Running training *****")
    logging.info(f"  Num examples = {total_train_examples}")
    logging.info("  Batch size = %d", args.train_batch_size)
    logging.info("  Num steps = %d", num_train_optimization_steps)
    model.train()
    for epoch in range(args.epochs):
        epoch_dataset = PregeneratedDataset(epoch=epoch, training_path=args.pregenerated_data, tokenizer=tokenizer,
                                            num_data_epochs=num_data_epochs)
        if args.local_rank == -1:
            train_sampler = RandomSampler(epoch_dataset)
        else:
            train_sampler = DistributedSampler(epoch_dataset)
        train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch}") as pbar:
            for step, batch in enumerate(train_dataloader):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
                loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                pbar.update(1)
                mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
                pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps,
                                                                          args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

    # Save a trained model
    logging.info("** ** * Saving fine-tuned model ** ** * ")
    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
    output_model_file = args.output_dir / "pytorch_model.bin"
    torch.save(model_to_save.state_dict(), str(output_model_file))


if __name__ == '__main__':
    main()
examples/lm_finetuning/pregenerate_training_data.py (new file, 292 lines)
@@ -0,0 +1,292 @@
from argparse import ArgumentParser
from pathlib import Path
from tqdm import tqdm, trange
from tempfile import TemporaryDirectory
import shelve

from random import random, randint, shuffle, choice, sample
from pytorch_pretrained_bert.tokenization import BertTokenizer
import numpy as np
import json


class DocumentDatabase:
    def __init__(self, reduce_memory=False):
        if reduce_memory:
            self.temp_dir = TemporaryDirectory()
            self.working_dir = Path(self.temp_dir.name)
            self.document_shelf_filepath = self.working_dir / 'shelf.db'
            self.document_shelf = shelve.open(str(self.document_shelf_filepath),
                                              flag='n', protocol=-1)
            self.documents = None
        else:
            self.documents = []
            self.document_shelf = None
            self.document_shelf_filepath = None
            self.temp_dir = None
        self.doc_lengths = []
        self.doc_cumsum = None
        self.cumsum_max = None
        self.reduce_memory = reduce_memory

    def add_document(self, document):
        if self.reduce_memory:
            current_idx = len(self.doc_lengths)
            self.document_shelf[str(current_idx)] = document
        else:
            self.documents.append(document)
        self.doc_lengths.append(len(document))

    def _precalculate_doc_weights(self):
        self.doc_cumsum = np.cumsum(self.doc_lengths)
        self.cumsum_max = self.doc_cumsum[-1]

    def sample_doc(self, current_idx, sentence_weighted=True):
        # Uses the current iteration counter to ensure we don't sample the same doc twice
        if sentence_weighted:
            # With sentence weighting, we sample docs proportionally to their sentence length
            if self.doc_cumsum is None or len(self.doc_cumsum) != len(self.doc_lengths):
                self._precalculate_doc_weights()
            rand_start = self.doc_cumsum[current_idx]
            rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx]
            sentence_index = randint(rand_start, rand_end-1) % self.cumsum_max
            sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right')
        else:
            # If we don't use sentence weighting, then every doc has an equal chance to be chosen
            sampled_doc_index = current_idx + randint(1, len(self.doc_lengths)-1)
        assert sampled_doc_index != current_idx
        if self.reduce_memory:
            return self.document_shelf[str(sampled_doc_index)]
        else:
            return self.documents[sampled_doc_index]

    def __len__(self):
        return len(self.doc_lengths)

    def __getitem__(self, item):
        if self.reduce_memory:
            return self.document_shelf[str(item)]
        else:
            return self.documents[item]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, traceback):
        if self.document_shelf is not None:
            self.document_shelf.close()
        if self.temp_dir is not None:
            self.temp_dir.cleanup()
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
    """Truncates a pair of sequences to a maximum sequence length. Lifted from Google's BERT repo."""
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_num_tokens:
            break

        trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        assert len(trunc_tokens) >= 1

        # We want to sometimes truncate from the front and sometimes from the
        # back to add more randomness and avoid biases.
        if random() < 0.5:
            del trunc_tokens[0]
        else:
            trunc_tokens.pop()


def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, vocab_list):
    """Creates the predictions for the masked LM objective. This is mostly copied from the Google BERT repo, but
    with several refactors to clean it up and remove a lot of unnecessary variables."""
    cand_indices = []
    for (i, token) in enumerate(tokens):
        if token == "[CLS]" or token == "[SEP]":
            continue
        cand_indices.append(i)

    num_to_mask = min(max_predictions_per_seq,
                      max(1, int(round(len(tokens) * masked_lm_prob))))
    shuffle(cand_indices)
    mask_indices = sorted(sample(cand_indices, num_to_mask))
    masked_token_labels = []
    for index in mask_indices:
        # 80% of the time, replace with [MASK]
        if random() < 0.8:
            masked_token = "[MASK]"
        else:
            # 10% of the time, keep original
            if random() < 0.5:
                masked_token = tokens[index]
            # 10% of the time, replace with random word
            else:
                masked_token = choice(vocab_list)
        masked_token_labels.append(tokens[index])
        # Once we've saved the true label for that token, we can overwrite it with the masked version
        tokens[index] = masked_token

    return tokens, mask_indices, masked_token_labels
def create_instances_from_document(
        doc_database, doc_idx, max_seq_length, short_seq_prob,
        masked_lm_prob, max_predictions_per_seq, vocab_list):
    """This code is mostly a duplicate of the equivalent function from Google BERT's repo.
    However, we make some changes and improvements. Sampling is improved and no longer requires a loop in this function.
    Also, documents are sampled proportionally to the number of sentences they contain, which means each sentence
    (rather than each document) has an equal chance of being sampled as a false example for the NextSentence task."""
    document = doc_database[doc_idx]
    # Account for [CLS], [SEP], [SEP]
    max_num_tokens = max_seq_length - 3

    # We *usually* want to fill up the entire sequence since we are padding
    # to `max_seq_length` anyways, so short sequences are generally wasted
    # computation. However, we *sometimes*
    # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
    # sequences to minimize the mismatch between pre-training and fine-tuning.
    # The `target_seq_length` is just a rough target however, whereas
    # `max_seq_length` is a hard limit.
    target_seq_length = max_num_tokens
    if random() < short_seq_prob:
        target_seq_length = randint(2, max_num_tokens)

    # We DON'T just concatenate all of the tokens from a document into a long
    # sequence and choose an arbitrary split point because this would make the
    # next sentence prediction task too easy. Instead, we split the input into
    # segments "A" and "B" based on the actual "sentences" provided by the user
    # input.
    instances = []
    current_chunk = []
    current_length = 0
    i = 0
    while i < len(document):
        segment = document[i]
        current_chunk.append(segment)
        current_length += len(segment)
        if i == len(document) - 1 or current_length >= target_seq_length:
            if current_chunk:
                # `a_end` is how many segments from `current_chunk` go into the `A`
                # (first) sentence.
                a_end = 1
                if len(current_chunk) >= 2:
                    a_end = randint(1, len(current_chunk) - 1)

                tokens_a = []
                for j in range(a_end):
                    tokens_a.extend(current_chunk[j])

                tokens_b = []

                # Random next
                if len(current_chunk) == 1 or random() < 0.5:
                    is_random_next = True
                    target_b_length = target_seq_length - len(tokens_a)

                    # Sample a random document, with longer docs being sampled more frequently
                    random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True)

                    random_start = randint(0, len(random_document) - 1)
                    for j in range(random_start, len(random_document)):
                        tokens_b.extend(random_document[j])
                        if len(tokens_b) >= target_b_length:
                            break
                    # We didn't actually use these segments so we "put them back" so
                    # they don't go to waste.
                    num_unused_segments = len(current_chunk) - a_end
                    i -= num_unused_segments
                # Actual next
                else:
                    is_random_next = False
                    for j in range(a_end, len(current_chunk)):
                        tokens_b.extend(current_chunk[j])
                truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)

                assert len(tokens_a) >= 1
                assert len(tokens_b) >= 1

                tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
                # The segment IDs are 0 for the [CLS] token, the A tokens and the first [SEP]
                # They are 1 for the B tokens and the final [SEP]
                segment_ids = [0 for _ in range(len(tokens_a) + 2)] + [1 for _ in range(len(tokens_b) + 1)]

                tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(
                    tokens, masked_lm_prob, max_predictions_per_seq, vocab_list)

                instance = {
                    "tokens": tokens,
                    "segment_ids": segment_ids,
                    "is_random_next": is_random_next,
                    "masked_lm_positions": masked_lm_positions,
                    "masked_lm_labels": masked_lm_labels}
                instances.append(instance)
            current_chunk = []
            current_length = 0
        i += 1

    return instances
def main():
    parser = ArgumentParser()
    parser.add_argument('--train_corpus', type=Path, required=True)
    parser.add_argument("--output_dir", type=Path, required=True)
    parser.add_argument("--bert_model", type=str, required=True,
                        choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
                                 "bert-base-multilingual", "bert-base-chinese"])
    parser.add_argument("--do_lower_case", action="store_true")

    parser.add_argument("--reduce_memory", action="store_true",
                        help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")

    parser.add_argument("--epochs_to_generate", type=int, default=3,
                        help="Number of epochs of data to pregenerate")
    parser.add_argument("--max_seq_len", type=int, default=128)
    parser.add_argument("--short_seq_prob", type=float, default=0.1,
                        help="Probability of making a short sentence as a training example")
    parser.add_argument("--masked_lm_prob", type=float, default=0.15,
                        help="Probability of masking each token for the LM task")
    parser.add_argument("--max_predictions_per_seq", type=int, default=20,
                        help="Maximum number of tokens to mask in each sequence")

    args = parser.parse_args()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
    vocab_list = list(tokenizer.vocab.keys())
    with DocumentDatabase(reduce_memory=args.reduce_memory) as docs:
        with args.train_corpus.open() as f:
            doc = []
            for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
                line = line.strip()
                if line == "":
                    docs.add_document(doc)
                    doc = []
                else:
                    tokens = tokenizer.tokenize(line)
                    doc.append(tokens)

        args.output_dir.mkdir(exist_ok=True)
        for epoch in trange(args.epochs_to_generate, desc="Epoch"):
            epoch_filename = args.output_dir / f"epoch_{epoch}.json"
            num_instances = 0
            with epoch_filename.open('w') as epoch_file:
                for doc_idx in trange(len(docs), desc="Document"):
                    doc_instances = create_instances_from_document(
                        docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
                        masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
                        vocab_list=vocab_list)
                    doc_instances = [json.dumps(instance) for instance in doc_instances]
                    for instance in doc_instances:
                        epoch_file.write(instance + '\n')
                        num_instances += 1
            metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
            with metrics_file.open('w') as metrics_file:
                metrics = {
                    "num_training_examples": num_instances,
                    "max_seq_len": args.max_seq_len
                }
                metrics_file.write(json.dumps(metrics))


if __name__ == '__main__':
    main()
@@ -33,9 +33,6 @@ from pytorch_pretrained_bert.modeling import BertForPreTraining
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear

from torch.utils.data import Dataset
import random

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
@@ -404,7 +401,7 @@ def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_file",
    parser.add_argument("--train_corpus",
                        default=None,
                        type=str,
                        required=True,
@@ -514,8 +511,8 @@ def main():
    #train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        print("Loading Train Dataset", args.train_file)
        train_dataset = BERTDataset(args.train_file, tokenizer, seq_len=args.max_seq_length,
        print("Loading Train Dataset", args.train_corpus)
        train_dataset = BERTDataset(args.train_corpus, tokenizer, seq_len=args.max_seq_length,
                                    corpus_lines=None, on_memory=args.on_memory)
        num_train_optimization_steps = int(
            len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
@@ -31,6 +31,10 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from torch.nn import CrossEntropyLoss, MSELoss
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.tokenization import BertTokenizer
@@ -167,6 +171,16 @@ class MnliProcessor(DataProcessor):
        return examples


class MnliMismatchedProcessor(MnliProcessor):
    """Processor for the MultiNLI Mismatched data set (GLUE version)."""

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
            "dev_matched")


class ColaProcessor(DataProcessor):
    """Processor for the CoLA data set (GLUE version)."""
@@ -227,13 +241,181 @@ class Sst2Processor(DataProcessor):
        return examples


def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
class StsbProcessor(DataProcessor):
    """Processor for the STS-B data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        return [None]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[7]
            text_b = line[8]
            label = line[-1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class QqpProcessor(DataProcessor):
    """Processor for the QQP data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            try:
                text_a = line[3]
                text_b = line[4]
                label = line[5]
            except IndexError:
                continue
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class QnliProcessor(DataProcessor):
    """Processor for the QNLI data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")),
            "dev_matched")

    def get_labels(self):
        """See base class."""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
class RteProcessor(DataProcessor):
|
||||
"""Processor for the RTE data set (GLUE version)."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["entailment", "not_entailment"]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % (set_type, line[0])
|
||||
text_a = line[1]
|
||||
text_b = line[2]
|
||||
label = line[-1]
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
class WnliProcessor(DataProcessor):
|
||||
"""Processor for the WNLI data set (GLUE version)."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1"]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % (set_type, line[0])
|
||||
text_a = line[1]
|
||||
text_b = line[2]
|
||||
label = line[-1]
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
def convert_examples_to_features(examples, label_list, max_seq_length,
|
||||
tokenizer, output_mode):
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
|
||||
label_map = {label : i for i, label in enumerate(label_list)}
|
||||
|
||||
features = []
|
||||
for (ex_index, example) in enumerate(examples):
|
||||
if ex_index % 10000 == 0:
|
||||
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
|
||||
|
||||
tokens_a = tokenizer.tokenize(example.text_a)
|
||||
|
||||
tokens_b = None
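The GLUE processors added above all follow the same pattern: read a task TSV, pick the task-specific columns, and emit `InputExample` objects that `convert_examples_to_features` then tokenizes. A minimal illustration of that contract, not part of the commit — the tiny `InputExample` stand-in and the fake STS-B row are assumptions made only for this sketch:

```python
# Sketch only: mirrors the InputExample/_create_examples convention used above.
class InputExample(object):
    def __init__(self, guid, text_a, text_b=None, label=None):
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label

# A made-up row in the STS-B layout the processor expects: the sentence pair
# sits in columns 7 and 8 and the gold similarity score is the last column
# (the other columns are placeholders here).
line = ["42", "colA", "colB", "colC", "colD", "colE", "colF",
        "A man is playing a guitar.", "A person plays a guitar.", "4.8"]

guid = "dev-%s" % line[0]
example = InputExample(guid=guid, text_a=line[7], text_b=line[8], label=line[-1])
print(example.guid, example.text_a, "|", example.text_b, "|", example.label)
```

Note that `StsbProcessor.get_labels()` returns `[None]`: the STS-B label is a float score, which the new "regression" output mode (introduced below) casts with `float(example.label)` instead of looking it up in `label_map`.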
@ -260,7 +442,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambigiously separates the sequences, but it makes
# since the [SEP] token unambiguously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is

@ -289,7 +471,13 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length

label_id = label_map[example.label]
if output_mode == "classification":
label_id = label_map[example.label]
elif output_mode == "regression":
label_id = float(example.label)
else:
raise KeyError(output_mode)

if ex_index < 5:
logger.info("*** Example ***")
logger.info("guid: %s" % (example.guid))

@ -325,9 +513,56 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
else:
tokens_b.pop()

def accuracy(out, labels):
outputs = np.argmax(out, axis=1)
return np.sum(outputs == labels)

def simple_accuracy(preds, labels):
return (preds == labels).mean()


def acc_and_f1(preds, labels):
acc = simple_accuracy(preds, labels)
f1 = f1_score(y_true=labels, y_pred=preds)
return {
"acc": acc,
"f1": f1,
"acc_and_f1": (acc + f1) / 2,
}


def pearson_and_spearman(preds, labels):
pearson_corr = pearsonr(preds, labels)[0]
spearman_corr = spearmanr(preds, labels)[0]
return {
"pearson": pearson_corr,
"spearmanr": spearman_corr,
"corr": (pearson_corr + spearman_corr) / 2,
}


def compute_metrics(task_name, preds, labels):
assert len(preds) == len(labels)
if task_name == "cola":
return {"mcc": matthews_corrcoef(labels, preds)}
elif task_name == "sst-2":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "mrpc":
return acc_and_f1(preds, labels)
elif task_name == "sts-b":
return pearson_and_spearman(preds, labels)
elif task_name == "qqp":
return acc_and_f1(preds, labels)
elif task_name == "mnli":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "mnli-mm":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "qnli":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "rte":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "wnli":
return {"acc": simple_accuracy(preds, labels)}
else:
raise KeyError(task_name)


def main():
parser = argparse.ArgumentParser()
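The metric helpers above replace the old argmax-accuracy with per-task reporting: Matthews correlation for CoLA, accuracy and/or F1 for the classification tasks, and Pearson/Spearman correlation for STS-B. A quick, self-contained sanity check with made-up predictions (illustrative only; it re-implements `simple_accuracy` locally rather than importing the fine-tuning script):

```python
# Hand-made numbers, purely to show what the new metric dictionaries look like.
import numpy as np
from scipy.stats import pearsonr, spearmanr

def simple_accuracy(preds, labels):
    return (preds == labels).mean()

# Classification-style task (e.g. "sst-2"): preds are argmaxed class ids.
preds = np.array([1, 0, 1, 1])
labels = np.array([1, 0, 0, 1])
print({"acc": simple_accuracy(preds, labels)})   # -> {'acc': 0.75}

# Regression-style task ("sts-b"): preds are raw scores, metrics are correlations.
preds = np.array([4.5, 1.0, 3.2, 0.4])
labels = np.array([4.8, 0.9, 2.9, 0.7])
print({"pearson": pearsonr(preds, labels)[0],
       "spearmanr": spearmanr(preds, labels)[0]})
```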
@ -431,15 +666,26 @@ def main():
processors = {
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mnli-mm": MnliMismatchedProcessor,
"mrpc": MrpcProcessor,
"sst-2": Sst2Processor,
"sts-b": StsbProcessor,
"qqp": QqpProcessor,
"qnli": QnliProcessor,
"rte": RteProcessor,
"wnli": WnliProcessor,
}

num_labels_task = {
"cola": 2,
"sst-2": 2,
"mnli": 3,
"mrpc": 2,
output_modes = {
"cola": "classification",
"mnli": "classification",
"mrpc": "classification",
"sst-2": "classification",
"sts-b": "regression",
"qqp": "classification",
"qnli": "classification",
"rte": "classification",
"wnli": "classification",
}

if args.local_rank == -1 or args.no_cuda:

@ -480,8 +726,10 @@ def main():
raise ValueError("Task not found: %s" % (task_name))

processor = processors[task_name]()
num_labels = num_labels_task[task_name]
output_mode = output_modes[task_name]

label_list = processor.get_labels()
num_labels = len(label_list)

tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

@ -498,7 +746,7 @@ def main():
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank))
model = BertForSequenceClassification.from_pretrained(args.bert_model,
cache_dir=cache_dir,
num_labels = num_labels)
num_labels=num_labels)
if args.fp16:
model.half()
model.to(device)

@ -546,7 +794,7 @@ def main():
tr_loss = 0
if args.do_train:
train_features = convert_examples_to_features(
train_examples, label_list, args.max_seq_length, tokenizer)
train_examples, label_list, args.max_seq_length, tokenizer, output_mode)
logger.info("***** Running training *****")
logger.info("  Num examples = %d", len(train_examples))
logger.info("  Batch size = %d", args.train_batch_size)

@ -554,7 +802,12 @@ def main():
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)

if output_mode == "classification":
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
elif output_mode == "regression":
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.float)

train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.local_rank == -1:
train_sampler = RandomSampler(train_data)

@ -569,7 +822,17 @@ def main():
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
loss = model(input_ids, segment_ids, input_mask, label_ids)

# define a new function to compute loss values for both output_modes
logits = model(input_ids, segment_ids, input_mask, labels=None)

if output_mode == "classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
elif output_mode == "regression":
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), label_ids.view(-1))

if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if args.gradient_accumulation_steps > 1:
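Because the model is now called with `labels=None`, the loss lives in the training loop and switches on `output_mode`: classification tasks use a cross-entropy over `num_labels` logits and long-typed label ids, while STS-B regression uses an MSE against float targets. A stand-alone sketch of that switch with dummy tensors (the shapes and values are assumptions, not taken from the script):

```python
# Illustrative only: reproduces the classification/regression loss split above.
import torch
from torch.nn import CrossEntropyLoss, MSELoss

batch_size, num_labels = 4, 3

# "classification" branch: (batch, num_labels) logits vs. long class ids.
logits_cls = torch.randn(batch_size, num_labels)
label_ids_cls = torch.tensor([0, 2, 1, 2])               # dtype=torch.long
loss_cls = CrossEntropyLoss()(logits_cls.view(-1, num_labels),
                              label_ids_cls.view(-1))

# "regression" branch (STS-B): a single score per example vs. float targets.
logits_reg = torch.randn(batch_size, 1)
label_ids_reg = torch.tensor([4.8, 0.9, 2.9, 0.7])       # dtype=torch.float
loss_reg = MSELoss()(logits_reg.view(-1), label_ids_reg.view(-1))

print(loss_cls.item(), loss_reg.item())
```

This is also why the label tensors above are built with `dtype=torch.long` for classification but `dtype=torch.float` for regression.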
@ -594,7 +857,6 @@ def main():
optimizer.zero_grad()
global_step += 1

if args.do_train:
# Save a trained model and the associated configuration
model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)

@ -614,22 +876,28 @@ def main():
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = processor.get_dev_examples(args.data_dir)
eval_features = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer)
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
logger.info("***** Running evaluation *****")
logger.info("  Num examples = %d", len(eval_examples))
logger.info("  Batch size = %d", args.eval_batch_size)
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)

if output_mode == "classification":
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
elif output_mode == "regression":
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.float)

eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
# Run prediction for full data
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

model.eval()
eval_loss, eval_accuracy = 0, 0
nb_eval_steps, nb_eval_examples = 0, 0
eval_loss = 0
nb_eval_steps = 0
preds = []

for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
input_ids = input_ids.to(device)

@ -638,26 +906,36 @@ def main():
label_ids = label_ids.to(device)

with torch.no_grad():
tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
logits = model(input_ids, segment_ids, input_mask)

logits = logits.detach().cpu().numpy()
label_ids = label_ids.to('cpu').numpy()
tmp_eval_accuracy = accuracy(logits, label_ids)
logits = model(input_ids, segment_ids, input_mask, labels=None)

# create eval loss and other metric required by the task
if output_mode == "classification":
loss_fct = CrossEntropyLoss()
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
elif output_mode == "regression":
loss_fct = MSELoss()
tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))

eval_loss += tmp_eval_loss.mean().item()
eval_accuracy += tmp_eval_accuracy

nb_eval_examples += input_ids.size(0)
nb_eval_steps += 1
if len(preds) == 0:
preds.append(logits.detach().cpu().numpy())
else:
preds[0] = np.append(
preds[0], logits.detach().cpu().numpy(), axis=0)

eval_loss = eval_loss / nb_eval_steps
eval_accuracy = eval_accuracy / nb_eval_examples
preds = preds[0]
if output_mode == "classification":
preds = np.argmax(preds, axis=1)
elif output_mode == "regression":
preds = np.squeeze(preds)
result = compute_metrics(task_name, preds, all_label_ids.numpy())
loss = tr_loss/nb_tr_steps if args.do_train else None
result = {'eval_loss': eval_loss,
'eval_accuracy': eval_accuracy,
'global_step': global_step,
'loss': loss}

result['eval_loss'] = eval_loss
result['global_step'] = global_step
result['loss'] = loss

output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:

@ -666,5 +944,73 @@ def main():
logger.info("  %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))

# hack for MNLI-MM
if task_name == "mnli":
task_name = "mnli-mm"
processor = processors[task_name]()

if os.path.exists(args.output_dir + '-MM') and os.listdir(args.output_dir + '-MM') and args.do_train:
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
if not os.path.exists(args.output_dir + '-MM'):
os.makedirs(args.output_dir + '-MM')

eval_examples = processor.get_dev_examples(args.data_dir)
eval_features = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
logger.info("***** Running evaluation *****")
logger.info("  Num examples = %d", len(eval_examples))
logger.info("  Batch size = %d", args.eval_batch_size)
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)

eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
# Run prediction for full data
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

model.eval()
eval_loss = 0
nb_eval_steps = 0
preds = []

for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)

with torch.no_grad():
logits = model(input_ids, segment_ids, input_mask, labels=None)

loss_fct = CrossEntropyLoss()
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))

eval_loss += tmp_eval_loss.mean().item()
nb_eval_steps += 1
if len(preds) == 0:
preds.append(logits.detach().cpu().numpy())
else:
preds[0] = np.append(
preds[0], logits.detach().cpu().numpy(), axis=0)

eval_loss = eval_loss / nb_eval_steps
preds = preds[0]
preds = np.argmax(preds, axis=1)
result = compute_metrics(task_name, preds, all_label_ids.numpy())
loss = tr_loss/nb_tr_steps if args.do_train else None

result['eval_loss'] = eval_loss
result['global_step'] = global_step
result['loss'] = loss

output_eval_file = os.path.join(args.output_dir + '-MM', "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key in sorted(result.keys()):
logger.info("  %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))

if __name__ == "__main__":
main()

@ -109,3 +109,4 @@ def run_model():

if __name__ == '__main__':
run_model()

@ -85,9 +85,9 @@ class SquadExample(object):
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
if self.start_position:
s += ", start_position: %d" % (self.start_position)
if self.start_position:
if self.end_position:
s += ", end_position: %d" % (self.end_position)
if self.start_position:
if self.is_impossible:
s += ", is_impossible: %r" % (self.is_impossible)
return s

@ -471,7 +471,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
prelim_predictions = []
# keep track of the minimum score of null start+end of position 0
score_null = 1000000  # large and positive
min_null_feature_index = 0  # the paragraph slice with min mull score
min_null_feature_index = 0  # the paragraph slice with min null score
null_start_logit = 0  # the start logit at the slice with min null score
null_end_logit = 0  # the end logit at the slice with min null score
for (feature_index, feature) in enumerate(features):

@ -620,7 +620,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
all_predictions[example.qas_id] = ""
else:
all_predictions[example.qas_id] = best_non_null_entry.text
all_nbest_json[example.qas_id] = nbest_json
all_nbest_json[example.qas_id] = nbest_json

with open(output_prediction_file, "w") as writer:
writer.write(json.dumps(all_predictions, indent=4) + "\n")

@ -657,8 +657,8 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
#
# What we really want to return is "Steve Smith".
#
# Therefore, we have to apply a semi-complicated alignment heruistic between
# `pred_text` and `orig_text` to get a character-to-charcter alignment. This
# Therefore, we have to apply a semi-complicated alignment heuristic between
# `pred_text` and `orig_text` to get a character-to-character alignment. This
# can fail in certain cases in which case we just return `orig_text`.

def _strip_spaces(text):

@ -15,6 +15,8 @@
# limitations under the License.
"""BERT finetuning runner."""

from __future__ import absolute_import

import argparse
import csv
import logging

@ -31,7 +33,7 @@ from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForMultipleChoice
from pytorch_pretrained_bert.modeling import (BertForMultipleChoice, BertConfig, WEIGHTS_NAME, CONFIG_NAME)
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
from pytorch_pretrained_bert.tokenization import BertTokenizer

@ -507,7 +509,7 @@ def main():
model.eval()
eval_loss, eval_accuracy = 0, 0
nb_eval_steps, nb_eval_examples = 0, 0
for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)

@ -18,7 +18,7 @@ from .modeling_gpt2 import (GPT2Config, GPT2Model,
GPT2LMHeadModel, GPT2DoubleHeadsModel,
load_tf_weights_in_gpt2)

from .optimization import *
from .optimization import BertAdam
from .optimization_openai import OpenAIAdam

from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path

@ -76,7 +76,7 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path):
name = name.split('/')
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if any(n in ["adam_v", "adam_m"] for n in name):
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
print("Skipping {}".format("/".join(name)))
continue
pointer = model

@ -91,8 +91,14 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path):
pointer = getattr(pointer, 'bias')
elif l[0] == 'output_weights':
pointer = getattr(pointer, 'weight')
elif l[0] == 'squad':
pointer = getattr(pointer, 'classifier')
else:
pointer = getattr(pointer, l[0])
try:
pointer = getattr(pointer, l[0])
except AttributeError:
print("Skipping {}".format("/".join(name)))
continue
if len(l) >= 2:
num = int(l[1])
pointer = pointer[num]

@ -217,7 +223,7 @@ class BertConfig(object):
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except ImportError:
logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
class BertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).

@ -238,7 +244,7 @@ class BertEmbeddings(nn.Module):
"""
def __init__(self, config):
super(BertEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

@ -15,6 +15,8 @@
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""

from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import copy
import json

@ -615,8 +617,14 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
lm_logits = self.lm_head(hidden_states)
if lm_labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[:, :-1].contiguous()
shift_labels = lm_labels[:, 1:].contiguous()

# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
return loss
return lm_logits, presents
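The language-modeling heads above now shift logits and labels before the loss: the prediction at position i is scored against the token at position i+1, so the logits for the final position and the label for the first position are dropped. A small self-contained illustration of the alignment (random tensors, assumed shapes; not the model code itself):

```python
# Sketch of the shift used above: "tokens < n predict n".
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 11
lm_logits = torch.randn(2, 5, vocab_size)          # (batch, seq_len, vocab)
lm_labels = torch.randint(0, vocab_size, (2, 5))   # next-token targets

shift_logits = lm_logits[:, :-1].contiguous()      # predictions for positions 0..3
shift_labels = lm_labels[:, 1:].contiguous()       # the tokens that actually follow

# ignore_index=-1 mirrors the loss construction in the diff.
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1))
print(loss.item())
```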
@ -688,8 +696,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
losses = []
if lm_labels is not None:
shift_logits = lm_logits[:, :-1].contiguous()
shift_labels = lm_labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss(ignore_index=-1)
losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)))
losses.append(loss_fct(shift_logits.view(-1,
shift_logits.size(-1)), shift_labels.view(-1)))
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))

@ -15,6 +15,8 @@
# limitations under the License.
"""PyTorch OpenAI GPT model."""

from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import copy
import json

@ -374,6 +376,7 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
# (bsz, num_choices, 1, hidden_size)
multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
# (bsz, num_choices, hidden_size)
multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2)
multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1)
# (bsz, num_choices)
return multiple_choice_logits

@ -713,8 +716,14 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
lm_logits = self.lm_head(hidden_states)
if lm_labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[:, :-1].contiguous()
shift_labels = lm_labels[:, 1:].contiguous()

# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
return loss
return lm_logits

@ -800,8 +809,11 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
losses = []
if lm_labels is not None:
shift_logits = lm_logits[:, :-1].contiguous()
shift_labels = lm_labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss(ignore_index=-1)
losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)))
losses.append(loss_fct(shift_logits.view(-1,
shift_logits.size(-1)), shift_labels.view(-1)))
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))

@ -18,6 +18,8 @@
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
"""

from __future__ import absolute_import, division, print_function, unicode_literals

import os
import copy
import json

@ -26,7 +26,8 @@ logger = logging.getLogger(__name__)
def warmup_cosine(x, warmup=0.002):
if x < warmup:
return x/warmup
return 0.5 * (1.0 + torch.cos(math.pi * x))
x_ = (x - warmup) / (1 - warmup) # progress after warmup
return 0.5 * (1. + math.cos(math.pi * x_))
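The corrected `warmup_cosine` measures progress from the end of warmup instead of feeding the raw step fraction into the cosine (and it uses `math.cos` rather than `torch.cos` on a plain float). A quick check of the schedule values, copied from the new version above purely for illustration:

```python
import math

def warmup_cosine(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    x_ = (x - warmup) / (1 - warmup)  # progress after warmup
    return 0.5 * (1. + math.cos(math.pi * x_))

for x in (0.001, 0.002, 0.25, 0.5, 1.0):
    print(x, round(warmup_cosine(x), 4))
# 0.001 -> 0.5 while still warming up, 0.002 -> ~1.0 at the end of warmup,
# then a smooth cosine decay that reaches 0.0 at x = 1.0.
```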
def warmup_constant(x, warmup=0.002):
""" Linearly increases learning rate over `warmup`*`t_total` (as provided to OpenAIAdam) training steps.

@ -105,13 +105,13 @@ class BertTokenizer(object):
self.max_len = max_len if max_len is not None else int(1e12)

def tokenize(self, text):
split_tokens = []
if self.do_basic_tokenize:
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
else:
split_tokens = self.wordpiece_tokenizer.tokenize(text)
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens

def convert_tokens_to_ids(self, tokens):

@ -142,6 +142,16 @@ class BertTokenizer(object):
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
logger.warning("The pre-trained model you are loading is a cased model but you have not set "
"`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
"you may want to check this behavior.")
kwargs['do_lower_case'] = False
elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
logger.warning("The pre-trained model you are loading is an uncased model but you have set "
"`do_lower_case` to False. We are setting `do_lower_case=True` for you "
"but you may want to check this behavior.")
kwargs['do_lower_case'] = True
else:
vocab_file = pretrained_model_name_or_path
if os.path.isdir(vocab_file):
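With the new guard in `BertTokenizer.from_pretrained`, a mismatch between a cased or uncased checkpoint and the `do_lower_case` setting is detected and corrected with a warning instead of silently producing wrong tokenization. A minimal usage sketch (assuming the standard `bert-base-cased` shortcut available in this package):

```python
# Illustrative only: loading a cased checkpoint without do_lower_case=False now
# logs a warning and forces do_lower_case=False rather than lower-casing the input.
from pytorch_pretrained_bert.tokenization import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
print(tokenizer.tokenize("Steve Smith"))
```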