Automatic Speech Recognition - Flax Examples

Sequence to Sequence

The script run_flax_speech_recognition_seq2seq.py can be used to fine-tune any Flax Speech Sequence-to-Sequence Model for automatic speech recognition on one of the official speech recognition datasets or a custom dataset. This includes the Whisper model from OpenAI as well as a warm-started Speech-Encoder-Decoder Model; a fine-tuning example for Whisper is included below.
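For the warm-started case, the encoder and decoder are assembled from separately pretrained checkpoints. A minimal sketch of how this can be done (the wav2vec2 and BART checkpoints here are illustrative choices, not requirements of the script):

from transformers import AutoFeatureExtractor, AutoTokenizer, FlaxSpeechEncoderDecoderModel

# Warm-start a speech encoder-decoder model from a pretrained speech
# encoder and a pretrained text decoder (illustrative checkpoints).
model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    "facebook/wav2vec2-large-lv60", "facebook/bart-large"
)

# The pre- and post-processing classes come from the respective components.
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-large-lv60")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

Note that the cross-attention weights connecting the encoder to the decoder are newly initialised, so a warm-started model must be fine-tuned before it produces useful transcriptions.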

Whisper Model

We can load all components of the Whisper model directly from the pretrained checkpoint, including the pretrained model weights, feature extractor, and tokenizer. We simply have to specify the id of the fine-tuning dataset and the necessary training hyperparameters.
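As a sketch of what this looks like in code, assuming the openai/whisper-small checkpoint used in the command below:

from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor

# The processor bundles the feature extractor (which converts raw audio
# into log-mel spectrograms) and the tokenizer from the same checkpoint.
processor = WhisperProcessor.from_pretrained("openai/whisper-small")

# from_pt=True converts the PyTorch weights on the fly in case the
# checkpoint does not host native Flax weights.
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-small", from_pt=True)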

The following example shows how to fine-tune the Whisper small checkpoint on the Hindi subset of the Common Voice 13 dataset. Note that before running this script you must accept the dataset's terms of use on the Hugging Face Hub and register your Hub token on your device by running huggingface-cli login.
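Before launching a full run, it can be worth checking that your token actually grants access to the gated dataset. A minimal sketch, streaming a single example rather than downloading the full dataset (trust_remote_code=True is needed because Common Voice is distributed with a dataset loading script):

from datasets import load_dataset

# Stream one training example to confirm dataset access.
ds = load_dataset(
    "mozilla-foundation/common_voice_13_0",
    "hi",
    split="train",
    streaming=True,
    trust_remote_code=True,  # the dataset is defined by a loading script
)
print(next(iter(ds))["sentence"])

Once this prints a Hindi sentence, fine-tuning can be launched: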

python run_flax_speech_recognition_seq2seq.py \
	--model_name_or_path="openai/whisper-small" \
	--dataset_name="mozilla-foundation/common_voice_13_0" \
	--dataset_config_name="hi" \
	--language="hindi" \
	--train_split_name="train+validation" \
	--eval_split_name="test" \
	--output_dir="./whisper-small-hi-flax" \
	--per_device_train_batch_size="16" \
	--per_device_eval_batch_size="16" \
	--num_train_epochs="10" \
	--learning_rate="1e-4" \
	--warmup_steps="500" \
	--logging_steps="25" \
	--generation_max_length="40" \
	--preprocessing_num_workers="32" \
	--dataloader_num_workers="32" \
	--max_duration_in_seconds="30" \
	--text_column_name="sentence" \
	--overwrite_output_dir \
	--do_train \
	--do_eval \
	--predict_with_generate \
	--push_to_hub \
	--use_auth_token

On a TPU v4-8, training should take approximately 25 minutes, with a final cross-entropy loss of 0.02 and word error rate of 34%. See the checkpoint sanchit-gandhi/whisper-small-hi-flax for an example training run.
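Once training is complete, the fine-tuned model can be used for inference directly from Flax. A minimal sketch, using sanchit-gandhi/whisper-small-hi-flax purely as a stand-in for your own pushed checkpoint (substitute your Hub repo id):

from datasets import Audio, load_dataset
from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor

# Load the fine-tuned checkpoint (stand-in repo id; substitute your own).
repo_id = "sanchit-gandhi/whisper-small-hi-flax"
processor = WhisperProcessor.from_pretrained(repo_id)
model = FlaxWhisperForConditionalGeneration.from_pretrained(repo_id)

# Stream one test example and resample it to the 16 kHz Whisper expects.
ds = load_dataset(
    "mozilla-foundation/common_voice_13_0", "hi", split="test", streaming=True, trust_remote_code=True
)
ds = ds.cast_column("audio", Audio(sampling_rate=16_000))
audio = next(iter(ds))["audio"]

inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="np")
# Flax generate returns an output object; the token ids live in `.sequences`.
pred_ids = model.generate(inputs.input_features, max_length=40).sequences
print(processor.batch_decode(pred_ids, skip_special_tokens=True)[0])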