Wav2Vec2 Pretraining (#11306)

* Working quantizer forward

* Working quantizer forward

* Clean up unused model parts, test reproducibility

* Working quantizer forward

* Clean up unused model parts, test reproducibility

* Remove custom outputs from the shared ones

* correct conversion

* correct bug

* add first pretrain script

* save intermediate

* static shapes

* save intermediate

* finish first pretrain script version

* more refactor

* remove wanddb

* refactor more

* improve test

* correct perplexity compute bug

* finish model implementation

* add to docs

* finish docs

* finish pretraining script

* finish pretraining script

* remove wandb

* finish PR for merge

* finish config

* finish

* make deepspeed work

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* apply suggestions

* fix flaky test

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Authored by Anton Lozhkov on 2021-06-09 19:40:56 +02:00, committed by GitHub
parent b1a8aa94f0
commit d472bd7b18
11 changed files with 1190 additions and 74 deletions


@@ -79,3 +79,10 @@ Wav2Vec2ForCTC
.. autoclass:: transformers.Wav2Vec2ForCTC
    :members: forward
Wav2Vec2ForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Wav2Vec2ForPreTraining
    :members: forward


@@ -184,3 +184,35 @@ run_asr.py \
--preprocessing_num_workers=1 --group_by_length --freeze_feature_extractor --verbose_logging \
--deepspeed ds_config_wav2vec2_zero3.json
```
### Pretraining Wav2Vec2
The `run_pretrain.py` script allows one to pretrain a Wav2Vec2 model from scratch using Wav2Vec2's contrastive loss objective (see official [paper](https://arxiv.org/abs/2006.11477) for more information).
It is recommended to pre-train Wav2Vec2 with Trainer + DeepSpeed (please refer to [this guide](https://huggingface.co/transformers/master/main_classes/deepspeed.html#deepspeed-trainer-integration) for more information).
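For reference, the training objective is the paper's contrastive loss plus a weighted codebook-diversity term (a sketch in the paper's notation; `kappa` corresponds to `contrastive_logits_temperature` and `alpha` to `diversity_loss_weight` in `Wav2Vec2Config`):

```latex
L = L_m + \alpha L_d, \qquad
L_m = -\log \frac{\exp\big(\mathrm{sim}(c_t, q_t)/\kappa\big)}{\sum_{\tilde{q} \in Q_t} \exp\big(\mathrm{sim}(c_t, \tilde{q})/\kappa\big)}, \qquad
L_d = \frac{GV - \sum_{g=1}^{G} \exp\big(-\sum_{v=1}^{V} \bar{p}_{g,v} \log \bar{p}_{g,v}\big)}{GV}
```

where `G` is `num_codevector_groups`, `V` is `num_codevectors_per_group`, and `p̄` is the batch-averaged codevector distribution.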
Here is an example of how you can use DeepSpeed ZeRO-2 to pretrain a small Wav2Vec2 model:
```
PYTHONPATH=../../../src deepspeed --num_gpus 2 run_pretrain.py \
--output_dir="./wav2vec2-base-libri-100h" \
--num_train_epochs="3" \
--per_device_train_batch_size="32" \
--per_device_eval_batch_size="32" \
--gradient_accumulation_steps="2" \
--save_total_limit="3" \
--save_steps="500" \
--logging_steps="10" \
--learning_rate="5e-4" \
--weight_decay="0.01" \
--warmup_steps="3000" \
--model_name_or_path="patrickvonplaten/wav2vec2-base-libri-100h" \
--dataset_name="librispeech_asr" \
--dataset_config_name="clean" \
--train_split_name="train.100" \
--preprocessing_num_workers="4" \
--max_duration_in_seconds="10.0" \
--group_by_length \
--verbose_logging \
--fp16 \
--deepspeed ds_config_wav2vec2_zero2.json \
```
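For orientation, here is a minimal sketch of the pretraining forward pass that the script drives (random audio instead of a real dataset; the checkpoint name is taken from the model docstring example further down and may need to be adapted):

```python
import torch
from transformers import Wav2Vec2ForPreTraining
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices

model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")

# one second of already-normalized 16 kHz audio, batch size 1
input_values = torch.randn(1, 16000)
batch_size, raw_sequence_length = input_values.shape
sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)

# sample the time steps whose quantized targets the model has to predict
mask_time_indices = _compute_mask_indices(
    (batch_size, sequence_length), mask_prob=0.2, mask_length=2, device=model.device
)

model.train()  # the contrastive + diversity loss is only computed in train mode
loss = model(input_values, mask_time_indices=mask_time_indices).loss
```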


@@ -0,0 +1,370 @@
#!/usr/bin/env python3
import logging
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
from datasets import DatasetDict, load_dataset
from packaging import version
import librosa
from transformers import (
HfArgumentParser,
Trainer,
TrainingArguments,
Wav2Vec2Config,
Wav2Vec2FeatureExtractor,
Wav2Vec2ForPreTraining,
is_apex_available,
trainer_utils,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
if is_apex_available():
from apex import amp
if version.parse(torch.__version__) >= version.parse("1.6"):
_is_native_amp_available = True
from torch.cuda.amp import autocast
logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
)
freeze_feature_extractor: Optional[bool] = field(
default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
)
gradient_checkpointing: Optional[bool] = field(
default=False, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
)
verbose_logging: Optional[bool] = field(
default=False,
metadata={"help": "Whether to log verbose messages or not."},
)
max_gumbel_temperature: Optional[float] = field(
default=2.0, metadata={"help": "Maximum temperature for gumbel softmax."}
)
min_gumbel_temperature: Optional[float] = field(
default=0.5, metadata={"help": "Minimum temperature for gumbel softmax."}
)
gumbel_temperature_decay: Optional[float] = field(
default=0.999995, metadata={"help": "Decay of gumbel temperature during training."}
)
def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
logging_level = logging.WARNING
if model_args.verbose_logging:
logging_level = logging.DEBUG
elif trainer_utils.is_main_process(training_args.local_rank):
logging_level = logging.INFO
logger.setLevel(logging_level)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
Using `HfArgumentParser` we can turn this class
into argparse arguments to be able to specify them on
the command line.
"""
dataset_name: str = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_split_name: Optional[str] = field(
default="train",
metadata={
"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
},
)
validation_split_name: Optional[str] = field(
default="validation",
metadata={
"help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
},
)
speech_file_column: Optional[str] = field(
default="file",
metadata={"help": "Column in the dataset that contains speech file path. Defaults to 'file'"},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_duration_in_seconds: Optional[float] = field(
default=20.0, metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"}
)
validation_split_percentage: Optional[int] = field(
default=1,
metadata={"help": "Percentage of the train split to use as validation when the dataset has no validation split (used below)."},
)
@dataclass
class DataCollatorForWav2Vec2Pretraining:
"""
Data collator that will dynamically pad the inputs received and prepare masked indices
for self-supervised pretraining.
Args:
model (:class:`~transformers.Wav2Vec2ForPreTraining`):
The Wav2Vec2 model used for pretraining. The data collator needs to have access
to config and ``_get_feat_extract_output_lengths`` function for correct padding.
feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`):
The processor used for processing the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
"""
model: Wav2Vec2ForPreTraining
feature_extractor: Wav2Vec2FeatureExtractor
padding: Union[bool, str] = "longest"
pad_to_multiple_of: Optional[int] = None
max_length: Optional[int] = None
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# reformat list to dict and set to pytorch format
batch = self.feature_extractor.pad(
features,
max_length=self.max_length,
padding=self.padding,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
# sample randomly masked indices
batch["mask_time_indices"] = _compute_mask_indices(
(batch["input_values"].shape[0], mask_indices_seq_length),
self.model.config.mask_time_prob,
self.model.config.mask_time_length,
device=batch["input_values"].device,
min_masks=2,
)
return batch
class Wav2Vec2PreTrainer(Trainer):
"""
Subclassed :class:`~transformers.Trainer` for Wav2Vec2-like pretraining. Trainer can decay gumbel softmax temperature during training.
"""
def __init__(self, *args, max_gumbel_temp=1, min_gumbel_temp=0, gumbel_temp_decay=1.0, **kwargs):
super().__init__(*args, **kwargs)
self.num_update_step = 0
self.max_gumbel_temp = max_gumbel_temp
self.min_gumbel_temp = min_gumbel_temp
self.gumbel_temp_decay = gumbel_temp_decay
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
Subclass and override to inject custom behavior.
Args:
model (:obj:`nn.Module`):
The model to train.
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
Return:
:obj:`torch.Tensor`: The tensor with training loss on this batch.
"""
model.train()
inputs = self._prepare_inputs(inputs)
if self.use_amp:
with autocast():
loss = self.compute_loss(model, inputs)
else:
loss = self.compute_loss(model, inputs)
if self.args.n_gpu > 1 or self.deepspeed:
if model.module.config.ctc_loss_reduction == "mean":
loss = loss.mean()
elif model.module.config.ctc_loss_reduction == "sum":
loss = loss.sum() / (inputs["mask_time_indices"]).sum()
else:
raise ValueError(f"{model.config.ctc_loss_reduction} is not valid. Choose one of ['mean', 'sum']")
if self.args.gradient_accumulation_steps > 1:
loss = loss / self.args.gradient_accumulation_steps
if self.use_amp:
self.scaler.scale(loss).backward()
elif self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
elif self.deepspeed:
self.deepspeed.backward(loss)
else:
loss.backward()
self.num_update_step += 1
# make sure gumbel softmax temperature is decayed
if self.args.n_gpu > 1 or self.deepspeed:
model.module.set_gumbel_temperature(
max(self.max_gumbel_temp * self.gumbel_temp_decay ** self.num_update_step, self.min_gumbel_temp)
)
else:
model.set_gumbel_temperature(
max(self.max_gumbel_temp * self.gumbel_temp_decay ** self.num_update_step, self.min_gumbel_temp)
)
return loss.detach()
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
configure_logger(model_args, training_args)
# Downloading and loading a dataset from the hub.
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
if "validation" not in datasets.keys():
# make sure only "validation" and "train" keys remain
datasets = DatasetDict()
datasets["validation"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"{data_args.train_split_name}[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
)
datasets["train"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"{data_args.train_split_name}[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
)
else:
# make sure only "validation" and "train" keys remain
datasets = DatasetDict()
datasets["validation"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split="validation",
cache_dir=model_args.cache_dir,
)
datasets["train"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"{data_args.train_split_name}",
cache_dir=model_args.cache_dir,
)
# only normalized-inputs-training is supported
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True
)
def prepare_dataset(batch):
# check that all files have the correct sampling rate
batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate)
return batch
# load audio files into numpy arrays
vectorized_datasets = datasets.map(
prepare_dataset, num_proc=data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names
)
# filter audio files that are too long
vectorized_datasets = vectorized_datasets.filter(
lambda data: len(data["speech"]) < int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
)
def normalize(batch):
return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)
# normalize and transform to `BatchFeatures`
vectorized_datasets = vectorized_datasets.map(
normalize,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
remove_columns=vectorized_datasets["train"].column_names,
)
# pretraining is only supported for "newer" stable layer norm architecture
# apply_spec_augment has to be True, mask_feature_prob has to be 0.0
config = Wav2Vec2Config.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
gradient_checkpointing=model_args.gradient_checkpointing,
)
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
raise ValueError(
"PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
)
model = Wav2Vec2ForPreTraining(config)
data_collator = DataCollatorForWav2Vec2Pretraining(model=model, feature_extractor=feature_extractor)
trainer = Wav2Vec2PreTrainer(
model=model,
data_collator=data_collator,
args=training_args,
train_dataset=vectorized_datasets["train"],
eval_dataset=vectorized_datasets["validation"],
tokenizer=feature_extractor,
max_gumbel_temp=model_args.max_gumbel_temperature,
min_gumbel_temp=model_args.min_gumbel_temperature,
gumbel_temp_decay=model_args.gumbel_temperature_decay,
)
trainer.train()
if __name__ == "__main__":
main()
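The `max_gumbel_temperature`, `min_gumbel_temperature` and `gumbel_temperature_decay` arguments control the schedule applied in `Wav2Vec2PreTrainer.training_step` above; a small sketch of what that schedule looks like with the default values:

```python
# Sketch of the Gumbel temperature schedule used by Wav2Vec2PreTrainer above.
max_temp, min_temp, decay = 2.0, 0.5, 0.999995  # ModelArguments defaults

def gumbel_temperature(step: int) -> float:
    # geometric decay per update step, clamped at the minimum temperature
    return max(max_temp * decay**step, min_temp)

print(gumbel_temperature(0))          # 2.0
print(gumbel_temperature(100_000))    # ~1.21
print(gumbel_temperature(1_000_000))  # 0.5 (clamped)
```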


@@ -1046,6 +1046,7 @@ if is_torch_available():
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Wav2Vec2ForCTC",
"Wav2Vec2ForMaskedLM",
"Wav2Vec2ForPreTraining",
"Wav2Vec2Model",
"Wav2Vec2PreTrainedModel",
]
@@ -2411,6 +2412,7 @@ if TYPE_CHECKING:
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Wav2Vec2ForCTC,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForPreTraining,
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)


@@ -269,7 +269,7 @@ from ..tapas.modeling_tapas import (
from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
from ..visual_bert.modeling_visual_bert import VisualBertForPreTraining, VisualBertModel
from ..vit.modeling_vit import ViTForImageClassification, ViTModel
from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining, Wav2Vec2Model
from ..xlm.modeling_xlm import (
XLMForMultipleChoice,
XLMForQuestionAnsweringSimple,
@@ -463,6 +463,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
(IBertConfig, IBertForMaskedLM),
(DebertaConfig, DebertaForMaskedLM),
(DebertaV2Config, DebertaV2ForMaskedLM),
(Wav2Vec2Config, Wav2Vec2ForPreTraining),
]
)


@@ -32,6 +32,7 @@ if is_torch_available():
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Wav2Vec2ForCTC",
"Wav2Vec2ForMaskedLM",
"Wav2Vec2ForPreTraining",
"Wav2Vec2Model",
"Wav2Vec2PreTrainedModel",
]
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Wav2Vec2ForCTC,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForPreTraining,
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)


@@ -71,6 +71,8 @@ class Wav2Vec2Config(PretrainedConfig):
feat_extract_activation (:obj:`str, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the 1D convolutional layers of the feature
extractor. If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
feat_quantizer_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout probability for quantized feature extractor states.
conv_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature extractor. The length of `conv_dim` defines the number of 1D convolutional layers.
@@ -108,6 +110,22 @@
masked along the time axis. This is only relevant if ``apply_spec_augment is True``.
mask_feature_length (:obj:`int`, `optional`, defaults to 10):
Length of vector span along the feature axis.
num_codevectors_per_group (:obj:`int`, `optional`, defaults to 320):
Number of entries in each quantization codebook (group).
num_codevector_groups (:obj:`int`, `optional`, defaults to 2):
Number of codevector groups for product codevector quantization.
contrastive_logits_temperature (:obj:`float`, `optional`, defaults to 0.1):
The temperature `kappa` in the contrastive loss.
feat_quantizer_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout probability for the output of the feature extractor that's used by the quantizer.
num_negatives (:obj:`int`, `optional`, defaults to 100):
Number of negative samples for the contrastive loss.
codevector_dim (:obj:`int`, `optional`, defaults to 256):
Dimensionality of the quantized feature vectors.
proj_codevector_dim (:obj:`int`, `optional`, defaults to 256):
Dimensionality of the final projection of both the quantized and the transformer features.
diversity_loss_weight (:obj:`float`, `optional`, defaults to 0.1):
The weight of the codebook diversity loss component.
ctc_loss_reduction (:obj:`str`, `optional`, defaults to :obj:`"sum"`):
Specifies the reduction to apply to the output of ``torch.nn.CTCLoss``. Only relevant when training an
instance of :class:`~transformers.Wav2Vec2ForCTC`.
@@ -145,6 +163,7 @@
activation_dropout=0.1,
attention_dropout=0.1,
feat_proj_dropout=0.1,
feat_quantizer_dropout=0.0,
final_dropout=0.1,
layerdrop=0.1,
initializer_range=0.02,
@@ -163,6 +182,13 @@
mask_time_length=10,
mask_feature_prob=0.0,
mask_feature_length=10,
num_codevectors_per_group=320,
num_codevector_groups=2,
contrastive_logits_temperature=0.1,
num_negatives=100,
codevector_dim=256,
proj_codevector_dim=256,
diversity_loss_weight=0.1,
ctc_loss_reduction="sum", ctc_loss_reduction="sum",
ctc_zero_infinity=False, ctc_zero_infinity=False,
gradient_checkpointing=False, gradient_checkpointing=False,
@ -217,6 +243,16 @@ class Wav2Vec2Config(PretrainedConfig):
self.mask_feature_prob = mask_feature_prob self.mask_feature_prob = mask_feature_prob
self.mask_feature_length = mask_feature_length self.mask_feature_length = mask_feature_length
# parameters for pretraining with codevector quantized representations
self.num_codevectors_per_group = num_codevectors_per_group
self.num_codevector_groups = num_codevector_groups
self.contrastive_logits_temperature = contrastive_logits_temperature
self.feat_quantizer_dropout = feat_quantizer_dropout
self.num_negatives = num_negatives
self.codevector_dim = codevector_dim
self.proj_codevector_dim = proj_codevector_dim
self.diversity_loss_weight = diversity_loss_weight
# ctc loss
self.ctc_loss_reduction = ctc_loss_reduction
self.ctc_zero_infinity = ctc_zero_infinity
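A quick sketch of how the new quantization parameters fit together (values shown are the defaults documented above); note that `codevector_dim` must be divisible by `num_codevector_groups`, which the quantizer asserts further down:

```python
from transformers import Wav2Vec2Config

config = Wav2Vec2Config(
    num_codevector_groups=2,             # G codebook groups (product quantization)
    num_codevectors_per_group=320,       # V entries per group -> G * V = 640 codewords
    codevector_dim=256,                  # split across groups: 256 // 2 = 128 dims per codeword
    proj_codevector_dim=256,             # shared projection dim for the contrastive loss
    contrastive_logits_temperature=0.1,  # kappa
    num_negatives=100,                   # distractors sampled per masked time step
    diversity_loss_weight=0.1,           # alpha
    feat_quantizer_dropout=0.0,
)
```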


@@ -28,7 +28,7 @@ from transformers import (
Wav2Vec2CTCTokenizer,
Wav2Vec2FeatureExtractor,
Wav2Vec2ForCTC,
Wav2Vec2ForPreTraining,
Wav2Vec2Processor,
logging,
)
@@ -50,9 +50,20 @@ MAPPING = {
"final_layer_norm": "encoder.layers.*.final_layer_norm",
"encoder.layer_norm": "encoder.layer_norm",
"w2v_model.layer_norm": "feature_projection.layer_norm",
"quantizer.weight_proj": "quantizer.weight_proj",
"quantizer.vars": "quantizer.codevectors",
"project_q": "project_q",
"final_proj": "project_hid",
"w2v_encoder.proj": "lm_head", "w2v_encoder.proj": "lm_head",
"mask_emb": "masked_spec_embed", "mask_emb": "masked_spec_embed",
} }
TOP_LEVEL_KEYS = [
"lm_head",
"quantizer.weight_proj",
"quantizer.codevectors",
"project_q",
"project_hid",
]
def set_recursively(hf_pointer, key, value, full_name, weight_type):
@@ -82,11 +93,11 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
def recursively_load_weights(fairseq_model, hf_model, is_headless):
unused_weights = []
fairseq_dict = fairseq_model.state_dict()
feature_extractor = hf_model.wav2vec2.feature_extractor
for name, value in fairseq_dict.items():
is_used = False
@@ -101,9 +112,8 @@
is_used = True
else:
for key, mapped_key in MAPPING.items():
mapped_key = "wav2vec2." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
is_used = True
if "*" in mapped_key:
layer_index = name.split(key)[0].split(".")[-2]
@@ -112,10 +122,11 @@
weight_type = "weight_g"
elif "weight_v" in name:
weight_type = "weight_v"
elif "bias" in name:
weight_type = "bias"
elif "weight" in name:
# TODO: don't match quantizer.weight_proj
weight_type = "weight"
else:
weight_type = None
set_recursively(hf_model, mapped_key, value, name, weight_type)
@@ -213,7 +224,7 @@ def convert_wav2vec2_checkpoint(
hf_wav2vec = Wav2Vec2ForCTC(config)
else:
hf_wav2vec = Wav2Vec2ForPreTraining(config)
if is_finetuned:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
@@ -224,7 +235,7 @@
model = model[0].eval()
recursively_load_weights(model, hf_wav2vec, not is_finetuned)
hf_wav2vec.save_pretrained(pytorch_dump_folder_path)


@@ -15,7 +15,8 @@
""" PyTorch Wav2Vec2 model. """
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
@@ -26,7 +27,12 @@ from torch import nn
from transformers.deepspeed import is_deepspeed_zero3_enabled
from ...activations import ACT2FN
from ...file_utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
@@ -46,6 +52,71 @@ WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
@dataclass
class Wav2Vec2BaseModelOutput(ModelOutput):
"""
Output type of :class:`~transformers.Wav2Vec2BaseModelOutput`, with potential hidden states and attentions.
Args:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
extract_features (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, conv_dim[-1])`):
Sequence of extracted feature vectors of the last convolutional layer of the model.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor = None
extract_features: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class Wav2Vec2ForPreTrainingOutput(ModelOutput):
"""
Output type of :class:`~transformers.Wav2Vec2ForPreTrainingOutput`, with potential hidden states and attentions.
Args:
loss (`optional`, returned when model is in train mode, ``torch.FloatTensor`` of shape :obj:`(1,)`):
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the `official
paper <https://arxiv.org/pdf/2006.11477.pdf>`__ .
projected_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.proj_codevector_dim)`):
Hidden-states of the model projected to `config.proj_codevector_dim` that can be used to predict the masked
projected quantized states.
projected_quantized_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.proj_codevector_dim)`):
Quantized extracted feature vectors projected to `config.proj_codevector_dim` representing the positive
target vectors for contrastive loss.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
projected_states: torch.FloatTensor = None
projected_quantized_states: torch.FloatTensor = None
codevector_perplexity: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
@@ -271,10 +342,11 @@ class Wav2Vec2FeatureProjection(nn.Module):
self.dropout = nn.Dropout(config.feat_proj_dropout)
def forward(self, hidden_states):
# non-projected hidden states are needed for quantization
norm_hidden_states = self.layer_norm(hidden_states)
hidden_states = self.projection(norm_hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states, norm_hidden_states
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Wav2Vec2
@@ -685,6 +757,86 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
)
class Wav2Vec2GumbelVectorQuantizer(nn.Module):
"""
Vector quantization using gumbel softmax. See `CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX
<https://arxiv.org/pdf/1611.01144.pdf>`__ for more information.
"""
def __init__(self, config):
super().__init__()
self.num_groups = config.num_codevector_groups
self.num_vars = config.num_codevectors_per_group
assert (
config.codevector_dim % self.num_groups == 0
), f"`config.codevector_dim {config.codevector_dim} must be divisible by `config.num_codevector_groups` {self.num_groups} for concatenation"
# storage for codebook variables (codewords)
self.codevectors = nn.Parameter(
torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
)
self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
# can be decayed for training
self.temperature = 1
def set_temperature(self, temperature: int):
self.temperature = temperature
@staticmethod
def _compute_perplexity(probs, mask=None):
if mask is not None:
mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
marginal_probs = probs.sum(dim=0) / mask.sum()
else:
marginal_probs = probs.mean(dim=0)
perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
return perplexity
def forward(self, hidden_states, mask_time_indices=None):
batch_size, sequence_length, hidden_size = hidden_states.shape
# project to codevector dim
hidden_states = self.weight_proj(hidden_states)
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
if self.training:
# sample code vector probs via gumbel in differentiable way
codevector_probs = F.gumbel_softmax(hidden_states.float(), tau=self.temperature, hard=True).type_as(
hidden_states
)
# compute perplexity
codevector_soft_dist = torch.softmax(
hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
)
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
else:
# take argmax in non-differentiable way
# compute hard codevector distribution (one hot)
codevector_idx = hidden_states.argmax(dim=-1)
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
-1, codevector_idx.view(-1, 1), 1.0
)
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
# use probs to retrieve codevectors
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
codevectors = (
codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
.sum(-2)
.view(batch_size, sequence_length, -1)
)
return codevectors, perplexity
class Wav2Vec2PreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -697,7 +849,12 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
# gumbel softmax requires special init
if isinstance(module, Wav2Vec2GumbelVectorQuantizer):
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
module.weight_proj.bias.data.zero_()
nn.init.uniform_(module.codevectors)
elif isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@@ -720,7 +877,7 @@
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_()
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the convolutional layers
"""
@@ -733,7 +890,7 @@
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
return input_lengths
WAV_2_VEC_2_START_DOCSTRING = r"""
@@ -797,7 +954,7 @@ WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
WAV_2_VEC_2_START_DOCSTRING,
)
class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
def __init__(self, config: Wav2Vec2Config):
super().__init__(config)
self.config = config
self.feature_extractor = Wav2Vec2FeatureExtractor(config)
@@ -812,12 +969,53 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
self.init_weights()
def _mask_hidden_states(
self, hidden_states: torch.FloatTensor, mask_time_indices: Optional[torch.FloatTensor] = None
):
"""
Masks extracted features along time axis and/or along feature axis according to `SpecAugment
<https://arxiv.org/abs/1904.08779>`__ .
"""
# `config.apply_spec_augment` can set masking to False
if not getattr(self.config, "apply_spec_augment", True):
return hidden_states
if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training:
# generate indices & apply SpecAugment along time axis
batch_size, sequence_length, hidden_size = hidden_states.size()
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
device=hidden_states.device,
min_masks=2,
)
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
device=hidden_states.device,
)
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
return hidden_states
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values,
attention_mask=None,
mask_time_indices=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@@ -852,49 +1050,30 @@
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
extract_features = self.feature_extractor(input_values)
extract_features = extract_features.transpose(1, 2)
if attention_mask is not None:
# compute real output lengths according to convolution formula
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
attention_mask = torch.zeros(
extract_features.shape[:2], dtype=extract_features.dtype, device=extract_features.device
)
# these two operations makes sure that all values
# before the output lengths indices are attended to
attention_mask[
(torch.arange(attention_mask.shape[0], device=extract_features.device), output_lengths - 1)
] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
hidden_states, extract_features = self.feature_projection(extract_features)
if mask_time_indices is not None:  # apply SpecAugment along time axis with given indices
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
hidden_states = self._mask_hidden_states(hidden_states)
encoder_outputs = self.encoder(
hidden_states,
@@ -907,15 +1086,240 @@
hidden_states = encoder_outputs[0]
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top. """, WAV_2_VEC_2_START_DOCSTRING)
class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
def __init__(self, config: Wav2Vec2Config):
super().__init__(config)
self.wav2vec2 = Wav2Vec2Model(config)
self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
self.quantizer = Wav2Vec2GumbelVectorQuantizer(config)
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
self.init_weights()
def set_gumbel_temperature(self, temperature: int):
"""
Set the Gumbel softmax temperature to a given value. Only necessary for training
"""
return self.quantizer.set_temperature(temperature)
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature extractor so that its parameters
will not be updated during training.
"""
self.wav2vec2.feature_extractor._freeze_parameters()
@staticmethod
def _sample_negatives(features: torch.FloatTensor, num_negatives: int):
"""
Sample `num_negatives` vectors from feature vectors.
"""
batch_size, sequence_length, hidden_size = features.shape
if sequence_length <= 1:
raise ValueError(
f"`features should have `sequence_length` > 1, but are of shape (batch_size, sequence_length, hidden_size) = ({batch_size, sequence_length, hidden_size})."
)
features = features.view(-1, hidden_size) # BTC => (BxT)C
with torch.no_grad():
# get `num_negatives` random vector indices from the same utterance
sampled_negative_indices = torch.randint(
low=0,
high=sequence_length - 1,
size=(batch_size, num_negatives * sequence_length),
device=features.device,
)
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
feature_indices = (
torch.arange(sequence_length, device=features.device)[:, None]
.expand(sequence_length, num_negatives)
.flatten()
)
# avoid sampling the same positive vector, but keep the distribution uniform
sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1
# correct for batch size
for batch_idx in range(1, batch_size):
sampled_negative_indices[batch_idx] += batch_idx * sequence_length
# take negative vectors from sampled indices
sampled_negatives = features[sampled_negative_indices.view(-1)]
sampled_negatives = sampled_negatives.view(batch_size, sequence_length, num_negatives, hidden_size).permute(
2, 0, 1, 3
)
return sampled_negatives
@staticmethod
def compute_contrastive_logits(
target_features: torch.FloatTensor,
negative_features: torch.FloatTensor,
predicted_features: torch.FloatTensor,
temperature: int = 1,
):
"""
Compute logits for contrastive loss based on cosine similarity as the distance measure between
`[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
"""
target_features = torch.cat([target_features, negative_features], dim=0)
logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
target_features
)
# apply temperature
logits = logits / temperature
return logits
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values,
attention_mask=None,
mask_time_indices=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
mask_time_indices (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
masked extracted features in `config.proj_codevector_dim` space.
Returns:
Example::
>>> import torch
>>> from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base")
>>> model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = feature_extractor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
>>> # compute masked indices
>>> batch_size, raw_sequence_length = input_values.shape
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2, device=model.device)
>>> with torch.no_grad():
... outputs = model(input_values, mask_time_indices=mask_time_indices)
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
>>> cosine_sim = torch.cosine_similarity(
... outputs.projected_states, outputs.projected_quantized_states, dim=-1
... )
>>> # show that cosine similarity is much higher than random
>>> assert cosine_sim[mask_time_indices].mean() > 0.5
>>> # for contrastive loss training model should be put into train mode
>>> model.train()
>>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if mask_time_indices is not None:
mask_time_indices = mask_time_indices.to(torch.bool)
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
mask_time_indices=mask_time_indices,
return_dict=return_dict,
)
# 1. project all transformed features (including masked) to final vq dim
transformer_features = self.project_hid(outputs[0])
# 2. quantize all (unmasked) extracted features and project to final vq dim
extract_features = self.dropout_features(outputs[1])
quantized_features, codevector_perplexity = self.quantizer(extract_features, mask_time_indices)
quantized_features = self.project_q(quantized_features)
loss = None
if self.training:
# for training, we sample negatives
# 3. sample K negatives (distractors) quantized states for contrastive loss
negative_quantized_features = self._sample_negatives(quantized_features, self.config.num_negatives)
# 4. compute logits, corresponding to `logits = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
# of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
logits = self.compute_contrastive_logits(
quantized_features[None, :],
negative_quantized_features,
transformer_features,
self.config.contrastive_logits_temperature,
)
# 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
# its cosine similarity will be masked
neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
if neg_is_pos.any():
logits[1:][neg_is_pos] = float("-inf")
# 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logits) =
# -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
preds = logits.transpose(0, 2).reshape(-1, logits.size(0))
target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
contrastive_loss = F.cross_entropy(preds.float(), target, reduction="sum")
# 7. compute diversity loss: \mathbf{L}_d
num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors
# 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
if not return_dict:
if loss is not None:
return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
return Wav2Vec2ForPreTrainingOutput(
loss=loss,
projected_states=transformer_features,
projected_quantized_states=quantized_features,
codevector_perplexity=codevector_perplexity,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings("""Wav2Vec2 Model with a `language modeling` head on top. """, WAV_2_VEC_2_START_DOCSTRING)
class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
def __init__(self, config):
@@ -986,7 +1390,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
logits = self.lm_head(hidden_states)
if not return_dict:
output = (logits,) + outputs[2:]
return output
return MaskedLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
@@ -1089,7 +1493,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
attention_mask = (
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
)
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
# assuming that padded tokens are filled with -100
# when not being attended to
@@ -1112,7 +1516,7 @@
)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutput(

@@ -2969,6 +2969,11 @@ class Wav2Vec2ForMaskedLM:
requires_backends(self, ["torch"])
class Wav2Vec2ForPreTraining:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Wav2Vec2Model:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


@@ -29,8 +29,16 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
if is_torch_available():
import torch
from transformers import (
Wav2Vec2Config,
Wav2Vec2FeatureExtractor,
Wav2Vec2ForCTC,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForPreTraining,
Wav2Vec2Model,
Wav2Vec2Processor,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2GumbelVectorQuantizer, _compute_mask_indices
class Wav2Vec2ModelTester:
@@ -219,13 +227,7 @@
@require_torch
class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else ()
)
test_pruning = False
test_headmasking = False
@@ -316,8 +318,14 @@
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
uniform_init_parms = [
"conv.weight",
"masked_spec_embed",
"codevectors",
"quantizer.weight_proj.weight",
]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
@@ -333,10 +341,14 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
     def _mock_init_weights(self, module):
         if hasattr(module, "weight") and module.weight is not None:
             module.weight.data.fill_(3)
-        if hasattr(module, "weight_g") and module.weight is not None:
+        if hasattr(module, "weight_g") and module.weight_g is not None:
             module.weight_g.data.fill_(3)
+        if hasattr(module, "weight_v") and module.weight_v is not None:
+            module.weight_v.data.fill_(3)
         if hasattr(module, "bias") and module.bias is not None:
             module.bias.data.fill_(3)
+        if hasattr(module, "codevectors") and module.codevectors is not None:
+            module.codevectors.data.fill_(3)

     @slow
     def test_model_from_pretrained(self):
@@ -346,7 +358,9 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
 @require_torch
 class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
-    all_model_classes = (Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else ()
+    all_model_classes = (
+        (Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else ()
+    )
     test_pruning = False
     test_headmasking = False
     test_torchscript = False
@@ -442,8 +456,14 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
         for model_class in self.all_model_classes:
             model = model_class(config=configs_no_init)
             for name, param in model.named_parameters():
+                uniform_init_parms = [
+                    "conv.weight",
+                    "masked_spec_embed",
+                    "codevectors",
+                    "quantizer.weight_proj.weight",
+                ]
                 if param.requires_grad:
-                    if "conv.weight" in name or "masked_spec_embed" in name:
+                    if any([x in name for x in uniform_init_parms]):
                         self.assertTrue(
                             -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
                             msg=f"Parameter {name} of model {model_class} seems not properly initialized",
@@ -459,10 +479,47 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
     def _mock_init_weights(self, module):
         if hasattr(module, "weight") and module.weight is not None:
             module.weight.data.fill_(3)
-        if hasattr(module, "weight_g") and module.weight is not None:
+        if hasattr(module, "weight_g") and module.weight_g is not None:
             module.weight_g.data.fill_(3)
+        if hasattr(module, "weight_v") and module.weight_v is not None:
+            module.weight_v.data.fill_(3)
         if hasattr(module, "bias") and module.bias is not None:
             module.bias.data.fill_(3)
+        if hasattr(module, "codevectors") and module.codevectors is not None:
+            module.codevectors.data.fill_(3)
+
+    def test_model_for_pretraining(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        model = Wav2Vec2ForPreTraining(config).to(torch_device)
+
+        features_shape = (
+            inputs_dict["input_values"].shape[0],
+            model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
+        )
+
+        mask_time_indices = _compute_mask_indices(
+            features_shape,
+            model.config.mask_time_prob,
+            model.config.mask_time_length,
+            device=inputs_dict["input_values"].device,
+            min_masks=2,
+        ).to(torch_device)
+
+        loss = model(
+            inputs_dict["input_values"],
+            attention_mask=inputs_dict["attention_mask"],
+            mask_time_indices=mask_time_indices,
+        ).loss
+
+        mask_time_indices[:, : mask_time_indices.shape[-1] // 2] = True
+        loss_more_masked = model(
+            inputs_dict["input_values"],
+            attention_mask=inputs_dict["attention_mask"],
+            mask_time_indices=mask_time_indices,
+        ).loss
+
+        # loss_more_masked has to be bigger or equal loss since more masked inputs have to be predicted
+        self.assertTrue(loss.detach().item() <= loss_more_masked.detach().item())

     @slow
     def test_model_from_pretrained(self):
@@ -484,24 +541,56 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
     def test_compute_mask_indices_overlap(self):
         batch_size = 4
-        sequence_length = 60
+        sequence_length = 80
         mask_prob = 0.5
         mask_length = 4

         mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)

-        # because of overlap there is a range of possible masks
+        # because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
         for batch_sum in mask.sum(axis=-1):
-            self.assertIn(
-                int(batch_sum),
-                list(range(int(mask_prob // mask_length * sequence_length), int(mask_prob * sequence_length))),
-            )
+            self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
+
+    def test_compute_perplexity(self):
+        probs = torch.arange(100, device=torch_device).reshape(2, 5, 10) / 100
+
+        ppl = Wav2Vec2GumbelVectorQuantizer._compute_perplexity(probs)
+        self.assertTrue(abs(ppl.item() - 141.4291) < 1e-3)
+
+        # mask half of the input
+        mask = torch.ones((2,), device=torch_device, dtype=torch.bool)
+        mask[0] = 0
+
+        ppl = Wav2Vec2GumbelVectorQuantizer._compute_perplexity(probs, mask)
+        self.assertTrue(abs(ppl.item() - 58.6757) < 1e-3)
+
+    def test_sample_negatives(self):
+        batch_size = 2
+        sequence_length = 10
+        hidden_size = 4
+        num_negatives = 3
+
+        features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
+            sequence_length, hidden_size
+        )  # each value in vector consits of same value
+        features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
+
+        negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives)
+
+        self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
+
+        # make sure no negatively sampled vector is actually a positive one
+        for negative in negatives:
+            self.assertTrue(((negative - features) == 0).sum() == 0.0)
+
+        # make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
+        self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))


 @require_torch
+@slow
 @require_datasets
 @require_soundfile
-@slow
 class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
     def _load_datasamples(self, num_samples):
         from datasets import load_dataset
@@ -586,3 +675,160 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
             "his instant panic was followed by a small sharp blow high on his chest",
         ]
         self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
+
+    def test_inference_integration(self):
+        model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
+        model.to(torch_device)
+        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+            "patrickvonplaten/wav2vec2-base", return_attention_mask=True
+        )
+        input_speech = self._load_datasamples(2)
+
+        inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
+
+        features_shape = (
+            inputs_dict["input_values"].shape[0],
+            model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
+        )
+
+        torch.manual_seed(0)
+        mask_time_indices = _compute_mask_indices(
+            features_shape,
+            model.config.mask_time_prob,
+            model.config.mask_time_length,
+            device=inputs_dict["input_values"].device,
+            min_masks=2,
+        ).to(torch_device)
+
+        with torch.no_grad():
+            outputs = model(
+                inputs_dict.input_values.to(torch_device),
+                attention_mask=inputs_dict.attention_mask.to(torch_device),
+                mask_time_indices=mask_time_indices,
+            )
+
+        # compute cosine similarity
+        cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
+
+        # retrieve cosine sim of masked features
+        cosine_sim_masked = cosine_sim[mask_time_indices]
+
+        # fmt: off
+        expected_cosine_sim_masked = torch.tensor(
+            [0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831, 0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651, 0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371, 0.6997],
+            device=torch_device,
+        )
+        # fmt: on
+
+        self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))
+
+    def test_inference_pretrained(self):
+        model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
+        model.to(torch_device)
+        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+            "patrickvonplaten/wav2vec2-base", return_attention_mask=True
+        )
+        input_speech = self._load_datasamples(2)
+
+        inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
+
+        features_shape = (
+            inputs_dict["input_values"].shape[0],
+            model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
+        )
+
+        torch.manual_seed(0)
+        mask_time_indices = _compute_mask_indices(
+            features_shape,
+            model.config.mask_time_prob,
+            model.config.mask_time_length,
+            device=inputs_dict["input_values"].device,
+            min_masks=2,
+        ).to(torch_device)
+
+        with torch.no_grad():
+            outputs = model(
+                inputs_dict.input_values.to(torch_device),
+                attention_mask=inputs_dict.attention_mask.to(torch_device),
+                mask_time_indices=mask_time_indices,
+            )
+
+        # compute cosine similarity
+        cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
+
+        # retrieve cosine sim of masked features
+        cosine_sim_masked = cosine_sim[mask_time_indices]
+
+        # ... now compare to randomly initialized model
+
+        config = Wav2Vec2Config.from_pretrained("patrickvonplaten/wav2vec2-base")
+        model_rand = Wav2Vec2ForPreTraining(config).to(torch_device).eval()
+
+        with torch.no_grad():
+            outputs_rand = model_rand(
+                inputs_dict.input_values.to(torch_device),
+                attention_mask=inputs_dict.attention_mask.to(torch_device),
+                mask_time_indices=mask_time_indices,
+            )
+
+        # compute cosine similarity
+        cosine_sim_rand = torch.cosine_similarity(
+            outputs_rand.projected_states, outputs_rand.projected_quantized_states, dim=-1
+        )
+
+        # retrieve cosine sim of masked features
+        cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]
+
+        # a pretrained wav2vec2 model has learned to predict the quantized latent states
+        # => the cosine similarity between quantized states and predicted states > 0.5
+        # a random wav2vec2 model has not learned to predict the quantized latent states
+        # => the cosine similarity between quantized states and predicted states is very likely < 0.1
+        self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)
+
+    def test_loss_pretraining(self):
+        model = Wav2Vec2ForPreTraining.from_pretrained(
+            "patrickvonplaten/wav2vec2-base",
+            attention_dropout=0.0,
+            feat_proj_dropout=0.0,
+            hidden_dropout=0.0,
+            layerdrop=0.0,
+        )
+        model.to(torch_device).train()
+
+        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+            "patrickvonplaten/wav2vec2-base", return_attention_mask=True
+        )
+        input_speech = self._load_datasamples(2)
+
+        inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
+
+        features_shape = (
+            inputs_dict["input_values"].shape[0],
+            model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]),
+        )
+
+        torch.manual_seed(0)
+        mask_time_indices = _compute_mask_indices(
+            features_shape,
+            model.config.mask_time_prob,
+            model.config.mask_time_length,
+            device=inputs_dict["input_values"].device,
+            min_masks=2,
+        ).to(torch_device)
+
+        with torch.no_grad():
+            outputs = model(
+                inputs_dict.input_values.to(torch_device),
+                attention_mask=inputs_dict.attention_mask.to(torch_device),
+                mask_time_indices=mask_time_indices,
+            )
+
+        # check diversity loss
+        num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
+        diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
+        self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3)
+
+        # check overall loss (contrastive loss + diversity loss)
+        expected_loss = 62.5170 if model.device.type == "cpu" else 50.3612
+        self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)
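For orientation, here is a condensed sketch of the pretraining forward pass that the integration tests above exercise. It mirrors the test code rather than documenting a stable public API, substitutes random noise for the LibriSpeech samples, and assumes the `patrickvonplaten/wav2vec2-base` checkpoint referenced in the tests is reachable:

```python
import numpy as np
import torch

from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices

# Dummy batch: two 1-second clips of 16 kHz "audio" (random noise stands in for real speech).
raw_audio = [np.random.randn(16000).astype(np.float32) for _ in range(2)]

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
    "patrickvonplaten/wav2vec2-base", return_attention_mask=True
)
model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base").train()

inputs = feature_extractor(raw_audio, return_tensors="pt", padding=True)

# Shape of the feature-encoder output: (batch, number of feature frames).
features_shape = (
    inputs["input_values"].shape[0],
    model._get_feat_extract_output_lengths(torch.tensor(inputs["input_values"].shape[1])),
)

# Sample spans of masked time steps, exactly as the tests above do.
mask_time_indices = _compute_mask_indices(
    features_shape,
    model.config.mask_time_prob,
    model.config.mask_time_length,
    device=inputs["input_values"].device,
    min_masks=2,
)

outputs = model(
    inputs["input_values"],
    attention_mask=inputs["attention_mask"],
    mask_time_indices=mask_time_indices,
)

# The returned loss combines the contrastive term with a diversity term derived
# from the codevector perplexity, as checked in `test_loss_pretraining` above.
num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors

outputs.loss.backward()
```

The last lines reproduce the diversity-loss bookkeeping from `test_loss_pretraining`: the codevector perplexity is turned into a penalty that is largest when only a few codevectors are ever selected.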