Wav2Vec2 Pretraining (#11306)

* Working quantizer forward

* Working quantizer forward

* Clean up unused model parts, test reproducibility

* Working quantizer forward

* Clean up unused model parts, test reproducibility

* Remove custom outputs from the shared ones

* correct conversion

* correct bug

* add first pretrain script

* save intermediate

* static shapes

* save intermediate

* finish first pretrain script version

* more refactor

* remove wanddb

* refactor more

* improve test

* correct perplexity compute bug

* finish model implementation

* add to docs

* finish docs

* finish pretraining script

* finish pretraining script

* remove wandb

* finish PR for merge

* finish config

* finish

* make deepspeed work

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* apply suggestions

* fix flaky test

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Anton Lozhkov 2021-06-09 19:40:56 +02:00 committed by GitHub
parent b1a8aa94f0
commit d472bd7b18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 1190 additions and 74 deletions

View File

@ -79,3 +79,10 @@ Wav2Vec2ForCTC
.. autoclass:: transformers.Wav2Vec2ForCTC
:members: forward
Wav2Vec2ForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Wav2Vec2ForPreTraining
:members: forward

View File

@ -184,3 +184,35 @@ run_asr.py \
--preprocessing_num_workers=1 --group_by_length --freeze_feature_extractor --verbose_logging \
--deepspeed ds_config_wav2vec2_zero3.json
```
### Pretraining Wav2Vec2
The `run_pretrain.py` script allows one to pretrain a Wav2Vec2 model from scratch using Wav2Vec2's contrastive loss objective (see official [paper](https://arxiv.org/abs/2006.11477) for more information).
It is recommended to pre-train Wav2Vec2 with Trainer + Deepspeed (please refer to [this guide](https://huggingface.co/transformers/master/main_classes/deepspeed.html#deepspeed-trainer-integration) for more information).
Here is an example of how you can use DeepSpeed ZeRO-2 to pretrain a small Wav2Vec2 model:
```
PYTHONPATH=../../../src deepspeed --num_gpus 2 run_pretrain.py \
--output_dir="./wav2vec2-base-libri-100h" \
--num_train_epochs="3" \
--per_device_train_batch_size="32" \
--per_device_eval_batch_size="32" \
--gradient_accumulation_steps="2" \
--save_total_limit="3" \
--save_steps="500" \
--logging_steps="10" \
--learning_rate="5e-4" \
--weight_decay="0.01" \
--warmup_steps="3000" \
--model_name_or_path="patrickvonplaten/wav2vec2-base-libri-100h" \
--dataset_name="librispeech_asr" \
--dataset_config_name="clean" \
--train_split_name="train.100" \
--preprocessing_num_workers="4" \
--max_duration_in_seconds="10.0" \
--group_by_length \
--verbose_logging \
--fp16 \
--deepspeed ds_config_wav2vec2_zero2.json \
```

View File

@ -0,0 +1,370 @@
#!/usr/bin/env python3
import logging
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
from datasets import DatasetDict, load_dataset
from packaging import version
import librosa
from transformers import (
HfArgumentParser,
Trainer,
TrainingArguments,
Wav2Vec2Config,
Wav2Vec2FeatureExtractor,
Wav2Vec2ForPreTraining,
is_apex_available,
trainer_utils,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
if is_apex_available():
from apex import amp
if version.parse(torch.__version__) >= version.parse("1.6"):
_is_native_amp_available = True
from torch.cuda.amp import autocast
logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
)
freeze_feature_extractor: Optional[bool] = field(
default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
)
gradient_checkpointing: Optional[bool] = field(
default=False, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
)
verbose_logging: Optional[bool] = field(
default=False,
metadata={"help": "Whether to log verbose messages or not."},
)
max_gumbel_temperature: Optional[float] = field(
default=2.0, metadata={"help": "Maximum temperature for gumbel softmax."}
)
min_gumbel_temperature: Optional[float] = field(
default=0.5, metadata={"help": "Minimum temperature for gumbel softmax."}
)
gumbel_temperature_decay: Optional[float] = field(
default=0.999995, metadata={"help": "Decay of gumbel temperature during training."}
)
def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
logging_level = logging.WARNING
if model_args.verbose_logging:
logging_level = logging.DEBUG
elif trainer_utils.is_main_process(training_args.local_rank):
logging_level = logging.INFO
logger.setLevel(logging_level)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
Using `HfArgumentParser` we can turn this class
into argparse arguments to be able to specify them on
the command line.
"""
dataset_name: str = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_split_name: Optional[str] = field(
default="train",
metadata={
"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
},
)
validation_split_name: Optional[str] = field(
default="validation",
metadata={
"help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
},
)
speech_file_column: Optional[str] = field(
default="file",
metadata={"help": "Column in the dataset that contains speech file path. Defaults to 'file'"},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_duration_in_seconds: Optional[float] = field(
default=20.0, metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"}
)
@dataclass
class DataCollatorForWav2Vec2Pretraining:
"""
Data collator that will dynamically pad the inputs received and prepare masked indices
for self-supervised pretraining.
Args:
model (:class:`~transformers.Wav2Vec2ForPreTraining`):
The Wav2Vec2 model used for pretraining. The data collator needs to have access
to config and ``_get_feat_extract_output_lengths`` function for correct padding.
feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`):
The processor used for proccessing the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
"""
model: Wav2Vec2ForPreTraining
feature_extractor: Wav2Vec2FeatureExtractor
padding: Union[bool, str] = "longest"
pad_to_multiple_of: Optional[int] = None
max_length: Optional[int] = None
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# reformat list to dict and set to pytorch format
batch = self.feature_extractor.pad(
features,
max_length=self.max_length,
padding=self.padding,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
# sample randomly masked indices
batch["mask_time_indices"] = _compute_mask_indices(
(batch["input_values"].shape[0], mask_indices_seq_length),
self.model.config.mask_time_prob,
self.model.config.mask_time_length,
device=batch["input_values"].device,
min_masks=2,
)
return batch
class Wav2Vec2PreTrainer(Trainer):
"""
Subclassed :class:`~transformers.Trainer` for Wav2Vec2-like pretraining. Trainer can decay gumbel softmax temperature during training.
"""
def __init__(self, *args, max_gumbel_temp=1, min_gumbel_temp=0, gumbel_temp_decay=1.0, **kwargs):
super().__init__(*args, **kwargs)
self.num_update_step = 0
self.max_gumbel_temp = max_gumbel_temp
self.min_gumbel_temp = min_gumbel_temp
self.gumbel_temp_decay = gumbel_temp_decay
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
Subclass and override to inject custom behavior.
Args:
model (:obj:`nn.Module`):
The model to train.
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
Return:
:obj:`torch.Tensor`: The tensor with training loss on this batch.
"""
model.train()
inputs = self._prepare_inputs(inputs)
if self.use_amp:
with autocast():
loss = self.compute_loss(model, inputs)
else:
loss = self.compute_loss(model, inputs)
if self.args.n_gpu > 1 or self.deepspeed:
if model.module.config.ctc_loss_reduction == "mean":
loss = loss.mean()
elif model.module.config.ctc_loss_reduction == "sum":
loss = loss.sum() / (inputs["mask_time_indices"]).sum()
else:
raise ValueError(f"{model.config.ctc_loss_reduction} is not valid. Choose one of ['mean', 'sum']")
if self.args.gradient_accumulation_steps > 1:
loss = loss / self.args.gradient_accumulation_steps
if self.use_amp:
self.scaler.scale(loss).backward()
elif self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
elif self.deepspeed:
self.deepspeed.backward(loss)
else:
loss.backward()
self.num_update_step += 1
# make sure gumbel softmax temperature is decayed
if self.args.n_gpu > 1 or self.deepspeed:
model.module.set_gumbel_temperature(
max(self.max_gumbel_temp * self.gumbel_temp_decay ** self.num_update_step, self.min_gumbel_temp)
)
else:
model.set_gumbel_temperature(
max(self.max_gumbel_temp * self.gumbel_temp_decay ** self.num_update_step, self.min_gumbel_temp)
)
return loss.detach()
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
configure_logger(model_args, training_args)
# Downloading and loading a dataset from the hub.
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
if "validation" not in datasets.keys():
# make sure only "validation" and "train" keys remain"
datasets = DatasetDict()
datasets["validation"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"{data_args.train_split_name}[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
)
datasets["train"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"{data_args.train_split_name}[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
)
else:
# make sure only "validation" and "train" keys remain"
datasets = DatasetDict()
datasets["validation"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split="validation",
cache_dir=model_args.cache_dir,
)
datasets["train"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"{data_args.train_split_name}",
cache_dir=model_args.cache_dir,
)
# only normalized-inputs-training is supported
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True
)
def prepare_dataset(batch):
# check that all files have the correct sampling rate
batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate)
return batch
# load audio files into numpy arrays
vectorized_datasets = datasets.map(
prepare_dataset, num_proc=data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names
)
# filter audio files that are too long
vectorized_datasets = vectorized_datasets.filter(
lambda data: len(data["speech"]) < int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
)
def normalize(batch):
return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)
# normalize and transform to `BatchFeatures`
vectorized_datasets = vectorized_datasets.map(
normalize,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
remove_columns=vectorized_datasets["train"].column_names,
)
# pretraining is only supported for "newer" stable layer norm architecture
# apply_spec_augment has to be True, mask_feature_prob has to be 0.0
config = Wav2Vec2Config.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
gradient_checkpointing=model_args.gradient_checkpointing,
)
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
raise ValueError(
"PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
)
model = Wav2Vec2ForPreTraining(config)
data_collator = DataCollatorForWav2Vec2Pretraining(model=model, feature_extractor=feature_extractor)
trainer = Wav2Vec2PreTrainer(
model=model,
data_collator=data_collator,
args=training_args,
train_dataset=vectorized_datasets["train"],
eval_dataset=vectorized_datasets["validation"],
tokenizer=feature_extractor,
max_gumbel_temp=model_args.max_gumbel_temperature,
min_gumbel_temp=model_args.min_gumbel_temperature,
gumbel_temp_decay=model_args.gumbel_temperature_decay,
)
trainer.train()
if __name__ == "__main__":
main()

