## Fine-tuning Wav2Vec2

The `run_asr.py` script allows one to fine-tune pretrained Wav2Vec2 models that can be found [here](https://huggingface.co/models?search=facebook/wav2vec2).

This fine-tuning script can also be run as a Google Colab notebook [TODO: here]( ).

The script is actively maintained by [Patrick von Platen](https://github.com/patrickvonplaten).
Feel free to ask a question on the [Forum](https://discuss.huggingface.co/) or post an issue on [GitHub](https://github.com/huggingface/transformers/issues/new/choose), adding `@patrickvonplaten` as a tag.
### Fine-Tuning with TIMIT

Let's take a look at the [script](./finetune_base_timit_asr.sh) used to fine-tune [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base)
with the [TIMIT dataset](https://huggingface.co/datasets/timit_asr):

```bash
#!/usr/bin/env bash
python run_asr.py \
--output_dir="./wav2vec2-base-timit-asr" \
--num_train_epochs="30" \
--per_device_train_batch_size="20" \
--per_device_eval_batch_size="20" \
--evaluation_strategy="steps" \
--save_steps="500" \
--eval_steps="100" \
--logging_steps="50" \
--learning_rate="5e-4" \
--warmup_steps="3000" \
--model_name_or_path="facebook/wav2vec2-base" \
--fp16 \
--dataset_name="timit_asr" \
--train_split_name="train" \
--validation_split_name="test" \
--orthography="timit" \
--preprocessing_num_workers="$(nproc)" \
--group_by_length \
--freeze_feature_extractor \
--verbose_logging
```

The resulting model and inference examples can be found [here](https://huggingface.co/elgeish/wav2vec2-base-timit-asr).
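
For a quick sanity check, a snippet along these lines can run greedy CTC decoding with the resulting checkpoint. This is only a minimal sketch, not the exact example published on the model page; `sample.wav` is a placeholder for a 16 kHz mono recording.

```python
import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

model_name = "elgeish/wav2vec2-base-timit-asr"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

# "sample.wav" is a hypothetical 16 kHz mono recording.
speech, sampling_rate = sf.read("sample.wav")
inputs = processor(speech, sampling_rate=sampling_rate, return_tensors="pt")

with torch.no_grad():
    logits = model(inputs.input_values).logits

predicted_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(predicted_ids))
```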

Some of the arguments above may look unfamiliar; let's break down what's going on:

`--orthography="timit"` applies certain text preprocessing rules, for tokenization and normalization, to clean up the dataset.
In this case, we use the following instance of `Orthography`:

```python
Orthography(
    do_lower_case=True,
    # break compounds like "quarter-century-old" and replace pauses "--"
    translation_table=str.maketrans({"-": " "}),
)
```

The instance above is used as follows (a small standalone sketch of these steps appears after the list):
* creates a tokenizer with `do_lower_case=True` (ignores casing for input and lowercases output when decoding)
* replaces `"-"` with `" "` to break compounds like `"quarter-century-old"` and to clean up suspended hyphens
* cleans up consecutive whitespaces (replaces them with a single space: `" "`)
* removes characters not in vocabulary (lacking respective sound units)
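
To make these steps concrete, here is an illustrative sketch of the same cleanup applied to a sample transcript. The `normalize` helper and the character vocabulary below are made up for illustration and are not the exact code used in `run_asr.py`.

```python
import re

# Same translation table as the Orthography instance above.
translation_table = str.maketrans({"-": " "})
# Hypothetical character vocabulary (characters with corresponding sound units).
vocab = set("abcdefghijklmnopqrstuvwxyz' ")

def normalize(text: str) -> str:
    text = text.lower()                                # do_lower_case=True
    text = text.translate(translation_table)           # break compounds / pauses
    text = "".join(ch for ch in text if ch in vocab)   # drop out-of-vocabulary characters
    return re.sub(r"\s+", " ", text).strip()           # collapse consecutive whitespace

print(normalize("A quarter-century-old -- DREAM!"))  # a quarter century old dream
```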

`--verbose_logging` logs text preprocessing updates and, when evaluating on the validation split every `eval_steps`,
logs references and predictions.
### Fine-Tuning with Arabic Speech Corpus

Other datasets, like the [Arabic Speech Corpus dataset](https://huggingface.co/datasets/arabic_speech_corpus),
require more work! Let's take a look at the [script](./finetune_large_xlsr_53_arabic_speech_corpus.sh)
used to fine-tune [wav2vec2-large-xlsr-53](https://huggingface.co/elgeish/wav2vec2-large-xlsr-53-arabic):

```bash
#!/usr/bin/env bash
python run_asr.py \
--output_dir="./wav2vec2-large-xlsr-53-arabic-speech-corpus" \
--num_train_epochs="50" \
--per_device_train_batch_size="1" \
--per_device_eval_batch_size="1" \
--gradient_accumulation_steps="8" \
--evaluation_strategy="steps" \
--save_steps="500" \
--eval_steps="100" \
--logging_steps="50" \
--learning_rate="5e-4" \
--warmup_steps="3000" \
--model_name_or_path="elgeish/wav2vec2-large-xlsr-53-arabic" \
--fp16 \
--dataset_name="arabic_speech_corpus" \
--train_split_name="train" \
--validation_split_name="test" \
--max_duration_in_seconds="15" \
--orthography="buckwalter" \
--preprocessing_num_workers="$(nproc)" \
--group_by_length \
--freeze_feature_extractor \
--target_feature_extractor_sampling_rate \
--verbose_logging
```

First, let's understand how this dataset represents Arabic text; it uses a format called
[Buckwalter transliteration](https://en.wikipedia.org/wiki/Buckwalter_transliteration).
We use the [lang-trans](https://github.com/kariminf/lang-trans) package to convert back to Arabic when logging.
The Buckwalter format only includes ASCII characters, some of which are non-alpha (e.g., `">"` maps to `"أ"`).
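
As a rough illustration of that conversion (assuming the package's Python module is importable as `lang_trans`, matching the `untransliterator` shown further below; the sample string is made up):

```python
from lang_trans import arabic

buckwalter_text = ">ahlA"  # made-up sample; ">" maps to alef with hamza above
print(arabic.buckwalter.untransliterate(buckwalter_text))  # prints the Arabic-script form
```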

`--orthography="buckwalter"` applies certain text preprocessing rules, for tokenization and normalization, to clean up the dataset. In this case, we use the following instance of `Orthography`:

```python
Orthography(
    vocab_file=pathlib.Path(__file__).parent.joinpath("vocab/buckwalter.json"),
    word_delimiter_token="/",  # "|" is Arabic letter alef with madda above
    words_to_remove={"sil"},  # fixing "sil" in arabic_speech_corpus dataset
    untransliterator=arabic.buckwalter.untransliterate,
    translation_table=str.maketrans({
        "-": " ",  # sometimes used to represent pauses
        "^": "v",  # fixing "tha" in arabic_speech_corpus dataset
    }),
)
```

The instance above is used as follows (a short tokenizer sketch follows the list):
* creates a tokenizer with Buckwalter vocabulary and `word_delimiter_token="/"`
* replaces `"-"` with `" "` to clean up hyphens and fixes the orthography for `"ث"`
* removes words used as indicators (in this case, `"sil"` is used for silence)
* cleans up consecutive whitespaces (replaces them with a single space: `" "`)
* removes characters not in vocabulary (lacking respective sound units)
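
The first step, creating a tokenizer from the bundled vocabulary, looks roughly like this (a hedged sketch; the relative path assumes you run it from this example's directory, and `run_asr.py` may wire this up differently):

```python
from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer(
    "vocab/buckwalter.json",   # bundled Buckwalter character vocabulary
    word_delimiter_token="/",  # "|" is taken by a Buckwalter character
)
print(len(tokenizer), tokenizer.word_delimiter_token)
```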

`--verbose_logging` logs text preprocessing updates and, when evaluating on the validation split every `eval_steps`,
logs references and predictions. Since the Buckwalter format is used, text is also logged in the Arabic abjad.

`--target_feature_extractor_sampling_rate` resamples audio to the target feature extractor's sampling rate (16 kHz).
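
Resampling to 16 kHz can look like the following minimal sketch (using `soundfile` and `librosa`; the file path is a placeholder, and the script itself may load audio differently):

```python
import librosa
import soundfile as sf

speech, native_sr = sf.read("example.wav")  # placeholder path; assumes mono audio
if native_sr != 16_000:
    speech = librosa.resample(speech, orig_sr=native_sr, target_sr=16_000)
```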

`--max_duration_in_seconds="15"` filters out examples whose audio is longer than the specified limit,
which helps with capping GPU memory usage.
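
A hedged sketch of such a duration filter with the `datasets` library is shown below; it assumes the waveform is already decoded into a `"speech"` column sampled at the target rate (the column name and the toy data are assumptions, not necessarily what `run_asr.py` uses).

```python
from datasets import Dataset

max_duration_in_seconds = 15.0
target_sr = 16_000

# Tiny in-memory stand-in: two fake clips, 10 s and 20 s of silence.
dataset = Dataset.from_dict({"speech": [[0.0] * (10 * target_sr), [0.0] * (20 * target_sr)]})

def duration_filter(example):
    # Duration in seconds = number of samples / sampling rate.
    return len(example["speech"]) / target_sr <= max_duration_in_seconds

dataset = dataset.filter(duration_filter)
print(len(dataset))  # 1: only the 10-second clip survives the cap
```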