
# Summarization (Seq2Seq model) training examples
The following example showcases how to finetune a sequence-to-sequence model for summarization using the JAX/Flax backend.
JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU. Models written in JAX/Flax are immutable and updated in a purely functional way which enables simple and efficient model parallelism.
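To make the functional-update point concrete, here is a minimal, self-contained JAX sketch (not part of the example script): a jit-compiled training step that returns a fresh set of parameters instead of mutating them in place. The tiny linear model and loss are illustrative only.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # toy linear model; stands in for a real Flax module
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit  # traced once as a pure function, then compiled for GPU/TPU
def update(params, x, y, lr=0.1):
    grads = jax.grad(loss_fn)(params, x, y)
    # purely functional update: build a new params pytree, never mutate the old one
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
x, y = jnp.ones((8, 3)), jnp.ones((8, 1))
params = update(params, x, y)  # returns fresh parameters
```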
`run_summarization_flax.py` is a lightweight example of how to download and preprocess a dataset from the 🤗 Datasets library, or use your own files (jsonlines or csv), and then fine-tune a supported sequence-to-sequence architecture on it.

For custom datasets in jsonlines format please see https://huggingface.co/docs/datasets/loading_datasets.html#json-files; you will also find examples of these below.
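As a quick illustration of the jsonlines format, each line of your file is one JSON object, and the 🤗 Datasets library can load such files directly. The column names `"document"` and `"summary"` below are placeholders; use whatever names your files actually contain (and point the script's column-selection arguments at them, if you use custom files).

```python
from datasets import load_dataset

# train.json / validation.json contain one example per line, e.g.:
# {"document": "The full article text ...", "summary": "A one-sentence summary."}
dataset = load_dataset(
    "json",
    data_files={"train": "train.json", "validation": "validation.json"},
)
print(dataset["train"][0])
```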
Let's start by creating a model repository to save the trained model and logs.
Here we call the model `"bart-base-xsum"`, but you can change the model name as you like.
You can do this either directly on huggingface.co (assuming that you are logged in) or via the command line:
```bash
huggingface-cli repo create bart-base-xsum
```
Next we clone the model repository to add the tokenizer and model files.
```bash
git clone https://huggingface.co/<your-username>/bart-base-xsum
```
To ensure that all TensorBoard traces are uploaded correctly, we need to track them with Git LFS. You can run the following commands inside your model repo to do so.
```bash
cd bart-base-xsum
git lfs track "*tfevents*"
```
Great, we have set up our model repository. During training, we will automatically push the training logs and model weights to the repo.
Next, let's add a symbolic link to `run_summarization_flax.py`.
```bash
export MODEL_DIR="./bart-base-xsum"
ln -s ~/transformers/examples/flax/summarization/run_summarization_flax.py run_summarization_flax.py
```
## Train the model
Next we can run the example script to train the model:
```bash
python run_summarization_flax.py \
    --output_dir ${MODEL_DIR} \
    --model_name_or_path facebook/bart-base \
    --tokenizer_name facebook/bart-base \
    --dataset_name="xsum" \
    --do_train --do_eval --do_predict --predict_with_generate \
    --num_train_epochs 6 \
    --learning_rate 5e-5 --warmup_steps 0 \
    --per_device_train_batch_size 64 \
    --per_device_eval_batch_size 64 \
    --overwrite_output_dir \
    --max_source_length 512 --max_target_length 64 \
    --push_to_hub
```
This should finish in about 37 minutes, with a validation loss of 1.7785 and a ROUGE-2 score of 17.01 after 6 epochs. Training statistics can be accessed on tfhub.de.
Note that here we used the default `generate` arguments; using arguments specific to the `xsum` dataset should give better ROUGE scores.
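As an illustration of the kind of generation arguments meant here, the sketch below calls `generate` directly on the fine-tuned checkpoint with beam search and a shorter maximum target length. The concrete values (`num_beams=6`, `max_length=64`) are examples, not tuned xsum settings, and the model id is the repository created above.

```python
from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM

model_id = "<your-username>/bart-base-xsum"  # the repo trained above
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_id)

inputs = tokenizer(
    "The full article text ...",
    return_tensors="np",
    truncation=True,
    max_length=512,
)
# beam search with a dataset-appropriate maximum summary length
summary_ids = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    num_beams=6,
    max_length=64,
).sequences
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```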