Fixes to run_seq2seq and instructions (#9734)

* Fixes to run_seq2seq and instructions

* Add more defaults for summarization
Sylvain Gugger 2021-01-22 10:03:57 -05:00 committed by GitHub
parent d7c31abf38
commit 411c582109
2 changed files with 113 additions and 4 deletions

@@ -22,14 +22,98 @@ For deprecated `bertabs` instructions, see [`bertabs/README.md`](https://github.
### Supported Architectures
- `BartForConditionalGeneration` (and anything that inherits from it)
- `BartForConditionalGeneration`
- `MarianMTModel`
- `PegasusForConditionalGeneration`
- `MBartForConditionalGeneration`
- `FSMTForConditionalGeneration`
- `T5ForConditionalGeneration`
## Datasets
This directory is in a bit of a messy state and is undergoing some cleaning, please bear with us in the meantime :-) Here are the instructions for using the new and old scripts to fine-tune sequence-to-sequence models.
## New script
The new script for fine-tuning a model on a summarization or translation task is `run_seq2seq.py`. It is a lightweight example of how to download and preprocess a dataset from the [🤗 Datasets](https://github.com/huggingface/datasets) library or use your own files (json or csv), then fine-tune one of the architectures above on it.
Here is an example for a summarization task:
```bash
python examples/seq2seq/run_seq2seq.py \
--model_name_or_path t5-small \
--do_train \
--do_eval \
--task summarization \
--dataset_name xsum \
--output_dir ~/tmp/tst-summarization \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \
--overwrite_output_dir \
--predict_with_generate
```
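Under the hood, for a dataset from the Hub such as `xsum`, the script picks the input and target columns from the new `summarization_name_mapping` defaults (here `document` and `summary`). As a rough sanity check (this snippet is only illustrative and assumes `xsum` downloads without extra configuration), you can peek at those columns with 🤗 Datasets:
```python
from datasets import load_dataset

# Inspect the columns run_seq2seq.py will use by default for xsum
# (per summarization_name_mapping: inputs from "document", targets from "summary").
raw = load_dataset("xsum", split="validation[:5]")
print(raw.column_names)            # e.g. ['document', 'summary', 'id']
print(raw[0]["document"][:200])    # the article to summarize
print(raw[0]["summary"])           # the reference summary
```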
And here is how you would use it on your own files (replace `path_to_csv_or_json_file`, `text_column_name` and `summary_column_name` with the relevant values):
```bash
python examples/seq2seq/run_seq2seq.py \
--model_name_or_path t5-small \
--do_train \
--do_eval \
--task summarization \
--train_file path_to_csv_or_json_file \
--validation_file path_to_csv_or_json_file \
--output_dir ~/tmp/tst-summarization \
--overwrite_output_dir \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \
--predict_with_generate \
--text_column text_column_name \
--summary_column summary_column_name
```
The training and validation files should have a column for the input texts and a column for the summaries.
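For example, a minimal CSV that could be passed as `--train_file`/`--validation_file` might be created like this (a sketch only; the `text` and `summary` column names are placeholders for whatever you pass to `--text_column` and `--summary_column`):
```python
import csv

# Placeholder rows illustrating the expected layout: one input column, one summary column.
rows = [
    {"text": "The full document to summarize goes here.", "summary": "A short summary."},
    {"text": "Another input document.", "summary": "Another summary."},
]
with open("train.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["text", "summary"])
    writer.writeheader()
    writer.writerows(rows)
```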
Here is an example of fine-tuning on a translation task:
```bash
python examples/seq2seq/run_seq2seq.py \
--model_name_or_path sshleifer/student_marian_en_ro_6_1 \
--do_train \
--do_eval \
--task translation_en_to_ro \
--dataset_name wmt16 \
--dataset_config_name ro-en \
--source_lang en-XX \
--target_lang ro-RO \
--output_dir ~/tmp/tst-translation \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \
--overwrite_output_dir \
--predict_with_generate
```
And here is how you would use it on your own files (replace `path_to_json_file` with the relevant values):
```bash
python examples/seq2seq/run_seq2seq.py \
--model_name_or_path sshleifer/student_marian_en_ro_6_1 \
--do_train \
--do_eval \
--task translation_en_to_ro \
--dataset_name wmt16 \
--dataset_config_name ro-en \
--source_lang en-XX \
--target_lang ro-RO \
--train_file path_to_json_file \
--validation_file path_to_json_file \
--output_dir ~/tmp/tst-translation \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \
--overwrite_output_dir \
--predict_with_generate
```
Here the files are expected to be JSON files, with each input being a dictionary with a key `"translation"` containing one key per language (here `"en"` and `"ro"`).
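As a rough sketch (the file name and sentence pairs below are purely illustrative), such a file can be produced with a few lines of Python:
```python
import json

# One JSON object per line; each holds a "translation" dict with one entry per language.
pairs = [
    {"en": "UN Chief says there is no military solution in Syria",
     "ro": "Şeful ONU declară că nu există o soluţie militară în Siria"},
    {"en": "Hello, world!", "ro": "Salut, lume!"},
]
with open("train.json", "w", encoding="utf-8") as f:
    for pair in pairs:
        f.write(json.dumps({"translation": pair}, ensure_ascii=False) + "\n")
```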
## Old script
The new script is very recent and hasn't been widely tested yet. It also lacks some of the functionality offered by the old
script, which is why we are leaving the old script here for now.
### Download the Datasets
#### XSUM

@@ -136,10 +136,10 @@ class DataTrainingArguments:
},
)
val_max_target_length: Optional[int] = field(
default=142,
default=None,
metadata={
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded. "
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
"during ``evaluate`` and ``predict``."
},
@@ -175,6 +175,9 @@ class DataTrainingArguments:
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
},
)
source_prefix: Optional[str] = field(
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -190,11 +193,22 @@ class DataTrainingArguments:
raise ValueError(
"`task` should be summarization, summarization_{dataset}, translation or translation_{xx}_to_{yy}."
)
if self.val_max_target_length is None:
self.val_max_target_length = self.max_target_length
summarization_name_mapping = {
"amazon_reviews_multi": ("review_body", "review_title"),
"big_patent": ("description", "abstract"),
"cnn_dailymail": ("article", "highlights"),
"orange_sum": ("text", "summary"),
"pn_summary": ("article", "summary"),
"psc": ("extract_text", "summary_text"),
"samsum": ("dialogue", "summary"),
"thaisum": ("body", "summary"),
"xglue": ("news_body", "news_title"),
"xsum": ("document", "summary"),
"wiki_summary": ("article", "highlights"),
}
@@ -302,6 +316,16 @@ def main():
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
# Get the default prefix if None is passed.
if data_args.source_prefix is None:
task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
prefix = task_specific_params.get("prefix", "")
else:
prefix = ""
else:
prefix = data_args.source_prefix
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
if training_args.do_train:
@@ -362,6 +386,7 @@ def main():
else:
inputs = examples[text_column]
targets = examples[summary_column]
inputs = [prefix + inp for inp in inputs]
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
# Setup the tokenizer for targets