wav2vec2: support datasets other than LibriSpeech (#10581)

* wav2vec2: support datasets other than LibriSpeech

* Formatting run_asr.py to pass code quality test

* bundled orthography options and added verbose logs

* fixing a typo in timit fine-tuning script

* update comment for clarity

* resize_lm_head and load custom vocab from file

* adding a max_duration_in_seconds filter

* do not assign `duration_filter` lambda, use a def

* log untransliterated text as well

* fix base model for arabic

* fix duration filter when target_sr is not set

* drop duration_in_seconds when unneeded

* script for wav2vec2-large-lv60-timit-asr

* fix for "tha" in arabic corpus (huggingface#10581)

* adding more options to work with common_voice

* PR feedback (huggingface#10581)

* small README change
Mohamed El-Geish 2021-03-18 00:20:26 -07:00 committed by GitHub
parent 0b98ca368f
commit af8afdc88d
8 changed files with 465 additions and 19 deletions


@@ -1,8 +1,129 @@
## Fine-tuning Wav2Vec2
-The `run_training.py` script allows one to finetune pretrained Wav2Vec2 models that can be found [here](https://huggingface.co/models?search=facebook/wav2vec2).
The `run_asr.py` script allows one to fine-tune pretrained Wav2Vec2 models that can be found [here](https://huggingface.co/models?search=facebook/wav2vec2).
This fine-tuning script can also be run as a Google Colab [TODO: here]( ).
The script is actively maintained by [Patrick von Platen](https://github.com/patrickvonplaten).
Feel free to ask a question on the [Forum](https://discuss.huggingface.co/) or post an issue on [GitHub](https://github.com/huggingface/transformers/issues/new/choose), adding `@patrickvonplaten` as a tag.
### Fine-Tuning with TIMIT
Let's take a look at the [script](./finetune_base_timit_asr.sh) used to fine-tune [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base)
with the [TIMIT dataset](https://huggingface.co/datasets/timit_asr):
```bash
#!/usr/bin/env bash
python run_asr.py \
--output_dir="./wav2vec2-base-timit-asr" \
--num_train_epochs="30" \
--per_device_train_batch_size="20" \
--per_device_eval_batch_size="20" \
--evaluation_strategy="steps" \
--save_steps="500" \
--eval_steps="100" \
--logging_steps="50" \
--learning_rate="5e-4" \
--warmup_steps="3000" \
--model_name_or_path="facebook/wav2vec2-base" \
--fp16 \
--dataset_name="timit_asr" \
--train_split_name="train" \
--validation_split_name="test" \
--orthography="timit" \
--preprocessing_num_workers="$(nproc)" \
--group_by_length \
--freeze_feature_extractor \
--verbose_logging \
```
The resulting model and inference examples can be found [here](https://huggingface.co/elgeish/wav2vec2-base-timit-asr).
Some of the arguments above may look unfamiliar; let's break down what's going on:
`--orthography="timit"` applies certain text preprocessing rules for tokenization and normalization to clean up the dataset.
In this case, we use the following instance of `Orthography`:
```python
Orthography(
    do_lower_case=True,
    # break compounds like "quarter-century-old" and replace pauses "--"
    translation_table=str.maketrans({"-": " "}),
)
```
The instance above is used as follows:
* creates a tokenizer with `do_lower_case=True` (ignores casing for input and lowercases output when decoding)
* replaces `"-"` with `" "` to break compounds like `"quarter-century-old"` and to clean up suspended hyphens
* cleans up consecutive whitespaces (replaces them with a single space: `" "`)
* removes characters not in vocabulary (lacking respective sound units)
`--verbose_logging` logs text preprocessing updates and, when evaluating on the validation split every `eval_steps`, logs references and predictions.
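To make these rules concrete, here's a minimal stand-alone sketch of the same pipeline (the vocabulary string and sample sentence below are made up for illustration; the actual logic lives in `Orthography.preprocess_for_training` and the vocabulary cleaner in `run_asr.py`):
```python
import re

translation_table = str.maketrans({"-": " "})
vocab_chars = "abcdefghijklmnopqrstuvwxyz'"  # assumed single-character vocabulary
cleaner = re.compile(rf"[^\s{re.escape(vocab_chars)}]", flags=re.IGNORECASE)

text = "A quarter-century-old -- example?"
text = text.lower().translate(translation_table)  # lowercase; break compounds and pauses
text = " ".join(text.split())  # clean up consecutive whitespaces
text = cleaner.sub("", text)  # remove characters not in vocabulary
print(text)  # a quarter century old example
```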
### Fine-Tuning with Arabic Speech Corpus
Other datasets, like the [Arabic Speech Corpus dataset](https://huggingface.co/datasets/arabic_speech_corpus),
require more work! Let's take a look at the [script](./finetune_large_xlsr_53_arabic_speech_corpus.sh)
used to fine-tune [wav2vec2-large-xlsr-53-arabic](https://huggingface.co/elgeish/wav2vec2-large-xlsr-53-arabic):
```bash
#!/usr/bin/env bash
python run_asr.py \
--output_dir="./wav2vec2-large-xlsr-53-arabic-speech-corpus" \
--num_train_epochs="50" \
--per_device_train_batch_size="1" \
--per_device_eval_batch_size="1" \
--gradient_accumulation_steps="8" \
--evaluation_strategy="steps" \
--save_steps="500" \
--eval_steps="100" \
--logging_steps="50" \
--learning_rate="5e-4" \
--warmup_steps="3000" \
--model_name_or_path="elgeish/wav2vec2-large-xlsr-53-arabic" \
--fp16 \
--dataset_name="arabic_speech_corpus" \
--train_split_name="train" \
--validation_split_name="test" \
--max_duration_in_seconds="15" \
--orthography="buckwalter" \
--preprocessing_num_workers="$(nproc)" \
--group_by_length \
--freeze_feature_extractor \
--target_feature_extractor_sampling_rate \
--verbose_logging \
```
First, let's understand how this dataset represents Arabic text; it uses a format called
[Buckwalter transliteration](https://en.wikipedia.org/wiki/Buckwalter_transliteration).
We use the [lang-trans](https://github.com/kariminf/lang-trans) package to convert back to Arabic when logging.
The Buckwalter format only includes ASCII characters, some of which are non-alpha (e.g., `">"` maps to `"أ"`).
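Since the script already depends on the `lang-trans` package, a quick round trip shows what the transliteration looks like (the sample word is made up; `transliterate` is assumed to be the inverse function exposed by the same module):
```python
from lang_trans import arabic

buckwalter_text = "kaviyr"  # hypothetical sample: Arabic for "many"
arabic_text = arabic.buckwalter.untransliterate(buckwalter_text)
print(arabic_text)  # expected: كَثِير
print(arabic.buckwalter.transliterate(arabic_text))  # back to the ASCII form
```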
`--orthography="buckwalter"` applies certain text preprocessing rules, for tokenization and normalization, to clean up the dataset. In this case, we use the following instance of `Orthography`:
```python
Orthography(
    vocab_file=pathlib.Path(__file__).parent.joinpath("vocab/buckwalter.json"),
    word_delimiter_token="/",  # "|" is Arabic letter alef with madda above
    words_to_remove={"sil"},  # fixing "sil" in arabic_speech_corpus dataset
    untransliterator=arabic.buckwalter.untransliterate,
    translation_table=str.maketrans({
        "-": " ",  # sometimes used to represent pauses
        "^": "v",  # fixing "tha" in arabic_speech_corpus dataset
    }),
)
```
The instance above is used as follows:
* creates a tokenizer with Buckwalter vocabulary and `word_delimiter_token="/"`
* replaces `"-"` with `" "` to clean up hyphens and fixes the orthography for `"ث"`
* removes words used as indicators (in this case, `"sil"` is used for silence)
* cleans up consecutive whitespaces (replaces them with a single space: `" "`)
* removes characters not in vocabulary (lacking respective sound units)
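Putting these rules together, here's a small sketch of what happens to a single transcript (the sample text is made up; compare `Orthography.preprocess_for_training` in `run_asr.py`):
```python
translation_table = str.maketrans({"-": " ", "^": "v"})
words_to_remove = {"sil"}

text = "sil ka^iyr - jid~AF sil"  # hypothetical Buckwalter transcript
text = text.translate(translation_table)  # "-" -> " " and "^" -> "v"
text = " ".join(w for w in text.split() if w not in words_to_remove)
print(text)  # kaviyr jid~AF
```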
`--verbose_logging` logs text preprocessing updates and, when evaluating on the validation split every `eval_steps`, logs references and predictions. When using the Buckwalter format, text is also logged in the Arabic abjad.
`--target_feature_extractor_sampling_rate` resamples audio to the target feature extractor's sampling rate (16 kHz).
`--max_duration_in_seconds="15"` filters out examples whose audio is longer than the specified limit,
which helps with capping GPU memory usage.
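Per example, these two options roughly amount to the following sketch (the file path is hypothetical; see `prepare_example` in `run_asr.py` for the actual implementation):
```python
import librosa

target_sr = 16_000  # the feature extractor's sampling rate
speech, sampling_rate = librosa.load("sample.wav", sr=target_sr)  # resampled on load
duration_in_seconds = len(speech) / sampling_rate
keep = duration_in_seconds <= 15.0  # filter out examples longer than the limit
```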


@@ -0,0 +1,22 @@
#!/usr/bin/env bash
python run_asr.py \
--output_dir="./wav2vec2-base-timit-asr" \
--num_train_epochs="30" \
--per_device_train_batch_size="20" \
--per_device_eval_batch_size="20" \
--evaluation_strategy="steps" \
--save_steps="500" \
--eval_steps="100" \
--logging_steps="50" \
--learning_rate="5e-4" \
--warmup_steps="3000" \
--model_name_or_path="facebook/wav2vec2-base" \
--fp16 \
--dataset_name="timit_asr" \
--train_split_name="train" \
--validation_split_name="test" \
--orthography="timit" \
--preprocessing_num_workers="$(nproc)" \
--group_by_length \
--freeze_feature_extractor \
--verbose_logging \


@@ -0,0 +1,23 @@
#!/usr/bin/env bash
python run_asr.py \
--output_dir="./wav2vec2-large-lv60-timit-asr" \
--num_train_epochs="30" \
--per_device_train_batch_size="2" \
--per_device_eval_batch_size="2" \
--gradient_accumulation_steps="4" \
--evaluation_strategy="steps" \
--save_steps="500" \
--eval_steps="100" \
--logging_steps="50" \
--learning_rate="5e-4" \
--warmup_steps="3000" \
--model_name_or_path="facebook/wav2vec2-large-lv60" \
--fp16 \
--dataset_name="timit_asr" \
--train_split_name="train" \
--validation_split_name="test" \
--orthography="timit" \
--preprocessing_num_workers="$(nproc)" \
--group_by_length \
--freeze_feature_extractor \
--verbose_logging \


@@ -0,0 +1,25 @@
#!/usr/bin/env bash
python run_asr.py \
--output_dir="./wav2vec2-large-xlsr-53-arabic-speech-corpus" \
--num_train_epochs="50" \
--per_device_train_batch_size="1" \
--per_device_eval_batch_size="1" \
--gradient_accumulation_steps="8" \
--evaluation_strategy="steps" \
--save_steps="500" \
--eval_steps="100" \
--logging_steps="50" \
--learning_rate="5e-4" \
--warmup_steps="3000" \
--model_name_or_path="elgeish/wav2vec2-large-xlsr-53-arabic" \
--fp16 \
--dataset_name="arabic_speech_corpus" \
--train_split_name="train" \
--validation_split_name="test" \
--max_duration_in_seconds="15" \
--orthography="buckwalter" \
--preprocessing_num_workers="$(nproc)" \
--group_by_length \
--freeze_feature_extractor \
--target_feature_extractor_sampling_rate \
--verbose_logging \


@@ -1,4 +1,6 @@
transformers
datasets
-torch >= 1.5.0
-jiwer
torch>=1.5.0
jiwer==2.2.0
lang-trans==0.6.0
librosa==0.8.0


@@ -1,6 +1,10 @@
#!/usr/bin/env python3
import logging
import pathlib
import re
import sys
from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Set, Union
import datasets
import numpy as np
@@ -8,26 +12,32 @@ import torch
import torch.nn as nn
from packaging import version
import soundfile as sf
import librosa
from lang_trans import arabic
from transformers import (
HfArgumentParser,
Trainer,
TrainingArguments,
Wav2Vec2CTCTokenizer,
Wav2Vec2FeatureExtractor,
Wav2Vec2ForCTC,
Wav2Vec2Processor,
is_apex_available,
trainer_utils,
)
if is_apex_available():
from apex import amp
if version.parse(torch.__version__) >= version.parse("1.6"):
_is_native_amp_available = True
from torch.cuda.amp import autocast
logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
"""
@@ -44,6 +54,27 @@ class ModelArguments:
freeze_feature_extractor: Optional[bool] = field(
default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
)
gradient_checkpointing: Optional[bool] = field(
default=False, metadata={"help": "Whether to use gradient checkpointing to save memory at the expense of a slower backward pass."}
)
verbose_logging: Optional[bool] = field(
default=False,
metadata={"help": "Whether to log verbose messages or not."},
)
def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
logging_level = logging.WARNING
if model_args.verbose_logging:
logging_level = logging.DEBUG
elif trainer_utils.is_main_process(training_args.local_rank):
logging_level = logging.INFO
logger.setLevel(logging_level)
@dataclass
@@ -68,6 +99,34 @@ class DataTrainingArguments:
"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
},
)
validation_split_name: Optional[str] = field(
default="validation",
metadata={
"help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
},
)
target_text_column: Optional[str] = field(
default="text",
metadata={"help": "Column in the dataset that contains label (target text). Defaults to 'text'"},
)
speech_file_column: Optional[str] = field(
default="file",
metadata={"help": "Column in the dataset that contains speech file path. Defaults to 'file'"},
)
target_feature_extractor_sampling_rate: Optional[bool] = field(
default=False,
metadata={"help": "Resample loaded audio to target feature extractor's sampling rate or not."},
)
max_duration_in_seconds: Optional[float] = field(
default=None,
metadata={"help": "Filters out examples longer than specified. Defaults to no filtering."},
)
orthography: Optional[str] = field(
default="librispeech",
metadata={
"help": "Orthography used for normalization and tokenization: 'librispeech' (default), 'timit', or 'buckwalter'."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
)
@@ -77,6 +136,88 @@
)
@dataclass
class Orthography:
"""
Orthography scheme used for text normalization and tokenization.
Args:
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to accept lowercase input and lowercase the output when decoding.
vocab_file (:obj:`str`, `optional`, defaults to :obj:`None`):
File containing the vocabulary.
word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`"|"`):
The token used for delimiting words; it needs to be in the vocabulary.
translation_table (:obj:`Dict[str, str]`, `optional`, defaults to :obj:`{}`):
Table to use with `str.translate()` when preprocessing text (e.g., "-" -> " ").
words_to_remove (:obj:`Set[str]`, `optional`, defaults to :obj:`set()`):
Words to remove when preprocessing text (e.g., "sil").
untransliterator (:obj:`Callable[[str], str]`, `optional`, defaults to :obj:`None`):
Function that untransliterates text back into native writing system.
"""
do_lower_case: bool = False
vocab_file: Optional[str] = None
word_delimiter_token: Optional[str] = "|"
translation_table: Optional[Dict[str, str]] = field(default_factory=dict)
words_to_remove: Optional[Set[str]] = field(default_factory=set)
untransliterator: Optional[Callable[[str], str]] = None
@classmethod
def from_name(cls, name: str):
if name == "librispeech":
return cls()
if name == "timit":
return cls(
do_lower_case=True,
# break compounds like "quarter-century-old" and replace pauses "--"
translation_table=str.maketrans({"-": " "}),
)
if name == "buckwalter":
translation_table = {
"-": " ", # sometimes used to represent pauses
"^": "v", # fixing "tha" in arabic_speech_corpus dataset
}
return cls(
vocab_file=pathlib.Path(__file__).parent.joinpath("vocab/buckwalter.json"),
word_delimiter_token="/", # "|" is Arabic letter alef with madda above
translation_table=str.maketrans(translation_table),
words_to_remove={"sil"}, # fixing "sil" in arabic_speech_corpus dataset
untransliterator=arabic.buckwalter.untransliterate,
)
raise ValueError(f"Unsupported orthography: '{name}'.")
def preprocess_for_training(self, text: str) -> str:
# TODO(elgeish) return a pipeline (e.g., from jiwer) instead? Or rely on branch predictor as is
if len(self.translation_table) > 0:
text = text.translate(self.translation_table)
if len(self.words_to_remove) == 0:
text = " ".join(text.split()) # clean up whitespaces
else:
text = " ".join(w for w in text.split() if w not in self.words_to_remove) # and clean up whilespaces
return text
def create_processor(self, model_args: ModelArguments) -> Wav2Vec2Processor:
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
model_args.model_name_or_path, cache_dir=model_args.cache_dir
)
if self.vocab_file:
tokenizer = Wav2Vec2CTCTokenizer(
self.vocab_file,
cache_dir=model_args.cache_dir,
do_lower_case=self.do_lower_case,
word_delimiter_token=self.word_delimiter_token,
)
else:
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
do_lower_case=self.do_lower_case,
word_delimiter_token=self.word_delimiter_token,
)
return Wav2Vec2Processor(feature_extractor, tokenizer)
@dataclass
class DataCollatorCTCWithPadding:
"""
@@ -201,25 +342,72 @@ def main():
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
configure_logger(model_args, training_args)
-model = Wav2Vec2ForCTC.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
-processor = Wav2Vec2Processor.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
orthography = Orthography.from_name(data_args.orthography.lower())
processor = orthography.create_processor(model_args)
model = Wav2Vec2ForCTC.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
gradient_checkpointing=model_args.gradient_checkpointing,
vocab_size=len(processor.tokenizer),
)
train_dataset = datasets.load_dataset(
data_args.dataset_name, data_args.dataset_config_name, split=data_args.train_split_name
)
-val_dataset = datasets.load_dataset(data_args.dataset_name, data_args.dataset_config_name, split="validation")
val_dataset = datasets.load_dataset(
data_args.dataset_name, data_args.dataset_config_name, split=data_args.validation_split_name
)
wer_metric = datasets.load_metric("wer")
target_sr = processor.feature_extractor.sampling_rate if data_args.target_feature_extractor_sampling_rate else None
vocabulary_chars_str = "".join(t for t in processor.tokenizer.get_vocab().keys() if len(t) == 1)
vocabulary_text_cleaner = re.compile( # remove characters not in vocabulary
f"[^\s{re.escape(vocabulary_chars_str)}]", # allow space in addition to chars in vocabulary
flags=re.IGNORECASE if processor.tokenizer.do_lower_case else 0,
)
text_updates = []
-def map_to_array(batch):
-speech_array, sampling_rate = sf.read(batch["file"])
-batch["speech"] = speech_array
-batch["sampling_rate"] = sampling_rate
-return batch
def prepare_example(example): # TODO(elgeish) make use of multiprocessing?
example["speech"], example["sampling_rate"] = librosa.load(example[data_args.speech_file_column], sr=target_sr)
if data_args.max_duration_in_seconds is not None:
example["duration_in_seconds"] = len(example["speech"]) / example["sampling_rate"]
# Normalize and clean up text; order matters!
updated_text = orthography.preprocess_for_training(example[data_args.target_text_column])
updated_text = vocabulary_text_cleaner.sub("", updated_text)
if updated_text != example[data_args.target_text_column]:
text_updates.append((example[data_args.target_text_column], updated_text))
example[data_args.target_text_column] = updated_text
return example
-train_dataset = train_dataset.map(map_to_array, remove_columns=["file"])
-val_dataset = val_dataset.map(map_to_array, remove_columns=["file"])
train_dataset = train_dataset.map(prepare_example, remove_columns=[data_args.speech_file_column])
val_dataset = val_dataset.map(prepare_example, remove_columns=[data_args.speech_file_column])
if data_args.max_duration_in_seconds is not None:
def filter_by_max_duration(example):
return example["duration_in_seconds"] <= data_args.max_duration_in_seconds
old_train_size = len(train_dataset)
old_val_size = len(val_dataset)
train_dataset = train_dataset.filter(filter_by_max_duration, remove_columns=["duration_in_seconds"])
val_dataset = val_dataset.filter(filter_by_max_duration, remove_columns=["duration_in_seconds"])
if len(train_dataset) < old_train_size:
logger.warning(
f"Filtered out {old_train_size - len(train_dataset)} train example(s) longer than {data_args.max_duration_in_seconds} second(s)."
)
if len(val_dataset) < old_val_size:
logger.warning(
f"Filtered out {old_val_size - len(val_dataset)} validation example(s) longer than {data_args.max_duration_in_seconds} second(s)."
)
logger.info(f"Split sizes: {len(train_dataset)} train and {len(val_dataset)} validation.")
logger.warning(f"Updated {len(text_updates)} transcript(s) using '{data_args.orthography}' orthography rules.")
if logger.isEnabledFor(logging.DEBUG):
for original_text, updated_text in text_updates:
logger.debug(f'Updated text: "{original_text}" -> "{updated_text}"')
text_updates = None
def prepare_dataset(batch):
# check that all files have the correct sampling rate
@@ -229,7 +417,7 @@ def main():
batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
with processor.as_target_processor():
batch["labels"] = processor(batch["text"]).input_ids
batch["labels"] = processor(batch[data_args.target_text_column]).input_ids
return batch
train_dataset = train_dataset.map(
@@ -256,6 +444,13 @@ def main():
pred_str = processor.batch_decode(pred_ids)
# we do not want to group tokens when computing the metrics
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
if logger.isEnabledFor(logging.DEBUG):
for reference, predicted in zip(label_str, pred_str):
logger.debug(f'reference: "{reference}"')
logger.debug(f'predicted: "{predicted}"')
if orthography.untransliterator is not None:
logger.debug(f'reference (untransliterated): "{orthography.untransliterator(reference)}"')
logger.debug(f'predicted (untransliterated): "{orthography.untransliterator(predicted)}"')
wer = wer_metric.compute(predictions=pred_str, references=label_str)


@@ -0,0 +1,58 @@
{
"<pad>": 0,
"<s>": 1,
"</s>": 2,
"<unk>": 3,
"/": 4,
"'": 5,
"|": 6,
">": 7,
"&": 8,
"<": 9,
"}": 10,
"A": 11,
"b": 12,
"p": 13,
"t": 14,
"v": 15,
"j": 16,
"H": 17,
"x": 18,
"d": 19,
"*": 20,
"r": 21,
"z": 22,
"s": 23,
"$": 24,
"S": 25,
"D": 26,
"T": 27,
"Z": 28,
"E": 29,
"g": 30,
"_": 31,
"f": 32,
"q": 33,
"k": 34,
"l": 35,
"m": 36,
"n": 37,
"h": 38,
"w": 39,
"Y": 40,
"y": 41,
"F": 42,
"N": 43,
"K": 44,
"a": 45,
"u": 46,
"i": 47,
"~": 48,
"o": 49,
"`": 50,
"{": 51,
"P": 52,
"J": 53,
"V": 54,
"G": 55
}
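For reference, a tokenizer can be built directly from this vocabulary, mirroring `Orthography.create_processor` above (a hedged sketch, assuming the file is available at the relative path below):
```python
from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer(
    "vocab/buckwalter.json",
    word_delimiter_token="/",  # as in the buckwalter orthography
)
ids = tokenizer("kaviyr/jid~AF").input_ids  # made-up Buckwalter sample
print(tokenizer.decode(ids, group_tokens=False))  # "/" decodes to a space
```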


@@ -145,7 +145,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
@property
def word_delimiter_token(self) -> str:
"""
-:obj:`str`: Padding token. Log an error if used while not having been set.
:obj:`str`: Word delimiter token. Log an error if used while not having been set.
"""
if self._word_delimiter_token is None and self.verbose:
logger.error("Using word_delimiter_token, but it is not set yet.")