mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
[examples/seq2seq/README.md] fix t5 examples (#10734)
* [examples/seq2seq] fix t5 examples This PR: * fixes T5 examples to include `--source_prefix` - it's **not** optional. If you give it a try you will see that you get 10x worse bleu scores w/o it. w/ `27.6849`, w/ `2.374` * added a normal translation example w/o the peculiarities of MBart and T5 * reduces the default max samples to 50 so it's much faster to test quickly summarization seems to be broken for t5 score-wise: https://github.com/huggingface/transformers/issues/10733 @sgugger * specify explicitly the t5 models requiring the special handling * one more * update the t5 summarization example to use cnn_dailymail * move max*samples into the top level README.md * better wording * better wording
This commit is contained in:
parent
094afa515d
commit
9352b5151a
@ -95,6 +95,23 @@ Coming soon!
|
|||||||
| [**`translation`**](https://github.com/huggingface/transformers/tree/master/examples/seq2seq) | WMT | ✅ | - | - | -
|
| [**`translation`**](https://github.com/huggingface/transformers/tree/master/examples/seq2seq) | WMT | ✅ | - | - | -
|
||||||
|
|
||||||
|
|
||||||
|
## Running quick tests
|
||||||
|
|
||||||
|
Most examples are equipped with a mechanism to truncate the number of dataset samples to the desired length. This is useful for debugging purposes, for example to quickly check that all stages of the programs can complete, before running the same setup on the full dataset which may take hours to complete.
|
||||||
|
|
||||||
|
For example here is how to truncate all three splits to just 50 samples each:
|
||||||
|
```
|
||||||
|
examples/token-classification/run_ner.py \
|
||||||
|
--max_train_samples 50 \
|
||||||
|
--max_val_samples 50 \
|
||||||
|
--max_test_samples 50 \
|
||||||
|
[...]
|
||||||
|
```
|
||||||
|
|
||||||
|
Most example scripts should have the first two command line arguments and some have the third one. You can quickly check if a given example supports any of these by passing a `-h` option, e.g.:
|
||||||
|
```
|
||||||
|
examples/token-classification/run_ner.py -h
|
||||||
|
```
|
||||||
|
|
||||||
## Resuming training
|
## Resuming training
|
||||||
|
|
||||||
|
@ -24,10 +24,10 @@ For the old `finetune_trainer.py` and related utils, see [`examples/legacy/seq2s
|
|||||||
### Supported Architectures
|
### Supported Architectures
|
||||||
|
|
||||||
- `BartForConditionalGeneration`
|
- `BartForConditionalGeneration`
|
||||||
|
- `FSMTForConditionalGeneration` (translation only)
|
||||||
|
- `MBartForConditionalGeneration`
|
||||||
- `MarianMTModel`
|
- `MarianMTModel`
|
||||||
- `PegasusForConditionalGeneration`
|
- `PegasusForConditionalGeneration`
|
||||||
- `MBartForConditionalGeneration`
|
|
||||||
- `FSMTForConditionalGeneration` (translation only)
|
|
||||||
- `T5ForConditionalGeneration`
|
- `T5ForConditionalGeneration`
|
||||||
|
|
||||||
`run_summarization.py` and `run_translation.py` are lightweight examples of how to download and preprocess a dataset from the [🤗 Datasets](https://github.com/huggingface/datasets) library or use your own files (jsonlines or csv), then fine-tune one of the architectures above on it.
|
`run_summarization.py` and `run_translation.py` are lightweight examples of how to download and preprocess a dataset from the [🤗 Datasets](https://github.com/huggingface/datasets) library or use your own files (jsonlines or csv), then fine-tune one of the architectures above on it.
|
||||||
@ -43,17 +43,21 @@ python examples/seq2seq/run_summarization.py \
|
|||||||
--model_name_or_path t5-small \
|
--model_name_or_path t5-small \
|
||||||
--do_train \
|
--do_train \
|
||||||
--do_eval \
|
--do_eval \
|
||||||
--dataset_name xsum \
|
--dataset_name cnn_dailymail \
|
||||||
|
--dataset_config "3.0.0" \
|
||||||
|
--source_prefix "summarize: " \
|
||||||
--output_dir /tmp/tst-summarization \
|
--output_dir /tmp/tst-summarization \
|
||||||
--per_device_train_batch_size=4 \
|
--per_device_train_batch_size=4 \
|
||||||
--per_device_eval_batch_size=4 \
|
--per_device_eval_batch_size=4 \
|
||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
--predict_with_generate \
|
--predict_with_generate
|
||||||
--max_train_samples 500 \
|
|
||||||
--max_val_samples 500
|
|
||||||
```
|
```
|
||||||
|
|
||||||
CNN/DailyMail dataset is another commonly used dataset for the task of summarization. To use it replace `--dataset_name xsum` with `--dataset_name cnn_dailymail --dataset_config "3.0.0"`.
|
Only T5 models `t5-small`, `t5-base`, `t5-large`, `t5-3b` and `t5-11b` must use an additional argument: `--source_prefix "summarize: "`.
|
||||||
|
|
||||||
|
We used CNN/DailyMail dataset in this example as `t5-small` was trained on it and one can get good scores even when pre-training with a very small sample.
|
||||||
|
|
||||||
|
Extreme Summarization (XSum) Dataset is another commonly used dataset for the task of summarization. To use it replace `--dataset_name cnn_dailymail --dataset_config "3.0.0"` with `--dataset_name xsum`.
|
||||||
|
|
||||||
And here is how you would use it on your own files, after adjusting the values for the arguments
|
And here is how you would use it on your own files, after adjusting the values for the arguments
|
||||||
`--train_file`, `--validation_file`, `--text_column` and `--summary_column` to match your setup:
|
`--train_file`, `--validation_file`, `--text_column` and `--summary_column` to match your setup:
|
||||||
@ -65,13 +69,12 @@ python examples/seq2seq/run_summarization.py \
|
|||||||
--do_eval \
|
--do_eval \
|
||||||
--train_file path_to_csv_or_jsonlines_file \
|
--train_file path_to_csv_or_jsonlines_file \
|
||||||
--validation_file path_to_csv_or_jsonlines_file \
|
--validation_file path_to_csv_or_jsonlines_file \
|
||||||
|
--source_prefix "summarize: " \
|
||||||
--output_dir /tmp/tst-summarization \
|
--output_dir /tmp/tst-summarization \
|
||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
--per_device_train_batch_size=4 \
|
--per_device_train_batch_size=4 \
|
||||||
--per_device_eval_batch_size=4 \
|
--per_device_eval_batch_size=4 \
|
||||||
--predict_with_generate \
|
--predict_with_generate
|
||||||
--max_train_samples 500 \
|
|
||||||
--max_val_samples 500
|
|
||||||
```
|
```
|
||||||
|
|
||||||
The task of summarization supports custom CSV and JSONLINES formats.
|
The task of summarization supports custom CSV and JSONLINES formats.
|
||||||
@ -135,11 +138,11 @@ And as with the CSV files, you can specify which values to select from the file,
|
|||||||
|
|
||||||
### Translation
|
### Translation
|
||||||
|
|
||||||
Here is an example of a translation fine-tuning with T5:
|
Here is an example of a translation fine-tuning with a MarianMT model:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python examples/seq2seq/run_translation.py \
|
python examples/seq2seq/run_translation.py \
|
||||||
--model_name_or_path t5-small \
|
--model_name_or_path Helsinki-NLP/opus-mt-en-ro \
|
||||||
--do_train \
|
--do_train \
|
||||||
--do_eval \
|
--do_eval \
|
||||||
--source_lang en \
|
--source_lang en \
|
||||||
@ -150,12 +153,35 @@ python examples/seq2seq/run_translation.py \
|
|||||||
--per_device_train_batch_size=4 \
|
--per_device_train_batch_size=4 \
|
||||||
--per_device_eval_batch_size=4 \
|
--per_device_eval_batch_size=4 \
|
||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
--predict_with_generate \
|
--predict_with_generate
|
||||||
--max_train_samples 500 \
|
|
||||||
--max_val_samples 500
|
|
||||||
```
|
```
|
||||||
|
|
||||||
And the same with MBart:
|
MBart and some T5 models require special handling.
|
||||||
|
|
||||||
|
T5 models `t5-small`, `t5-base`, `t5-large`, `t5-3b` and `t5-11b` must use an additional argument: `--source_prefix "translate {source_lang} to {target_lang}"`. For example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python examples/seq2seq/run_translation.py \
|
||||||
|
--model_name_or_path t5-small \
|
||||||
|
--do_train \
|
||||||
|
--do_eval \
|
||||||
|
--source_lang en \
|
||||||
|
--target_lang ro \
|
||||||
|
--source_prefix "translate English to Romanian: " \
|
||||||
|
--dataset_name wmt16 \
|
||||||
|
--dataset_config_name ro-en \
|
||||||
|
--output_dir /tmp/tst-translation \
|
||||||
|
--per_device_train_batch_size=4 \
|
||||||
|
--per_device_eval_batch_size=4 \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--predict_with_generate
|
||||||
|
```
|
||||||
|
|
||||||
|
If you get a terrible BLEU score, make sure that you didn't forget to use the `--source_prefix` argument.
|
||||||
|
|
||||||
|
For the aforementioned group of T5 models it's important to remember that if you switch to a different language pair, make sure to adjust the source and target values in all 3 language-specific command line argument: `--source_lang`, `--target_lang` and `--source_prefix`.
|
||||||
|
|
||||||
|
MBart models require a different format for `--source_lang` and `--target_lang` values, e.g. instead of `en` it expects `en_XX`, for `ro` it expects `ro_RO`. The full MBart specification for language codes can be found [here](https://huggingface.co/facebook/mbart-large-cc25). For example:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python examples/seq2seq/run_translation.py \
|
python examples/seq2seq/run_translation.py \
|
||||||
@ -170,18 +196,9 @@ python examples/seq2seq/run_translation.py \
|
|||||||
--per_device_train_batch_size=4 \
|
--per_device_train_batch_size=4 \
|
||||||
--per_device_eval_batch_size=4 \
|
--per_device_eval_batch_size=4 \
|
||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
--predict_with_generate \
|
--predict_with_generate
|
||||||
--max_train_samples 500 \
|
|
||||||
--max_val_samples 500
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Note, that depending on the used model additional language-specific command-line arguments are sometimes required. Specifically:
|
|
||||||
|
|
||||||
* MBart models require different `--{source,target}_lang` values, e.g. in place of `en` it expects `en_XX`, for `ro` it expects `ro_RO`. The full MBart specification for language codes can be looked up [here](https://huggingface.co/facebook/mbart-large-cc25)
|
|
||||||
* T5 models can use a `--source_prefix` argument to override the otherwise automated prefix of the form `translate {source_lang} to {target_lang}` for `run_translation.py` and `summarize: ` for `run_summarization.py`
|
|
||||||
|
|
||||||
Also, if you switch to a different language pair, make sure to adjust the source and target values in all command line arguments.
|
|
||||||
|
|
||||||
And here is how you would use the translation finetuning on your own files, after adjusting the
|
And here is how you would use the translation finetuning on your own files, after adjusting the
|
||||||
values for the arguments `--train_file`, `--validation_file` to match your setup:
|
values for the arguments `--train_file`, `--validation_file` to match your setup:
|
||||||
|
|
||||||
@ -192,6 +209,7 @@ python examples/seq2seq/run_translation.py \
|
|||||||
--do_eval \
|
--do_eval \
|
||||||
--source_lang en \
|
--source_lang en \
|
||||||
--target_lang ro \
|
--target_lang ro \
|
||||||
|
--source_prefix "translate English to Romanian: " \
|
||||||
--dataset_name wmt16 \
|
--dataset_name wmt16 \
|
||||||
--dataset_config_name ro-en \
|
--dataset_config_name ro-en \
|
||||||
--train_file path_to_jsonlines_file \
|
--train_file path_to_jsonlines_file \
|
||||||
@ -200,9 +218,7 @@ python examples/seq2seq/run_translation.py \
|
|||||||
--per_device_train_batch_size=4 \
|
--per_device_train_batch_size=4 \
|
||||||
--per_device_eval_batch_size=4 \
|
--per_device_eval_batch_size=4 \
|
||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
--predict_with_generate \
|
--predict_with_generate
|
||||||
--max_train_samples 500 \
|
|
||||||
--max_val_samples 500
|
|
||||||
```
|
```
|
||||||
|
|
||||||
The task of translation supports only custom JSONLINES files, with each line being a dictionary with a key `"translation"` and its value another dictionary whose keys is the language pair. For example:
|
The task of translation supports only custom JSONLINES files, with each line being a dictionary with a key `"translation"` and its value another dictionary whose keys is the language pair. For example:
|
||||||
@ -213,7 +229,7 @@ The task of translation supports only custom JSONLINES files, with each line bei
|
|||||||
```
|
```
|
||||||
Here the languages are Romanian (`ro`) and English (`en`).
|
Here the languages are Romanian (`ro`) and English (`en`).
|
||||||
|
|
||||||
If you want to use a pre-processed dataset that leads to high bleu scores, but for the `en-de` language pair, you can use `--dataset_name wmt14-en-de-pre-processed`, as following:
|
If you want to use a pre-processed dataset that leads to high BLEU scores, but for the `en-de` language pair, you can use `--dataset_name stas/wmt14-en-de-pre-processed`, as following:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python examples/seq2seq/run_translation.py \
|
python examples/seq2seq/run_translation.py \
|
||||||
@ -222,12 +238,11 @@ python examples/seq2seq/run_translation.py \
|
|||||||
--do_eval \
|
--do_eval \
|
||||||
--source_lang en \
|
--source_lang en \
|
||||||
--target_lang de \
|
--target_lang de \
|
||||||
--dataset_name wmt14-en-de-pre-processed \
|
--source_prefix "translate English to German: " \
|
||||||
|
--dataset_name stas/wmt14-en-de-pre-processed \
|
||||||
--output_dir /tmp/tst-translation \
|
--output_dir /tmp/tst-translation \
|
||||||
--per_device_train_batch_size=4 \
|
--per_device_train_batch_size=4 \
|
||||||
--per_device_eval_batch_size=4 \
|
--per_device_eval_batch_size=4 \
|
||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
--predict_with_generate \
|
--predict_with_generate
|
||||||
--max_train_samples 500 \
|
|
||||||
--max_val_samples 500
|
|
||||||
```
|
```
|
||||||
|
Loading…
Reference in New Issue
Block a user