mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 21:30:07 +06:00
Add MMS CTC Fine-Tuning (#24281)
* Add mms ctc fine tuning * make style * More fixes that are needed * make fix-copies * make draft for README * add new file * move to new file * make style * make style * add quick test * make style * make style
This commit is contained in:
parent
0c3fdccf2f
commit
1609a436ec
@ -26,6 +26,10 @@ limitations under the License.
|
|||||||
- [Librispeech](#librispeech-ctc)
|
- [Librispeech](#librispeech-ctc)
|
||||||
- [Common Voice](#common-voice-ctc)
|
- [Common Voice](#common-voice-ctc)
|
||||||
- [Multilingual Librispeech](#multilingual-librispeech-ctc)
|
- [Multilingual Librispeech](#multilingual-librispeech-ctc)
|
||||||
|
- [Automatic Speech Recognition with CTC and Adapter Layers](#connectionist-temporal-classification-with-adapters)
|
||||||
|
- [Massive Multilingual Speech (MMS)](#mms-model)
|
||||||
|
- [Examples](#examples-ctc-adapter)
|
||||||
|
- [Common Voice](#common-voice-ctc-adapter)
|
||||||
- [Automatic Speech Recognition with Sequence-to-Sequence](#sequence-to-sequence)
|
- [Automatic Speech Recognition with Sequence-to-Sequence](#sequence-to-sequence)
|
||||||
- [Whisper Model](#whisper-model)
|
- [Whisper Model](#whisper-model)
|
||||||
- [Speech-Encoder-Decoder Model](#warm-started-speech-encoder-decoder-model)
|
- [Speech-Encoder-Decoder Model](#warm-started-speech-encoder-decoder-model)
|
||||||
@ -243,6 +247,111 @@ they can serve as a baseline to improve upon.
|
|||||||
| [Multilingual Librispeech](https://huggingface.co/datasets/multilingual_librispeech)| `"german"` | [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) | 0.13 | - | 1 GPU Titan 24 GB RAM | 15h04 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-xlsr-53-300m-mls-german-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-xlsr-53-300m-mls-german-ft/blob/main/run.sh) |
|
| [Multilingual Librispeech](https://huggingface.co/datasets/multilingual_librispeech)| `"german"` | [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) | 0.13 | - | 1 GPU Titan 24 GB RAM | 15h04 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-xlsr-53-300m-mls-german-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-xlsr-53-300m-mls-german-ft/blob/main/run.sh) |
|
||||||
| [Multilingual Librispeech](https://huggingface.co/datasets/multilingual_librispeech)| `"german"` | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 0.15 | - | 1 GPU Titan 24 GB RAM | 15h04 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-300m-mls-german-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-300m-mls-german-ft/blob/main/run.sh) |
|
| [Multilingual Librispeech](https://huggingface.co/datasets/multilingual_librispeech)| `"german"` | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 0.15 | - | 1 GPU Titan 24 GB RAM | 15h04 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-300m-mls-german-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-300m-mls-german-ft/blob/main/run.sh) |
|
||||||
|
|
||||||
|
## Connectionist Temporal Classification With Adapters
|
||||||
|
|
||||||
|
The script [`run_speech_recognition_ctc_adapter.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-recognition/run_speech_recognition_ctc_adapter.py) can be used to fine-tune adapter layers for [Wav2Vec2-like models like MMS](https://huggingface.co/docs/transformers/main/en/model_doc/mms) for automatic speech recognition.
|
||||||
|
|
||||||
|
### MMS Model
|
||||||
|
|
||||||
|
The [Massive Multilingual Speech (MMS) model](https://huggingface.co/facebook/mms-1b-all) has been pre-trained and fine-tuned
|
||||||
|
on 1000+ languages. The model makes use of adapter attention layers to fine-tune only a small part
|
||||||
|
of the model on a specific language. The model already comes with fine-tuned adapter layers for 1000+ languages and
|
||||||
|
can be used for inference for 1000+ languages out of the box.
|
||||||
|
|
||||||
|
However, for improved performance or more specific use cases one can re-initialize the adapter weights, freeze all
|
||||||
|
other weights and fine-tune them on a specific dataset as shown in the [example below](#examples-ctc-adapter).
|
||||||
|
|
||||||
|
Note that the adapter weights include low dimensional linear layers for every attention block as well as the final language
|
||||||
|
model head layers.
|
||||||
|
|
||||||
|
### Examples CTC Adapter
|
||||||
|
|
||||||
|
In the following we will look at how one can fine-tune adapter weights for any of the
|
||||||
|
[MMS CTC checkpoints](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&other=mms&sort=downloads) in less than 1 hour.
|
||||||
|
|
||||||
|
#### Common Voice CTC Adapter
|
||||||
|
|
||||||
|
As in the examples [above](#examples-ctc), we fine-tune on Common Voice's 6 dataset in Turkish as an example.
|
||||||
|
Contrary to [`run_speech_recognition_ctc.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py) before there is a `--target_language` which has to be defined to state for which
|
||||||
|
language or concept the adapter layers shall be trained. The adapter weights will then
|
||||||
|
accordingly be called `adapter.{<target_language}.safetensors`.
|
||||||
|
|
||||||
|
Let's run an example script. Make sure to be logged in so that your model can be directly uploaded to the Hub.
|
||||||
|
```
|
||||||
|
huggingface-cli login
|
||||||
|
```
|
||||||
|
|
||||||
|
Now, let's run an example and upload it to the Hub under `wav2vec2-common_voice-tr-mms-demo`.
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python run_speech_recognition_ctc.py \
|
||||||
|
--dataset_name="common_voice" \
|
||||||
|
--model_name_or_path="facebook/mms-1b-all" \
|
||||||
|
--dataset_config_name="tr" \
|
||||||
|
--output_dir="./wav2vec2-common_voice-tr-mms-demo" \
|
||||||
|
--num_train_epochs="4" \
|
||||||
|
--per_device_train_batch_size="32" \
|
||||||
|
--learning_rate="1e-3" \
|
||||||
|
--warmup_steps="100" \
|
||||||
|
--evaluation_strategy="steps" \
|
||||||
|
--text_column_name="sentence" \
|
||||||
|
--length_column_name="input_length" \
|
||||||
|
--save_steps="200" \
|
||||||
|
--eval_steps="100" \
|
||||||
|
--save_total_limit="3" \
|
||||||
|
--target_language="tur" \
|
||||||
|
--gradient_checkpointing \
|
||||||
|
--chars_to_ignore , ? . ! - \; \: \" “ % ‘ ” <20> \
|
||||||
|
--fp16 \
|
||||||
|
--group_by_length \
|
||||||
|
--do_train --do_eval \
|
||||||
|
--push_to_hub
|
||||||
|
```
|
||||||
|
|
||||||
|
This should take less than 10 minutes on most GPUs and you should very quickly get word error rates
|
||||||
|
below 27%.
|
||||||
|
|
||||||
|
For an example run, you can have a look at [`patrickvonplaten/wav2vec2-common_voice-tr-mms-demo`](https://huggingface.co/patrickvonplaten/wav2vec2-common_voice-tr-mms-demo).
|
||||||
|
|
||||||
|
|
||||||
|
If you'd like to train another adapter model with the same base model, you can simply re-use the same `--output_dir`,
|
||||||
|
but make sure to pass the `--output_dir` folder also to `--tokenizer_name_or_path` so that the vocabulary is not
|
||||||
|
overwritten but **extended**. Assuming you would like to train adapter weights on Swedish in addition to Turkish and save
|
||||||
|
the adapter weights in the same model repo, you can run:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python run_speech_recognition_ctc.py \
|
||||||
|
--dataset_name="common_voice" \
|
||||||
|
--model_name_or_path="facebook/mms-1b-all" \
|
||||||
|
--dataset_config_name="sw" \
|
||||||
|
--output_dir="./wav2vec2-common_voice-tr-mms-demo" \
|
||||||
|
--tokenizer_name_or_path="./wav2vec2-common_voice-tr-mms-demo" \
|
||||||
|
--num_train_epochs="4" \
|
||||||
|
--per_device_train_batch_size="32" \
|
||||||
|
--learning_rate="1e-3" \
|
||||||
|
--warmup_steps="100" \
|
||||||
|
--evaluation_strategy="steps" \
|
||||||
|
--text_column_name="sentence" \
|
||||||
|
--length_column_name="input_length" \
|
||||||
|
--save_steps="200" \
|
||||||
|
--eval_steps="100" \
|
||||||
|
--save_total_limit="3" \
|
||||||
|
--target_language="swe" \
|
||||||
|
--gradient_checkpointing \
|
||||||
|
--chars_to_ignore , ? . ! - \; \: \" “ % ‘ ” <20> \
|
||||||
|
--fp16 \
|
||||||
|
--group_by_length \
|
||||||
|
--do_train --do_eval \
|
||||||
|
--push_to_hub
|
||||||
|
```
|
||||||
|
|
||||||
|
Now you should have both `adapter.tur.safetensors` and `adapter.swe.safetensors` in the model repo
|
||||||
|
and you can load the respective language with:
|
||||||
|
```py
|
||||||
|
model.load_adapter("tur") # or "swe"
|
||||||
|
```
|
||||||
|
respectively.
|
||||||
|
|
||||||
## Sequence to Sequence
|
## Sequence to Sequence
|
||||||
|
|
||||||
The script [`run_speech_recognition_seq2seq.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py) can be used to fine-tune any [Speech Sequence-to-Sequence Model](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForSpeechSeq2Seq) for automatic speech
|
The script [`run_speech_recognition_seq2seq.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py) can be used to fine-tune any [Speech Sequence-to-Sequence Model](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForSpeechSeq2Seq) for automatic speech
|
||||||
|
799
examples/pytorch/speech-recognition/run_speech_recognition_ctc_adapter.py
Executable file
799
examples/pytorch/speech-recognition/run_speech_recognition_ctc_adapter.py
Executable file
@ -0,0 +1,799 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
""" Fine-tuning a 🤗 Transformers CTC adapter model for automatic speech recognition"""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import datasets
|
||||||
|
import evaluate
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from datasets import DatasetDict, load_dataset
|
||||||
|
from safetensors.torch import save_file as safe_save_file
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
|
AutoFeatureExtractor,
|
||||||
|
AutoModelForCTC,
|
||||||
|
AutoProcessor,
|
||||||
|
AutoTokenizer,
|
||||||
|
HfArgumentParser,
|
||||||
|
Trainer,
|
||||||
|
TrainingArguments,
|
||||||
|
Wav2Vec2Processor,
|
||||||
|
set_seed,
|
||||||
|
)
|
||||||
|
from transformers.models.wav2vec2.modeling_wav2vec2 import WAV2VEC2_ADAPTER_SAFE_FILE
|
||||||
|
from transformers.trainer_utils import get_last_checkpoint, is_main_process
|
||||||
|
from transformers.utils import check_min_version, send_example_telemetry
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
|
||||||
|
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||||
|
check_min_version("4.31.0.dev0")
|
||||||
|
|
||||||
|
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def list_field(default=None, metadata=None):
|
||||||
|
return field(default_factory=lambda: default, metadata=metadata)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArguments:
|
||||||
|
"""
|
||||||
|
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_name_or_path: str = field(
|
||||||
|
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||||
|
)
|
||||||
|
tokenizer_name_or_path: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Path to pretrained tokenizer or tokenizer identifier from huggingface.co/models"},
|
||||||
|
)
|
||||||
|
cache_dir: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
||||||
|
)
|
||||||
|
final_dropout: float = field(
|
||||||
|
default=0.0,
|
||||||
|
metadata={"help": "The dropout probability for the final projection layer."},
|
||||||
|
)
|
||||||
|
mask_time_prob: float = field(
|
||||||
|
default=0.05,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Probability of each feature vector along the time axis to be chosen as the start of the vector"
|
||||||
|
"span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
|
||||||
|
"vectors will be masked along the time axis."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
mask_time_length: int = field(
|
||||||
|
default=10,
|
||||||
|
metadata={"help": "Length of vector span to mask along the time axis."},
|
||||||
|
)
|
||||||
|
mask_feature_prob: float = field(
|
||||||
|
default=0.0,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Probability of each feature vector along the feature axis to be chosen as the start of the vectorspan"
|
||||||
|
" to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature"
|
||||||
|
" bins will be masked along the time axis."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
mask_feature_length: int = field(
|
||||||
|
default=10,
|
||||||
|
metadata={"help": "Length of vector span to mask along the feature axis."},
|
||||||
|
)
|
||||||
|
layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
|
||||||
|
ctc_loss_reduction: Optional[str] = field(
|
||||||
|
default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
|
||||||
|
)
|
||||||
|
adapter_attn_dim: int = field(
|
||||||
|
default=16,
|
||||||
|
metadata={
|
||||||
|
"help": "The hidden dimension of the adapter layers that will be randomly initialized and trained. The higher the dimension, the more capacity is given to the adapter weights. Note that only the adapter weights are fine-tuned."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataTrainingArguments:
|
||||||
|
"""
|
||||||
|
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||||
|
|
||||||
|
Using `HfArgumentParser` we can turn this class
|
||||||
|
into argparse arguments to be able to specify them on
|
||||||
|
the command line.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_name: str = field(
|
||||||
|
metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||||
|
)
|
||||||
|
target_language: Optional[str] = field(
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"The target language on which the adapter attention layers"
|
||||||
|
" should be trained on in ISO 693-3 code, e.g. `tur` for Turkish"
|
||||||
|
" Wav2Vec2's MMS ISO codes can be looked up here: https://dl.fbaipublicfiles.com/mms/misc/language_coverage_mms.html"
|
||||||
|
" If you are not training the adapter layers on a language, simply choose"
|
||||||
|
" another accronym that fits your data."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
dataset_config_name: str = field(
|
||||||
|
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||||
|
)
|
||||||
|
train_split_name: str = field(
|
||||||
|
default="train+validation",
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"The name of the training data set split to use (via the datasets library). Defaults to "
|
||||||
|
"'train+validation'"
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
eval_split_name: str = field(
|
||||||
|
default="test",
|
||||||
|
metadata={
|
||||||
|
"help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
audio_column_name: str = field(
|
||||||
|
default="audio",
|
||||||
|
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
|
||||||
|
)
|
||||||
|
text_column_name: str = field(
|
||||||
|
default="text",
|
||||||
|
metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
|
||||||
|
)
|
||||||
|
overwrite_cache: bool = field(
|
||||||
|
default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
|
||||||
|
)
|
||||||
|
preprocessing_num_workers: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||||
|
)
|
||||||
|
max_train_samples: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
|
"value if set."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
max_eval_samples: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||||
|
"value if set."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
chars_to_ignore: Optional[List[str]] = list_field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "A list of characters to remove from the transcripts."},
|
||||||
|
)
|
||||||
|
eval_metrics: List[str] = list_field(
|
||||||
|
default=["wer"],
|
||||||
|
metadata={"help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"},
|
||||||
|
)
|
||||||
|
max_duration_in_seconds: float = field(
|
||||||
|
default=20.0,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Filter audio files that are longer than `max_duration_in_seconds` seconds to"
|
||||||
|
" 'max_duration_in_seconds`"
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
min_duration_in_seconds: float = field(
|
||||||
|
default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
|
||||||
|
)
|
||||||
|
preprocessing_only: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Whether to only do data preprocessing and skip training. This is especially useful when data"
|
||||||
|
" preprocessing errors out in distributed training due to timeout. In this case, one should run the"
|
||||||
|
" preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
|
||||||
|
" can consequently be loaded in distributed training"
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
use_auth_token: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"If :obj:`True`, will use the token generated when running"
|
||||||
|
":obj:`huggingface-cli login` as HTTP bearer authorization for remote files."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
unk_token: str = field(
|
||||||
|
default="[UNK]",
|
||||||
|
metadata={"help": "The unk token for the tokenizer"},
|
||||||
|
)
|
||||||
|
pad_token: str = field(
|
||||||
|
default="[PAD]",
|
||||||
|
metadata={"help": "The padding token for the tokenizer"},
|
||||||
|
)
|
||||||
|
word_delimiter_token: str = field(
|
||||||
|
default="|",
|
||||||
|
metadata={"help": "The word delimiter token for the tokenizer"},
|
||||||
|
)
|
||||||
|
overwrite_lang_vocab: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": ("If :obj:`True`, will overwrite existing `target_language` vocabulary of tokenizer.")},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataCollatorCTCWithPadding:
|
||||||
|
"""
|
||||||
|
Data collator that will dynamically pad the inputs received.
|
||||||
|
Args:
|
||||||
|
processor (:class:`~transformers.AutoProcessor`)
|
||||||
|
The processor used for proccessing the data.
|
||||||
|
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
||||||
|
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
||||||
|
among:
|
||||||
|
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
||||||
|
sequence if provided).
|
||||||
|
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||||
|
maximum acceptable input length for the model if that argument is not provided.
|
||||||
|
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||||
|
different lengths).
|
||||||
|
max_length (:obj:`int`, `optional`):
|
||||||
|
Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
|
||||||
|
max_length_labels (:obj:`int`, `optional`):
|
||||||
|
Maximum length of the ``labels`` returned list and optionally padding length (see above).
|
||||||
|
pad_to_multiple_of (:obj:`int`, `optional`):
|
||||||
|
If set will pad the sequence to a multiple of the provided value.
|
||||||
|
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
||||||
|
7.5 (Volta).
|
||||||
|
"""
|
||||||
|
|
||||||
|
processor: AutoProcessor
|
||||||
|
padding: Union[bool, str] = "longest"
|
||||||
|
pad_to_multiple_of: Optional[int] = None
|
||||||
|
pad_to_multiple_of_labels: Optional[int] = None
|
||||||
|
|
||||||
|
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||||
|
# split inputs and labels since they have to be of different lenghts and need
|
||||||
|
# different padding methods
|
||||||
|
input_features = [{"input_values": feature["input_values"]} for feature in features]
|
||||||
|
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
||||||
|
|
||||||
|
batch = self.processor.pad(
|
||||||
|
input_features,
|
||||||
|
padding=self.padding,
|
||||||
|
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
labels_batch = self.processor.pad(
|
||||||
|
labels=label_features,
|
||||||
|
padding=self.padding,
|
||||||
|
pad_to_multiple_of=self.pad_to_multiple_of_labels,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
# replace padding with -100 to ignore loss correctly
|
||||||
|
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
|
||||||
|
|
||||||
|
batch["labels"] = labels
|
||||||
|
if "attention_mask" in batch:
|
||||||
|
batch["attention_mask"] = batch["attention_mask"].to(torch.long)
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def create_vocabulary_from_data(
|
||||||
|
datasets: DatasetDict,
|
||||||
|
word_delimiter_token: Optional[str] = None,
|
||||||
|
unk_token: Optional[str] = None,
|
||||||
|
pad_token: Optional[str] = None,
|
||||||
|
):
|
||||||
|
# Given training and test labels create vocabulary
|
||||||
|
def extract_all_chars(batch):
|
||||||
|
all_text = " ".join(batch["target_text"])
|
||||||
|
vocab = list(set(all_text))
|
||||||
|
return {"vocab": [vocab], "all_text": [all_text]}
|
||||||
|
|
||||||
|
vocabs = datasets.map(
|
||||||
|
extract_all_chars,
|
||||||
|
batched=True,
|
||||||
|
batch_size=-1,
|
||||||
|
keep_in_memory=True,
|
||||||
|
remove_columns=datasets["train"].column_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
# take union of all unique characters in each dataset
|
||||||
|
vocab_set = functools.reduce(
|
||||||
|
lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]), vocabs.values()
|
||||||
|
)
|
||||||
|
|
||||||
|
vocab_dict = {v: k for k, v in enumerate(sorted(vocab_set))}
|
||||||
|
|
||||||
|
# replace white space with delimiter token
|
||||||
|
if word_delimiter_token is not None:
|
||||||
|
vocab_dict[word_delimiter_token] = vocab_dict[" "]
|
||||||
|
del vocab_dict[" "]
|
||||||
|
|
||||||
|
# add unk and pad token
|
||||||
|
if unk_token is not None:
|
||||||
|
vocab_dict[unk_token] = len(vocab_dict)
|
||||||
|
|
||||||
|
if pad_token is not None:
|
||||||
|
vocab_dict[pad_token] = len(vocab_dict)
|
||||||
|
|
||||||
|
return vocab_dict
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# See all possible arguments in src/transformers/training_args.py
|
||||||
|
# or by passing the --help flag to this script.
|
||||||
|
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||||
|
|
||||||
|
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
||||||
|
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||||
|
# If we pass only one argument to the script and it's the path to a json file,
|
||||||
|
# let's parse it to get our arguments.
|
||||||
|
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||||
|
else:
|
||||||
|
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||||
|
|
||||||
|
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
|
||||||
|
# information sent is the one passed as arguments along with your Python/PyTorch versions.
|
||||||
|
send_example_telemetry("run_speech_recognition_ctc_adapter", model_args, data_args)
|
||||||
|
|
||||||
|
# Detecting last checkpoint.
|
||||||
|
last_checkpoint = None
|
||||||
|
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
|
||||||
|
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||||
|
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
||||||
|
"Use --overwrite_output_dir to overcome."
|
||||||
|
)
|
||||||
|
elif last_checkpoint is not None:
|
||||||
|
logger.info(
|
||||||
|
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
||||||
|
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
|
handlers=[logging.StreamHandler(sys.stdout)],
|
||||||
|
)
|
||||||
|
logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
|
||||||
|
|
||||||
|
# Log on each process the small summary:
|
||||||
|
logger.warning(
|
||||||
|
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||||
|
f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||||
|
)
|
||||||
|
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||||
|
if is_main_process(training_args.local_rank):
|
||||||
|
transformers.utils.logging.set_verbosity_info()
|
||||||
|
logger.info("Training/evaluation parameters %s", training_args)
|
||||||
|
|
||||||
|
# Set seed before initializing model.
|
||||||
|
set_seed(training_args.seed)
|
||||||
|
|
||||||
|
# 1. First, let's load the dataset
|
||||||
|
raw_datasets = DatasetDict()
|
||||||
|
|
||||||
|
if training_args.do_train:
|
||||||
|
raw_datasets["train"] = load_dataset(
|
||||||
|
data_args.dataset_name,
|
||||||
|
data_args.dataset_config_name,
|
||||||
|
split=data_args.train_split_name,
|
||||||
|
use_auth_token=data_args.use_auth_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
if data_args.audio_column_name not in raw_datasets["train"].column_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'."
|
||||||
|
" Make sure to set `--audio_column_name` to the correct audio column - one of"
|
||||||
|
f" {', '.join(raw_datasets['train'].column_names)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if data_args.text_column_name not in raw_datasets["train"].column_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
|
||||||
|
"Make sure to set `--text_column_name` to the correct text column - one of "
|
||||||
|
f"{', '.join(raw_datasets['train'].column_names)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if data_args.max_train_samples is not None:
|
||||||
|
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
|
||||||
|
|
||||||
|
if training_args.do_eval:
|
||||||
|
raw_datasets["eval"] = load_dataset(
|
||||||
|
data_args.dataset_name,
|
||||||
|
data_args.dataset_config_name,
|
||||||
|
split=data_args.eval_split_name,
|
||||||
|
use_auth_token=data_args.use_auth_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
if data_args.max_eval_samples is not None:
|
||||||
|
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
|
||||||
|
|
||||||
|
# 2. We remove some special characters from the datasets
|
||||||
|
# that make training complicated and do not help in transcribing the speech
|
||||||
|
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
|
||||||
|
# that could be easily picked up by the model
|
||||||
|
chars_to_ignore_regex = (
|
||||||
|
f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
|
||||||
|
)
|
||||||
|
text_column_name = data_args.text_column_name
|
||||||
|
|
||||||
|
def remove_special_characters(batch):
|
||||||
|
if chars_to_ignore_regex is not None:
|
||||||
|
batch["target_text"] = re.sub(chars_to_ignore_regex, "", batch[text_column_name]).lower() + " "
|
||||||
|
else:
|
||||||
|
batch["target_text"] = batch[text_column_name].lower() + " "
|
||||||
|
return batch
|
||||||
|
|
||||||
|
with training_args.main_process_first(desc="dataset map special characters removal"):
|
||||||
|
raw_datasets = raw_datasets.map(
|
||||||
|
remove_special_characters,
|
||||||
|
remove_columns=[text_column_name],
|
||||||
|
desc="remove special characters from datasets",
|
||||||
|
)
|
||||||
|
|
||||||
|
# save special tokens for tokenizer
|
||||||
|
word_delimiter_token = data_args.word_delimiter_token
|
||||||
|
unk_token = data_args.unk_token
|
||||||
|
pad_token = data_args.pad_token
|
||||||
|
|
||||||
|
# 3. Next, let's load the config as we might need it to create
|
||||||
|
# the tokenizer
|
||||||
|
# load config
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Next, if no tokenizer file is defined,
|
||||||
|
# we create the vocabulary of the model by extracting all unique characters from
|
||||||
|
# the training and evaluation datasets
|
||||||
|
# We need to make sure that only first rank saves vocabulary
|
||||||
|
# make sure all processes wait until vocab is created
|
||||||
|
tokenizer_name_or_path = model_args.tokenizer_name_or_path
|
||||||
|
tokenizer_kwargs = {}
|
||||||
|
|
||||||
|
vocab_dict = {}
|
||||||
|
if tokenizer_name_or_path is not None:
|
||||||
|
# load vocabulary of other adapter languages so that new language can be appended
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_auth_token=data_args.use_auth_token)
|
||||||
|
vocab_dict = tokenizer.vocab.copy()
|
||||||
|
if tokenizer.target_lang is None:
|
||||||
|
raise ValueError("Make sure to load a multi-lingual tokenizer with a set target language.")
|
||||||
|
|
||||||
|
if data_args.target_language in tokenizer.vocab and not data_args.overwrite_lang_vocab:
|
||||||
|
logger.info(
|
||||||
|
"Adapter language already exists."
|
||||||
|
" Skipping vocabulary creating. If you want to create a new vocabulary"
|
||||||
|
f" for {data_args.target_language} make sure to add '--overwrite_lang_vocab'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tokenizer_name_or_path = None
|
||||||
|
|
||||||
|
if tokenizer_name_or_path is None:
|
||||||
|
# save vocab in training output dir
|
||||||
|
tokenizer_name_or_path = training_args.output_dir
|
||||||
|
|
||||||
|
vocab_file = os.path.join(tokenizer_name_or_path, "vocab.json")
|
||||||
|
|
||||||
|
with training_args.main_process_first():
|
||||||
|
if training_args.overwrite_output_dir and os.path.isfile(vocab_file):
|
||||||
|
try:
|
||||||
|
os.remove(vocab_file)
|
||||||
|
except OSError:
|
||||||
|
# in shared file-systems it might be the case that
|
||||||
|
# two processes try to delete the vocab file at the some time
|
||||||
|
pass
|
||||||
|
|
||||||
|
with training_args.main_process_first(desc="dataset map vocabulary creation"):
|
||||||
|
if not os.path.isfile(vocab_file):
|
||||||
|
os.makedirs(tokenizer_name_or_path, exist_ok=True)
|
||||||
|
lang_dict = create_vocabulary_from_data(
|
||||||
|
raw_datasets,
|
||||||
|
word_delimiter_token=word_delimiter_token,
|
||||||
|
unk_token=unk_token,
|
||||||
|
pad_token=pad_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
# if we doing adapter language training, save
|
||||||
|
# vocab with adpter language
|
||||||
|
if data_args.target_language is not None:
|
||||||
|
vocab_dict[data_args.target_language] = lang_dict
|
||||||
|
|
||||||
|
# save vocab dict to be loaded into tokenizer
|
||||||
|
with open(vocab_file, "w") as file:
|
||||||
|
json.dump(vocab_dict, file)
|
||||||
|
|
||||||
|
# if tokenizer has just been created
|
||||||
|
# it is defined by `tokenizer_class` if present in config else by `model_type`
|
||||||
|
tokenizer_kwargs = {
|
||||||
|
"config": config if config.tokenizer_class is not None else None,
|
||||||
|
"tokenizer_type": config.model_type if config.tokenizer_class is None else None,
|
||||||
|
"unk_token": unk_token,
|
||||||
|
"pad_token": pad_token,
|
||||||
|
"word_delimiter_token": word_delimiter_token,
|
||||||
|
"target_lang": data_args.target_language,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 5. Now we can instantiate the feature extractor, tokenizer and model
|
||||||
|
# Note for distributed training, the .from_pretrained methods guarantee that only
|
||||||
|
# one local process can concurrently download model & vocab.
|
||||||
|
|
||||||
|
# load feature_extractor and tokenizer
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
tokenizer_name_or_path,
|
||||||
|
use_auth_token=data_args.use_auth_token,
|
||||||
|
**tokenizer_kwargs,
|
||||||
|
)
|
||||||
|
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||||
|
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
|
||||||
|
)
|
||||||
|
|
||||||
|
# adapt config
|
||||||
|
config.update(
|
||||||
|
{
|
||||||
|
"final_dropout": model_args.final_dropout,
|
||||||
|
"mask_time_prob": model_args.mask_time_prob,
|
||||||
|
"mask_time_length": model_args.mask_time_length,
|
||||||
|
"mask_feature_prob": model_args.mask_feature_prob,
|
||||||
|
"mask_feature_length": model_args.mask_feature_length,
|
||||||
|
"gradient_checkpointing": training_args.gradient_checkpointing,
|
||||||
|
"layerdrop": model_args.layerdrop,
|
||||||
|
"ctc_loss_reduction": model_args.ctc_loss_reduction,
|
||||||
|
"pad_token_id": tokenizer.pad_token_id,
|
||||||
|
"vocab_size": len(tokenizer),
|
||||||
|
"adapter_attn_dim": model_args.adapter_attn_dim,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# create model
|
||||||
|
model = AutoModelForCTC.from_pretrained(
|
||||||
|
model_args.model_name_or_path,
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
config=config,
|
||||||
|
use_auth_token=data_args.use_auth_token,
|
||||||
|
ignore_mismatched_sizes=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# if attn adapter is defined, freeze all non-adapter weights
|
||||||
|
if model.config.adapter_attn_dim is not None:
|
||||||
|
model.init_adapter_layers()
|
||||||
|
# first we freeze the whole base model
|
||||||
|
model.freeze_base_model()
|
||||||
|
|
||||||
|
# next we unfreeze all adapter layers
|
||||||
|
adapter_weights = model._get_adapters()
|
||||||
|
for param in adapter_weights.values():
|
||||||
|
param.requires_grad = True
|
||||||
|
|
||||||
|
# 6. Now we preprocess the datasets including loading the audio, resampling and normalization
|
||||||
|
# Thankfully, `datasets` takes care of automatically loading and resampling the audio,
|
||||||
|
# so that we just need to set the correct target sampling rate and normalize the input
|
||||||
|
# via the `feature_extractor`
|
||||||
|
|
||||||
|
# make sure that dataset decodes audio with correct sampling rate
|
||||||
|
dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
|
||||||
|
if dataset_sampling_rate != feature_extractor.sampling_rate:
|
||||||
|
raw_datasets = raw_datasets.cast_column(
|
||||||
|
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
||||||
|
)
|
||||||
|
|
||||||
|
# derive max & min input length for sample rate & max duration
|
||||||
|
max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
|
||||||
|
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
|
||||||
|
audio_column_name = data_args.audio_column_name
|
||||||
|
num_workers = data_args.preprocessing_num_workers
|
||||||
|
|
||||||
|
# Preprocessing the datasets.
|
||||||
|
# We need to read the audio files as arrays and tokenize the targets.
|
||||||
|
def prepare_dataset(batch):
|
||||||
|
# load audio
|
||||||
|
sample = batch[audio_column_name]
|
||||||
|
|
||||||
|
inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
|
||||||
|
batch["input_values"] = inputs.input_values[0]
|
||||||
|
batch["input_length"] = len(batch["input_values"])
|
||||||
|
|
||||||
|
# encode targets
|
||||||
|
batch["labels"] = tokenizer(batch["target_text"]).input_ids
|
||||||
|
return batch
|
||||||
|
|
||||||
|
with training_args.main_process_first(desc="dataset map preprocessing"):
|
||||||
|
vectorized_datasets = raw_datasets.map(
|
||||||
|
prepare_dataset,
|
||||||
|
remove_columns=next(iter(raw_datasets.values())).column_names,
|
||||||
|
num_proc=num_workers,
|
||||||
|
desc="preprocess datasets",
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_audio_in_length_range(length):
|
||||||
|
return length > min_input_length and length < max_input_length
|
||||||
|
|
||||||
|
# filter data that is shorter than min_input_length
|
||||||
|
vectorized_datasets = vectorized_datasets.filter(
|
||||||
|
is_audio_in_length_range,
|
||||||
|
num_proc=num_workers,
|
||||||
|
input_columns=["input_length"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 7. Next, we can prepare the training.
|
||||||
|
# Let's use word error rate (WER) as our evaluation metric,
|
||||||
|
# instantiate a data collator and the trainer
|
||||||
|
|
||||||
|
# Define evaluation metrics during training, *i.e.* word error rate, character error rate
|
||||||
|
eval_metrics = {metric: evaluate.load(metric) for metric in data_args.eval_metrics}
|
||||||
|
|
||||||
|
# for large datasets it is advised to run the preprocessing on a
|
||||||
|
# single machine first with ``args.preprocessing_only`` since there will mostly likely
|
||||||
|
# be a timeout when running the script in distributed mode.
|
||||||
|
# In a second step ``args.preprocessing_only`` can then be set to `False` to load the
|
||||||
|
# cached dataset
|
||||||
|
if data_args.preprocessing_only:
|
||||||
|
logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
|
||||||
|
return
|
||||||
|
|
||||||
|
def compute_metrics(pred):
|
||||||
|
pred_logits = pred.predictions
|
||||||
|
pred_ids = np.argmax(pred_logits, axis=-1)
|
||||||
|
|
||||||
|
pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
|
||||||
|
|
||||||
|
pred_str = tokenizer.batch_decode(pred_ids)
|
||||||
|
# we do not want to group tokens when computing the metrics
|
||||||
|
label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)
|
||||||
|
|
||||||
|
metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()}
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
# Now save everything to be able to create a single processor later
|
||||||
|
# make sure all processes wait until data is saved
|
||||||
|
with training_args.main_process_first():
|
||||||
|
# only the main process saves them
|
||||||
|
if is_main_process(training_args.local_rank):
|
||||||
|
# save feature extractor, tokenizer and config
|
||||||
|
feature_extractor.save_pretrained(training_args.output_dir)
|
||||||
|
tokenizer.save_pretrained(training_args.output_dir)
|
||||||
|
config.save_pretrained(training_args.output_dir)
|
||||||
|
|
||||||
|
try:
|
||||||
|
processor = AutoProcessor.from_pretrained(training_args.output_dir)
|
||||||
|
except (OSError, KeyError):
|
||||||
|
warnings.warn(
|
||||||
|
"Loading a processor from a feature extractor config that does not"
|
||||||
|
" include a `processor_class` attribute is deprecated and will be removed in v5. Please add the following "
|
||||||
|
" attribute to your `preprocessor_config.json` file to suppress this warning: "
|
||||||
|
" `'processor_class': 'Wav2Vec2Processor'`",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
processor = Wav2Vec2Processor.from_pretrained(training_args.output_dir)
|
||||||
|
|
||||||
|
# Instantiate custom data collator
|
||||||
|
data_collator = DataCollatorCTCWithPadding(processor=processor)
|
||||||
|
|
||||||
|
# Initialize Trainer
|
||||||
|
trainer = Trainer(
|
||||||
|
model=model,
|
||||||
|
data_collator=data_collator,
|
||||||
|
args=training_args,
|
||||||
|
compute_metrics=compute_metrics,
|
||||||
|
train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
|
||||||
|
eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
|
||||||
|
tokenizer=processor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 8. Finally, we can start training
|
||||||
|
|
||||||
|
# Training
|
||||||
|
if training_args.do_train:
|
||||||
|
# use last checkpoint if exist
|
||||||
|
if last_checkpoint is not None:
|
||||||
|
checkpoint = last_checkpoint
|
||||||
|
elif os.path.isdir(model_args.model_name_or_path):
|
||||||
|
checkpoint = model_args.model_name_or_path
|
||||||
|
else:
|
||||||
|
checkpoint = None
|
||||||
|
|
||||||
|
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
|
trainer.save_model()
|
||||||
|
|
||||||
|
metrics = train_result.metrics
|
||||||
|
max_train_samples = (
|
||||||
|
data_args.max_train_samples
|
||||||
|
if data_args.max_train_samples is not None
|
||||||
|
else len(vectorized_datasets["train"])
|
||||||
|
)
|
||||||
|
metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"]))
|
||||||
|
|
||||||
|
trainer.log_metrics("train", metrics)
|
||||||
|
trainer.save_metrics("train", metrics)
|
||||||
|
trainer.save_state()
|
||||||
|
|
||||||
|
# Evaluation
|
||||||
|
results = {}
|
||||||
|
if training_args.do_eval:
|
||||||
|
logger.info("*** Evaluate ***")
|
||||||
|
metrics = trainer.evaluate()
|
||||||
|
max_eval_samples = (
|
||||||
|
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
|
||||||
|
)
|
||||||
|
metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
|
||||||
|
|
||||||
|
trainer.log_metrics("eval", metrics)
|
||||||
|
trainer.save_metrics("eval", metrics)
|
||||||
|
|
||||||
|
# Write model card and (optionally) push to hub
|
||||||
|
config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
|
||||||
|
kwargs = {
|
||||||
|
"finetuned_from": model_args.model_name_or_path,
|
||||||
|
"tasks": "automatic-speech-recognition",
|
||||||
|
"tags": ["automatic-speech-recognition", data_args.dataset_name, "mms"],
|
||||||
|
"dataset_args": (
|
||||||
|
f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split:"
|
||||||
|
f" {data_args.eval_split_name}"
|
||||||
|
),
|
||||||
|
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
|
||||||
|
}
|
||||||
|
if "common_voice" in data_args.dataset_name:
|
||||||
|
kwargs["language"] = config_name
|
||||||
|
|
||||||
|
# make sure that adapter weights are saved seperately
|
||||||
|
adapter_file = WAV2VEC2_ADAPTER_SAFE_FILE.format(data_args.target_language)
|
||||||
|
adapter_file = os.path.join(training_args.output_dir, adapter_file)
|
||||||
|
logger.info(f"Saving adapter weights under {adapter_file}...")
|
||||||
|
safe_save_file(model._get_adapters(), adapter_file, metadata={"format": "pt"})
|
||||||
|
|
||||||
|
if training_args.push_to_hub:
|
||||||
|
trainer.push_to_hub(**kwargs)
|
||||||
|
else:
|
||||||
|
trainer.create_model_card(**kwargs)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -63,6 +63,7 @@ if SRC_DIRS is not None:
|
|||||||
import run_semantic_segmentation
|
import run_semantic_segmentation
|
||||||
import run_seq2seq_qa as run_squad_seq2seq
|
import run_seq2seq_qa as run_squad_seq2seq
|
||||||
import run_speech_recognition_ctc
|
import run_speech_recognition_ctc
|
||||||
|
import run_speech_recognition_ctc_adapter
|
||||||
import run_speech_recognition_seq2seq
|
import run_speech_recognition_seq2seq
|
||||||
import run_summarization
|
import run_summarization
|
||||||
import run_swag
|
import run_swag
|
||||||
@ -446,6 +447,38 @@ class ExamplesTests(TestCasePlus):
|
|||||||
result = get_results(tmp_dir)
|
result = get_results(tmp_dir)
|
||||||
self.assertLess(result["eval_loss"], result["train_loss"])
|
self.assertLess(result["eval_loss"], result["train_loss"])
|
||||||
|
|
||||||
|
def test_run_speech_recognition_ctc_adapter(self):
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
testargs = f"""
|
||||||
|
run_speech_recognition_ctc_adapter.py
|
||||||
|
--output_dir {tmp_dir}
|
||||||
|
--model_name_or_path hf-internal-testing/tiny-random-wav2vec2
|
||||||
|
--dataset_name hf-internal-testing/librispeech_asr_dummy
|
||||||
|
--dataset_config_name clean
|
||||||
|
--train_split_name validation
|
||||||
|
--eval_split_name validation
|
||||||
|
--do_train
|
||||||
|
--do_eval
|
||||||
|
--learning_rate 1e-4
|
||||||
|
--per_device_train_batch_size 2
|
||||||
|
--per_device_eval_batch_size 1
|
||||||
|
--remove_unused_columns False
|
||||||
|
--overwrite_output_dir True
|
||||||
|
--preprocessing_num_workers 16
|
||||||
|
--max_steps 10
|
||||||
|
--target_language tur
|
||||||
|
--seed 42
|
||||||
|
""".split()
|
||||||
|
|
||||||
|
if is_cuda_and_apex_available():
|
||||||
|
testargs.append("--fp16")
|
||||||
|
|
||||||
|
with patch.object(sys, "argv", testargs):
|
||||||
|
run_speech_recognition_ctc_adapter.main()
|
||||||
|
result = get_results(tmp_dir)
|
||||||
|
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "./adapter.tur.safetensors")))
|
||||||
|
self.assertLess(result["eval_loss"], result["train_loss"])
|
||||||
|
|
||||||
def test_run_speech_recognition_seq2seq(self):
|
def test_run_speech_recognition_seq2seq(self):
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
|
@ -1181,6 +1181,14 @@ class HubertForCTC(HubertPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
self.hubert.feature_extractor._freeze_parameters()
|
self.hubert.feature_extractor._freeze_parameters()
|
||||||
|
|
||||||
|
def freeze_base_model(self):
|
||||||
|
"""
|
||||||
|
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
||||||
|
be updated during training. Only the classification head will be updated.
|
||||||
|
"""
|
||||||
|
for param in self.hubert.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
@ -1016,6 +1016,14 @@ class SEWForCTC(SEWPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
self.sew.feature_extractor._freeze_parameters()
|
self.sew.feature_extractor._freeze_parameters()
|
||||||
|
|
||||||
|
def freeze_base_model(self):
|
||||||
|
"""
|
||||||
|
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
||||||
|
be updated during training. Only the classification head will be updated.
|
||||||
|
"""
|
||||||
|
for param in self.sew.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
@ -1556,6 +1556,14 @@ class SEWDForCTC(SEWDPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
self.sew_d.feature_extractor._freeze_parameters()
|
self.sew_d.feature_extractor._freeze_parameters()
|
||||||
|
|
||||||
|
def freeze_base_model(self):
|
||||||
|
"""
|
||||||
|
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
||||||
|
be updated during training. Only the classification head will be updated.
|
||||||
|
"""
|
||||||
|
for param in self.sew_d.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
@ -1425,6 +1425,14 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
self.unispeech.feature_extractor._freeze_parameters()
|
self.unispeech.feature_extractor._freeze_parameters()
|
||||||
|
|
||||||
|
def freeze_base_model(self):
|
||||||
|
"""
|
||||||
|
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
||||||
|
be updated during training. Only the classification head will be updated.
|
||||||
|
"""
|
||||||
|
for param in self.unispeech.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
@ -1432,6 +1432,14 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
self.unispeech_sat.feature_extractor._freeze_parameters()
|
self.unispeech_sat.feature_extractor._freeze_parameters()
|
||||||
|
|
||||||
|
def freeze_base_model(self):
|
||||||
|
"""
|
||||||
|
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
||||||
|
be updated during training. Only the classification head will be updated.
|
||||||
|
"""
|
||||||
|
for param in self.unispeech_sat.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
@ -1194,6 +1194,19 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
return adapter_weights
|
return adapter_weights
|
||||||
|
|
||||||
|
def init_adapter_layers(self):
|
||||||
|
"""
|
||||||
|
(Re-)initialize attention adapter layers and lm head for adapter-only fine-tuning
|
||||||
|
"""
|
||||||
|
# init attention adapters
|
||||||
|
for module in self.modules():
|
||||||
|
if isinstance(module, Wav2Vec2AttnAdapterLayer):
|
||||||
|
self._init_weights(module)
|
||||||
|
|
||||||
|
# init lm head
|
||||||
|
if isinstance(self, Wav2Vec2ForCTC):
|
||||||
|
self._init_weights(self.lm_head)
|
||||||
|
|
||||||
def load_adapter(self, target_lang: str, **kwargs):
|
def load_adapter(self, target_lang: str, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Load a language adapter model from a pre-trained adapter model.
|
Load a language adapter model from a pre-trained adapter model.
|
||||||
@ -1888,6 +1901,14 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
self.wav2vec2.feature_extractor._freeze_parameters()
|
self.wav2vec2.feature_extractor._freeze_parameters()
|
||||||
|
|
||||||
|
def freeze_base_model(self):
|
||||||
|
"""
|
||||||
|
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
||||||
|
be updated during training. Only the classification head will be updated.
|
||||||
|
"""
|
||||||
|
for param in self.wav2vec2.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
@ -1319,6 +1319,14 @@ class WavLMForCTC(WavLMPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
self.wavlm.feature_extractor._freeze_parameters()
|
self.wavlm.feature_extractor._freeze_parameters()
|
||||||
|
|
||||||
|
def freeze_base_model(self):
|
||||||
|
"""
|
||||||
|
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
||||||
|
be updated during training. Only the classification head will be updated.
|
||||||
|
"""
|
||||||
|
for param in self.wavlm.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
Loading…
Reference in New Issue
Block a user