[s2s] distillBART docs for paper replication (#8150)

Sam Shleifer 2020-10-29 12:01:15 -04:00 committed by GitHub
parent acf56408d8
commit 49e4fece5c
4 changed files with 199 additions and 56 deletions

View File

@@ -1,8 +1,8 @@
## Sequence to Sequence
## Sequence to Sequence Training and Evaluation
This directory contains examples for finetuning and evaluating transformers on summarization and translation tasks.
Please tag @sshleifer with any issues/unexpected behaviors, or send a PR!
For `bertabs` instructions, see [`bertabs/README.md`](bertabs/README.md).
For deprecated `bertabs` instructions, see [`bertabs/README.md`](bertabs/README.md).
### Supported Architectures
@@ -66,7 +66,7 @@ Multiple eval datasets are available for download from:
https://github.com/stas00/porting/tree/master/datasets/pegasus
#### Private Data
#### Your Data
If you are using your own data, it must be formatted as one directory with 6 files:
```
@@ -232,7 +232,191 @@ Following command fine-tunes `sshleifer/student_marian_en_ro_6_3` on TPU V3-8 an
./builtin_trainer/train_distil_marian_enro_tpu.sh
```
### Evaluation Commands
# DistilBART
<!---It should be called distilling bart and pegasus, but I don't want to break the link in the paper.-->
This section describes all code and artifacts from our [paper](http://arxiv.org/abs/2010.13002).
![DBART](https://huggingface.co/front/thumbnails/distilbart_large.png)
+ For the CNN/DailyMail dataset (relatively longer, more extractive summaries), we found a simple technique that works, which we call "Shrink and Fine-tune", or SFT:
you simply copy alternating layers from `facebook/bart-large-cnn` and fine-tune further on the CNN/DM data. `sshleifer/distill-pegasus-cnn-16-4`, `sshleifer/distilbart-cnn-12-6`, and all other checkpoints under `sshleifer` that start with `distilbart-cnn` were trained this way.
+ For the XSUM dataset, training on pseudo-labels worked best for Pegasus (`sshleifer/distill-pegasus-xsum-16-4`), while training with KD worked best for `distilbart-xsum-12-6`.
+ For `sshleifer/dbart-xsum-12-3`
+ We ran hundreds of experiments and didn't want to document hundreds of commands. If you want a command to replicate a figure from the paper that is not documented below, feel free to ask on the [forums](https://discuss.huggingface.co/t/seq2seq-distillation-methodology-questions/1270) and tag `@sshleifer`.
+ You can see the performance trade-offs of different model sizes [here](https://docs.google.com/spreadsheets/d/1EkhDMwVO02m8jCD1cG3RoFPLicpcL1GQHTQjfvDYgIM/edit#gid=0), and more granular timing results [here](https://docs.google.com/spreadsheets/d/1EkhDMwVO02m8jCD1cG3RoFPLicpcL1GQHTQjfvDYgIM/edit#gid=1753259047&range=B2:I23).
### Evaluation
Use [run_distributed_eval](./run_distributed_eval.py) with the following convenient alias:
```bash
deval () {
    proc=$1  # number of processes/GPUs
    m=$2     # model name or path
    dd=$3    # data directory
    sd=$4    # save directory
    shift 4  # everything else is forwarded to run_distributed_eval.py
    python -m torch.distributed.launch --nproc_per_node=$proc run_distributed_eval.py \
        --model_name $m --save_dir $sd --data_dir $dd "$@"
}
```
On a 1-GPU system, here are four commands (they assume `xsum` and `cnn_dm` are already downloaded; Cmd-F for those download links elsewhere in this file).
`distilBART`:
```bash
deval 1 sshleifer/distilbart-xsum-12-3 xsum dbart_12_3_xsum_eval --fp16 # --help for more choices.
deval 1 sshleifer/distilbart-cnn-12-6 cnn_dm dbart_12_6_cnn_eval --fp16
```
`distill-pegasus`:
```bash
deval 1 sshleifer/distill-pegasus-cnn-16-4 cnn_dm dpx_cnn_eval
deval 1 sshleifer/distill-pegasus-xsum-16-4 xsum dpx_xsum_eval
```
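Any arguments after the fourth are forwarded straight to `run_distributed_eval.py` (see `--help` for the full list of options). For example, here is a sketch of running generation on the training split, as is done for pseudo-label creation; the save-dir name is arbitrary:
```bash
deval 1 sshleifer/distilbart-xsum-12-3 xsum dbart_12_3_xsum_train_eval --fp16 --type_path train
```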
### Distillation
+ For all of the following commands, you can get roughly equivalent results and faster run times by passing `--num_beams=4`; that is not what we did for the paper, though.
+ Except for the KD section, you can also run the commands with the built-in transformers trainer. See, for example, [builtin_trainer/train_distilbart_cnn.sh](./builtin_trainer/train_distilbart_cnn.sh).
+ Large performance deviations (> 5X slower or more than 0.5 ROUGE-2 worse) should be reported.
+ Multi-GPU training (controlled with `--gpus`) should work, but might require more epochs.
#### Recommended Workflow
+ Get your dataset in the right format (see the 6 files above, and the layout sketch after this list).
+ Find a teacher model: [Pegasus](https://huggingface.co/models?search=pegasus) (slower, better ROUGE) or `facebook/bart-large-xsum`/`facebook/bart-large-cnn` (faster, slightly lower ROUGE).
Choose the checkpoint whose fine-tuning dataset is most similar (or identical) to your dataset.
+ Follow the sections below in order. You can stop after SFT if you are satisfied, or move on to pseudo-labeling if you want more performance.
+ Student size: if you want a close-to-free 50% speedup, cut the decoder in half; if you want a larger speedup, cut it to a quarter.
+ If your SFT run starts at a validation ROUGE-2 more than 10 points below the teacher's validation ROUGE-2, you have a bug. Switching to a more expensive technique will not help. Try setting a breakpoint and inspecting the generation and truncation defaults/hyper-parameters, and share your experience on the forums!
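For reference, here is the expected data layout, assuming the standard `{train,val,test}.{source,target}` naming used by the other commands in this doc (each line of a `.source` file is one input document, and the corresponding line of the `.target` file is its reference summary):
```
xsum/
    train.source
    train.target
    val.source
    val.target
    test.source
    test.target
```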
#### Initialization
We use [make_student.py](./make_student.py) to copy alternating layers from the teacher and save the resulting student to disk:
```bash
python make_student.py facebook/bart-large-xsum --save_path dbart_xsum_12_3 -e 12 -d 3
```
or for `pegasus-xsum`
```bash
python make_student.py google/pegasus-xsum --save_path dpx_xsum_16_4 --e 16 --d 4
```
We now have an initialized student saved to `dbart_xsum_12_3`, which we will use in the following commands.
+ Extension: to replicate the more complicated initialization experiments in Section 6.1, or to try your own, use the `create_student_by_copying_alternating_layers` function.
#### Pegasus
+ The following commands are written for BART and will require, at minimum, the following modifications (a hypothetical single-GPU adaptation is sketched right after this list):
+ Reduce the batch size and increase gradient accumulation steps so that the product `gpus * batch_size * gradient_accumulation_steps = 256`. We used `--learning_rate` = 1e-4 * gradient accumulation steps.
+ Don't use fp16.
+ Pass `--tokenizer_name google/pegasus-large`.
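As a concrete illustration of the arithmetic above, here is a hypothetical single-GPU adaptation of the SFT command below for the `dpx_xsum_16_4` student. The batch size, accumulation steps, and output dir are illustrative assumptions rather than tuned values: `1 gpu * 8 * 32 = 256`, and `1e-4 * 32 = 3.2e-3` per the rule above.
```bash
python finetune.py \
    --data_dir xsum \
    --model_name_or_path dpx_xsum_16_4 \
    --tokenizer_name google/pegasus-large \
    --learning_rate=3.2e-3 \
    --do_train --do_predict \
    --freeze_encoder --freeze_embeds \
    --train_batch_size=8 --eval_batch_size=8 \
    --gradient_accumulation_steps=32 \
    --sortish_sampler \
    --num_train_epochs=6 \
    --warmup_steps 500 \
    --output_dir dpx_xsum_sft_16_4 --gpus 1  # note: no --fp16, Pegasus does not work in fp16
```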
### SFT (No Teacher Distillation)
You don't need `distillation.py`; you can just run:
```bash
python finetune.py \
--data_dir xsum \
--freeze_encoder --freeze_embeds \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 --fp16_opt_level=O1 \
--val_check_interval 0.1 --n_val 1000 --eval_beams 2 --length_penalty=0.5 \
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
--model_name_or_path dbart_xsum_12_3 \
--train_batch_size=64 --eval_batch_size=64 \
--sortish_sampler \
--num_train_epochs=6 \
--warmup_steps 500 \
--output_dir distilbart_xsum_sft_12_3 --gpus 1
```
+ Note: The command that produced `sshleifer/distilbart-cnn-12-6` is at [train_distilbart_cnn.sh](./train_distilbart_cnn.sh).
```bash
./train_distilbart_cnn.sh
```
<!--- runtime: 6H on NVIDIA RTX 24GB GPU -->
+ Tip: You can get the same simple distillation logic by running `distillation.py --no_teacher` followed by the same arguments as in `train_distilbart_cnn.sh`.
If you are using `wandb` and comparing the two distillation methods, using this entry point will make your logs consistent,
because you will have the same hyper-parameters logged in every run.
### Pseudo-Labeling
+ You don't need `distillation.py`.
+ Instructions to generate pseudo-labels and use pre-computed pseudo-labels can be found [here](./precomputed_pseudo_labels.md).
Simply run `finetune.py` with one of those pseudo-label datasets as `--data_dir` (`DATA`, below).
```bash
python finetune.py \
--teacher facebook/bart-large-xsum --data_dir DATA \
--freeze_encoder --freeze_embeds \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 --fp16_opt_level=O1 \
--val_check_interval 0.1 --n_val 1000 --eval_beams 2 --length_penalty=0.5 \
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
--model_name_or_path dbart_xsum_12_3 \
--train_batch_size=32 --eval_batch_size=32 \
--sortish_sampler \
--num_train_epochs=5 \
--warmup_steps 500 \
--output_dir dbart_xsum_12_3_PL --gpus 1 --logger_name wandb
```
To combine datasets, as in Section 6.2, try something like:
```bash
curl -S https://cdn-datasets.huggingface.co/pseudo/xsum/bart_xsum_pl.tgz | tar -xvz -C .
curl -S https://cdn-datasets.huggingface.co/pseudo/xsum/pegasus_xsum.tgz | tar -xvz -C .
curl -S https://cdn-datasets.huggingface.co/summarization/xsum.tar.gz | tar -xvz -C .
mkdir all_pl
cat bart_xsum_pl/train.source pegasus_xsum/train.source xsum/train.source > all_pl/train.source
cat bart_xsum_pl/train.target pegasus_xsum/train.target xsum/train.target > all_pl/train.target
cp xsum/val* all_pl
cp xsum/test* all_pl
```
then use `all_pl` as DATA in the command above.
#### Direct Knowledge Distillation (KD)
+ In this method, we try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states, using `BartSummarizationDistiller`.
+ This is how the `sshleifer/distilbart-xsum-12-6`, `6-6`, and `9-6` checkpoints were produced.
+ You must use [`distillation.py`](./distillation.py). Note that this command initializes the student for you.
The command that produced `sshleifer/distilbart-xsum-12-6` is at [train_distilbart_xsum.sh](./train_distilbart_xsum.sh):
```bash
./train_distilbart_xsum.sh --logger_name wandb --gpus 1
```
+ Expected ROUGE-2 between 21.3 and 21.6, run time ~13H.
+ Direct KD + Pegasus is VERY slow and works best with `--supervise_forward --normalize_hidden`.
<!--- runtime: 13H on V-100 16GB GPU. -->
### Citation
```bibtex
@misc{shleifer2020pretrained,
title={Pre-trained Summarization Distillation},
author={Sam Shleifer and Alexander M. Rush},
year={2020},
eprint={2010.13002},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@article{Wolf2019HuggingFacesTS,
title={HuggingFace's Transformers: State-of-the-art Natural Language Processing},
author={Thomas Wolf and Lysandre Debut and Victor Sanh and Julien Chaumond and Clement Delangue and Anthony Moi and Pierric Cistac and Tim Rault and Rémi Louf and Morgan Funtowicz and Joe Davison and Sam Shleifer and Patrick von Platen and Clara Ma and Yacine Jernite and Julien Plu and Canwen Xu and Teven Le Scao and Sylvain Gugger and Mariama Drame and Quentin Lhoest and Alexander M. Rush},
journal={ArXiv},
year={2019},
volume={abs/1910.03771}
}
```
This is the end of the distillation section; the rest of this doc pertains to general seq2seq commands.
## Evaluation Commands
To create summaries for each article in a dataset, we use `run_eval.py`. Here are a few commands that run evaluation for different tasks and models.
If 'translation' is in your task name, the computed metric will be BLEU. Otherwise, ROUGE will be used.
@@ -272,9 +456,8 @@ export DATA_DIR=cnn_dm
--score_path cnn_rouge.json \
--task summarization \
--n_obs 100 \
--device cuda \
--max_source_length 1024 \
--max_target_length 56 \
--fp16 \
--bs 32
```
@@ -356,47 +539,6 @@ stas/wmt19-en-ru data/en-ru/val.source data/en-ru/test_translations.txt --refere
If you pass `--info "some experiment-specific info"` it will get printed before the results table - this is useful for scripting and multiple runs, so one can tell the different sets of results from each other.
### DistilBART
![DBART](https://huggingface.co/front/thumbnails/distilbart_large.png)
For the CNN/DailyMail dataset (relatively longer, more extractive summaries), we found a simple technique that works:
you just copy alternating layers from `bart-large-cnn` and finetune more on the same data.
For the XSUM dataset, that didn't work as well, so we used the same initialization strategy followed by a combination of DistilBERT's ce_loss and the hidden-states MSE loss used in the TinyBERT paper.
You can see the performance tradeoffs of model sizes [here](https://docs.google.com/spreadsheets/d/1EkhDMwVO02m8jCD1cG3RoFPLicpcL1GQHTQjfvDYgIM/edit#gid=0).
and more granular timing results [here](https://docs.google.com/spreadsheets/d/1EkhDMwVO02m8jCD1cG3RoFPLicpcL1GQHTQjfvDYgIM/edit#gid=1753259047&range=B2:I23).
#### No Teacher Distillation
To run the simpler distilbart-cnn style distillation all you need is data, a GPU, and a properly initialized student.
You don't even need `distillation.py`.
Some [un-finetuned students](https://huggingface.co/models?search=sshleifer%2Fstudent) are available for replication purposes.
They are initialized by copying layers from the associated `bart-large-{cnn|xsum}` teacher using `--init_strategy alternate`. (You can read about that in `initialization_utils.py`)
The command that produced `sshleifer/distilbart-cnn-12-6` is
```bash
./train_distilbart_cnn.sh
```
runtime: 6H on NVIDIA RTX 24GB GPU
*Note*: You can get the same simple distillation logic by using `./run_distiller.sh --no_teacher` followed by identical arguments as the ones in `train_distilbart_cnn.sh`.
If you are using `wandb` and comparing the two distillation methods, using this entry point will make your logs consistent,
because you will have the same hyperparameters logged in every run.
#### With a teacher (Intermediate Supervision)
*Note*: only BART variants are supported.
In this method, we try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `BartSummarizationDistiller`.
This is how `sshleifer/distilbart-xsum*` checkpoints were produced.
The command that produced `sshleifer/distilbart-xsum-12-6` is:
```bash
./train_distilbart_xsum.sh --logger_name wandb --gpus 1
```
runtime: 13H on V-100 16GB GPU.
### Contributing
- follow the standard contributing guidelines and code of conduct.
- add tests to `test_seq2seq_examples.py`
@@ -417,7 +559,7 @@ python convert_pl_checkpoint_to_hf.py PATH_TO_CKPT randomly_initialized_hf_mode
Then run either `run_eval` or `run_distributed_eval` with `save_dir/best_tfmr` (see the previous sections).
## Experimental Features
# Experimental Features
These features are harder to use and not always useful.
### Dynamic Batch Size for MT
@@ -444,3 +586,5 @@ uses 12,723 batches of length 48 and takes slightly more time 9.5 minutes.
The feature is still experimental, because:
+ We can make it much more robust if we have memory-mapped/preprocessed datasets.
+ The speedup over sortish sampler is not that large at the moment.

View File

@@ -1,19 +1,18 @@
export WANDB_PROJECT=distilbart-cnn
export WANDB_PROJECT=distilbart-trainer
export BS=32
export GAS=1
export m=sshleifer/student_cnn_12_6
export tok=facebook/bart-large
export MAX_TGT_LEN=142
python finetune_trainer.py \
--model_name_or_path $m --tokenizer_name $tok \
--data_dir $CNN_DIR \
--data_dir cnn_dm \
--output_dir distilbart-cnn-12-6 --overwrite_output_dir \
--learning_rate=3e-5 \
--warmup_steps 500 --sortish_sampler \
--fp16 \
--n_val 500 \
--gradient_accumulation_steps=$GAS \
--gradient_accumulation_steps=1 \
--per_device_train_batch_size=$BS --per_device_eval_batch_size=$BS \
--freeze_encoder --freeze_embeds \
--num_train_epochs=2 \
@@ -21,5 +20,5 @@ python finetune_trainer.py \
--logging_first_step \
--max_target_length 56 --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
--do_train --do_eval --do_predict --evaluate_during_training \
--predict_with_generate \
--predict_with_generate --sortish_sampler \
"$@"

View File

@@ -33,7 +33,7 @@ python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \
--type_path train
```
+ These command takes a while to run. For example, pegasus_cnn_cnn_pls.tgz took 8 hours on 8 GPUs.
+ These commands take a while to run. For example, `pegasus_cnn_cnn_pls.tgz` took 8 hours on 8 GPUs.
+ Pegasus does not work in fp16 :(, but BART, mBART, and Marian do.
+ Even if you have 1 GPU, `run_distributed_eval.py` is 10-20% faster than `run_eval.py` because it uses `SortishSampler` to minimize padding computation.

View File

@@ -13,7 +13,7 @@ python finetune.py \
--val_check_interval 0.25 \
--n_val 500 \
--num_train_epochs 2 \
--freeze_encoder --freeze_embeds --data_dir $CNN_DIR \
--freeze_encoder --freeze_embeds --data_dir cnn_dm \
--max_target_length 142 --val_max_target_length=142 \
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
--model_name_or_path sshleifer/student_cnn_12_6 \