# Fixes to run_seq2seq and instructions (#9734)

* Fixes to run_seq2seq and instructions
* Add more defaults for summarization

Commit: `411c582109` (parent: `d7c31abf38`)

@@ -22,14 +22,98 @@ For deprecated `bertabs` instructions, see [`bertabs/README.md`](https://github.

### Supported Architectures

- `BartForConditionalGeneration` (and anything that inherits from it)
- `MarianMTModel`
- `PegasusForConditionalGeneration`
- `MBartForConditionalGeneration`
- `FSMTForConditionalGeneration`
- `T5ForConditionalGeneration`

## Datasets

This directory is in a bit of a messy state and is undergoing some cleaning, so please bear with us in the meantime :-) Here are the instructions for using the new and old scripts to fine-tune sequence-to-sequence models.

## New script

The new script for fine-tuning a model on a summarization or translation task is `run_seq2seq.py`. It is a lightweight example of how to download and preprocess a dataset from the [🤗 Datasets](https://github.com/huggingface/datasets) library or use your own files (JSON or CSV), then fine-tune one of the architectures above on it.

Here is an example on a summarization task:

```bash
python examples/seq2seq/run_seq2seq.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --task summarization \
    --dataset_name xsum \
    --output_dir ~/tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
```
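
If you are curious what the script will consume with `--dataset_name xsum`, a quick way to check is to load the dataset yourself with 🤗 Datasets. This is only an illustrative sketch, separate from the command above (the `document`/`summary` column names are the ones XSum provides; other datasets use different names):

```python
# Illustrative sketch: peek at the dataset that `--dataset_name xsum` refers to.
from datasets import load_dataset

raw_datasets = load_dataset("xsum")         # downloads and caches XSum on first use
print(raw_datasets)                         # DatasetDict with train/validation/test splits
print(raw_datasets["train"].column_names)   # ['document', 'summary', 'id']
print(raw_datasets["train"][0]["summary"])  # one reference summary
```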

And here is how you would use it on your own files (replace `path_to_csv_or_json_file`, `text_column_name` and `summary_column_name` with the relevant values):

```bash
python examples/seq2seq/run_seq2seq.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --task summarization \
    --train_file path_to_csv_or_json_file \
    --validation_file path_to_csv_or_json_file \
    --output_dir ~/tmp/tst-summarization \
    --overwrite_output_dir \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --predict_with_generate \
    --text_column text_column_name \
    --summary_column summary_column_name
```

The training and validation files should have a column for the input texts and a column for the summaries.

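For example, a minimal CSV training file with hypothetical column names `text` and `summary` could be created like this (the file name and column names are placeholders for this sketch, not requirements of the script):

```python
# Minimal sketch: write a tiny CSV file with one column for inputs and one for summaries.
# "train.csv", "text" and "summary" are placeholder names chosen for this example.
import csv

rows = [
    {"text": "The full input document goes here.", "summary": "A short summary."},
    {"text": "Another input document.", "summary": "Its summary."},
]
with open("train.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=["text", "summary"])
    writer.writeheader()
    writer.writerows(rows)
```

You would then point `--train_file`/`--validation_file` at such files and pass `--text_column text --summary_column summary`, as in the command above.
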
Here is an example of translation fine-tuning:

```bash
python examples/seq2seq/run_seq2seq.py \
    --model_name_or_path sshleifer/student_marian_en_ro_6_1 \
    --do_train \
    --do_eval \
    --task translation_en_to_ro \
    --dataset_name wmt16 \
    --dataset_config_name ro-en \
    --source_lang en-XX \
    --target_lang ro-RO \
    --output_dir ~/tmp/tst-translation \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
```

And here is how you would use it on your own files (replace `path_to_json_file` with the relevant values):

```bash
python examples/seq2seq/run_seq2seq.py \
    --model_name_or_path sshleifer/student_marian_en_ro_6_1 \
    --do_train \
    --do_eval \
    --task translation_en_to_ro \
    --dataset_name wmt16 \
    --dataset_config_name ro-en \
    --source_lang en-XX \
    --target_lang ro-RO \
    --train_file path_to_json_file \
    --validation_file path_to_json_file \
    --output_dir ~/tmp/tst-translation \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
```

Here the files are expected to be JSON files, with each input being a dictionary with a key `"translation"` containing one key per language (here `"en"` and `"ro"`).

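For example, a small file in that format could be written as follows (a minimal sketch; the file name and the sentence pairs are placeholders):

```python
# Minimal sketch: write a JSON-lines file where each line holds a "translation" dictionary
# with one entry per language ("en" and "ro" here, matching the command above).
import json

pairs = [
    {"translation": {"en": "Hello, how are you?", "ro": "Bună, ce mai faci?"}},
    {"translation": {"en": "The weather is nice today.", "ro": "Vremea este frumoasă astăzi."}},
]
with open("train.json", "w", encoding="utf-8") as f:
    for pair in pairs:
        f.write(json.dumps(pair, ensure_ascii=False) + "\n")
```
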
## Old script

The new script is still very recent and hasn't been widely tested yet. It is also missing some functionality offered by the old script, which is why we are keeping the old script around for now.

### Download the Datasets

#### XSUM

The remaining hunks are from `examples/seq2seq/run_seq2seq.py`. In `DataTrainingArguments`, `val_max_target_length` now defaults to `None` instead of `142`, and its help text documents the fallback to `max_target_length`:

@@ -136,10 +136,10 @@ class DataTrainingArguments:

```python
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
            "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
            "during ``evaluate`` and ``predict``."
        },
```

A `source_prefix` argument is added:

@@ -175,6 +175,9 @@ class DataTrainingArguments:

```python
            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
        },
    )
    source_prefix: Optional[str] = field(
        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
```

In `__post_init__`, `val_max_target_length` now falls back to `max_target_length` when unset, and the table of default summarization columns gains more datasets:

@@ -190,11 +193,22 @@ class DataTrainingArguments:

```python
            raise ValueError(
                "`task` should be summarization, summarization_{dataset}, translation or translation_{xx}_to_{yy}."
            )
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length


summarization_name_mapping = {
    "amazon_reviews_multi": ("review_body", "review_title"),
    "big_patent": ("description", "abstract"),
    "cnn_dailymail": ("article", "highlights"),
    "orange_sum": ("text", "summary"),
    "pn_summary": ("article", "summary"),
    "psc": ("extract_text", "summary_text"),
    "samsum": ("dialogue", "summary"),
    "thaisum": ("body", "summary"),
    "xglue": ("news_body", "news_title"),
    "xsum": ("document", "summary"),
    "wiki_summary": ("article", "highlights"),
}
```
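
For readers wondering how such a table is typically consumed, here is a small illustrative sketch (a hypothetical helper, not the script's actual code) of resolving the text and summary columns when `--text_column`/`--summary_column` are not given:

```python
# Illustrative sketch (hypothetical helper, not the script's actual code):
# resolve column names from the mapping, falling back to the dataset's own columns.
from typing import Dict, List, Optional, Tuple


def resolve_summarization_columns(
    name_mapping: Dict[str, Tuple[str, str]],
    dataset_name: Optional[str],
    column_names: List[str],
    text_column: Optional[str] = None,
    summary_column: Optional[str] = None,
) -> Tuple[str, str]:
    defaults = name_mapping.get(dataset_name) if dataset_name is not None else None
    if text_column is None:
        text_column = defaults[0] if defaults is not None else column_names[0]
    if summary_column is None:
        summary_column = defaults[1] if defaults is not None else column_names[1]
    return text_column, summary_column


# For example, with the mapping above this yields ("document", "summary") for XSum:
# resolve_summarization_columns(summarization_name_mapping, "xsum", ["document", "summary", "id"])
```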

In `main()`, a default prefix is picked up from the model config when `--source_prefix` is not passed:

@@ -302,6 +316,16 @@ def main():

```python
    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    # Get the default prefix if None is passed.
    if data_args.source_prefix is None:
        task_specific_params = model.config.task_specific_params
        if task_specific_params is not None:
            prefix = task_specific_params.get("prefix", "")
        else:
            prefix = ""
    else:
        prefix = data_args.source_prefix

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    if training_args.do_train:
```

The prefix is then prepended to every input before tokenization:

@@ -362,6 +386,7 @@ def main():

```python
        else:
            inputs = examples[text_column]
            targets = examples[summary_column]
        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)

        # Setup the tokenizer for targets
```