mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Support for torch-lightning in NER examples (#2890)
* initial pytorch lightning commit * tested multigpu * Fix learning rate schedule * black formatting * fix flake8 * isort * isort * . Co-authored-by: Check your git settings! <chris@chris-laptop>
This commit is contained in:
parent
ab1238393c
commit
b662f0e625
@ -3,7 +3,7 @@
|
||||
In this section a few examples are put together. All of these examples work for several models, making use of the very
|
||||
similar API between the different models.
|
||||
|
||||
**Important**
|
||||
**Important**
|
||||
To run the latest versions of the examples, you have to install from source and install some specific requirements for the examples.
|
||||
Execute the following steps in a new virtual environment:
|
||||
|
||||
@ -15,8 +15,8 @@ pip install -r ./examples/requirements.txt
|
||||
```
|
||||
|
||||
| Section | Description |
|
||||
|----------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| [TensorFlow 2.0 models on GLUE](#TensorFlow-2.0-Bert-models-on-GLUE) | Examples running BERT TensorFlow 2.0 model on the GLUE tasks.
|
||||
|----------------------------|------------------------------------------------------------------------------------------------------------------------------------------
|
||||
| [TensorFlow 2.0 models on GLUE](#TensorFlow-2.0-Bert-models-on-GLUE) | Examples running BERT TensorFlow 2.0 model on the GLUE tasks. |
|
||||
| [Language Model training](#language-model-training) | Fine-tuning (or training from scratch) the library models for language modeling on a text dataset. Causal language modeling for GPT/GPT-2, masked language modeling for BERT/RoBERTa. |
|
||||
| [Language Generation](#language-generation) | Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet. |
|
||||
| [GLUE](#glue) | Examples running BERT/XLM/XLNet/RoBERTa on the 9 GLUE tasks. Examples feature distributed training as well as half-precision. |
|
||||
@ -88,7 +88,7 @@ a score of ~20 perplexity once fine-tuned on the dataset.
|
||||
|
||||
The following example fine-tunes RoBERTa on WikiText-2. Here too, we're using the raw WikiText-2. The loss is different
|
||||
as BERT/RoBERTa have a bidirectional mechanism; we're therefore using the same loss that was used during their
|
||||
pre-training: masked language modeling.
|
||||
pre-training: masked language modeling.
|
||||
|
||||
In accordance to the RoBERTa paper, we use dynamic masking rather than static masking. The model may, therefore, converge
|
||||
slightly slower (over-fitting takes more epochs).
|
||||
@ -130,8 +130,8 @@ python run_generation.py \
|
||||
|
||||
Based on the script [`run_glue.py`](https://github.com/huggingface/transformers/blob/master/examples/run_glue.py).
|
||||
|
||||
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
|
||||
Evaluation](https://gluebenchmark.com/). This script can fine-tune the following models: BERT, XLM, XLNet and RoBERTa.
|
||||
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
|
||||
Evaluation](https://gluebenchmark.com/). This script can fine-tune the following models: BERT, XLM, XLNet and RoBERTa.
|
||||
|
||||
GLUE is made up of a total of 9 different tasks. We get the following results on the dev set of the benchmark with an
|
||||
uncased BERT base model (the checkpoint `bert-base-uncased`). All experiments ran single V100 GPUs with a total train
|
||||
@ -179,20 +179,20 @@ python run_glue.py \
|
||||
|
||||
where task name can be one of CoLA, SST-2, MRPC, STS-B, QQP, MNLI, QNLI, RTE, WNLI.
|
||||
|
||||
The dev set results will be present within the text file `eval_results.txt` in the specified output_dir.
|
||||
In case of MNLI, since there are two separate dev sets (matched and mismatched), there will be a separate
|
||||
The dev set results will be present within the text file `eval_results.txt` in the specified output_dir.
|
||||
In case of MNLI, since there are two separate dev sets (matched and mismatched), there will be a separate
|
||||
output folder called `/tmp/MNLI-MM/` in addition to `/tmp/MNLI/`.
|
||||
|
||||
The code has not been tested with half-precision training with apex on any GLUE task apart from MRPC, MNLI,
|
||||
CoLA, SST-2. The following section provides details on how to run half-precision training with MRPC. With that being
|
||||
said, there shouldn’t be any issues in running half-precision training with the remaining GLUE tasks as well,
|
||||
The code has not been tested with half-precision training with apex on any GLUE task apart from MRPC, MNLI,
|
||||
CoLA, SST-2. The following section provides details on how to run half-precision training with MRPC. With that being
|
||||
said, there shouldn’t be any issues in running half-precision training with the remaining GLUE tasks as well,
|
||||
since the data processor for each task inherits from the base class DataProcessor.
|
||||
|
||||
### MRPC
|
||||
|
||||
#### Fine-tuning example
|
||||
|
||||
The following examples fine-tune BERT on the Microsoft Research Paraphrase Corpus (MRPC) corpus and runs in less
|
||||
The following examples fine-tune BERT on the Microsoft Research Paraphrase Corpus (MRPC) corpus and runs in less
|
||||
than 10 minutes on a single K-80 and in 27 seconds (!) on single tesla V100 16GB with apex installed.
|
||||
|
||||
Before running any one of these GLUE tasks you should download the
|
||||
@ -219,12 +219,12 @@ python run_glue.py \
|
||||
```
|
||||
|
||||
Our test ran on a few seeds with [the original implementation hyper-
|
||||
parameters](https://github.com/google-research/bert#sentence-and-sentence-pair-classification-tasks) gave evaluation
|
||||
parameters](https://github.com/google-research/bert#sentence-and-sentence-pair-classification-tasks) gave evaluation
|
||||
results between 84% and 88%.
|
||||
|
||||
#### Using Apex and mixed-precision
|
||||
|
||||
Using Apex and 16 bit precision, the fine-tuning on MRPC only takes 27 seconds. First install
|
||||
Using Apex and 16 bit precision, the fine-tuning on MRPC only takes 27 seconds. First install
|
||||
[apex](https://github.com/NVIDIA/apex), then run the following example:
|
||||
|
||||
```bash
|
||||
@ -360,8 +360,8 @@ Based on the script [`run_squad.py`](https://github.com/huggingface/transformers
|
||||
|
||||
#### Fine-tuning BERT on SQuAD1.0
|
||||
|
||||
This example code fine-tunes BERT on the SQuAD1.0 dataset. It runs in 24 min (with BERT-base) or 68 min (with BERT-large)
|
||||
on a single tesla V100 16GB. The data for SQuAD can be downloaded with the following links and should be saved in a
|
||||
This example code fine-tunes BERT on the SQuAD1.0 dataset. It runs in 24 min (with BERT-base) or 68 min (with BERT-large)
|
||||
on a single tesla V100 16GB. The data for SQuAD can be downloaded with the following links and should be saved in a
|
||||
$SQUAD_DIR directory.
|
||||
|
||||
* [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
|
||||
@ -516,185 +516,6 @@ Larger batch size may improve the performance while costing more memory.
|
||||
|
||||
|
||||
|
||||
## Named Entity Recognition
|
||||
|
||||
Based on the scripts [`run_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/run_ner.py) for Pytorch and
|
||||
[`run_tf_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/run_tf_ner.py) for Tensorflow 2.
|
||||
This example fine-tune Bert Multilingual on GermEval 2014 (German NER).
|
||||
Details and results for the fine-tuning provided by @stefan-it.
|
||||
|
||||
### Data (Download and pre-processing steps)
|
||||
|
||||
Data can be obtained from the [GermEval 2014](https://sites.google.com/site/germeval2014ner/data) shared task page.
|
||||
|
||||
Here are the commands for downloading and pre-processing train, dev and test datasets. The original data format has four (tab-separated) columns, in a pre-processing step only the two relevant columns (token and outer span NER annotation) are extracted:
|
||||
|
||||
```bash
|
||||
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-train.tsv?attredirects=0&d=1' \
|
||||
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > train.txt.tmp
|
||||
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-dev.tsv?attredirects=0&d=1' \
|
||||
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > dev.txt.tmp
|
||||
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-test.tsv?attredirects=0&d=1' \
|
||||
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > test.txt.tmp
|
||||
```
|
||||
|
||||
The GermEval 2014 dataset contains some strange "control character" tokens like `'\x96', '\u200e', '\x95', '\xad' or '\x80'`. One problem with these tokens is, that `BertTokenizer` returns an empty token for them, resulting in misaligned `InputExample`s. I wrote a script that a) filters these tokens and b) splits longer sentences into smaller ones (once the max. subtoken length is reached).
|
||||
|
||||
```bash
|
||||
wget "https://raw.githubusercontent.com/stefan-it/fine-tuned-berts-seq/master/scripts/preprocess.py"
|
||||
```
|
||||
Let's define some variables that we need for further pre-processing steps and training the model:
|
||||
|
||||
```bash
|
||||
export MAX_LENGTH=128
|
||||
export BERT_MODEL=bert-base-multilingual-cased
|
||||
```
|
||||
|
||||
Run the pre-processing script on training, dev and test datasets:
|
||||
|
||||
```bash
|
||||
python3 preprocess.py train.txt.tmp $BERT_MODEL $MAX_LENGTH > train.txt
|
||||
python3 preprocess.py dev.txt.tmp $BERT_MODEL $MAX_LENGTH > dev.txt
|
||||
python3 preprocess.py test.txt.tmp $BERT_MODEL $MAX_LENGTH > test.txt
|
||||
```
|
||||
|
||||
The GermEval 2014 dataset has much more labels than CoNLL-2002/2003 datasets, so an own set of labels must be used:
|
||||
|
||||
```bash
|
||||
cat train.txt dev.txt test.txt | cut -d " " -f 2 | grep -v "^$"| sort | uniq > labels.txt
|
||||
```
|
||||
|
||||
### Prepare the run
|
||||
|
||||
Additional environment variables must be set:
|
||||
|
||||
```bash
|
||||
export OUTPUT_DIR=germeval-model
|
||||
export BATCH_SIZE=32
|
||||
export NUM_EPOCHS=3
|
||||
export SAVE_STEPS=750
|
||||
export SEED=1
|
||||
```
|
||||
|
||||
### Run the Pytorch version
|
||||
|
||||
To start training, just run:
|
||||
|
||||
```bash
|
||||
python3 run_ner.py --data_dir ./ \
|
||||
--model_type bert \
|
||||
--labels ./labels.txt \
|
||||
--model_name_or_path $BERT_MODEL \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_seq_length $MAX_LENGTH \
|
||||
--num_train_epochs $NUM_EPOCHS \
|
||||
--per_gpu_train_batch_size $BATCH_SIZE \
|
||||
--save_steps $SAVE_STEPS \
|
||||
--seed $SEED \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--do_predict
|
||||
```
|
||||
|
||||
If your GPU supports half-precision training, just add the `--fp16` flag. After training, the model will be both evaluated on development and test datasets.
|
||||
|
||||
#### Evaluation
|
||||
|
||||
Evaluation on development dataset outputs the following for our example:
|
||||
|
||||
```bash
|
||||
10/04/2019 00:42:06 - INFO - __main__ - ***** Eval results *****
|
||||
10/04/2019 00:42:06 - INFO - __main__ - f1 = 0.8623348017621146
|
||||
10/04/2019 00:42:06 - INFO - __main__ - loss = 0.07183869666975543
|
||||
10/04/2019 00:42:06 - INFO - __main__ - precision = 0.8467916366258111
|
||||
10/04/2019 00:42:06 - INFO - __main__ - recall = 0.8784592370979806
|
||||
```
|
||||
|
||||
On the test dataset the following results could be achieved:
|
||||
|
||||
```bash
|
||||
10/04/2019 00:42:42 - INFO - __main__ - ***** Eval results *****
|
||||
10/04/2019 00:42:42 - INFO - __main__ - f1 = 0.8614389652384803
|
||||
10/04/2019 00:42:42 - INFO - __main__ - loss = 0.07064602487454782
|
||||
10/04/2019 00:42:42 - INFO - __main__ - precision = 0.8604651162790697
|
||||
10/04/2019 00:42:42 - INFO - __main__ - recall = 0.8624150210424085
|
||||
```
|
||||
|
||||
#### Comparing BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased)
|
||||
|
||||
Here is a small comparison between BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased) with the same hyperparameters as specified in the [example documentation](https://huggingface.co/transformers/examples.html#named-entity-recognition) (one run):
|
||||
|
||||
| Model | F-Score Dev | F-Score Test
|
||||
| --------------------------------- | ------- | --------
|
||||
| `bert-large-cased` | 95.59 | 91.70
|
||||
| `roberta-large` | 95.96 | 91.87
|
||||
| `distilbert-base-uncased` | 94.34 | 90.32
|
||||
|
||||
### Run the Tensorflow 2 version
|
||||
|
||||
To start training, just run:
|
||||
|
||||
```bash
|
||||
python3 run_tf_ner.py --data_dir ./ \
|
||||
--model_type bert \
|
||||
--labels ./labels.txt \
|
||||
--model_name_or_path $BERT_MODEL \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_seq_length $MAX_LENGTH \
|
||||
--num_train_epochs $NUM_EPOCHS \
|
||||
--per_device_train_batch_size $BATCH_SIZE \
|
||||
--save_steps $SAVE_STEPS \
|
||||
--seed $SEED \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--do_predict
|
||||
```
|
||||
|
||||
Such as the Pytorch version, if your GPU supports half-precision training, just add the `--fp16` flag. After training, the model will be both evaluated on development and test datasets.
|
||||
|
||||
#### Evaluation
|
||||
|
||||
Evaluation on development dataset outputs the following for our example:
|
||||
```bash
|
||||
precision recall f1-score support
|
||||
|
||||
LOCderiv 0.7619 0.6154 0.6809 52
|
||||
PERpart 0.8724 0.8997 0.8858 4057
|
||||
OTHpart 0.9360 0.9466 0.9413 711
|
||||
ORGpart 0.7015 0.6989 0.7002 269
|
||||
LOCpart 0.7668 0.8488 0.8057 496
|
||||
LOC 0.8745 0.9191 0.8963 235
|
||||
ORGderiv 0.7723 0.8571 0.8125 91
|
||||
OTHderiv 0.4800 0.6667 0.5581 18
|
||||
OTH 0.5789 0.6875 0.6286 16
|
||||
PERderiv 0.5385 0.3889 0.4516 18
|
||||
PER 0.5000 0.5000 0.5000 2
|
||||
ORG 0.0000 0.0000 0.0000 3
|
||||
|
||||
micro avg 0.8574 0.8862 0.8715 5968
|
||||
macro avg 0.8575 0.8862 0.8713 5968
|
||||
```
|
||||
|
||||
On the test dataset the following results could be achieved:
|
||||
```bash
|
||||
precision recall f1-score support
|
||||
|
||||
PERpart 0.8847 0.8944 0.8896 9397
|
||||
OTHpart 0.9376 0.9353 0.9365 1639
|
||||
ORGpart 0.7307 0.7044 0.7173 697
|
||||
LOC 0.9133 0.9394 0.9262 561
|
||||
LOCpart 0.8058 0.8157 0.8107 1150
|
||||
ORG 0.0000 0.0000 0.0000 8
|
||||
OTHderiv 0.5882 0.4762 0.5263 42
|
||||
PERderiv 0.6571 0.5227 0.5823 44
|
||||
OTH 0.4906 0.6667 0.5652 39
|
||||
ORGderiv 0.7016 0.7791 0.7383 172
|
||||
LOCderiv 0.8256 0.6514 0.7282 109
|
||||
PER 0.0000 0.0000 0.0000 11
|
||||
|
||||
micro avg 0.8722 0.8774 0.8748 13869
|
||||
macro avg 0.8712 0.8774 0.8740 13869
|
||||
```
|
||||
|
||||
## XNLI
|
||||
|
||||
@ -705,7 +526,7 @@ Based on the script [`run_xnli.py`](https://github.com/huggingface/transformers/
|
||||
#### Fine-tuning on XNLI
|
||||
|
||||
This example code fine-tunes mBERT (multi-lingual BERT) on the XNLI dataset. It runs in 106 mins
|
||||
on a single tesla V100 16GB. The data for XNLI can be downloaded with the following links and should be both saved (and un-zipped) in a
|
||||
on a single tesla V100 16GB. The data for XNLI can be downloaded with the following links and should be both saved (and un-zipped) in a
|
||||
`$XNLI_DIR` directory.
|
||||
|
||||
* [XNLI 1.0](https://www.nyu.edu/projects/bowman/xnli/XNLI-1.0.zip)
|
||||
|
179
examples/ner/README.md
Normal file
179
examples/ner/README.md
Normal file
@ -0,0 +1,179 @@
|
||||
## Named Entity Recognition
|
||||
|
||||
Based on the scripts [`run_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/ner/run_ner.py) for Pytorch and
|
||||
[`run_tf_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/ner/run_tf_ner.py) for Tensorflow 2.
|
||||
This example fine-tune Bert Multilingual on GermEval 2014 (German NER).
|
||||
Details and results for the fine-tuning provided by @stefan-it.
|
||||
|
||||
### Data (Download and pre-processing steps)
|
||||
|
||||
Data can be obtained from the [GermEval 2014](https://sites.google.com/site/germeval2014ner/data) shared task page.
|
||||
|
||||
Here are the commands for downloading and pre-processing train, dev and test datasets. The original data format has four (tab-separated) columns, in a pre-processing step only the two relevant columns (token and outer span NER annotation) are extracted:
|
||||
|
||||
```bash
|
||||
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-train.tsv?attredirects=0&d=1' \
|
||||
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > train.txt.tmp
|
||||
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-dev.tsv?attredirects=0&d=1' \
|
||||
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > dev.txt.tmp
|
||||
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-test.tsv?attredirects=0&d=1' \
|
||||
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > test.txt.tmp
|
||||
```
|
||||
|
||||
The GermEval 2014 dataset contains some strange "control character" tokens like `'\x96', '\u200e', '\x95', '\xad' or '\x80'`. One problem with these tokens is, that `BertTokenizer` returns an empty token for them, resulting in misaligned `InputExample`s. I wrote a script that a) filters these tokens and b) splits longer sentences into smaller ones (once the max. subtoken length is reached).
|
||||
|
||||
```bash
|
||||
wget "https://raw.githubusercontent.com/stefan-it/fine-tuned-berts-seq/master/scripts/preprocess.py"
|
||||
```
|
||||
Let's define some variables that we need for further pre-processing steps and training the model:
|
||||
|
||||
```bash
|
||||
export MAX_LENGTH=128
|
||||
export BERT_MODEL=bert-base-multilingual-cased
|
||||
```
|
||||
|
||||
Run the pre-processing script on training, dev and test datasets:
|
||||
|
||||
```bash
|
||||
python3 preprocess.py train.txt.tmp $BERT_MODEL $MAX_LENGTH > train.txt
|
||||
python3 preprocess.py dev.txt.tmp $BERT_MODEL $MAX_LENGTH > dev.txt
|
||||
python3 preprocess.py test.txt.tmp $BERT_MODEL $MAX_LENGTH > test.txt
|
||||
```
|
||||
|
||||
The GermEval 2014 dataset has much more labels than CoNLL-2002/2003 datasets, so an own set of labels must be used:
|
||||
|
||||
```bash
|
||||
cat train.txt dev.txt test.txt | cut -d " " -f 2 | grep -v "^$"| sort | uniq > labels.txt
|
||||
```
|
||||
|
||||
### Prepare the run
|
||||
|
||||
Additional environment variables must be set:
|
||||
|
||||
```bash
|
||||
export OUTPUT_DIR=germeval-model
|
||||
export BATCH_SIZE=32
|
||||
export NUM_EPOCHS=3
|
||||
export SAVE_STEPS=750
|
||||
export SEED=1
|
||||
```
|
||||
|
||||
### Run the Pytorch version
|
||||
|
||||
To start training, just run:
|
||||
|
||||
```bash
|
||||
python3 run_ner.py --data_dir ./ \
|
||||
--model_type bert \
|
||||
--labels ./labels.txt \
|
||||
--model_name_or_path $BERT_MODEL \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_seq_length $MAX_LENGTH \
|
||||
--num_train_epochs $NUM_EPOCHS \
|
||||
--per_gpu_train_batch_size $BATCH_SIZE \
|
||||
--save_steps $SAVE_STEPS \
|
||||
--seed $SEED \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--do_predict
|
||||
```
|
||||
|
||||
If your GPU supports half-precision training, just add the `--fp16` flag. After training, the model will be both evaluated on development and test datasets.
|
||||
|
||||
#### Evaluation
|
||||
|
||||
Evaluation on development dataset outputs the following for our example:
|
||||
|
||||
```bash
|
||||
10/04/2019 00:42:06 - INFO - __main__ - ***** Eval results *****
|
||||
10/04/2019 00:42:06 - INFO - __main__ - f1 = 0.8623348017621146
|
||||
10/04/2019 00:42:06 - INFO - __main__ - loss = 0.07183869666975543
|
||||
10/04/2019 00:42:06 - INFO - __main__ - precision = 0.8467916366258111
|
||||
10/04/2019 00:42:06 - INFO - __main__ - recall = 0.8784592370979806
|
||||
```
|
||||
|
||||
On the test dataset the following results could be achieved:
|
||||
|
||||
```bash
|
||||
10/04/2019 00:42:42 - INFO - __main__ - ***** Eval results *****
|
||||
10/04/2019 00:42:42 - INFO - __main__ - f1 = 0.8614389652384803
|
||||
10/04/2019 00:42:42 - INFO - __main__ - loss = 0.07064602487454782
|
||||
10/04/2019 00:42:42 - INFO - __main__ - precision = 0.8604651162790697
|
||||
10/04/2019 00:42:42 - INFO - __main__ - recall = 0.8624150210424085
|
||||
```
|
||||
|
||||
#### Comparing BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased)
|
||||
|
||||
Here is a small comparison between BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased) with the same hyperparameters as specified in the [example documentation](https://huggingface.co/transformers/examples.html#named-entity-recognition) (one run):
|
||||
|
||||
| Model | F-Score Dev | F-Score Test
|
||||
| --------------------------------- | ------- | --------
|
||||
| `bert-large-cased` | 95.59 | 91.70
|
||||
| `roberta-large` | 95.96 | 91.87
|
||||
| `distilbert-base-uncased` | 94.34 | 90.32
|
||||
|
||||
### Run the Tensorflow 2 version
|
||||
|
||||
To start training, just run:
|
||||
|
||||
```bash
|
||||
python3 run_tf_ner.py --data_dir ./ \
|
||||
--model_type bert \
|
||||
--labels ./labels.txt \
|
||||
--model_name_or_path $BERT_MODEL \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_seq_length $MAX_LENGTH \
|
||||
--num_train_epochs $NUM_EPOCHS \
|
||||
--per_device_train_batch_size $BATCH_SIZE \
|
||||
--save_steps $SAVE_STEPS \
|
||||
--seed $SEED \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--do_predict
|
||||
```
|
||||
|
||||
Such as the Pytorch version, if your GPU supports half-precision training, just add the `--fp16` flag. After training, the model will be both evaluated on development and test datasets.
|
||||
|
||||
#### Evaluation
|
||||
|
||||
Evaluation on development dataset outputs the following for our example:
|
||||
```bash
|
||||
precision recall f1-score support
|
||||
|
||||
LOCderiv 0.7619 0.6154 0.6809 52
|
||||
PERpart 0.8724 0.8997 0.8858 4057
|
||||
OTHpart 0.9360 0.9466 0.9413 711
|
||||
ORGpart 0.7015 0.6989 0.7002 269
|
||||
LOCpart 0.7668 0.8488 0.8057 496
|
||||
LOC 0.8745 0.9191 0.8963 235
|
||||
ORGderiv 0.7723 0.8571 0.8125 91
|
||||
OTHderiv 0.4800 0.6667 0.5581 18
|
||||
OTH 0.5789 0.6875 0.6286 16
|
||||
PERderiv 0.5385 0.3889 0.4516 18
|
||||
PER 0.5000 0.5000 0.5000 2
|
||||
ORG 0.0000 0.0000 0.0000 3
|
||||
|
||||
micro avg 0.8574 0.8862 0.8715 5968
|
||||
macro avg 0.8575 0.8862 0.8713 5968
|
||||
```
|
||||
|
||||
On the test dataset the following results could be achieved:
|
||||
```bash
|
||||
precision recall f1-score support
|
||||
|
||||
PERpart 0.8847 0.8944 0.8896 9397
|
||||
OTHpart 0.9376 0.9353 0.9365 1639
|
||||
ORGpart 0.7307 0.7044 0.7173 697
|
||||
LOC 0.9133 0.9394 0.9262 561
|
||||
LOCpart 0.8058 0.8157 0.8107 1150
|
||||
ORG 0.0000 0.0000 0.0000 8
|
||||
OTHderiv 0.5882 0.4762 0.5263 42
|
||||
PERderiv 0.6571 0.5227 0.5823 44
|
||||
OTH 0.4906 0.6667 0.5652 39
|
||||
ORGderiv 0.7016 0.7791 0.7383 172
|
||||
LOCderiv 0.8256 0.6514 0.7282 109
|
||||
PER 0.0000 0.0000 0.0000 11
|
||||
|
||||
micro avg 0.8722 0.8774 0.8748 13869
|
||||
macro avg 0.8712 0.8774 0.8740 13869
|
||||
```
|
32
examples/ner/run.sh
Normal file
32
examples/ner/run.sh
Normal file
@ -0,0 +1,32 @@
|
||||
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-train.tsv?attredirects=0&d=1' \
|
||||
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > train.txt.tmp
|
||||
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-dev.tsv?attredirects=0&d=1' \
|
||||
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > dev.txt.tmp
|
||||
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-test.tsv?attredirects=0&d=1' \
|
||||
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > test.txt.tmp
|
||||
wget "https://raw.githubusercontent.com/stefan-it/fine-tuned-berts-seq/master/scripts/preprocess.py"
|
||||
export MAX_LENGTH=128
|
||||
export BERT_MODEL=bert-base-multilingual-cased
|
||||
python3 preprocess.py train.txt.tmp $BERT_MODEL $MAX_LENGTH > train.txt
|
||||
python3 preprocess.py dev.txt.tmp $BERT_MODEL $MAX_LENGTH > dev.txt
|
||||
python3 preprocess.py test.txt.tmp $BERT_MODEL $MAX_LENGTH > test.txt
|
||||
cat train.txt dev.txt test.txt | cut -d " " -f 2 | grep -v "^$"| sort | uniq > labels.txt
|
||||
export OUTPUT_DIR=germeval-model
|
||||
export BATCH_SIZE=32
|
||||
export NUM_EPOCHS=3
|
||||
export SAVE_STEPS=750
|
||||
export SEED=1
|
||||
|
||||
python3 run_ner.py --data_dir ./ \
|
||||
--model_type bert \
|
||||
--labels ./labels.txt \
|
||||
--model_name_or_path $BERT_MODEL \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_seq_length $MAX_LENGTH \
|
||||
--num_train_epochs $NUM_EPOCHS \
|
||||
--per_gpu_train_batch_size $BATCH_SIZE \
|
||||
--save_steps $SAVE_STEPS \
|
||||
--seed $SEED \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--do_predict
|
21
examples/ner/run_pl.sh
Normal file
21
examples/ner/run_pl.sh
Normal file
@ -0,0 +1,21 @@
|
||||
# Require pytorch-lightning=0.6
|
||||
export MAX_LENGTH=128
|
||||
export BERT_MODEL=bert-base-multilingual-cased
|
||||
export OUTPUT_DIR=germeval-model
|
||||
export BATCH_SIZE=32
|
||||
export NUM_EPOCHS=3
|
||||
export SAVE_STEPS=750
|
||||
export SEED=1
|
||||
|
||||
python3 run_pl_ner.py --data_dir ./ \
|
||||
--model_type bert \
|
||||
--labels ./labels.txt \
|
||||
--model_name_or_path $BERT_MODEL \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_seq_length $MAX_LENGTH \
|
||||
--num_train_epochs $NUM_EPOCHS \
|
||||
--train_batch_size 32 \
|
||||
--save_steps $SAVE_STEPS \
|
||||
--seed $SEED \
|
||||
--do_train \
|
||||
--do_predict
|
238
examples/ner/run_pl_ner.py
Normal file
238
examples/ner/run_pl_ner.py
Normal file
@ -0,0 +1,238 @@
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from seqeval.metrics import f1_score, precision_score, recall_score
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from transformer_base import BaseTransformer, add_generic_args, generic_train
|
||||
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NERTransformer(BaseTransformer):
|
||||
"""
|
||||
A training module for NER. See BaseTransformer for the core options.
|
||||
"""
|
||||
|
||||
def __init__(self, hparams):
|
||||
self.labels = get_labels(hparams.labels)
|
||||
num_labels = len(self.labels)
|
||||
super(NERTransformer, self).__init__(hparams, num_labels)
|
||||
|
||||
def forward(self, **inputs):
|
||||
return self.model(**inputs)
|
||||
|
||||
def training_step(self, batch, batch_num):
|
||||
"Compute loss"
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if self.hparams.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if self.hparams.model_type in ["bert", "xlnet"] else None
|
||||
) # XLM and RoBERTa don"t use segment_ids
|
||||
|
||||
outputs = self.forward(**inputs)
|
||||
loss = outputs[0]
|
||||
|
||||
tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
|
||||
return {"loss": loss, "log": tensorboard_logs}
|
||||
|
||||
def load_dataset(self, mode, batch_size):
|
||||
labels = get_labels(self.hparams.labels)
|
||||
self.pad_token_label_id = CrossEntropyLoss().ignore_index
|
||||
dataset = self.load_and_cache_examples(labels, self.pad_token_label_id, mode)
|
||||
if mode == "train":
|
||||
if self.hparams.n_gpu > 1:
|
||||
sampler = DistributedSampler(dataset)
|
||||
else:
|
||||
sampler = RandomSampler(dataset)
|
||||
else:
|
||||
sampler = SequentialSampler(dataset)
|
||||
dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
|
||||
return dataloader
|
||||
|
||||
def validation_step(self, batch, batch_nb):
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if self.hparams.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if self.hparams.model_type in ["bert", "xlnet"] else None
|
||||
) # XLM and RoBERTa don"t use segment_ids
|
||||
outputs = self.forward(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
preds = logits.detach().cpu().numpy()
|
||||
out_label_ids = inputs["labels"].detach().cpu().numpy()
|
||||
|
||||
return {"val_loss": tmp_eval_loss, "pred": preds, "target": out_label_ids}
|
||||
|
||||
def _eval_end(self, outputs):
|
||||
"Task specific validation"
|
||||
val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean()
|
||||
preds = np.concatenate([x["pred"] for x in outputs], axis=0)
|
||||
preds = np.argmax(preds, axis=2)
|
||||
out_label_ids = np.concatenate([x["target"] for x in outputs], axis=0)
|
||||
|
||||
label_map = {i: label for i, label in enumerate(self.labels)}
|
||||
out_label_list = [[] for _ in range(out_label_ids.shape[0])]
|
||||
preds_list = [[] for _ in range(out_label_ids.shape[0])]
|
||||
|
||||
for i in range(out_label_ids.shape[0]):
|
||||
for j in range(out_label_ids.shape[1]):
|
||||
if out_label_ids[i, j] != self.pad_token_label_id:
|
||||
out_label_list[i].append(label_map[out_label_ids[i][j]])
|
||||
preds_list[i].append(label_map[preds[i][j]])
|
||||
|
||||
results = {
|
||||
"val_loss": val_loss_mean,
|
||||
"precision": precision_score(out_label_list, preds_list),
|
||||
"recall": recall_score(out_label_list, preds_list),
|
||||
"f1": f1_score(out_label_list, preds_list),
|
||||
}
|
||||
|
||||
if self.is_logger():
|
||||
logger.info(self.proc_rank)
|
||||
logger.info("***** Eval results *****")
|
||||
for key in sorted(results.keys()):
|
||||
logger.info(" %s = %s", key, str(results[key]))
|
||||
|
||||
tensorboard_logs = results
|
||||
ret = {k: v for k, v in results.items()}
|
||||
ret["log"] = tensorboard_logs
|
||||
return ret, preds_list, out_label_list
|
||||
|
||||
def validation_end(self, outputs):
|
||||
ret, preds, targets = self._eval_end(outputs)
|
||||
return ret
|
||||
|
||||
def test_end(self, outputs):
|
||||
ret, predictions, targets = self._eval_end(outputs)
|
||||
|
||||
if self.is_logger():
|
||||
# Write output to a file:
|
||||
# Save results
|
||||
output_test_results_file = os.path.join(self.hparams.output_dir, "test_results.txt")
|
||||
with open(output_test_results_file, "w") as writer:
|
||||
for key in sorted(ret.keys()):
|
||||
if key != "log":
|
||||
writer.write("{} = {}\n".format(key, str(ret[key])))
|
||||
# Save predictions
|
||||
output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
|
||||
with open(output_test_predictions_file, "w") as writer:
|
||||
with open(os.path.join(self.hparams.data_dir, "test.txt"), "r") as f:
|
||||
example_id = 0
|
||||
for line in f:
|
||||
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
|
||||
writer.write(line)
|
||||
if not predictions[example_id]:
|
||||
example_id += 1
|
||||
elif predictions[example_id]:
|
||||
output_line = line.split()[0] + " " + predictions[example_id].pop(0) + "\n"
|
||||
writer.write(output_line)
|
||||
else:
|
||||
logger.warning(
|
||||
"Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
|
||||
)
|
||||
return ret
|
||||
|
||||
def load_and_cache_examples(self, labels, pad_token_label_id, mode):
|
||||
args = self.hparams
|
||||
tokenizer = self.tokenizer
|
||||
if self.proc_rank not in [-1, 0] and mode == "train":
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(
|
||||
args.data_dir,
|
||||
"cached_{}_{}_{}".format(
|
||||
mode, list(filter(None, args.model_name_or_path.split("/"))).pop(), str(args.max_seq_length)
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||
examples = read_examples_from_file(args.data_dir, mode)
|
||||
features = convert_examples_to_features(
|
||||
examples,
|
||||
labels,
|
||||
args.max_seq_length,
|
||||
tokenizer,
|
||||
cls_token_at_end=bool(args.model_type in ["xlnet"]),
|
||||
cls_token=tokenizer.cls_token,
|
||||
cls_token_segment_id=2 if args.model_type in ["xlnet"] else 0,
|
||||
sep_token=tokenizer.sep_token,
|
||||
sep_token_extra=bool(args.model_type in ["roberta"]),
|
||||
pad_on_left=bool(args.model_type in ["xlnet"]),
|
||||
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
||||
pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
|
||||
pad_token_label_id=pad_token_label_id,
|
||||
)
|
||||
if self.proc_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
|
||||
if self.proc_rank == 0 and mode == "train":
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Convert to Tensors and build dataset
|
||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
|
||||
all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long)
|
||||
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
# Add NER specific options
|
||||
BaseTransformer.add_model_specific_args(parser, root_dir)
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--labels",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
add_generic_args(parser, os.getcwd())
|
||||
parser = NERTransformer.add_model_specific_args(parser, os.getcwd())
|
||||
args = parser.parse_args()
|
||||
model = NERTransformer(args)
|
||||
trainer = generic_train(model, args)
|
||||
|
||||
if args.do_predict:
|
||||
checkpoints = list(sorted(glob.glob(args.output_dir + "/checkpoint_*.ckpt", recursive=True)))
|
||||
NERTransformer.load_from_checkpoint(checkpoints[-1])
|
||||
trainer.test(model)
|
264
examples/ner/transformer_base.py
Normal file
264
examples/ner/transformer_base.py
Normal file
@ -0,0 +1,264 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AdamW,
|
||||
BertConfig,
|
||||
BertForTokenClassification,
|
||||
BertTokenizer,
|
||||
CamembertConfig,
|
||||
CamembertForTokenClassification,
|
||||
CamembertTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForTokenClassification,
|
||||
DistilBertTokenizer,
|
||||
RobertaConfig,
|
||||
RobertaForTokenClassification,
|
||||
RobertaTokenizer,
|
||||
XLMRobertaConfig,
|
||||
XLMRobertaForTokenClassification,
|
||||
XLMRobertaTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
|
||||
|
||||
ALL_MODELS = sum(
|
||||
(
|
||||
tuple(conf.pretrained_config_archive_map.keys())
|
||||
for conf in (BertConfig, RobertaConfig, DistilBertConfig, CamembertConfig, XLMRobertaConfig)
|
||||
),
|
||||
(),
|
||||
)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
|
||||
"camembert": (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer),
|
||||
"xlmroberta": (XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer),
|
||||
}
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
class BaseTransformer(pl.LightningModule):
|
||||
def __init__(self, hparams, num_labels=None):
|
||||
"Initialize a model."
|
||||
|
||||
super(BaseTransformer, self).__init__()
|
||||
self.hparams = hparams
|
||||
self.hparams.model_type = self.hparams.model_type.lower()
|
||||
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[self.hparams.model_type]
|
||||
config = config_class.from_pretrained(
|
||||
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
|
||||
do_lower_case=self.hparams.do_lower_case,
|
||||
cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
self.hparams.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
|
||||
)
|
||||
self.config, self.tokenizer, self.model = config, tokenizer, model
|
||||
self.proc_rank = -1
|
||||
|
||||
def is_logger(self):
|
||||
return self.proc_rank <= 0
|
||||
|
||||
def configure_optimizers(self):
|
||||
"Prepare optimizer and schedule (linear warmup and decay)"
|
||||
model = self.model
|
||||
|
||||
t_total = (
|
||||
len(self.train_dataloader())
|
||||
// self.hparams.gradient_accumulation_steps
|
||||
* float(self.hparams.num_train_epochs)
|
||||
)
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": self.hparams.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
self.lr_scheduler = scheduler
|
||||
return [optimizer]
|
||||
|
||||
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
|
||||
|
||||
# Step each time.
|
||||
optimizer.step()
|
||||
self.lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
def get_tqdm_dict(self):
|
||||
tqdm_dict = {"loss": "{:.3f}".format(self.trainer.avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
|
||||
|
||||
return tqdm_dict
|
||||
|
||||
def test_step(self, batch, batch_nb):
|
||||
return self.validation_step(batch, batch_nb)
|
||||
|
||||
def test_end(self, outputs):
|
||||
return self.validation_end(outputs)
|
||||
|
||||
@pl.data_loader
|
||||
def train_dataloader(self):
|
||||
return self.load_dataset("train", self.hparams.train_batch_size)
|
||||
|
||||
@pl.data_loader
|
||||
def val_dataloader(self):
|
||||
return self.load_dataset("dev", self.hparams.eval_batch_size)
|
||||
|
||||
@pl.data_loader
|
||||
def test_dataloader(self):
|
||||
return self.load_dataset("test", self.hparams.eval_batch_size)
|
||||
|
||||
def init_ddp_connection(self, proc_rank, world_size):
|
||||
self.proc_rank = proc_rank
|
||||
super(BaseTransformer, self).init_ddp_connection(proc_rank, world_size)
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
|
||||
)
|
||||
|
||||
parser.add_argument("--train_batch_size", default=32, type=int)
|
||||
parser.add_argument("--eval_batch_size", default=32, type=int)
|
||||
|
||||
|
||||
def add_generic_args(parser, root_dir):
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
|
||||
parser.add_argument("--n_gpu", type=int, default=1)
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
|
||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
|
||||
def generic_train(model, args):
|
||||
# init model
|
||||
set_seed(args)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
|
||||
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
||||
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5
|
||||
)
|
||||
|
||||
trainer = pl.Trainer(
|
||||
accumulate_grad_batches=args.gradient_accumulation_steps,
|
||||
gpus=args.n_gpu,
|
||||
max_epochs=args.num_train_epochs,
|
||||
use_amp=args.fp16,
|
||||
amp_level=args.fp16_opt_level,
|
||||
distributed_backend="ddp",
|
||||
gradient_clip_val=args.max_grad_norm,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
)
|
||||
if args.do_train:
|
||||
trainer.fit(model)
|
||||
|
||||
return trainer
|
Loading…
Reference in New Issue
Block a user