Sequence to Sequence Training and Evaluation

This directory contains examples for fine-tuning and evaluating Transformers models on summarization and translation tasks. Please tag @patil-suraj on any issues or unexpected behavior, or send a PR! For the deprecated bertabs instructions, see bertabs/README.md. For the old finetune_trainer.py and related utilities, see examples/legacy/seq2seq.

Supported Architectures

  • BartForConditionalGeneration
  • MarianMTModel
  • PegasusForConditionalGeneration
  • MBartForConditionalGeneration
  • FSMTForConditionalGeneration
  • T5ForConditionalGeneration

run_seq2seq.py is a lightweight example of how to download and preprocess a dataset from the 🤗 Datasets library or use your own files (jsonlines or csv), then fine-tune one of the architectures above on it.

For custom datasets in jsonlines format, please see: https://huggingface.co/docs/datasets/loading_datasets.html#json-files
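A jsonlines file holds one JSON object per line. This minimal sketch uses only the Python standard library to write a small summarization training file and read it back. The column names "text" and "summary" are hypothetical, so use whatever names you pass to --text_column and --summary_column. The .json extension is an assumption: it presumes the script picks its loader (csv or json) from the file extension.

```python
import json

# Hypothetical column names: pass the names you actually use to
# --text_column and --summary_column.
records = [
    {"text": "Heavy rain caused flooding across the region on Monday.",
     "summary": "Rain flooded the region."},
    {"text": "The city council approved a new budget after a long debate.",
     "summary": "Council approved the budget."},
]


def write_jsonlines(path, rows):
    """Write each row as one JSON object on its own line."""
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")


def read_jsonlines(path):
    """Read a jsonlines file back into a list of dicts, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


write_jsonlines("train.json", records)
print(len(read_jsonlines("train.json")))  # 2
```

The resulting path is what you would pass as --train_file (or --validation_file).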

Here is an example on a summarization task:

python examples/seq2seq/run_seq2seq.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --task summarization \
    --dataset_name xsum \
    --output_dir ~/tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate

And here is how you would use it on your own files (replace path_to_csv_or_jsonlines_file, text_column_name and summary_column_name with the relevant values):

python examples/seq2seq/run_seq2seq.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --task summarization \
    --train_file path_to_csv_or_jsonlines_file \
    --validation_file path_to_csv_or_jsonlines_file \
    --output_dir ~/tmp/tst-summarization \
    --overwrite_output_dir \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --predict_with_generate \
    --text_column text_column_name \
    --summary_column summary_column_name

The training and validation files should have one column for the input texts and one column for the summaries; pass their names with --text_column and --summary_column.
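For CSV files, a quick pre-flight check can catch a misspelled column name before a training run starts. This is a minimal standard-library sketch; the column names "document" and "summary" and the helper missing_columns are hypothetical, not part of run_seq2seq.py.

```python
import csv

# Hypothetical column names for this sketch.
TEXT_COLUMN, SUMMARY_COLUMN = "document", "summary"

rows = [
    {"document": "Heavy rain, strong wind and flooding hit the coast.",
     "summary": "Storm hit the coast."},
]


def write_csv(path, rows, columns):
    """Write rows to a CSV file with a header row; fields containing commas are quoted."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=columns)
        writer.writeheader()
        writer.writerows(rows)


def missing_columns(path, *required):
    """Return the required column names that are absent from the CSV header."""
    with open(path, newline="", encoding="utf-8") as f:
        header = next(csv.reader(f), [])
    return [name for name in required if name not in header]


write_csv("validation.csv", rows, [TEXT_COLUMN, SUMMARY_COLUMN])
print(missing_columns("validation.csv", TEXT_COLUMN, SUMMARY_COLUMN))  # []
print(missing_columns("validation.csv", "text", "summary"))  # ['text']
```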

Here is an example of fine-tuning for translation:

python examples/seq2seq/run_seq2seq.py \
    --model_name_or_path sshleifer/student_marian_en_ro_6_1 \
    --do_train \
    --do_eval \
    --task translation_en_to_ro \
    --dataset_name wmt16 \
    --dataset_config_name ro-en \
    --source_lang en_XX \
    --target_lang ro_RO \
    --output_dir ~/tmp/tst-translation \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
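The command above passes mBART-style language codes (en_XX, ro_RO), while wmt16 ro-en records are keyed by plain language names such as "en" and "ro". Our assumption is that the script keeps only the part of each code before the underscore and otherwise falls back on the pair encoded in the task name. The hypothetical language_keys helper below illustrates that mapping; it is not copied from run_seq2seq.py, so check the script for the exact logic.

```python
import re


def language_keys(task, source_lang=None, target_lang=None):
    """Hypothetical helper: map language flags to the keys used in each
    record's "translation" dict."""
    # Fall back on a task name such as "translation_en_to_ro".
    match = re.fullmatch(r"translation_([a-z]+)_to_([a-z]+)", task)
    if (source_lang is None or target_lang is None) and match is None:
        raise ValueError(
            "pass --source_lang/--target_lang or name the task translation_xx_to_yy"
        )
    # Keep only the language part of mBART-style codes: "en_XX" -> "en".
    src = source_lang.split("_")[0] if source_lang else match.group(1)
    tgt = target_lang.split("_")[0] if target_lang else match.group(2)
    return src, tgt


print(language_keys("translation_en_to_ro", "en_XX", "ro_RO"))  # ('en', 'ro')
print(language_keys("translation_en_to_ro"))  # ('en', 'ro')
```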

And here is how you would use it on your own files (replace path_to_jsonlines_file with the relevant path):

python examples/seq2seq/run_seq2seq.py \
    --model_name_or_path sshleifer/student_marian_en_ro_6_1 \
    --do_train \
    --do_eval \
    --task translation_en_to_ro \
    --source_lang en_XX \
    --target_lang ro_RO \
    --train_file path_to_jsonlines_file \
    --validation_file path_to_jsonlines_file \
    --output_dir ~/tmp/tst-translation \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate

Here the files are expected to be jsonlines files in which each line is a JSON object with a "translation" key whose value is a dictionary with one key per language (here "en" and "ro").
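A minimal standard-library sketch of that layout: it writes sentence pairs as {"translation": {"en": ..., "ro": ...}} objects, one per line, and flags any line missing either language key. The sentence pairs, the file name translation_train.json and the helper bad_lines are hypothetical, for illustration only.

```python
import json

# Hypothetical English-Romanian sentence pairs.
pairs = [
    ("Hello, world!", "Salut, lume!"),
    ("Thank you.", "Mulțumesc."),
]


def write_translation_jsonlines(path, pairs, src="en", tgt="ro"):
    """Write one {"translation": {src: ..., tgt: ...}} object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for source, target in pairs:
            record = {"translation": {src: source, tgt: target}}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def bad_lines(path, src="en", tgt="ro"):
    """Return 1-based numbers of non-blank lines lacking both language keys."""
    bad = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            record = json.loads(line)
            translation = record.get("translation") if isinstance(record, dict) else None
            if not isinstance(translation, dict) or not all(
                key in translation for key in (src, tgt)
            ):
                bad.append(lineno)
    return bad


write_translation_jsonlines("translation_train.json", pairs)
print(bad_lines("translation_train.json"))  # []
```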