View File

@ -1046,6 +1046,7 @@ if is_torch_available():
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Wav2Vec2ForCTC",
"Wav2Vec2ForMaskedLM",
"Wav2Vec2ForPreTraining",
"Wav2Vec2Model",
"Wav2Vec2PreTrainedModel",
]
@ -2411,6 +2412,7 @@ if TYPE_CHECKING:
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Wav2Vec2ForCTC,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForPreTraining,
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)

View File

@ -269,7 +269,7 @@ from ..tapas.modeling_tapas import (
from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
from ..visual_bert.modeling_visual_bert import VisualBertForPreTraining, VisualBertModel
from ..vit.modeling_vit import ViTForImageClassification, ViTModel
from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2ForMaskedLM, Wav2Vec2Model
from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining, Wav2Vec2Model
from ..xlm.modeling_xlm import (
XLMForMultipleChoice,
XLMForQuestionAnsweringSimple,
@ -463,6 +463,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
(IBertConfig, IBertForMaskedLM),
(DebertaConfig, DebertaForMaskedLM),
(DebertaV2Config, DebertaV2ForMaskedLM),
(Wav2Vec2Config, Wav2Vec2ForPreTraining),
]
)

View File

@ -32,6 +32,7 @@ if is_torch_available():
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Wav2Vec2ForCTC",
"Wav2Vec2ForMaskedLM",
"Wav2Vec2ForPreTraining",
"Wav2Vec2Model",
"Wav2Vec2PreTrainedModel",
]
@ -48,6 +49,7 @@ if TYPE_CHECKING:
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Wav2Vec2ForCTC,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForPreTraining,
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)

