<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Automatic speech recognition
<Youtube id="TksaY_FDgnk"/>
Automatic speech recognition (ASR) converts a speech signal to text. It is an example of a sequence-to-sequence task, going from a sequence of audio inputs to textual outputs. Voice assistants like Siri and Alexa utilize ASR models to assist users.
This guide will show you how to fine-tune [Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base) on the [TIMIT](https://huggingface.co/datasets/timit_asr) dataset to transcribe audio to text.
<Tip>
See the automatic speech recognition [task page](https://huggingface.co/tasks/automatic-speech-recognition) for more information about its associated models, datasets, and metrics.
</Tip>
## Load TIMIT dataset
Load the TIMIT dataset from the 🤗 Datasets library:
```py
>>> from datasets import load_dataset
>>> timit = load_dataset("timit_asr")
```
Then take a look at the dataset:
```py
>>> timit
DatasetDict({
    train: Dataset({
        features: ['file', 'audio', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id'],
        num_rows: 4620
    })
    test: Dataset({
        features: ['file', 'audio', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id'],
        num_rows: 1680
    })
})
```
While the dataset contains a lot of helpful information, like `dialect_region` and `sentence_type`, you will focus on the `audio` and `text` fields in this guide. Remove the other columns:
```py
>>> timit = timit.remove_columns(
...     ["phonetic_detail", "word_detail", "dialect_region", "id", "sentence_type", "speaker_id"]
... )
```
Take a look at the first training example:
```py
>>> timit["train"][0]
{'audio': {'array': array([-2.1362305e-04,  6.1035156e-05,  3.0517578e-05, ...,
        -3.0517578e-05, -9.1552734e-05, -6.1035156e-05], dtype=float32),
  'path': '/root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV',
  'sampling_rate': 16000},
 'file': '/root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV',
 'text': 'Would such an act of refusal be useful?'}
```
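The `audio` column contains the `path` to the audio file, a 1-dimensional `array` holding the decoded speech signal, and its `sampling_rate`. The audio file is only loaded, and resampled if necessary, when you access the `audio` column.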
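Resampling changes how many samples represent each second of audio. TIMIT is already sampled at 16 kHz, the rate Wav2Vec2 was pretrained on, so no resampling actually happens in this guide. The plain-Python sketch below only illustrates the idea with naive linear interpolation; `resample_linear` is a hypothetical helper, and in practice 🤗 Datasets delegates resampling to an audio library that also applies the anti-aliasing filtering this sketch skips.

```python
from typing import List


def resample_linear(signal: List[float], orig_sr: int, target_sr: int) -> List[float]:
    """Hypothetical helper: resample `signal` from orig_sr to target_sr by linear interpolation."""
    if orig_sr == target_sr or not signal:
        return list(signal)
    # Number of output samples needed to cover the same duration.
    n_out = round(len(signal) * target_sr / orig_sr)
    out = []
    for i in range(n_out):
        # Position of output sample i on the original time axis, in input-sample units.
        pos = i * orig_sr / target_sr
        left = int(pos)
        right = min(left + 1, len(signal) - 1)
        frac = pos - left
        # Blend the two neighboring input samples.
        out.append(signal[left] * (1 - frac) + signal[right] * frac)
    return out


one_second_at_8k = [0.0] * 8000
print(len(resample_linear(one_second_at_8k, 8000, 16000)))  # 16000
```

One second of audio keeps its duration but is now described by twice as many samples.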
## Preprocess
Load the Wav2Vec2 processor to process the audio signal and transcribed text:
```py
>>> from transformers import AutoProcessor
>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base")
```
The preprocessing function needs to:
1. Access the `audio` column to load and, if necessary, resample the audio file.
2. Extract the `input_values` from the audio array with the processor's feature extractor.
3. Tokenize the `text` into `labels`. Calling the processor normally calls its feature extractor, so wrap the call in the `as_target_processor` context manager to make the processor call its tokenizer instead.
```py
>>> def prepare_dataset(batch):
...     audio = batch["audio"]
...     batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
...     batch["input_length"] = len(batch["input_values"])
...     with processor.as_target_processor():
...         batch["labels"] = processor(batch["text"]).input_ids
...     return batch
```
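To make the output of `prepare_dataset` concrete, here is a dependency-free sketch of what it computes for a single example. It rests on assumptions: the feature extractor is assumed to apply zero-mean, unit-variance normalization (as `Wav2Vec2FeatureExtractor` does when `do_normalize=True`), and the tokenizer is modeled as a character-level CTC tokenizer that maps spaces to a `|` word delimiter. `VOCAB`, `normalize`, `tokenize`, and `prepare_example` are hypothetical stand-ins, not the processor's API.

```python
import math
from typing import Dict, List

# Hypothetical toy vocabulary: padding, unknown, word delimiter, then uppercase letters and apostrophe.
VOCAB = {"<pad>": 0, "<unk>": 1, "|": 2}
VOCAB.update({ch: i + 3 for i, ch in enumerate("ABCDEFGHIJKLMNOPQRSTUVWXYZ'")})


def normalize(values: List[float]) -> List[float]:
    """Zero-mean, unit-variance normalization (assumed to match do_normalize=True)."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + 1e-7) for v in values]


def tokenize(text: str) -> List[int]:
    """Character-level encoding: spaces become "|", unknown characters map to <unk>."""
    return [VOCAB.get(ch, VOCAB["<unk>"]) for ch in text.upper().replace(" ", "|")]


def prepare_example(audio: Dict, text: str) -> Dict:
    """Hypothetical single-example counterpart of prepare_dataset."""
    input_values = normalize(audio["array"])
    return {
        "input_values": input_values,
        "input_length": len(input_values),
        "labels": tokenize(text),
    }


example = prepare_example({"array": [0.0, 0.5, -0.5, 0.0], "sampling_rate": 16000}, "an act")
print(example["input_length"])  # 4
print(example["labels"])  # [3, 16, 2, 3, 5, 22]
```

The toy vocabulary only holds uppercase letters, so `tokenize` uppercases its input and punctuation such as `?` falls back to `<unk>`. If the real checkpoint's vocabulary is also uppercase-only, you may want to uppercase the `text` column before tokenizing; check the tokenizer's vocabulary first.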
Use the 🤗 Datasets [`map`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map) function to apply the preprocessing function over the entire dataset, and remove the original columns since they are no longer needed. You can speed up `map` by increasing the number of processes with `num_proc`:
```py
>>> timit = timit.map(prepare_dataset, remove_columns=timit.column_names["train"], num_proc=4)
```
🤗 Transformers doesn't have a data collator for automatic speech recognition, so you will need to create one. You can adapt [`DataCollatorWithPadding`] to create a batch of examples for automatic speech recognition. The collator dynamically pads the `input_values` and `labels` to the length of the longest element in each batch so they have a uniform length. While it is possible to pad your text in the `tokenizer` function by setting `padding=True`, dynamic padding is more efficient because each batch is only padded as far as it needs to be.
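The efficiency gain is easy to quantify by counting padding positions. The sketch below uses made-up sequence lengths and a hypothetical `padded_elements` helper to compare padding each batch to its own longest element with padding every sequence to one fixed length.

```python
from typing import List, Optional


def padded_elements(lengths: List[int], batch_size: int, pad_to: Optional[int] = None) -> int:
    """Count padding positions when `lengths` are batched in order.

    pad_to=None pads each batch to its own longest element (dynamic padding);
    an integer pads every sequence to that fixed length instead.
    """
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        target = max(batch) if pad_to is None else pad_to
        total += sum(target - n for n in batch)
    return total


lengths = [3, 4, 10, 12]  # made-up sequence lengths
print(padded_elements(lengths, batch_size=2))  # 3
print(padded_elements(lengths, batch_size=2, pad_to=max(lengths)))  # 19
```

The gap grows with the spread of lengths, and raw audio lengths vary a lot: at 16 kHz, every extra second of a clip adds 16,000 samples.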
Unlike other data collators, this one needs to pad `input_values` and `labels` differently: the audio inputs are padded by the feature extractor, and the labels by the tokenizer. Use the `as_target_processor` context manager to switch the processor to its tokenizer for the labels:
```py
>>> import torch
>>> from dataclasses import dataclass
>>> from typing import Dict, List, Union

>>> @dataclass
... class DataCollatorCTCWithPadding:
...     processor: AutoProcessor
...     padding: Union[bool, str] = True
...
...     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
...         # split inputs and labels since they have to be of different lengths and need
...         # different padding methods
...         input_features = [{"input_values": feature["input_values"]} for feature in features]
...         label_features = [{"input_ids": feature["labels"]} for feature in features]
...
...         batch = self.processor.pad(
...             input_features,
...             padding=self.padding,
...             return_tensors="pt",
...         )
...         with self.processor.as_target_processor():
...             labels_batch = self.processor.pad(
...                 label_features,
...                 padding=self.padding,
...                 return_tensors="pt",
...             )
...
...         # replace padding with -100 to ignore loss correctly
...         labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
...         batch["labels"] = labels
...         return batch
```
Create a batch of examples and dynamically pad them with `DataCollatorCTCWithPadding`:
```py
>>> data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
```
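To see the collator's padding and masking without 🤗 Transformers or PyTorch, here is a plain-Python sketch of the same steps on lists. It assumes a padding value of `0.0` for `input_values` and a label pad id of `0`; `collate` is a hypothetical helper. Replacing padded label positions with `-100` matters because, to our understanding, `Wav2Vec2ForCTC` treats only non-negative label ids as real targets when computing each transcript's length.

```python
from typing import Dict, List


def collate(features: List[Dict], pad_token_id: int = 0, input_pad_value: float = 0.0) -> Dict[str, List]:
    """Hypothetical list-based version of DataCollatorCTCWithPadding."""
    # Inputs and labels are padded separately, each to its own longest item in the batch.
    max_in = max(len(f["input_values"]) for f in features)
    max_lab = max(len(f["labels"]) for f in features)
    input_values = [
        f["input_values"] + [input_pad_value] * (max_in - len(f["input_values"])) for f in features
    ]
    labels, attention_mask = [], []
    for f in features:
        n_pad = max_lab - len(f["labels"])
        labels.append(f["labels"] + [pad_token_id] * n_pad)
        attention_mask.append([1] * len(f["labels"]) + [0] * n_pad)
    # Mirror masked_fill(attention_mask.ne(1), -100): padded label positions are ignored by the loss.
    labels = [
        [tok if m == 1 else -100 for tok, m in zip(row, mask)]
        for row, mask in zip(labels, attention_mask)
    ]
    return {"input_values": input_values, "labels": labels}


batch = collate([
    {"input_values": [0.1, 0.2, 0.3], "labels": [5, 6]},
    {"input_values": [0.4], "labels": [7, 8, 9]},
])
print(batch["input_values"])  # [[0.1, 0.2, 0.3], [0.4, 0.0, 0.0]]
print(batch["labels"])  # [[5, 6, -100], [7, 8, 9]]
```

Masking through the attention mask, rather than by comparing against the pad id, keeps the logic correct even if a real label happened to share the pad id.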
## Train
<frameworkcontent>
<pt>
Load Wav2Vec2 with [`AutoModelForCTC`]. For `ctc_loss_reduction`, it is often better to use the average instead of the default summation:
```py
>>> from transformers import AutoModelForCTC, TrainingArguments, Trainer
>>> model = AutoModelForCTC.from_pretrained(
...     "facebook/wav2vec2-base",
...     ctc_loss_reduction="mean",
...     pad_token_id=processor.tokenizer.pad_token_id,
... )
```
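The choice matters because the two reductions scale differently. The sketch below mirrors, as we understand them, the `reduction` options documented for PyTorch's `torch.nn.functional.ctc_loss`: `"sum"` adds the per-example losses, while `"mean"` first divides each loss by its target (transcript) length and then averages over the batch. The per-example losses are made up, and `reduce_ctc_loss` is a hypothetical helper.

```python
from typing import List


def reduce_ctc_loss(per_example_loss: List[float], target_lengths: List[int], reduction: str) -> float:
    """Hypothetical helper reducing per-example CTC losses to one scalar."""
    if reduction == "sum":
        # Grows with batch size and with transcript length.
        return sum(per_example_loss)
    if reduction == "mean":
        # Normalize by transcript length first, then average over the batch.
        return sum(loss / n for loss, n in zip(per_example_loss, target_lengths)) / len(per_example_loss)
    raise ValueError(f"unsupported reduction: {reduction!r}")


losses = [40.0, 90.0]  # made-up per-example losses
target_lengths = [20, 30]  # transcript lengths in tokens
print(reduce_ctc_loss(losses, target_lengths, "sum"))  # 130.0
print(reduce_ctc_loss(losses, target_lengths, "mean"))  # 2.5
```

With summation, the loss (and therefore the gradient magnitude) shifts with batch size and transcript length; the mean keeps it on a comparable scale from batch to batch, which tends to make the learning rate easier to tune.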
<Tip>
If you aren't familiar with fine-tuning a model with the [`Trainer`], take a look at the basic tutorial [here](../training#finetune-with-trainer)!
</Tip>
At this point, only three steps remain:
1. Define your training hyperparameters in [`TrainingArguments`].
2. Pass the training arguments to [`Trainer`] along with the model, datasets, tokenizer, and data collator.
3. Call [`~Trainer.train`] to fine-tune your model.
```py
>>> training_args = TrainingArguments(
...     output_dir="./results",
...     group_by_length=True,
...     per_device_train_batch_size=16,
...     evaluation_strategy="steps",
...     num_train_epochs=3,
...     fp16=True,
...     gradient_checkpointing=True,
...     learning_rate=1e-4,
...     weight_decay=0.005,
...     save_total_limit=2,
... )

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     train_dataset=timit["train"],
...     eval_dataset=timit["test"],
...     tokenizer=processor.feature_extractor,
...     data_collator=data_collator,
... )
>>> trainer.train()
```
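`group_by_length=True` asks the [`Trainer`] to place examples of similar length in the same batch, which reduces padding for variable-length audio. The sketch below is a simplified version of that idea, not the Trainer's exact sampler: shuffle the indices, split them into "megabatches" of several batches, sort each megabatch by length, and cut it into batches. `length_grouped_batches` and its `megabatch_mult` parameter are hypothetical.

```python
import random
from typing import List


def length_grouped_batches(
    lengths: List[int], batch_size: int, megabatch_mult: int = 4, seed: int = 0
) -> List[List[int]]:
    """Simplified length grouping: returns batches of example indices."""
    rng = random.Random(seed)
    indices = list(range(len(lengths)))
    rng.shuffle(indices)  # shuffling keeps the order of megabatches random
    megabatch_size = megabatch_mult * batch_size
    batches = []
    for start in range(0, len(indices), megabatch_size):
        # Sorting inside a megabatch (longest first) puts similar lengths in the same batch.
        megabatch = sorted(indices[start:start + megabatch_size], key=lambda i: lengths[i], reverse=True)
        batches.extend(megabatch[b:b + batch_size] for b in range(0, len(megabatch), batch_size))
    return batches


print(length_grouped_batches([5, 1, 9, 3], batch_size=2, megabatch_mult=2))  # [[2, 0], [3, 1]]
```

Sorting only within megabatches, instead of over the whole dataset, is a compromise: batches stay fairly uniform in length while the training order still changes with the shuffle.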
</pt>
</frameworkcontent>
<Tip>
For a more in-depth example of how to fine-tune a model for automatic speech recognition, take a look at this blog [post](https://huggingface.co/blog/fine-tune-wav2vec2-english) for English ASR and this [post](https://huggingface.co/blog/fine-tune-xlsr-wav2vec2) for multilingual ASR.
</Tip>