
Summarization (Seq2Seq model) training examples

The following example showcases how to fine-tune a sequence-to-sequence model for summarization using the JAX/Flax backend.

JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU. Models written in JAX/Flax are immutable and updated in a purely functional way, which enables simple and efficient model parallelism.

run_summarization_flax.py is a lightweight example of how to download and preprocess a dataset from the 🤗 Datasets library, or use your own files (JSON Lines or CSV), and then fine-tune one of the supported sequence-to-sequence models on it.

For custom datasets in JSON Lines format, please see https://huggingface.co/docs/datasets/loading_datasets#json-files; you will also find an example of this below.
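
For example, here is a minimal sketch of a run on custom files. It assumes each line of train.json and validation.json is a JSON object with a text field (the input document) and a summary field (the target), and that the script accepts --train_file, --validation_file, --text_column and --summary_column as in the other summarization examples; check python run_summarization_flax.py --help for the exact argument names before running it.

python run_summarization_flax.py \
	--output_dir ./bart-base-custom-summarization \
	--model_name_or_path facebook/bart-base \
	--tokenizer_name facebook/bart-base \
	--train_file ./train.json \
	--validation_file ./validation.json \
	--text_column text \
	--summary_column summary \
	--do_train --do_eval \
	--num_train_epochs 3 \
	--learning_rate 5e-5 --warmup_steps 0 \
	--per_device_train_batch_size 8 \
	--per_device_eval_batch_size 8 \
	--max_source_length 512 --max_target_length 64 \
	--overwrite_output_dir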

Train the model

Next we can run the example script to train the model:

python run_summarization_flax.py \
	--output_dir ./bart-base-xsum \
	--model_name_or_path facebook/bart-base \
	--tokenizer_name facebook/bart-base \
	--dataset_name="xsum" \
	--do_train --do_eval --do_predict --predict_with_generate \
	--num_train_epochs 6 \
	--learning_rate 5e-5 --warmup_steps 0 \
	--per_device_train_batch_size 64 \
	--per_device_eval_batch_size 64 \
	--overwrite_output_dir \
	--max_source_length 512 --max_target_length 64 \
	--push_to_hub

This should finish in about 37 minutes, with a validation loss of 1.7785 and a ROUGE2 score of 17.01 after 6 epochs. Training statistics can be accessed on tfhub.de.

Note that we used the default generate arguments here; using generation arguments tuned for the XSum dataset should give better ROUGE scores.
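
For example, beam search usually improves ROUGE on XSum. Assuming the script forwards a --num_beams argument to model.generate at evaluation and prediction time (check python run_summarization_flax.py --help to confirm this for your version), the command above can be extended like this:

python run_summarization_flax.py \
	--output_dir ./bart-base-xsum \
	--model_name_or_path facebook/bart-base \
	--tokenizer_name facebook/bart-base \
	--dataset_name="xsum" \
	--do_train --do_eval --do_predict --predict_with_generate \
	--num_train_epochs 6 \
	--learning_rate 5e-5 --warmup_steps 0 \
	--per_device_train_batch_size 64 \
	--per_device_eval_batch_size 64 \
	--overwrite_output_dir \
	--max_source_length 512 --max_target_length 64 \
	--num_beams 6 \
	--push_to_hub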