View File

@ -71,6 +71,8 @@ class Wav2Vec2Config(PretrainedConfig):
feat_extract_activation (:obj:`str, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the 1D convolutional layers of the feature
extractor. If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
feat_quantizer_dropout (obj:`float`, `optional`, defaults to 0.0):
The dropout probabilitiy for quantized feature extractor states.
conv_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature extractor. The length of `conv_dim` defines the number of 1D convolutional layers.
@ -108,6 +110,22 @@ class Wav2Vec2Config(PretrainedConfig):
masked along the time axis. This is only relevant if ``apply_spec_augment is True``.
mask_feature_length (:obj:`int`, `optional`, defaults to 10):
Length of vector span along the feature axis.
num_codevectors_per_group (:obj:`int`, `optional`, defaults to 320):
Number of entries in each quantization codebook (group).
num_codevector_groups (:obj:`int`, `optional`, defaults to 2):
Number of codevector groups for product codevector quantization.
contrastive_logits_temperature (:obj:`float`, `optional`, defaults to 0.1):
The temperature `kappa` in the contrastive loss.
feat_quantizer_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout probabilitiy for the output of the feature extractor that's used by the quantizer.
num_negatives (:obj:`int`, `optional`, defaults to 100):
Number of negative samples for the contrastive loss.
codevector_dim (:obj:`int`, `optional`, defaults to 256):
Dimensionality of the quantized feature vectors.
proj_codevector_dim (:obj:`int`, `optional`, defaults to 256):
Dimensionality of the final projection of both the quantized and the transformer features.
diversity_loss_weight (:obj:`int`, `optional`, defaults to 0.1):
The weight of the codebook diversity loss component.
ctc_loss_reduction (:obj:`str`, `optional`, defaults to :obj:`"sum"`):
Specifies the reduction to apply to the output of ``torch.nn.CTCLoss``. Only relevant when training an
instance of :class:`~transformers.Wav2Vec2ForCTC`.
@ -145,6 +163,7 @@ class Wav2Vec2Config(PretrainedConfig):
activation_dropout=0.1,
attention_dropout=0.1,
feat_proj_dropout=0.1,
feat_quantizer_dropout=0.0,
final_dropout=0.1,
layerdrop=0.1,
initializer_range=0.02,
@ -163,6 +182,13 @@ class Wav2Vec2Config(PretrainedConfig):
mask_time_length=10,
mask_feature_prob=0.0,
mask_feature_length=10,
num_codevectors_per_group=320,
num_codevector_groups=2,
contrastive_logits_temperature=0.1,
num_negatives=100,
codevector_dim=256,
proj_codevector_dim=256,
diversity_loss_weight=0.1,
ctc_loss_reduction="sum",
ctc_zero_infinity=False,
gradient_checkpointing=False,
@ -217,6 +243,16 @@ class Wav2Vec2Config(PretrainedConfig):
self.mask_feature_prob = mask_feature_prob
self.mask_feature_length = mask_feature_length
# parameters for pretraining with codevector quantized representations
self.num_codevectors_per_group = num_codevectors_per_group
self.num_codevector_groups = num_codevector_groups
self.contrastive_logits_temperature = contrastive_logits_temperature
self.feat_quantizer_dropout = feat_quantizer_dropout
self.num_negatives = num_negatives
self.codevector_dim = codevector_dim
self.proj_codevector_dim = proj_codevector_dim
self.diversity_loss_weight = diversity_loss_weight
# ctc loss
self.ctc_loss_reduction = ctc_loss_reduction
self.ctc_zero_infinity = ctc_zero_infinity

View File

@ -28,7 +28,7 @@ from transformers import (
Wav2Vec2CTCTokenizer,
Wav2Vec2FeatureExtractor,
Wav2Vec2ForCTC,
Wav2Vec2Model,
Wav2Vec2ForPreTraining,
Wav2Vec2Processor,
logging,
)
@ -50,9 +50,20 @@ MAPPING = {
"final_layer_norm": "encoder.layers.*.final_layer_norm",
"encoder.layer_norm": "encoder.layer_norm",
"w2v_model.layer_norm": "feature_projection.layer_norm",
"quantizer.weight_proj": "quantizer.weight_proj",
"quantizer.vars": "quantizer.codevectors",
"project_q": "project_q",
"final_proj": "project_hid",
"w2v_encoder.proj": "lm_head",
"mask_emb": "masked_spec_embed",
}
TOP_LEVEL_KEYS = [
"lm_head",
"quantizer.weight_proj",
"quantizer.codevectors",
"project_q",
"project_hid",
]
def set_recursively(hf_pointer, key, value, full_name, weight_type):
@ -82,11 +93,11 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
def recursively_load_weights(fairseq_model, hf_model, is_headless):
unused_weights = []
fairseq_dict = fairseq_model.state_dict()
feature_extractor = hf_model.wav2vec2.feature_extractor if is_finetuned else hf_model.feature_extractor
feature_extractor = hf_model.wav2vec2.feature_extractor
for name, value in fairseq_dict.items():
is_used = False
@ -101,9 +112,8 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
is_used = True
else:
for key, mapped_key in MAPPING.items():
mapped_key = "wav2vec2." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key
if key in name or (key.split("w2v_model.")[-1] == name.split(".")[0] and not is_finetuned):
mapped_key = "wav2vec2." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
is_used = True
if "*" in mapped_key:
layer_index = name.split(key)[0].split(".")[-2]
@ -112,10 +122,11 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
weight_type = "weight_g"
elif "weight_v" in name:
weight_type = "weight_v"
elif "weight" in name:
weight_type = "weight"
elif "bias" in name:
weight_type = "bias"
elif "weight" in name:
# TODO: don't match quantizer.weight_proj
weight_type = "weight"
else:
weight_type = None
set_recursively(hf_model, mapped_key, value, name, weight_type)
@ -213,7 +224,7 @@ def convert_wav2vec2_checkpoint(
hf_wav2vec = Wav2Vec2ForCTC(config)
else:
hf_wav2vec = Wav2Vec2Model(config)
hf_wav2vec = Wav2Vec2ForPreTraining(config)
if is_finetuned:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
@ -224,7 +235,7 @@ def convert_wav2vec2_checkpoint(
model = model[0].eval()
recursively_load_weights(model, hf_wav2vec, is_finetuned)
recursively_load_weights(model, hf_wav2vec, not is_finetuned)
hf_wav2vec.save_pretrained(pytorch_dump_folder_path)

View File

@ -15,7 +15,8 @@
""" PyTorch Wav2Vec2 model. """
import warnings
from typing import Optional, Tuple
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
@ -26,7 +27,12 @@ from torch import nn
from transformers.deepspeed import is_deepspeed_zero3_enabled
from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...file_utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
@ -46,6 +52,71 @@ WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
@dataclass
class Wav2Vec2BaseModelOutput(ModelOutput):
"""
Output type of :class:`~transformers.Wav2Vec2BaseModelOutput`, with potential hidden states and attentions.
Args:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
extract_features (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, conv_dim[-1])`):
Sequence of extracted feature vectors of the last convolutional layer of the model.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor = None
extract_features: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class Wav2Vec2ForPreTrainingOutput(ModelOutput):
"""
Output type of :class:`~transformers.Wav2Vec2ForPreTrainingOutput`, with potential hidden states and attentions.
Args:
loss (`optional`, returned when model is in train mode, ``torch.FloatTensor`` of shape :obj:`(1,)`):
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the `official
paper <https://arxiv.org/pdf/2006.11477.pdf>`__ . (classification) loss.
projected_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.proj_codevector_dim)`):
Hidden-states of the model projected to `config.proj_codevector_dim` that can be used to predict the masked
projected quantized states.
projected_quantized_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.proj_codevector_dim)`):
Quantized extracted feature vectors projected to `config.proj_codevector_dim` representing the positive
target vectors for contrastive loss.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
projected_states: torch.FloatTensor = None
projected_quantized_states: torch.FloatTensor = None
codevector_perplexity: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
@ -271,10 +342,11 @@ class Wav2Vec2FeatureProjection(nn.Module):
self.dropout = nn.Dropout(config.feat_proj_dropout)
def forward(self, hidden_states):
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.projection(hidden_states)
# non-projected hidden states are needed for quantization
norm_hidden_states = self.layer_norm(hidden_states)
hidden_states = self.projection(norm_hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
return hidden_states, norm_hidden_states
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Wav2Vec2
@ -685,6 +757,86 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
)
class Wav2Vec2GumbelVectorQuantizer(nn.Module):
"""
Vector quantization using gumbel softmax. See `CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX
<https://arxiv.org/pdf/1611.01144.pdf>`__ for more information.
"""
def __init__(self, config):
super().__init__()
self.num_groups = config.num_codevector_groups
self.num_vars = config.num_codevectors_per_group
assert (
config.codevector_dim % self.num_groups == 0
), f"`config.codevector_dim {config.codevector_dim} must be divisible by `config.num_codevector_groups` {self.num_groups} for concatenation"
# storage for codebook variables (codewords)
self.codevectors = nn.Parameter(
torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
)
self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
# can be decayed for training
self.temperature = 1
def set_temperature(self, temperature: int):
self.temperature = temperature
@staticmethod
def _compute_perplexity(probs, mask=None):
if mask is not None:
mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
marginal_probs = probs.sum(dim=0) / mask.sum()
else:
marginal_probs = probs.mean(dim=0)
perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
return perplexity
def forward(self, hidden_states, mask_time_indices=None):
batch_size, sequence_length, hidden_size = hidden_states.shape
# project to codevector dim
hidden_states = self.weight_proj(hidden_states)
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
if self.training:
# sample code vector probs via gumbel in differentiateable way
codevector_probs = F.gumbel_softmax(hidden_states.float(), tau=self.temperature, hard=True).type_as(
hidden_states
)
# compute perplexity
codevector_soft_dist = torch.softmax(
hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
)
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
else:
# take argmax in non-differentiable way
# comptute hard codevector distribution (one hot)
codevector_idx = hidden_states.argmax(dim=-1)
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
-1, codevector_idx.view(-1, 1), 1.0
)
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
# use probs to retrieve codevectors
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
codevectors = (
codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
.sum(-2)
.view(batch_size, sequence_length, -1)
)
return codevectors, perplexity
class Wav2Vec2PreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@ -697,7 +849,12 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
# gumbel softmax requires special init
if isinstance(module, Wav2Vec2GumbelVectorQuantizer):
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
module.weight_proj.bias.data.zero_()
nn.init.uniform_(module.codevectors)
elif isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@ -720,7 +877,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_()
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the convolutional layers
"""
@ -733,7 +890,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
return input_lengths.to(torch.long)
return input_lengths
WAV_2_VEC_2_START_DOCSTRING = r"""
@ -797,7 +954,7 @@ WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
WAV_2_VEC_2_START_DOCSTRING,
)
class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
def __init__(self, config):
def __init__(self, config: Wav2Vec2Config):
super().__init__(config)
self.config = config
self.feature_extractor = Wav2Vec2FeatureExtractor(config)
@ -812,12 +969,53 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
self.init_weights()
def _mask_hidden_states(
self, hidden_states: torch.FloatTensor, mask_time_indices: Optional[torch.FloatTensor] = None
):
"""
Masks extracted features along time axis and/or along feature axis according to `SpecAugment
<https://arxiv.org/abs/1904.08779>`__ .
"""
# `config.apply_spec_augment` can set masking to False
if not getattr(self.config, "apply_spec_augment", True):
return hidden_states
if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training:
# generate indices & apply SpecAugment along time axis
batch_size, sequence_length, hidden_size = hidden_states.size()
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
device=hidden_states.device,
min_masks=2,
)
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
device=hidden_states.device,
)
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
return hidden_states
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values,
attention_mask=None,
mask_time_indices=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@ -852,49 +1050,30 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = self.feature_extractor(input_values)
hidden_states = hidden_states.transpose(1, 2)
extract_features = self.feature_extractor(input_values)
extract_features = extract_features.transpose(1, 2)
if attention_mask is not None:
# compute real output lengths according to convolution formula
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
attention_mask = torch.zeros(
hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device
extract_features.shape[:2], dtype=extract_features.dtype, device=extract_features.device
)
# these two operations makes sure that all values
# before the output lengths indices are attended to
attention_mask[
(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)
(torch.arange(attention_mask.shape[0], device=extract_features.device), output_lengths - 1)
] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
hidden_states = self.feature_projection(hidden_states)
hidden_states, extract_features = self.feature_projection(extract_features)
if self.config.apply_spec_augment and self.training:
batch_size, sequence_length, hidden_size = hidden_states.size()
if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
# apply SpecAugment along time axis
if self.config.mask_time_prob > 0:
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
device=hidden_states.device,
min_masks=2,
)
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
# apply SpecAugment along feature axis
if self.config.mask_feature_prob > 0:
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
device=hidden_states.device,
)
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
hidden_states = self._mask_hidden_states(hidden_states)
encoder_outputs = self.encoder(
hidden_states,
@ -907,15 +1086,240 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
hidden_states = encoder_outputs[0]
if not return_dict:
return (hidden_states,) + encoder_outputs[1:]
return (hidden_states, extract_features) + encoder_outputs[1:]
return BaseModelOutput(
return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top. """, WAV_2_VEC_2_START_DOCSTRING)
class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
def __init__(self, config: Wav2Vec2Config):
super().__init__(config)
self.wav2vec2 = Wav2Vec2Model(config)
self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
self.quantizer = Wav2Vec2GumbelVectorQuantizer(config)
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
self.init_weights()
def set_gumbel_temperature(self, temperature: int):
"""
Set the Gumbel softmax temperature to a given value. Only necessary for training
"""
return self.quantizer.set_temperature(temperature)
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature extractor so that its parameters
will not be updated during training.
"""
self.wav2vec2.feature_extractor._freeze_parameters()
@staticmethod
def _sample_negatives(features: torch.FloatTensor, num_negatives: int):
"""
Sample `num_negatives` vectors from feature vectors.
"""
batch_size, sequence_length, hidden_size = features.shape
if sequence_length <= 1:
raise ValueError(
f"`features should have `sequence_length` > 1, but are of shape (batch_size, sequence_length, hidden_size) = ({batch_size, sequence_length, hidden_size})."
)
features = features.view(-1, hidden_size) # BTC => (BxT)C
with torch.no_grad():
# get `num_negatives` random vector indices from the same utterance
sampled_negative_indices = torch.randint(
low=0,
high=sequence_length - 1,
size=(batch_size, num_negatives * sequence_length),
device=features.device,
)
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
feature_indices = (
torch.arange(sequence_length, device=features.device)[:, None]
.expand(sequence_length, num_negatives)
.flatten()
)
# avoid sampling the same positive vector, but keep the distribution uniform
sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1
# correct for batch size
for batch_idx in range(1, batch_size):
sampled_negative_indices[batch_idx] += batch_idx * sequence_length
# take negative vectors from sampled indices
sampled_negatives = features[sampled_negative_indices.view(-1)]
sampled_negatives = sampled_negatives.view(batch_size, sequence_length, num_negatives, hidden_size).permute(
2, 0, 1, 3
)
return sampled_negatives
@staticmethod
def compute_contrastive_logits(
target_features: torch.FloatTensor,
negative_features: torch.FloatTensor,
predicted_features: torch.FloatTensor,
temperature: int = 1,
):
"""
Compute logits for contrastive loss based using cosine similarity as the distance measure between
`[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
"""
target_features = torch.cat([target_features, negative_features], dim=0)
logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
target_features
)
# apply temperature
logits = logits / temperature
return logits
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values,
attention_mask=None,
mask_time_indices=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
mask_time_indices (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
masked extracted features in `config.proj_codevector_dim` space.
Returns:
Example::
>>> import torch
>>> from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base")
>>> model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = feature_extractor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
>>> # compute masked indices
>>> batch_size, raw_sequence_length = input_values.shape
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2, device=model.device)
>>> with torch.no_grad():
... outputs = model(input_values, mask_time_indices=mask_time_indices)
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
>>> cosine_sim = torch.cosine_similarity(
... outputs.projected_states, outputs.projected_quantized_states, dim=-1
... )
>>> # show that cosine similarity is much higher than random
>>> assert cosine_sim[mask_time_indices].mean() > 0.5
>>> # for contrastive loss training model should be put into train mode
>>> model.train()
>>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if mask_time_indices is not None:
mask_time_indices = mask_time_indices.to(torch.bool)
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
mask_time_indices=mask_time_indices,
return_dict=return_dict,
)
# 1. project all transformed features (including masked) to final vq dim
transformer_features = self.project_hid(outputs[0])
# 2. quantize all (unmasked) extracted features and project to final vq dim
extract_features = self.dropout_features(outputs[1])
quantized_features, codevector_perplexity = self.quantizer(extract_features, mask_time_indices)
quantized_features = self.project_q(quantized_features)
loss = None
if self.training:
# for training, we sample negatives
# 3. sample K negatives (distractors) quantized states for contrastive loss
negative_quantized_features = self._sample_negatives(quantized_features, self.config.num_negatives)
# 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
# of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
logits = self.compute_contrastive_logits(
quantized_features[None, :],
negative_quantized_features,
transformer_features,
self.config.contrastive_logits_temperature,
)
# 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
# its cosine similarity will be masked
neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
if neg_is_pos.any():
logits[1:][neg_is_pos] = float("-inf")
# 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
# -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
preds = logits.transpose(0, 2).reshape(-1, logits.size(0))
target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
contrastive_loss = F.cross_entropy(preds.float(), target, reduction="sum")
# 7. compute diversity loss: \mathbf{L}_d
num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors
# 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
if not return_dict:
if loss is not None:
return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
return Wav2Vec2ForPreTrainingOutput(
loss=loss,
projected_states=transformer_features,
projected_quantized_states=quantized_features,
codevector_perplexity=codevector_perplexity,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings("""Wav2Vec2 Model with a `language modeling` head on top. """, WAV_2_VEC_2_START_DOCSTRING)
class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
def __init__(self, config):
@ -986,7 +1390,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
logits = self.lm_head(hidden_states)
if not return_dict:
output = (logits,) + outputs[1:]
output = (logits,) + outputs[2:]
return output
return MaskedLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
@ -1089,7 +1493,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
attention_mask = (
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
)
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
# assuming that padded tokens are filled with -100
# when not being attended to
@ -1112,7 +1516,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
)
if not return_dict:
output = (logits,) + outputs[1:]
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutput(

View File

@ -2969,6 +2969,11 @@ class Wav2Vec2ForMaskedLM:
requires_backends(self, ["torch"])
class Wav2Vec2ForPreTraining:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Wav2Vec2Model:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

View File

@ -29,8 +29,16 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
if is_torch_available():
import torch
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
from transformers import (
Wav2Vec2Config,
Wav2Vec2FeatureExtractor,
Wav2Vec2ForCTC,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForPreTraining,
Wav2Vec2Model,
Wav2Vec2Processor,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2GumbelVectorQuantizer, _compute_mask_indices
class Wav2Vec2ModelTester:
@ -219,13 +227,7 @@ class Wav2Vec2ModelTester:
@require_torch
class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
Wav2Vec2ForCTC,
Wav2Vec2Model,
Wav2Vec2ForMaskedLM,
)
if is_torch_available()
else ()
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else ()
)
test_pruning = False
test_headmasking = False
@ -316,8 +318,14 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
uniform_init_parms = [
"conv.weight",
"masked_spec_embed",
"codevectors",
"quantizer.weight_proj.weight",
]
if param.requires_grad:
if "conv.weight" in name or "masked_spec_embed" in name:
if any([x in name for x in uniform_init_parms]):
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
@ -333,10 +341,14 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "weight_g") and module.weight is not None:
if hasattr(module, "weight_g") and module.weight_g is not None:
module.weight_g.data.fill_(3)
if hasattr(module, "weight_v") and module.weight_v is not None:
module.weight_v.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)
if hasattr(module, "codevectors") and module.codevectors is not None:
module.codevectors.data.fill_(3)
@slow
def test_model_from_pretrained(self):
@ -346,7 +358,9 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else ()
all_model_classes = (
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else ()
)
test_pruning = False
test_headmasking = False
test_torchscript = False
@ -442,8 +456,14 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
uniform_init_parms = [
"conv.weight",
"masked_spec_embed",
"codevectors",
"quantizer.weight_proj.weight",
]
if param.requires_grad:
if "conv.weight" in name or "masked_spec_embed" in name:
if any([x in name for x in uniform_init_parms]):
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
@ -459,10 +479,47 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "weight_g") and module.weight is not None:
if hasattr(module, "weight_g") and module.weight_g is not None:
module.weight_g.data.fill_(3)
if hasattr(module, "weight_v") and module.weight_v is not None:
module.weight_v.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)
if hasattr(module, "codevectors") and module.codevectors is not None:
module.codevectors.data.fill_(3)
def test_model_for_pretraining(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = Wav2Vec2ForPreTraining(config).to(torch_device)
features_shape = (
inputs_dict["input_values"].shape[0],
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
loss = model(
inputs_dict["input_values"],
attention_mask=inputs_dict["attention_mask"],
mask_time_indices=mask_time_indices,
).loss
mask_time_indices[:, : mask_time_indices.shape[-1] // 2] = True
loss_more_masked = model(
inputs_dict["input_values"],
attention_mask=inputs_dict["attention_mask"],
mask_time_indices=mask_time_indices,
).loss
# loss_more_masked has to be bigger or equal loss since more masked inputs have to be predicted
self.assertTrue(loss.detach().item() <= loss_more_masked.detach().item())
@slow
def test_model_from_pretrained(self):
@ -484,24 +541,56 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
def test_compute_mask_indices_overlap(self):
batch_size = 4
sequence_length = 60
sequence_length = 80
mask_prob = 0.5
mask_length = 4
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
# because of overlap there is a range of possible masks
# because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
for batch_sum in mask.sum(axis=-1):
self.assertIn(
int(batch_sum),
list(range(int(mask_prob // mask_length * sequence_length), int(mask_prob * sequence_length))),
)
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
def test_compute_perplexity(self):
probs = torch.arange(100, device=torch_device).reshape(2, 5, 10) / 100
ppl = Wav2Vec2GumbelVectorQuantizer._compute_perplexity(probs)
self.assertTrue(abs(ppl.item() - 141.4291) < 1e-3)
# mask half of the input
mask = torch.ones((2,), device=torch_device, dtype=torch.bool)
mask[0] = 0
ppl = Wav2Vec2GumbelVectorQuantizer._compute_perplexity(probs, mask)
self.assertTrue(abs(ppl.item() - 58.6757) < 1e-3)
def test_sample_negatives(self):
batch_size = 2
sequence_length = 10
hidden_size = 4
num_negatives = 3
features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
sequence_length, hidden_size
) # each value in vector consits of same value
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives)
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
# make sure no negatively sampled vector is actually a positive one
for negative in negatives:
self.assertTrue(((negative - features) == 0).sum() == 0.0)
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
@require_torch
@slow
@require_datasets
@require_soundfile
@slow
class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):
from datasets import load_dataset
@ -586,3 +675,160 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
"his instant panic was followed by a small sharp blow high on his chest",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_integration(self):
model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
model.to(torch_device)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"patrickvonplaten/wav2vec2-base", return_attention_mask=True
)
input_speech = self._load_datasamples(2)
inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
features_shape = (
inputs_dict["input_values"].shape[0],
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
)
torch.manual_seed(0)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
with torch.no_grad():
outputs = model(
inputs_dict.input_values.to(torch_device),
attention_mask=inputs_dict.attention_mask.to(torch_device),
mask_time_indices=mask_time_indices,
)
# compute cosine similarity
cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
# retrieve cosine sim of masked features
cosine_sim_masked = cosine_sim[mask_time_indices]
# fmt: off
expected_cosine_sim_masked = torch.tensor(
[0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831, 0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651, 0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371, 0.6997],
device=torch_device,
)
# fmt: on
self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))
def test_inference_pretrained(self):
model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
model.to(torch_device)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"patrickvonplaten/wav2vec2-base", return_attention_mask=True
)
input_speech = self._load_datasamples(2)
inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
features_shape = (
inputs_dict["input_values"].shape[0],
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
)
torch.manual_seed(0)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
with torch.no_grad():
outputs = model(
inputs_dict.input_values.to(torch_device),
attention_mask=inputs_dict.attention_mask.to(torch_device),
mask_time_indices=mask_time_indices,
)
# compute cosine similarity
cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
# retrieve cosine sim of masked features
cosine_sim_masked = cosine_sim[mask_time_indices]
# ... now compare to randomly initialized model
config = Wav2Vec2Config.from_pretrained("patrickvonplaten/wav2vec2-base")
model_rand = Wav2Vec2ForPreTraining(config).to(torch_device).eval()
with torch.no_grad():
outputs_rand = model_rand(
inputs_dict.input_values.to(torch_device),
attention_mask=inputs_dict.attention_mask.to(torch_device),
mask_time_indices=mask_time_indices,
)
# compute cosine similarity
cosine_sim_rand = torch.cosine_similarity(
outputs_rand.projected_states, outputs_rand.projected_quantized_states, dim=-1
)
# retrieve cosine sim of masked features
cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]
# a pretrained wav2vec2 model has learned to predict the quantized latent states
# => the cosine similarity between quantized states and predicted states > 0.5
# a random wav2vec2 model has not learned to predict the quantized latent states
# => the cosine similarity between quantized states and predicted states is very likely < 0.1
self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)
def test_loss_pretraining(self):
model = Wav2Vec2ForPreTraining.from_pretrained(
"patrickvonplaten/wav2vec2-base",
attention_dropout=0.0,
feat_proj_dropout=0.0,
hidden_dropout=0.0,
layerdrop=0.0,
)
model.to(torch_device).train()
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"patrickvonplaten/wav2vec2-base", return_attention_mask=True
)
input_speech = self._load_datasamples(2)
inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
features_shape = (
inputs_dict["input_values"].shape[0],
model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]),
)
torch.manual_seed(0)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
with torch.no_grad():
outputs = model(
inputs_dict.input_values.to(torch_device),
attention_mask=inputs_dict.attention_mask.to(torch_device),
mask_time_indices=mask_time_indices,
)
# check diversity loss
num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3)
# check overall loss (contrastive loss + diversity loss)
expected_loss = 62.5170 if model.device.type == "cpu" else 50.3612
self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)