wav2vec2: support datasets other than LibriSpeech (#10581)

* wav2vec2: support datasets other than LibriSpeech
* Formatting run_asr.py to pass code quality test
* bundled orthography options and added verbose logs
* fixing a typo in timit fine-tuning script
* update comment for clarity
* resize_lm_head and load custom vocab from file
* adding a max_duration_in_seconds filter
* do not assign `duration_filter` lambda, use a def
* log untransliterated text as well
* fix base model for arabic
* fix duration filter when target_sr is not set
* drop duration_in_seconds when unneeded
* script for wav2vec2-large-lv60-timit-asr
* fix for "tha" in arabic corpus (huggingface#10581)
* adding more options to work with common_voice
* PR feedback (huggingface#10581)
* small README change
commit af8afdc88d (parent 0b98ca368f)

examples/research_projects/wav2vec2/README.md (updated)

@@ -1,8 +1,129 @@
## Fine-tuning Wav2Vec2

The `run_asr.py` script allows one to fine-tune pretrained Wav2Vec2 models that can be found [here](https://huggingface.co/models?search=facebook/wav2vec2).

This fine-tuning script can also be run as a Google Colab [TODO: here]( ).

The script is actively maintained by [Patrick von Platen](https://github.com/patrickvonplaten).
Feel free to ask a question on the [Forum](https://discuss.huggingface.co/) or post an issue on [GitHub](https://github.com/huggingface/transformers/issues/new/choose) and add `@patrickvonplaten` as a tag.

### Fine-Tuning with TIMIT

Let's take a look at the [script](./finetune_base_timit_asr.sh) used to fine-tune [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base)
with the [TIMIT dataset](https://huggingface.co/datasets/timit_asr):

```bash
#!/usr/bin/env bash
python run_asr.py \
--output_dir="./wav2vec2-base-timit-asr" \
--num_train_epochs="30" \
--per_device_train_batch_size="20" \
--per_device_eval_batch_size="20" \
--evaluation_strategy="steps" \
--save_steps="500" \
--eval_steps="100" \
--logging_steps="50" \
--learning_rate="5e-4" \
--warmup_steps="3000" \
--model_name_or_path="facebook/wav2vec2-base" \
--fp16 \
--dataset_name="timit_asr" \
--train_split_name="train" \
--validation_split_name="test" \
--orthography="timit" \
--preprocessing_num_workers="$(nproc)" \
--group_by_length \
--freeze_feature_extractor \
--verbose_logging \
```

The resulting model and inference examples can be found [here](https://huggingface.co/elgeish/wav2vec2-base-timit-asr).
Some of the arguments above may look unfamiliar, so let's break down what's going on:

`--orthography="timit"` applies certain text preprocessing rules, for tokenization and normalization, to clean up the dataset.
In this case, we use the following instance of `Orthography`:

```python
Orthography(
    do_lower_case=True,
    # break compounds like "quarter-century-old" and replace pauses "--"
    translation_table=str.maketrans({"-": " "}),
)
```

The instance above is used as follows:
* creates a tokenizer with `do_lower_case=True` (ignores casing for input and lowercases output when decoding)
* replaces `"-"` with `" "` to break compounds like `"quarter-century-old"` and to clean up suspended hyphens
* cleans up consecutive whitespaces (replaces them with a single space: `" "`)
* removes characters not in vocabulary (lacking respective sound units)
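
To make these rules concrete, here is a minimal, self-contained sketch (not the script's exact code path; the toy vocabulary and sample transcript are made up for illustration):

```python
import re

# Hypothetical vocabulary: lowercase letters plus space.
vocab_chars = "abcdefghijklmnopqrstuvwxyz "
out_of_vocab = re.compile(f"[^{re.escape(vocab_chars)}]")

text = "A quarter-century-old DEBT -- unpaid!"
text = text.lower()                                # do_lower_case=True
text = text.translate(str.maketrans({"-": " "}))   # break compounds, drop "--" pauses
text = out_of_vocab.sub("", text)                  # remove characters lacking sound units
text = " ".join(text.split())                      # collapse consecutive whitespaces
print(text)  # -> "a quarter century old debt unpaid"
```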

`--verbose_logging` logs text preprocessing updates and, when evaluating on the validation split every `eval_steps`, logs references and predictions.

### Fine-Tuning with Arabic Speech Corpus

Other datasets, like the [Arabic Speech Corpus dataset](https://huggingface.co/datasets/arabic_speech_corpus),
require more work! Let's take a look at the [script](./finetune_large_xlsr_53_arabic_speech_corpus.sh)
used to fine-tune [wav2vec2-large-xlsr-53](https://huggingface.co/elgeish/wav2vec2-large-xlsr-53-arabic):

```bash
#!/usr/bin/env bash
python run_asr.py \
--output_dir="./wav2vec2-large-xlsr-53-arabic-speech-corpus" \
--num_train_epochs="50" \
--per_device_train_batch_size="1" \
--per_device_eval_batch_size="1" \
--gradient_accumulation_steps="8" \
--evaluation_strategy="steps" \
--save_steps="500" \
--eval_steps="100" \
--logging_steps="50" \
--learning_rate="5e-4" \
--warmup_steps="3000" \
--model_name_or_path="elgeish/wav2vec2-large-xlsr-53-arabic" \
--fp16 \
--dataset_name="arabic_speech_corpus" \
--train_split_name="train" \
--validation_split_name="test" \
--max_duration_in_seconds="15" \
--orthography="buckwalter" \
--preprocessing_num_workers="$(nproc)" \
--group_by_length \
--freeze_feature_extractor \
--target_feature_extractor_sampling_rate \
--verbose_logging \
```

First, let's understand how this dataset represents Arabic text; it uses a format called
[Buckwalter transliteration](https://en.wikipedia.org/wiki/Buckwalter_transliteration).
We use the [lang-trans](https://github.com/kariminf/lang-trans) package to convert back to Arabic when logging.
The Buckwalter format only includes ASCII characters, some of which are non-alpha (e.g., `">"` maps to `"أ"`).
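
For a quick taste of the conversion (using the same `lang-trans` call the script uses for logging; the input here is just the single character mentioned above):

```python
from lang_trans import arabic

# ">" is the Buckwalter code for alef with hamza above
print(arabic.buckwalter.untransliterate(">"))  # -> "أ"
```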

`--orthography="buckwalter"` applies certain text preprocessing rules, for tokenization and normalization, to clean up the dataset. In this case, we use the following instance of `Orthography`:

```python
Orthography(
    vocab_file=pathlib.Path(__file__).parent.joinpath("vocab/buckwalter.json"),
    word_delimiter_token="/",  # "|" is Arabic letter alef with madda above
    words_to_remove={"sil"},  # fixing "sil" in arabic_speech_corpus dataset
    untransliterator=arabic.buckwalter.untransliterate,
    translation_table=str.maketrans({
        "-": " ",  # sometimes used to represent pauses
        "^": "v",  # fixing "tha" in arabic_speech_corpus dataset
    }),
)
```

The instance above is used as follows:
* creates a tokenizer with the Buckwalter vocabulary and `word_delimiter_token="/"`
* replaces `"-"` with `" "` to clean up hyphens and fixes the orthography for `"ث"`
* removes words used as indicators (in this case, `"sil"` is used for silence)
* cleans up consecutive whitespaces (replaces them with a single space: `" "`)
* removes characters not in vocabulary (lacking respective sound units)
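
Putting those rules together, a tiny sketch of the preprocessing (the Buckwalter sample is made up; `"^"` and `"sil"` appear as they do in arabic_speech_corpus transcripts):

```python
translation_table = str.maketrans({"-": " ", "^": "v"})
words_to_remove = {"sil"}

text = "sil h*A ^awb - jamiyl sil"
text = text.translate(translation_table)  # "^" -> "v", "-" -> " "
text = " ".join(w for w in text.split() if w not in words_to_remove)
print(text)  # -> "h*A vawb jamiyl"
```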

`--verbose_logging` logs text preprocessing updates and, when evaluating on the validation split every `eval_steps`, logs references and predictions. With the Buckwalter format, text is also logged in the Arabic abjad.

`--target_feature_extractor_sampling_rate` resamples audio to the target feature extractor's sampling rate (16kHz).

`--max_duration_in_seconds="15"` filters out examples whose audio is longer than the specified limit,
which helps with capping GPU memory usage.
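
Both flags act at load time; here is a minimal sketch of the combined logic (mirroring what `run_asr.py` does, with a hypothetical file path):

```python
import librosa

target_sr = 16_000  # the feature extractor's sampling rate
max_duration_in_seconds = 15.0

# librosa resamples to target_sr while decoding; the path is hypothetical
speech, sampling_rate = librosa.load("example.wav", sr=target_sr)
duration_in_seconds = len(speech) / sampling_rate
if duration_in_seconds > max_duration_in_seconds:
    print(f"Filtered out: {duration_in_seconds:.1f}s > {max_duration_in_seconds}s")
```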

examples/research_projects/wav2vec2/finetune_base_timit_asr.sh (new executable file, 22 lines)

#!/usr/bin/env bash
python run_asr.py \
--output_dir="./wav2vec2-base-timit-asr" \
--num_train_epochs="30" \
--per_device_train_batch_size="20" \
--per_device_eval_batch_size="20" \
--evaluation_strategy="steps" \
--save_steps="500" \
--eval_steps="100" \
--logging_steps="50" \
--learning_rate="5e-4" \
--warmup_steps="3000" \
--model_name_or_path="facebook/wav2vec2-base" \
--fp16 \
--dataset_name="timit_asr" \
--train_split_name="train" \
--validation_split_name="test" \
--orthography="timit" \
--preprocessing_num_workers="$(nproc)" \
--group_by_length \
--freeze_feature_extractor \
--verbose_logging \

examples/research_projects/wav2vec2/finetune_large_lv60_timit_asr.sh (new executable file, 23 lines)

#!/usr/bin/env bash
python run_asr.py \
--output_dir="./wav2vec2-large-lv60-timit-asr" \
--num_train_epochs="30" \
--per_device_train_batch_size="2" \
--per_device_eval_batch_size="2" \
--gradient_accumulation_steps="4" \
--evaluation_strategy="steps" \
--save_steps="500" \
--eval_steps="100" \
--logging_steps="50" \
--learning_rate="5e-4" \
--warmup_steps="3000" \
--model_name_or_path="facebook/wav2vec2-large-lv60" \
--fp16 \
--dataset_name="timit_asr" \
--train_split_name="train" \
--validation_split_name="test" \
--orthography="timit" \
--preprocessing_num_workers="$(nproc)" \
--group_by_length \
--freeze_feature_extractor \
--verbose_logging \

examples/research_projects/wav2vec2/finetune_large_xlsr_53_arabic_speech_corpus.sh (new executable file, 25 lines)

#!/usr/bin/env bash
python run_asr.py \
--output_dir="./wav2vec2-large-xlsr-53-arabic-speech-corpus" \
--num_train_epochs="50" \
--per_device_train_batch_size="1" \
--per_device_eval_batch_size="1" \
--gradient_accumulation_steps="8" \
--evaluation_strategy="steps" \
--save_steps="500" \
--eval_steps="100" \
--logging_steps="50" \
--learning_rate="5e-4" \
--warmup_steps="3000" \
--model_name_or_path="elgeish/wav2vec2-large-xlsr-53-arabic" \
--fp16 \
--dataset_name="arabic_speech_corpus" \
--train_split_name="train" \
--validation_split_name="test" \
--max_duration_in_seconds="15" \
--orthography="buckwalter" \
--preprocessing_num_workers="$(nproc)" \
--group_by_length \
--freeze_feature_extractor \
--target_feature_extractor_sampling_rate \
--verbose_logging \

examples/research_projects/wav2vec2/requirements.txt (updated)

@@ -1,4 +1,6 @@
transformers
datasets
torch>=1.5.0
jiwer==2.2.0
lang-trans==0.6.0
librosa==0.8.0

examples/research_projects/wav2vec2/run_asr.py (updated)

@@ -1,6 +1,10 @@
#!/usr/bin/env python3
import logging
import pathlib
import re
import sys
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Set, Union

import datasets
import numpy as np

@@ -8,26 +12,32 @@ import torch
import torch.nn as nn
from packaging import version

import librosa
from lang_trans import arabic
from transformers import (
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    is_apex_available,
    trainer_utils,
)


if is_apex_available():
    from apex import amp


if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_native_amp_available = True
    from torch.cuda.amp import autocast


logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """

@@ -44,6 +54,27 @@ class ModelArguments:
    freeze_feature_extractor: Optional[bool] = field(
        default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
    )
    gradient_checkpointing: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use gradient checkpointing to save memory at the expense of a slower backward pass."},
    )
    verbose_logging: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to log verbose messages or not."},
    )


def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logging_level = logging.WARNING
    if model_args.verbose_logging:
        logging_level = logging.DEBUG
    elif trainer_utils.is_main_process(training_args.local_rank):
        logging_level = logging.INFO
    logger.setLevel(logging_level)


@dataclass

@@ -68,6 +99,34 @@ class DataTrainingArguments:
            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
        },
    )
    validation_split_name: Optional[str] = field(
        default="validation",
        metadata={
            "help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
        },
    )
    target_text_column: Optional[str] = field(
        default="text",
        metadata={"help": "Column in the dataset that contains label (target text). Defaults to 'text'"},
    )
    speech_file_column: Optional[str] = field(
        default="file",
        metadata={"help": "Column in the dataset that contains speech file path. Defaults to 'file'"},
    )
    target_feature_extractor_sampling_rate: Optional[bool] = field(
        default=False,
        metadata={"help": "Resample loaded audio to target feature extractor's sampling rate or not."},
    )
    max_duration_in_seconds: Optional[float] = field(
        default=None,
        metadata={"help": "Filters out examples longer than specified. Defaults to no filtering."},
    )
    orthography: Optional[str] = field(
        default="librispeech",
        metadata={
            "help": "Orthography used for normalization and tokenization: 'librispeech' (default), 'timit', or 'buckwalter'."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )

@@ -77,6 +136,88 @@ class DataTrainingArguments:
    )


@dataclass
class Orthography:
    """
    Orthography scheme used for text normalization and tokenization.

    Args:
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to accept lowercase input and lowercase the output when decoding.
        vocab_file (:obj:`str`, `optional`, defaults to :obj:`None`):
            File containing the vocabulary.
        word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`"|"`):
            The token used for delimiting words; it needs to be in the vocabulary.
        translation_table (:obj:`Dict[str, str]`, `optional`, defaults to :obj:`{}`):
            Table to use with `str.translate()` when preprocessing text (e.g., "-" -> " ").
        words_to_remove (:obj:`Set[str]`, `optional`, defaults to :obj:`set()`):
            Words to remove when preprocessing text (e.g., "sil").
        untransliterator (:obj:`Callable[[str], str]`, `optional`, defaults to :obj:`None`):
            Function that untransliterates text back into the native writing system.
    """

    do_lower_case: bool = False
    vocab_file: Optional[str] = None
    word_delimiter_token: Optional[str] = "|"
    translation_table: Optional[Dict[str, str]] = field(default_factory=dict)
    words_to_remove: Optional[Set[str]] = field(default_factory=set)
    untransliterator: Optional[Callable[[str], str]] = None

    @classmethod
    def from_name(cls, name: str):
        if name == "librispeech":
            return cls()
        if name == "timit":
            return cls(
                do_lower_case=True,
                # break compounds like "quarter-century-old" and replace pauses "--"
                translation_table=str.maketrans({"-": " "}),
            )
        if name == "buckwalter":
            translation_table = {
                "-": " ",  # sometimes used to represent pauses
                "^": "v",  # fixing "tha" in arabic_speech_corpus dataset
            }
            return cls(
                vocab_file=pathlib.Path(__file__).parent.joinpath("vocab/buckwalter.json"),
                word_delimiter_token="/",  # "|" is Arabic letter alef with madda above
                translation_table=str.maketrans(translation_table),
                words_to_remove={"sil"},  # fixing "sil" in arabic_speech_corpus dataset
                untransliterator=arabic.buckwalter.untransliterate,
            )
        raise ValueError(f"Unsupported orthography: '{name}'.")

    def preprocess_for_training(self, text: str) -> str:
        # TODO(elgeish) return a pipeline (e.g., from jiwer) instead? Or rely on branch predictor as is
        if len(self.translation_table) > 0:
            text = text.translate(self.translation_table)
        if len(self.words_to_remove) == 0:
            text = " ".join(text.split())  # clean up whitespaces
        else:
            text = " ".join(w for w in text.split() if w not in self.words_to_remove)  # and clean up whitespaces
        return text

    def create_processor(self, model_args: ModelArguments) -> Wav2Vec2Processor:
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            model_args.model_name_or_path, cache_dir=model_args.cache_dir
        )
        if self.vocab_file:
            tokenizer = Wav2Vec2CTCTokenizer(
                self.vocab_file,
                cache_dir=model_args.cache_dir,
                do_lower_case=self.do_lower_case,
                word_delimiter_token=self.word_delimiter_token,
            )
        else:
            tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=model_args.cache_dir,
                do_lower_case=self.do_lower_case,
                word_delimiter_token=self.word_delimiter_token,
            )
        return Wav2Vec2Processor(feature_extractor, tokenizer)


@dataclass
class DataCollatorCTCWithPadding:
    """

@@ -201,25 +342,72 @@ def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    configure_logger(model_args, training_args)

    orthography = Orthography.from_name(data_args.orthography.lower())
    processor = orthography.create_processor(model_args)
    model = Wav2Vec2ForCTC.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        gradient_checkpointing=model_args.gradient_checkpointing,
        vocab_size=len(processor.tokenizer),
    )

    train_dataset = datasets.load_dataset(
        data_args.dataset_name, data_args.dataset_config_name, split=data_args.train_split_name
    )
    val_dataset = datasets.load_dataset(
        data_args.dataset_name, data_args.dataset_config_name, split=data_args.validation_split_name
    )

    wer_metric = datasets.load_metric("wer")
    target_sr = processor.feature_extractor.sampling_rate if data_args.target_feature_extractor_sampling_rate else None
    vocabulary_chars_str = "".join(t for t in processor.tokenizer.get_vocab().keys() if len(t) == 1)
    vocabulary_text_cleaner = re.compile(  # remove characters not in vocabulary
        f"[^\s{re.escape(vocabulary_chars_str)}]",  # allow space in addition to chars in vocabulary
        flags=re.IGNORECASE if processor.tokenizer.do_lower_case else 0,
    )
    text_updates = []

    def prepare_example(example):  # TODO(elgeish) make use of multiprocessing?
        example["speech"], example["sampling_rate"] = librosa.load(example[data_args.speech_file_column], sr=target_sr)
        if data_args.max_duration_in_seconds is not None:
            example["duration_in_seconds"] = len(example["speech"]) / example["sampling_rate"]
        # Normalize and clean up text; order matters!
        updated_text = orthography.preprocess_for_training(example[data_args.target_text_column])
        updated_text = vocabulary_text_cleaner.sub("", updated_text)
        if updated_text != example[data_args.target_text_column]:
            text_updates.append((example[data_args.target_text_column], updated_text))
            example[data_args.target_text_column] = updated_text
        return example

    train_dataset = train_dataset.map(prepare_example, remove_columns=[data_args.speech_file_column])
    val_dataset = val_dataset.map(prepare_example, remove_columns=[data_args.speech_file_column])

    if data_args.max_duration_in_seconds is not None:

        def filter_by_max_duration(example):
            return example["duration_in_seconds"] <= data_args.max_duration_in_seconds

        old_train_size = len(train_dataset)
        old_val_size = len(val_dataset)
        train_dataset = train_dataset.filter(filter_by_max_duration, remove_columns=["duration_in_seconds"])
        val_dataset = val_dataset.filter(filter_by_max_duration, remove_columns=["duration_in_seconds"])
        if len(train_dataset) < old_train_size:
            logger.warning(
                f"Filtered out {old_train_size - len(train_dataset)} train example(s) longer than {data_args.max_duration_in_seconds} second(s)."
            )
        if len(val_dataset) < old_val_size:
            logger.warning(
                f"Filtered out {old_val_size - len(val_dataset)} validation example(s) longer than {data_args.max_duration_in_seconds} second(s)."
            )
        logger.info(f"Split sizes: {len(train_dataset)} train and {len(val_dataset)} validation.")

    logger.warning(f"Updated {len(text_updates)} transcript(s) using '{data_args.orthography}' orthography rules.")
    if logger.isEnabledFor(logging.DEBUG):
        for original_text, updated_text in text_updates:
            logger.debug(f'Updated text: "{original_text}" -> "{updated_text}"')
    text_updates = None

    def prepare_dataset(batch):
        # check that all files have the correct sampling rate

@@ -229,7 +417,7 @@ def main():
        batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
        with processor.as_target_processor():
            batch["labels"] = processor(batch[data_args.target_text_column]).input_ids
        return batch

    train_dataset = train_dataset.map(

@@ -256,6 +444,13 @@ def main():
        pred_str = processor.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
        if logger.isEnabledFor(logging.DEBUG):
            for reference, predicted in zip(label_str, pred_str):
                logger.debug(f'reference: "{reference}"')
                logger.debug(f'predicted: "{predicted}"')
                if orthography.untransliterator is not None:
                    logger.debug(f'reference (untransliterated): "{orthography.untransliterator(reference)}"')
                    logger.debug(f'predicted (untransliterated): "{orthography.untransliterator(predicted)}"')

        wer = wer_metric.compute(predictions=pred_str, references=label_str)

examples/research_projects/wav2vec2/vocab/buckwalter.json (new file, 58 lines)

{
    "<pad>": 0,
    "<s>": 1,
    "</s>": 2,
    "<unk>": 3,
    "/": 4,
    "'": 5,
    "|": 6,
    ">": 7,
    "&": 8,
    "<": 9,
    "}": 10,
    "A": 11,
    "b": 12,
    "p": 13,
    "t": 14,
    "v": 15,
    "j": 16,
    "H": 17,
    "x": 18,
    "d": 19,
    "*": 20,
    "r": 21,
    "z": 22,
    "s": 23,
    "$": 24,
    "S": 25,
    "D": 26,
    "T": 27,
    "Z": 28,
    "E": 29,
    "g": 30,
    "_": 31,
    "f": 32,
    "q": 33,
    "k": 34,
    "l": 35,
    "m": 36,
    "n": 37,
    "h": 38,
    "w": 39,
    "Y": 40,
    "y": 41,
    "F": 42,
    "N": 43,
    "K": 44,
    "a": 45,
    "u": 46,
    "i": 47,
    "~": 48,
    "o": 49,
    "`": 50,
    "{": 51,
    "P": 52,
    "J": 53,
    "V": 54,
    "G": 55
}
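
As a quick check that this vocabulary works end to end, here is a small sketch (assuming `transformers` is installed and the file above is saved at `vocab/buckwalter.json`; the sample word is chosen for illustration):

```python
from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer("vocab/buckwalter.json", word_delimiter_token="/")
ids = tokenizer("kitAb").input_ids  # "kitAb" is Buckwalter for "كتاب" (book)
print(ids)                  # one id per character: [34, 47, 14, 11, 12]
print(tokenizer.decode(ids))  # -> "kitAb"
```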

src/transformers/models/wav2vec2/tokenization_wav2vec2.py (updated)

@@ -145,7 +145,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
    @property
    def word_delimiter_token(self) -> str:
        """
        :obj:`str`: Word delimiter token. Log an error if used while not having been set.
        """
        if self._word_delimiter_token is None and self.verbose:
            logger.error("Using word_delimiter_token, but it is not set yet.")