Sequence to Sequence Training and Evaluation

This directory contains examples for fine-tuning and evaluating transformers on summarization and translation tasks. Please tag @patil-suraj with any issues or unexpected behaviors, or send a PR! For the deprecated bertabs instructions, see bertabs/README.md. For the old finetune_trainer.py and related utils, see examples/legacy/seq2seq.

Supported Architectures

  • BartForConditionalGeneration
  • MarianMTModel
  • PegasusForConditionalGeneration
  • MBartForConditionalGeneration
  • FSMTForConditionalGeneration
  • T5ForConditionalGeneration

run_seq2seq.py is a lightweight example of how to download and preprocess a dataset from the 🤗 Datasets library or use your own files (jsonlines or csv), then fine-tune one of the architectures above on it.

For custom datasets in jsonlines format, please see: https://huggingface.co/docs/datasets/loading_datasets.html#json-files
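In a jsonlines file, each line holds one JSON object, and every object represents one example. The sketch below writes and reads such a file using only the Python standard library. The column names "text" and "summary", the file name train.json and the sample sentences are illustrative assumptions and are not required by run_seq2seq.py.

```python
import json
import os
import tempfile

def write_jsonlines(path, records):
    """Write one JSON object per line (the jsonlines format)."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

def read_jsonlines(path):
    """Read a jsonlines file back into a list of dicts, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

# Illustrative records; "text" and "summary" are hypothetical column names.
records = [
    {"text": "The council approved the new park budget on Tuesday.", "summary": "Park budget approved."},
    {"text": "Heavy rain closed several roads across the region.", "summary": "Rain closes roads."},
]

path = os.path.join(tempfile.mkdtemp(), "train.json")
write_jsonlines(path, records)
loaded = read_jsonlines(path)
print(len(loaded), loaded[0]["summary"])  # 2 Park budget approved.
```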

Here is an example on a summarization task:

python examples/seq2seq/run_seq2seq.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --task summarization \
    --dataset_name xsum \
    --output_dir ~/tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
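The batch-size flags in this command apply per device. Under data-parallel training, one optimizer step therefore sees the per-device batch size multiplied by the number of devices, and multiplied again by the gradient accumulation steps (the `--gradient_accumulation_steps` training argument, default 1). The helper below is a hypothetical sketch of that arithmetic. It assumes every device processes a full batch of its own, and it treats a CPU-only run as one device.

```python
def effective_batch_size(per_device_batch_size, num_devices, gradient_accumulation_steps=1):
    """Examples contributing to one optimizer step (hypothetical helper).

    Assumes data-parallel training where each device processes its own batch;
    a CPU-only run (num_devices == 0) is treated as a single device.
    """
    return per_device_batch_size * max(1, num_devices) * gradient_accumulation_steps

print(effective_batch_size(4, 1))     # 4  (one GPU)
print(effective_batch_size(4, 8))     # 32 (eight GPUs)
print(effective_batch_size(4, 2, 4))  # 32 (two GPUs, 4 accumulation steps)
```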

And here is how you would use it on your own files (replace path_to_csv_or_jsonlines_file, text_column_name and summary_column_name with the relevant values):

python examples/seq2seq/run_seq2seq.py \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --task summarization \
    --train_file path_to_csv_or_jsonlines_file \
    --validation_file path_to_csv_or_jsonlines_file \
    --output_dir ~/tmp/tst-summarization \
    --overwrite_output_dir \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --predict_with_generate \
    --text_column text_column_name \
    --summary_column summary_column_name

The training and validation files should have one column for the input texts and one column for the summaries.
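The sketch below uses the standard `csv` module to build a small CSV file in this shape. It also checks that the header contains the column names you plan to pass to `--text_column` and `--summary_column`. The column names "document" and "summary" and the helper names are hypothetical. With this file, you would pass `--text_column document --summary_column summary`.

```python
import csv
import os
import tempfile

def write_summarization_csv(path, pairs, text_column, summary_column):
    """Write (text, summary) pairs under a header row (hypothetical helper)."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow([text_column, summary_column])
        writer.writerows(pairs)

def missing_columns(path, required):
    """Return the required column names that are absent from the CSV header."""
    with open(path, newline="", encoding="utf-8") as f:
        header = next(csv.reader(f))
    return [name for name in required if name not in header]

pairs = [
    ("The council approved the new park budget on Tuesday.", "Park budget approved."),
    ("Heavy rain closed several roads across the region.", "Rain closes roads."),
]
path = os.path.join(tempfile.mkdtemp(), "train.csv")
write_summarization_csv(path, pairs, "document", "summary")

print(missing_columns(path, ["document", "summary"]))  # []
print(missing_columns(path, ["text"]))                 # ['text']
```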

Here is an example of fine-tuning on a translation task:

python examples/seq2seq/run_seq2seq.py \
    --model_name_or_path sshleifer/student_marian_en_ro_6_1 \
    --do_train \
    --do_eval \
    --task translation_en_to_ro \
    --dataset_name wmt16 \
    --dataset_config_name ro-en \
    --source_lang en_XX \
    --target_lang ro_RO \
    --output_dir ~/tmp/tst-translation \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
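The task name in this command follows the pattern translation_<source>_to_<target>, so translation_en_to_ro names English as the source and Romanian as the target. The hypothetical helper below parses that pattern. It shows only the naming convention and makes no claim about how run_seq2seq.py itself handles the `--task` value.

```python
import re

def parse_translation_task(task):
    """Split a name like "translation_en_to_ro" into (source, target).

    Hypothetical helper illustrating the naming pattern only.
    """
    match = re.fullmatch(r"translation_([a-z]+)_to_([a-z]+)", task)
    if match is None:
        raise ValueError(f"not a translation task name: {task!r}")
    return match.group(1), match.group(2)

print(parse_translation_task("translation_en_to_ro"))  # ('en', 'ro')
```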

And here is how you would use it on your own files (replace path_to_jsonlines_file with the relevant values):

python examples/seq2seq/run_seq2seq.py \
    --model_name_or_path sshleifer/student_marian_en_ro_6_1 \
    --do_train \
    --do_eval \
    --task translation_en_to_ro \
    --source_lang en_XX \
    --target_lang ro_RO \
    --train_file path_to_jsonlines_file \
    --validation_file path_to_jsonlines_file \
    --output_dir ~/tmp/tst-translation \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate

Here the files are expected to be jsonlines files, where each line is a dictionary with a key "translation" that maps to one key per language (here "en" and "ro").
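The sketch below uses only the standard library to write a training file in this layout, where every line has the form {"translation": {"en": ..., "ro": ...}}. The sentence pairs, the helper name and the file name train.json are illustrative assumptions. Check which file extensions the script accepts before choosing a name.

```python
import json
import os
import tempfile

def write_translation_jsonlines(path, pairs, source_lang="en", target_lang="ro"):
    """Write (source, target) pairs as {"translation": {...}} objects, one per line."""
    with open(path, "w", encoding="utf-8") as f:
        for source, target in pairs:
            record = {"translation": {source_lang: source, target_lang: target}}
            # ensure_ascii=False keeps Romanian diacritics readable in the file.
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

pairs = [
    ("Good morning.", "Bună dimineața."),
    ("Thank you very much.", "Vă mulțumesc foarte mult."),
]
path = os.path.join(tempfile.mkdtemp(), "train.json")
write_translation_jsonlines(path, pairs)

with open(path, encoding="utf-8") as f:
    first_line = f.readline()
print(first_line.strip())
# {"translation": {"en": "Good morning.", "ro": "Bună dimineața."}}
```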