mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
Wav2Vec2 Pretraining (#11306)
* Working quantizer forward * Working quantizer forward * Clean up unused model parts, test reproducibility * Working quantizer forward * Clean up unused model parts, test reproducibility * Remove custom outputs from the shared ones * correct conversion * correct bug * add first pretrain script * save intermediate * static shapes * save intermediate * finish first pretrain script version * more refactor * remove wanddb * refactor more * improve test * correct perplexity compute bug * finish model implementation * add to docs * finish docs * finish pretraining script * finish pretraining script * remove wandb * finish PR for merge * finish config * finish * make deepspeed work * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * apply suggestions * fix flaky test Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
b1a8aa94f0
commit
d472bd7b18
@ -79,3 +79,10 @@ Wav2Vec2ForCTC
|
||||
|
||||
.. autoclass:: transformers.Wav2Vec2ForCTC
|
||||
:members: forward
|
||||
|
||||
|
||||
Wav2Vec2ForPreTraining
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.Wav2Vec2ForPreTraining
|
||||
:members: forward
|
||||
|
@ -184,3 +184,35 @@ run_asr.py \
|
||||
--preprocessing_num_workers=1 --group_by_length --freeze_feature_extractor --verbose_logging \
|
||||
--deepspeed ds_config_wav2vec2_zero3.json
|
||||
```
|
||||
|
||||
### Pretraining Wav2Vec2
|
||||
|
||||
The `run_pretrain.py` script allows one to pretrain a Wav2Vec2 model from scratch using Wav2Vec2's contrastive loss objective (see official [paper](https://arxiv.org/abs/2006.11477) for more information).
|
||||
It is recommended to pre-train Wav2Vec2 with Trainer + Deepspeed (please refer to [this guide](https://huggingface.co/transformers/master/main_classes/deepspeed.html#deepspeed-trainer-integration) for more information).
|
||||
|
||||
Here is an example of how you can use DeepSpeed ZeRO-2 to pretrain a small Wav2Vec2 model:
|
||||
|
||||
```
|
||||
PYTHONPATH=../../../src deepspeed --num_gpus 2 run_pretrain.py \
|
||||
--output_dir="./wav2vec2-base-libri-100h" \
|
||||
--num_train_epochs="3" \
|
||||
--per_device_train_batch_size="32" \
|
||||
--per_device_eval_batch_size="32" \
|
||||
--gradient_accumulation_steps="2" \
|
||||
--save_total_limit="3" \
|
||||
--save_steps="500" \
|
||||
--logging_steps="10" \
|
||||
--learning_rate="5e-4" \
|
||||
--weight_decay="0.01" \
|
||||
--warmup_steps="3000" \
|
||||
--model_name_or_path="patrickvonplaten/wav2vec2-base-libri-100h" \
|
||||
--dataset_name="librispeech_asr" \
|
||||
--dataset_config_name="clean" \
|
||||
--train_split_name="train.100" \
|
||||
--preprocessing_num_workers="4" \
|
||||
--max_duration_in_seconds="10.0" \
|
||||
--group_by_length \
|
||||
--verbose_logging \
|
||||
--fp16 \
|
||||
--deepspeed ds_config_wav2vec2_zero2.json \
|
||||
```
|
||||
|
370
examples/research_projects/wav2vec2/run_pretrain.py
Executable file
370
examples/research_projects/wav2vec2/run_pretrain.py
Executable file
@ -0,0 +1,370 @@
|
||||
#!/usr/bin/env python3
|
||||
import logging
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from datasets import DatasetDict, load_dataset
|
||||
from packaging import version
|
||||
|
||||
import librosa
|
||||
from transformers import (
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
Wav2Vec2Config,
|
||||
Wav2Vec2FeatureExtractor,
|
||||
Wav2Vec2ForPreTraining,
|
||||
is_apex_available,
|
||||
trainer_utils,
|
||||
)
|
||||
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
|
||||
|
||||
|
||||
if is_apex_available():
|
||||
from apex import amp
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse("1.6"):
|
||||
_is_native_amp_available = True
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
freeze_feature_extractor: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
|
||||
)
|
||||
gradient_checkpointing: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
|
||||
)
|
||||
verbose_logging: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to log verbose messages or not."},
|
||||
)
|
||||
max_gumbel_temperature: Optional[float] = field(
|
||||
default=2.0, metadata={"help": "Maximum temperature for gumbel softmax."}
|
||||
)
|
||||
min_gumbel_temperature: Optional[float] = field(
|
||||
default=0.5, metadata={"help": "Minimum temperature for gumbel softmax."}
|
||||
)
|
||||
gumbel_temperature_decay: Optional[float] = field(
|
||||
default=0.999995, metadata={"help": "Decay of gumbel temperature during training."}
|
||||
)
|
||||
|
||||
|
||||
def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
logging_level = logging.WARNING
|
||||
if model_args.verbose_logging:
|
||||
logging_level = logging.DEBUG
|
||||
elif trainer_utils.is_main_process(training_args.local_rank):
|
||||
logging_level = logging.INFO
|
||||
logger.setLevel(logging_level)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
|
||||
Using `HfArgumentParser` we can turn this class
|
||||
into argparse arguments to be able to specify them on
|
||||
the command line.
|
||||
"""
|
||||
|
||||
dataset_name: str = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
train_split_name: Optional[str] = field(
|
||||
default="train",
|
||||
metadata={
|
||||
"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
|
||||
},
|
||||
)
|
||||
validation_split_name: Optional[str] = field(
|
||||
default="validation",
|
||||
metadata={
|
||||
"help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
|
||||
},
|
||||
)
|
||||
speech_file_column: Optional[str] = field(
|
||||
default="file",
|
||||
metadata={"help": "Column in the dataset that contains speech file path. Defaults to 'file'"},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
max_duration_in_seconds: Optional[float] = field(
|
||||
default=20.0, metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForWav2Vec2Pretraining:
|
||||
"""
|
||||
Data collator that will dynamically pad the inputs received and prepare masked indices
|
||||
for self-supervised pretraining.
|
||||
|
||||
Args:
|
||||
model (:class:`~transformers.Wav2Vec2ForPreTraining`):
|
||||
The Wav2Vec2 model used for pretraining. The data collator needs to have access
|
||||
to config and ``_get_feat_extract_output_lengths`` function for correct padding.
|
||||
feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`):
|
||||
The processor used for proccessing the data.
|
||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
||||
among:
|
||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
||||
sequence if provided).
|
||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
max_length (:obj:`int`, `optional`):
|
||||
Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
|
||||
pad_to_multiple_of (:obj:`int`, `optional`):
|
||||
If set will pad the sequence to a multiple of the provided value.
|
||||
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
||||
7.5 (Volta).
|
||||
"""
|
||||
|
||||
model: Wav2Vec2ForPreTraining
|
||||
feature_extractor: Wav2Vec2FeatureExtractor
|
||||
padding: Union[bool, str] = "longest"
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
max_length: Optional[int] = None
|
||||
|
||||
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||
# reformat list to dict and set to pytorch format
|
||||
batch = self.feature_extractor.pad(
|
||||
features,
|
||||
max_length=self.max_length,
|
||||
padding=self.padding,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
|
||||
|
||||
# sample randomly masked indices
|
||||
batch["mask_time_indices"] = _compute_mask_indices(
|
||||
(batch["input_values"].shape[0], mask_indices_seq_length),
|
||||
self.model.config.mask_time_prob,
|
||||
self.model.config.mask_time_length,
|
||||
device=batch["input_values"].device,
|
||||
min_masks=2,
|
||||
)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class Wav2Vec2PreTrainer(Trainer):
|
||||
"""
|
||||
Subclassed :class:`~transformers.Trainer` for Wav2Vec2-like pretraining. Trainer can decay gumbel softmax temperature during training.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, max_gumbel_temp=1, min_gumbel_temp=0, gumbel_temp_decay=1.0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_update_step = 0
|
||||
self.max_gumbel_temp = max_gumbel_temp
|
||||
self.min_gumbel_temp = min_gumbel_temp
|
||||
self.gumbel_temp_decay = gumbel_temp_decay
|
||||
|
||||
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
|
||||
"""
|
||||
Perform a training step on a batch of inputs.
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
|
||||
Args:
|
||||
model (:obj:`nn.Module`):
|
||||
The model to train.
|
||||
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
|
||||
The inputs and targets of the model.
|
||||
|
||||
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
||||
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
|
||||
|
||||
Return:
|
||||
:obj:`torch.Tensor`: The tensor with training loss on this batch.
|
||||
"""
|
||||
|
||||
model.train()
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
|
||||
if self.use_amp:
|
||||
with autocast():
|
||||
loss = self.compute_loss(model, inputs)
|
||||
else:
|
||||
loss = self.compute_loss(model, inputs)
|
||||
|
||||
if self.args.n_gpu > 1 or self.deepspeed:
|
||||
if model.module.config.ctc_loss_reduction == "mean":
|
||||
loss = loss.mean()
|
||||
elif model.module.config.ctc_loss_reduction == "sum":
|
||||
loss = loss.sum() / (inputs["mask_time_indices"]).sum()
|
||||
else:
|
||||
raise ValueError(f"{model.config.ctc_loss_reduction} is not valid. Choose one of ['mean', 'sum']")
|
||||
|
||||
if self.args.gradient_accumulation_steps > 1:
|
||||
loss = loss / self.args.gradient_accumulation_steps
|
||||
|
||||
if self.use_amp:
|
||||
self.scaler.scale(loss).backward()
|
||||
elif self.use_apex:
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
elif self.deepspeed:
|
||||
self.deepspeed.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
self.num_update_step += 1
|
||||
# make sure gumbel softmax temperature is decayed
|
||||
if self.args.n_gpu > 1 or self.deepspeed:
|
||||
model.module.set_gumbel_temperature(
|
||||
max(self.max_gumbel_temp * self.gumbel_temp_decay ** self.num_update_step, self.min_gumbel_temp)
|
||||
)
|
||||
else:
|
||||
model.set_gumbel_temperature(
|
||||
max(self.max_gumbel_temp * self.gumbel_temp_decay ** self.num_update_step, self.min_gumbel_temp)
|
||||
)
|
||||
|
||||
return loss.detach()
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
||||
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
configure_logger(model_args, training_args)
|
||||
|
||||
# Downloading and loading a dataset from the hub.
|
||||
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
|
||||
|
||||
if "validation" not in datasets.keys():
|
||||
# make sure only "validation" and "train" keys remain"
|
||||
datasets = DatasetDict()
|
||||
datasets["validation"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
split=f"{data_args.train_split_name}[:{data_args.validation_split_percentage}%]",
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
datasets["train"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
split=f"{data_args.train_split_name}[{data_args.validation_split_percentage}%:]",
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
else:
|
||||
# make sure only "validation" and "train" keys remain"
|
||||
datasets = DatasetDict()
|
||||
datasets["validation"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
split="validation",
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
datasets["train"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
split=f"{data_args.train_split_name}",
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
|
||||
# only normalized-inputs-training is supported
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True
|
||||
)
|
||||
|
||||
def prepare_dataset(batch):
|
||||
# check that all files have the correct sampling rate
|
||||
batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate)
|
||||
return batch
|
||||
|
||||
# load audio files into numpy arrays
|
||||
vectorized_datasets = datasets.map(
|
||||
prepare_dataset, num_proc=data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names
|
||||
)
|
||||
|
||||
# filter audio files that are too long
|
||||
vectorized_datasets = vectorized_datasets.filter(
|
||||
lambda data: len(data["speech"]) < int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
|
||||
)
|
||||
|
||||
def normalize(batch):
|
||||
return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)
|
||||
|
||||
# normalize and transform to `BatchFeatures`
|
||||
vectorized_datasets = vectorized_datasets.map(
|
||||
normalize,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
remove_columns=vectorized_datasets["train"].column_names,
|
||||
)
|
||||
|
||||
# pretraining is only supported for "newer" stable layer norm architecture
|
||||
# apply_spec_augment has to be True, mask_feature_prob has to be 0.0
|
||||
config = Wav2Vec2Config.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
gradient_checkpointing=model_args.gradient_checkpointing,
|
||||
)
|
||||
|
||||
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
|
||||
raise ValueError(
|
||||
"PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
|
||||
)
|
||||
|
||||
model = Wav2Vec2ForPreTraining(config)
|
||||
|
||||
data_collator = DataCollatorForWav2Vec2Pretraining(model=model, feature_extractor=feature_extractor)
|
||||
|
||||
trainer = Wav2Vec2PreTrainer(
|
||||
model=model,
|
||||
data_collator=data_collator,
|
||||
args=training_args,
|
||||
train_dataset=vectorized_datasets["train"],
|
||||
eval_dataset=vectorized_datasets["validation"],
|
||||
tokenizer=feature_extractor,
|
||||
max_gumbel_temp=model_args.max_gumbel_temperature,
|
||||
min_gumbel_temp=model_args.min_gumbel_temperature,
|
||||
gumbel_temp_decay=model_args.gumbel_temperature_decay,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1046,6 +1046,7 @@ if is_torch_available():
|
||||
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"Wav2Vec2ForCTC",
|
||||
"Wav2Vec2ForMaskedLM",
|
||||
"Wav2Vec2ForPreTraining",
|
||||
"Wav2Vec2Model",
|
||||
"Wav2Vec2PreTrainedModel",
|
||||
]
|
||||
@ -2411,6 +2412,7 @@ if TYPE_CHECKING:
|
||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2ForMaskedLM,
|
||||
Wav2Vec2ForPreTraining,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2PreTrainedModel,
|
||||
)
|
||||
|
@ -269,7 +269,7 @@ from ..tapas.modeling_tapas import (
|
||||
from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
|
||||
from ..visual_bert.modeling_visual_bert import VisualBertForPreTraining, VisualBertModel
|
||||
from ..vit.modeling_vit import ViTForImageClassification, ViTModel
|
||||
from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2ForMaskedLM, Wav2Vec2Model
|
||||
from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining, Wav2Vec2Model
|
||||
from ..xlm.modeling_xlm import (
|
||||
XLMForMultipleChoice,
|
||||
XLMForQuestionAnsweringSimple,
|
||||
@ -463,6 +463,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
||||
(IBertConfig, IBertForMaskedLM),
|
||||
(DebertaConfig, DebertaForMaskedLM),
|
||||
(DebertaV2Config, DebertaV2ForMaskedLM),
|
||||
(Wav2Vec2Config, Wav2Vec2ForPreTraining),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -32,6 +32,7 @@ if is_torch_available():
|
||||
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"Wav2Vec2ForCTC",
|
||||
"Wav2Vec2ForMaskedLM",
|
||||
"Wav2Vec2ForPreTraining",
|
||||
"Wav2Vec2Model",
|
||||
"Wav2Vec2PreTrainedModel",
|
||||
]
|
||||
@ -48,6 +49,7 @@ if TYPE_CHECKING:
|
||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2ForMaskedLM,
|
||||
Wav2Vec2ForPreTraining,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2PreTrainedModel,
|
||||
)
|
||||
|
@ -71,6 +71,8 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
feat_extract_activation (:obj:`str, `optional`, defaults to :obj:`"gelu"`):
|
||||
The non-linear activation function (function or string) in the 1D convolutional layers of the feature
|
||||
extractor. If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
|
||||
feat_quantizer_dropout (obj:`float`, `optional`, defaults to 0.0):
|
||||
The dropout probabilitiy for quantized feature extractor states.
|
||||
conv_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 512, 512, 512)`):
|
||||
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
|
||||
feature extractor. The length of `conv_dim` defines the number of 1D convolutional layers.
|
||||
@ -108,6 +110,22 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
masked along the time axis. This is only relevant if ``apply_spec_augment is True``.
|
||||
mask_feature_length (:obj:`int`, `optional`, defaults to 10):
|
||||
Length of vector span along the feature axis.
|
||||
num_codevectors_per_group (:obj:`int`, `optional`, defaults to 320):
|
||||
Number of entries in each quantization codebook (group).
|
||||
num_codevector_groups (:obj:`int`, `optional`, defaults to 2):
|
||||
Number of codevector groups for product codevector quantization.
|
||||
contrastive_logits_temperature (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The temperature `kappa` in the contrastive loss.
|
||||
feat_quantizer_dropout (:obj:`float`, `optional`, defaults to 0.0):
|
||||
The dropout probabilitiy for the output of the feature extractor that's used by the quantizer.
|
||||
num_negatives (:obj:`int`, `optional`, defaults to 100):
|
||||
Number of negative samples for the contrastive loss.
|
||||
codevector_dim (:obj:`int`, `optional`, defaults to 256):
|
||||
Dimensionality of the quantized feature vectors.
|
||||
proj_codevector_dim (:obj:`int`, `optional`, defaults to 256):
|
||||
Dimensionality of the final projection of both the quantized and the transformer features.
|
||||
diversity_loss_weight (:obj:`int`, `optional`, defaults to 0.1):
|
||||
The weight of the codebook diversity loss component.
|
||||
ctc_loss_reduction (:obj:`str`, `optional`, defaults to :obj:`"sum"`):
|
||||
Specifies the reduction to apply to the output of ``torch.nn.CTCLoss``. Only relevant when training an
|
||||
instance of :class:`~transformers.Wav2Vec2ForCTC`.
|
||||
@ -145,6 +163,7 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
activation_dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
feat_proj_dropout=0.1,
|
||||
feat_quantizer_dropout=0.0,
|
||||
final_dropout=0.1,
|
||||
layerdrop=0.1,
|
||||
initializer_range=0.02,
|
||||
@ -163,6 +182,13 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
mask_time_length=10,
|
||||
mask_feature_prob=0.0,
|
||||
mask_feature_length=10,
|
||||
num_codevectors_per_group=320,
|
||||
num_codevector_groups=2,
|
||||
contrastive_logits_temperature=0.1,
|
||||
num_negatives=100,
|
||||
codevector_dim=256,
|
||||
proj_codevector_dim=256,
|
||||
diversity_loss_weight=0.1,
|
||||
ctc_loss_reduction="sum",
|
||||
ctc_zero_infinity=False,
|
||||
gradient_checkpointing=False,
|
||||
@ -217,6 +243,16 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
self.mask_feature_prob = mask_feature_prob
|
||||
self.mask_feature_length = mask_feature_length
|
||||
|
||||
# parameters for pretraining with codevector quantized representations
|
||||
self.num_codevectors_per_group = num_codevectors_per_group
|
||||
self.num_codevector_groups = num_codevector_groups
|
||||
self.contrastive_logits_temperature = contrastive_logits_temperature
|
||||
self.feat_quantizer_dropout = feat_quantizer_dropout
|
||||
self.num_negatives = num_negatives
|
||||
self.codevector_dim = codevector_dim
|
||||
self.proj_codevector_dim = proj_codevector_dim
|
||||
self.diversity_loss_weight = diversity_loss_weight
|
||||
|
||||
# ctc loss
|
||||
self.ctc_loss_reduction = ctc_loss_reduction
|
||||
self.ctc_zero_infinity = ctc_zero_infinity
|
||||
|
@ -28,7 +28,7 @@ from transformers import (
|
||||
Wav2Vec2CTCTokenizer,
|
||||
Wav2Vec2FeatureExtractor,
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2ForPreTraining,
|
||||
Wav2Vec2Processor,
|
||||
logging,
|
||||
)
|
||||
@ -50,9 +50,20 @@ MAPPING = {
|
||||
"final_layer_norm": "encoder.layers.*.final_layer_norm",
|
||||
"encoder.layer_norm": "encoder.layer_norm",
|
||||
"w2v_model.layer_norm": "feature_projection.layer_norm",
|
||||
"quantizer.weight_proj": "quantizer.weight_proj",
|
||||
"quantizer.vars": "quantizer.codevectors",
|
||||
"project_q": "project_q",
|
||||
"final_proj": "project_hid",
|
||||
"w2v_encoder.proj": "lm_head",
|
||||
"mask_emb": "masked_spec_embed",
|
||||
}
|
||||
TOP_LEVEL_KEYS = [
|
||||
"lm_head",
|
||||
"quantizer.weight_proj",
|
||||
"quantizer.codevectors",
|
||||
"project_q",
|
||||
"project_hid",
|
||||
]
|
||||
|
||||
|
||||
def set_recursively(hf_pointer, key, value, full_name, weight_type):
|
||||
@ -82,11 +93,11 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
|
||||
logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
|
||||
|
||||
|
||||
def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
|
||||
def recursively_load_weights(fairseq_model, hf_model, is_headless):
|
||||
unused_weights = []
|
||||
fairseq_dict = fairseq_model.state_dict()
|
||||
|
||||
feature_extractor = hf_model.wav2vec2.feature_extractor if is_finetuned else hf_model.feature_extractor
|
||||
feature_extractor = hf_model.wav2vec2.feature_extractor
|
||||
|
||||
for name, value in fairseq_dict.items():
|
||||
is_used = False
|
||||
@ -101,9 +112,8 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
|
||||
is_used = True
|
||||
else:
|
||||
for key, mapped_key in MAPPING.items():
|
||||
mapped_key = "wav2vec2." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key
|
||||
|
||||
if key in name or (key.split("w2v_model.")[-1] == name.split(".")[0] and not is_finetuned):
|
||||
mapped_key = "wav2vec2." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
|
||||
if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
|
||||
is_used = True
|
||||
if "*" in mapped_key:
|
||||
layer_index = name.split(key)[0].split(".")[-2]
|
||||
@ -112,10 +122,11 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
|
||||
weight_type = "weight_g"
|
||||
elif "weight_v" in name:
|
||||
weight_type = "weight_v"
|
||||
elif "weight" in name:
|
||||
weight_type = "weight"
|
||||
elif "bias" in name:
|
||||
weight_type = "bias"
|
||||
elif "weight" in name:
|
||||
# TODO: don't match quantizer.weight_proj
|
||||
weight_type = "weight"
|
||||
else:
|
||||
weight_type = None
|
||||
set_recursively(hf_model, mapped_key, value, name, weight_type)
|
||||
@ -213,7 +224,7 @@ def convert_wav2vec2_checkpoint(
|
||||
|
||||
hf_wav2vec = Wav2Vec2ForCTC(config)
|
||||
else:
|
||||
hf_wav2vec = Wav2Vec2Model(config)
|
||||
hf_wav2vec = Wav2Vec2ForPreTraining(config)
|
||||
|
||||
if is_finetuned:
|
||||
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
||||
@ -224,7 +235,7 @@ def convert_wav2vec2_checkpoint(
|
||||
|
||||
model = model[0].eval()
|
||||
|
||||
recursively_load_weights(model, hf_wav2vec, is_finetuned)
|
||||
recursively_load_weights(model, hf_wav2vec, not is_finetuned)
|
||||
|
||||
hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
@ -15,7 +15,8 @@
|
||||
""" PyTorch Wav2Vec2 model. """
|
||||
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -26,7 +27,12 @@ from torch import nn
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from ...file_utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
@ -46,6 +52,71 @@ WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Wav2Vec2BaseModelOutput(ModelOutput):
|
||||
"""
|
||||
Output type of :class:`~transformers.Wav2Vec2BaseModelOutput`, with potential hidden states and attentions.
|
||||
|
||||
Args:
|
||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
extract_features (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, conv_dim[-1])`):
|
||||
Sequence of extracted feature vectors of the last convolutional layer of the model.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor = None
|
||||
extract_features: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Wav2Vec2ForPreTrainingOutput(ModelOutput):
|
||||
"""
|
||||
Output type of :class:`~transformers.Wav2Vec2ForPreTrainingOutput`, with potential hidden states and attentions.
|
||||
|
||||
Args:
|
||||
loss (`optional`, returned when model is in train mode, ``torch.FloatTensor`` of shape :obj:`(1,)`):
|
||||
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the `official
|
||||
paper <https://arxiv.org/pdf/2006.11477.pdf>`__ . (classification) loss.
|
||||
projected_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.proj_codevector_dim)`):
|
||||
Hidden-states of the model projected to `config.proj_codevector_dim` that can be used to predict the masked
|
||||
projected quantized states.
|
||||
projected_quantized_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.proj_codevector_dim)`):
|
||||
Quantized extracted feature vectors projected to `config.proj_codevector_dim` representing the positive
|
||||
target vectors for contrastive loss.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
projected_states: torch.FloatTensor = None
|
||||
projected_quantized_states: torch.FloatTensor = None
|
||||
codevector_perplexity: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
def _compute_mask_indices(
|
||||
shape: Tuple[int, int],
|
||||
mask_prob: float,
|
||||
@ -271,10 +342,11 @@ class Wav2Vec2FeatureProjection(nn.Module):
|
||||
self.dropout = nn.Dropout(config.feat_proj_dropout)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states = self.projection(hidden_states)
|
||||
# non-projected hidden states are needed for quantization
|
||||
norm_hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states = self.projection(norm_hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
return hidden_states
|
||||
return hidden_states, norm_hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Wav2Vec2
|
||||
@ -685,6 +757,86 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
||||
"""
|
||||
Vector quantization using gumbel softmax. See `CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX
|
||||
<https://arxiv.org/pdf/1611.01144.pdf>`__ for more information.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.num_groups = config.num_codevector_groups
|
||||
self.num_vars = config.num_codevectors_per_group
|
||||
|
||||
assert (
|
||||
config.codevector_dim % self.num_groups == 0
|
||||
), f"`config.codevector_dim {config.codevector_dim} must be divisible by `config.num_codevector_groups` {self.num_groups} for concatenation"
|
||||
|
||||
# storage for codebook variables (codewords)
|
||||
self.codevectors = nn.Parameter(
|
||||
torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
|
||||
)
|
||||
self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
|
||||
|
||||
# can be decayed for training
|
||||
self.temperature = 1
|
||||
|
||||
def set_temperature(self, temperature: int):
|
||||
self.temperature = temperature
|
||||
|
||||
@staticmethod
|
||||
def _compute_perplexity(probs, mask=None):
|
||||
if mask is not None:
|
||||
mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
|
||||
probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
|
||||
marginal_probs = probs.sum(dim=0) / mask.sum()
|
||||
else:
|
||||
marginal_probs = probs.mean(dim=0)
|
||||
|
||||
perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
|
||||
return perplexity
|
||||
|
||||
def forward(self, hidden_states, mask_time_indices=None):
|
||||
batch_size, sequence_length, hidden_size = hidden_states.shape
|
||||
|
||||
# project to codevector dim
|
||||
hidden_states = self.weight_proj(hidden_states)
|
||||
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
|
||||
|
||||
if self.training:
|
||||
# sample code vector probs via gumbel in differentiateable way
|
||||
codevector_probs = F.gumbel_softmax(hidden_states.float(), tau=self.temperature, hard=True).type_as(
|
||||
hidden_states
|
||||
)
|
||||
|
||||
# compute perplexity
|
||||
codevector_soft_dist = torch.softmax(
|
||||
hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
|
||||
)
|
||||
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
|
||||
else:
|
||||
# take argmax in non-differentiable way
|
||||
# comptute hard codevector distribution (one hot)
|
||||
codevector_idx = hidden_states.argmax(dim=-1)
|
||||
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
|
||||
-1, codevector_idx.view(-1, 1), 1.0
|
||||
)
|
||||
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
|
||||
|
||||
perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
|
||||
|
||||
codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
|
||||
# use probs to retrieve codevectors
|
||||
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
|
||||
codevectors = (
|
||||
codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
|
||||
.sum(-2)
|
||||
.view(batch_size, sequence_length, -1)
|
||||
)
|
||||
|
||||
return codevectors, perplexity
|
||||
|
||||
|
||||
class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
@ -697,7 +849,12 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
# gumbel softmax requires special init
|
||||
if isinstance(module, Wav2Vec2GumbelVectorQuantizer):
|
||||
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
|
||||
module.weight_proj.bias.data.zero_()
|
||||
nn.init.uniform_(module.codevectors)
|
||||
elif isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
@ -720,7 +877,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
|
||||
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
|
||||
"""
|
||||
Computes the output length of the convolutional layers
|
||||
"""
|
||||
@ -733,7 +890,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
|
||||
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
||||
|
||||
return input_lengths.to(torch.long)
|
||||
return input_lengths
|
||||
|
||||
|
||||
WAV_2_VEC_2_START_DOCSTRING = r"""
|
||||
@ -797,7 +954,7 @@ WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
|
||||
WAV_2_VEC_2_START_DOCSTRING,
|
||||
)
|
||||
class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: Wav2Vec2Config):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.feature_extractor = Wav2Vec2FeatureExtractor(config)
|
||||
@ -812,12 +969,53 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def _mask_hidden_states(
|
||||
self, hidden_states: torch.FloatTensor, mask_time_indices: Optional[torch.FloatTensor] = None
|
||||
):
|
||||
"""
|
||||
Masks extracted features along time axis and/or along feature axis according to `SpecAugment
|
||||
<https://arxiv.org/abs/1904.08779>`__ .
|
||||
"""
|
||||
|
||||
# `config.apply_spec_augment` can set masking to False
|
||||
if not getattr(self.config, "apply_spec_augment", True):
|
||||
return hidden_states
|
||||
|
||||
if mask_time_indices is not None:
|
||||
# apply SpecAugment along time axis with given mask_time_indices
|
||||
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
||||
elif self.config.mask_time_prob > 0 and self.training:
|
||||
# generate indices & apply SpecAugment along time axis
|
||||
batch_size, sequence_length, hidden_size = hidden_states.size()
|
||||
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
(batch_size, sequence_length),
|
||||
mask_prob=self.config.mask_time_prob,
|
||||
mask_length=self.config.mask_time_length,
|
||||
device=hidden_states.device,
|
||||
min_masks=2,
|
||||
)
|
||||
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
||||
|
||||
if self.config.mask_feature_prob > 0 and self.training:
|
||||
# generate indices & apply SpecAugment along feature axis
|
||||
mask_feature_indices = _compute_mask_indices(
|
||||
(batch_size, hidden_size),
|
||||
mask_prob=self.config.mask_feature_prob,
|
||||
mask_length=self.config.mask_feature_length,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
|
||||
|
||||
return hidden_states
|
||||
|
||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
attention_mask=None,
|
||||
mask_time_indices=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
@ -852,49 +1050,30 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
hidden_states = self.feature_extractor(input_values)
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
extract_features = self.feature_extractor(input_values)
|
||||
extract_features = extract_features.transpose(1, 2)
|
||||
|
||||
if attention_mask is not None:
|
||||
# compute real output lengths according to convolution formula
|
||||
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
|
||||
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
||||
|
||||
attention_mask = torch.zeros(
|
||||
hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device
|
||||
extract_features.shape[:2], dtype=extract_features.dtype, device=extract_features.device
|
||||
)
|
||||
|
||||
# these two operations makes sure that all values
|
||||
# before the output lengths indices are attended to
|
||||
attention_mask[
|
||||
(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)
|
||||
(torch.arange(attention_mask.shape[0], device=extract_features.device), output_lengths - 1)
|
||||
] = 1
|
||||
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
||||
|
||||
hidden_states = self.feature_projection(hidden_states)
|
||||
hidden_states, extract_features = self.feature_projection(extract_features)
|
||||
|
||||
if self.config.apply_spec_augment and self.training:
|
||||
batch_size, sequence_length, hidden_size = hidden_states.size()
|
||||
if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
|
||||
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
||||
|
||||
# apply SpecAugment along time axis
|
||||
if self.config.mask_time_prob > 0:
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
(batch_size, sequence_length),
|
||||
mask_prob=self.config.mask_time_prob,
|
||||
mask_length=self.config.mask_time_length,
|
||||
device=hidden_states.device,
|
||||
min_masks=2,
|
||||
)
|
||||
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
||||
|
||||
# apply SpecAugment along feature axis
|
||||
if self.config.mask_feature_prob > 0:
|
||||
mask_feature_indices = _compute_mask_indices(
|
||||
(batch_size, hidden_size),
|
||||
mask_prob=self.config.mask_feature_prob,
|
||||
mask_length=self.config.mask_feature_length,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
|
||||
hidden_states = self._mask_hidden_states(hidden_states)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
hidden_states,
|
||||
@ -907,15 +1086,240 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
hidden_states = encoder_outputs[0]
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,) + encoder_outputs[1:]
|
||||
return (hidden_states, extract_features) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutput(
|
||||
return Wav2Vec2BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
extract_features=extract_features,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top. """, WAV_2_VEC_2_START_DOCSTRING)
|
||||
class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
|
||||
def __init__(self, config: Wav2Vec2Config):
|
||||
super().__init__(config)
|
||||
self.wav2vec2 = Wav2Vec2Model(config)
|
||||
self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
|
||||
|
||||
self.quantizer = Wav2Vec2GumbelVectorQuantizer(config)
|
||||
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
|
||||
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def set_gumbel_temperature(self, temperature: int):
|
||||
"""
|
||||
Set the Gumbel softmax temperature to a given value. Only necessary for training
|
||||
"""
|
||||
return self.quantizer.set_temperature(temperature)
|
||||
|
||||
def freeze_feature_extractor(self):
|
||||
"""
|
||||
Calling this function will disable the gradient computation for the feature extractor so that its parameters
|
||||
will not be updated during training.
|
||||
"""
|
||||
self.wav2vec2.feature_extractor._freeze_parameters()
|
||||
|
||||
@staticmethod
|
||||
def _sample_negatives(features: torch.FloatTensor, num_negatives: int):
|
||||
"""
|
||||
Sample `num_negatives` vectors from feature vectors.
|
||||
"""
|
||||
batch_size, sequence_length, hidden_size = features.shape
|
||||
if sequence_length <= 1:
|
||||
raise ValueError(
|
||||
f"`features should have `sequence_length` > 1, but are of shape (batch_size, sequence_length, hidden_size) = ({batch_size, sequence_length, hidden_size})."
|
||||
)
|
||||
|
||||
features = features.view(-1, hidden_size) # BTC => (BxT)C
|
||||
|
||||
with torch.no_grad():
|
||||
# get `num_negatives` random vector indices from the same utterance
|
||||
sampled_negative_indices = torch.randint(
|
||||
low=0,
|
||||
high=sequence_length - 1,
|
||||
size=(batch_size, num_negatives * sequence_length),
|
||||
device=features.device,
|
||||
)
|
||||
|
||||
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
|
||||
feature_indices = (
|
||||
torch.arange(sequence_length, device=features.device)[:, None]
|
||||
.expand(sequence_length, num_negatives)
|
||||
.flatten()
|
||||
)
|
||||
|
||||
# avoid sampling the same positive vector, but keep the distribution uniform
|
||||
sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1
|
||||
|
||||
# correct for batch size
|
||||
for batch_idx in range(1, batch_size):
|
||||
sampled_negative_indices[batch_idx] += batch_idx * sequence_length
|
||||
|
||||
# take negative vectors from sampled indices
|
||||
sampled_negatives = features[sampled_negative_indices.view(-1)]
|
||||
sampled_negatives = sampled_negatives.view(batch_size, sequence_length, num_negatives, hidden_size).permute(
|
||||
2, 0, 1, 3
|
||||
)
|
||||
|
||||
return sampled_negatives
|
||||
|
||||
@staticmethod
|
||||
def compute_contrastive_logits(
|
||||
target_features: torch.FloatTensor,
|
||||
negative_features: torch.FloatTensor,
|
||||
predicted_features: torch.FloatTensor,
|
||||
temperature: int = 1,
|
||||
):
|
||||
"""
|
||||
Compute logits for contrastive loss based using cosine similarity as the distance measure between
|
||||
`[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
|
||||
"""
|
||||
target_features = torch.cat([target_features, negative_features], dim=0)
|
||||
|
||||
logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
|
||||
target_features
|
||||
)
|
||||
|
||||
# apply temperature
|
||||
logits = logits / temperature
|
||||
return logits
|
||||
|
||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
attention_mask=None,
|
||||
mask_time_indices=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
mask_time_indices (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
|
||||
masked extracted features in `config.proj_codevector_dim` space.
|
||||
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> import torch
|
||||
>>> from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining
|
||||
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
|
||||
>>> from datasets import load_dataset
|
||||
>>> import soundfile as sf
|
||||
|
||||
>>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base")
|
||||
>>> model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
|
||||
|
||||
|
||||
>>> def map_to_array(batch):
|
||||
... speech, _ = sf.read(batch["file"])
|
||||
... batch["speech"] = speech
|
||||
... return batch
|
||||
|
||||
|
||||
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> ds = ds.map(map_to_array)
|
||||
|
||||
>>> input_values = feature_extractor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
|
||||
|
||||
>>> # compute masked indices
|
||||
>>> batch_size, raw_sequence_length = input_values.shape
|
||||
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
|
||||
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2, device=model.device)
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(input_values, mask_time_indices=mask_time_indices)
|
||||
|
||||
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
|
||||
>>> cosine_sim = torch.cosine_similarity(
|
||||
... outputs.projected_states, outputs.projected_quantized_states, dim=-1
|
||||
... )
|
||||
|
||||
>>> # show that cosine similarity is much higher than random
|
||||
>>> assert cosine_sim[mask_time_indices].mean() > 0.5
|
||||
|
||||
>>> # for contrastive loss training model should be put into train mode
|
||||
>>> model.train()
|
||||
>>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
|
||||
"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if mask_time_indices is not None:
|
||||
mask_time_indices = mask_time_indices.to(torch.bool)
|
||||
|
||||
outputs = self.wav2vec2(
|
||||
input_values,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
mask_time_indices=mask_time_indices,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# 1. project all transformed features (including masked) to final vq dim
|
||||
transformer_features = self.project_hid(outputs[0])
|
||||
|
||||
# 2. quantize all (unmasked) extracted features and project to final vq dim
|
||||
extract_features = self.dropout_features(outputs[1])
|
||||
quantized_features, codevector_perplexity = self.quantizer(extract_features, mask_time_indices)
|
||||
quantized_features = self.project_q(quantized_features)
|
||||
|
||||
loss = None
|
||||
if self.training:
|
||||
# for training, we sample negatives
|
||||
# 3. sample K negatives (distractors) quantized states for contrastive loss
|
||||
negative_quantized_features = self._sample_negatives(quantized_features, self.config.num_negatives)
|
||||
|
||||
# 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
|
||||
# of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
|
||||
logits = self.compute_contrastive_logits(
|
||||
quantized_features[None, :],
|
||||
negative_quantized_features,
|
||||
transformer_features,
|
||||
self.config.contrastive_logits_temperature,
|
||||
)
|
||||
|
||||
# 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
|
||||
# its cosine similarity will be masked
|
||||
neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
|
||||
if neg_is_pos.any():
|
||||
logits[1:][neg_is_pos] = float("-inf")
|
||||
|
||||
# 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
|
||||
# -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
|
||||
preds = logits.transpose(0, 2).reshape(-1, logits.size(0))
|
||||
target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
|
||||
contrastive_loss = F.cross_entropy(preds.float(), target, reduction="sum")
|
||||
|
||||
# 7. compute diversity loss: \mathbf{L}_d
|
||||
num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
|
||||
diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors
|
||||
|
||||
# 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
|
||||
loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
|
||||
|
||||
if not return_dict:
|
||||
if loss is not None:
|
||||
return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
|
||||
return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
|
||||
|
||||
return Wav2Vec2ForPreTrainingOutput(
|
||||
loss=loss,
|
||||
projected_states=transformer_features,
|
||||
projected_quantized_states=quantized_features,
|
||||
codevector_perplexity=codevector_perplexity,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings("""Wav2Vec2 Model with a `language modeling` head on top. """, WAV_2_VEC_2_START_DOCSTRING)
|
||||
class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
@ -986,7 +1390,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
output = (logits,) + outputs[2:]
|
||||
return output
|
||||
|
||||
return MaskedLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
|
||||
@ -1089,7 +1493,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
attention_mask = (
|
||||
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
|
||||
)
|
||||
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
|
||||
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
||||
|
||||
# assuming that padded tokens are filled with -100
|
||||
# when not being attended to
|
||||
@ -1112,7 +1516,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return CausalLMOutput(
|
||||
|
@ -2969,6 +2969,11 @@ class Wav2Vec2ForMaskedLM:
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Wav2Vec2ForPreTraining:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Wav2Vec2Model:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
@ -29,8 +29,16 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Processor
|
||||
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
|
||||
from transformers import (
|
||||
Wav2Vec2Config,
|
||||
Wav2Vec2FeatureExtractor,
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2ForMaskedLM,
|
||||
Wav2Vec2ForPreTraining,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2Processor,
|
||||
)
|
||||
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2GumbelVectorQuantizer, _compute_mask_indices
|
||||
|
||||
|
||||
class Wav2Vec2ModelTester:
|
||||
@ -219,13 +227,7 @@ class Wav2Vec2ModelTester:
|
||||
@require_torch
|
||||
class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2ForMaskedLM,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else ()
|
||||
)
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
@ -316,8 +318,14 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
uniform_init_parms = [
|
||||
"conv.weight",
|
||||
"masked_spec_embed",
|
||||
"codevectors",
|
||||
"quantizer.weight_proj.weight",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if "conv.weight" in name or "masked_spec_embed" in name:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
self.assertTrue(
|
||||
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
@ -333,10 +341,14 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def _mock_init_weights(self, module):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data.fill_(3)
|
||||
if hasattr(module, "weight_g") and module.weight is not None:
|
||||
if hasattr(module, "weight_g") and module.weight_g is not None:
|
||||
module.weight_g.data.fill_(3)
|
||||
if hasattr(module, "weight_v") and module.weight_v is not None:
|
||||
module.weight_v.data.fill_(3)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.fill_(3)
|
||||
if hasattr(module, "codevectors") and module.codevectors is not None:
|
||||
module.codevectors.data.fill_(3)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
@ -346,7 +358,9 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else ()
|
||||
)
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_torchscript = False
|
||||
@ -442,8 +456,14 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
uniform_init_parms = [
|
||||
"conv.weight",
|
||||
"masked_spec_embed",
|
||||
"codevectors",
|
||||
"quantizer.weight_proj.weight",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if "conv.weight" in name or "masked_spec_embed" in name:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
self.assertTrue(
|
||||
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
@ -459,10 +479,47 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def _mock_init_weights(self, module):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data.fill_(3)
|
||||
if hasattr(module, "weight_g") and module.weight is not None:
|
||||
if hasattr(module, "weight_g") and module.weight_g is not None:
|
||||
module.weight_g.data.fill_(3)
|
||||
if hasattr(module, "weight_v") and module.weight_v is not None:
|
||||
module.weight_v.data.fill_(3)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.fill_(3)
|
||||
if hasattr(module, "codevectors") and module.codevectors is not None:
|
||||
module.codevectors.data.fill_(3)
|
||||
|
||||
def test_model_for_pretraining(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = Wav2Vec2ForPreTraining(config).to(torch_device)
|
||||
|
||||
features_shape = (
|
||||
inputs_dict["input_values"].shape[0],
|
||||
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
|
||||
)
|
||||
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
features_shape,
|
||||
model.config.mask_time_prob,
|
||||
model.config.mask_time_length,
|
||||
device=inputs_dict["input_values"].device,
|
||||
min_masks=2,
|
||||
).to(torch_device)
|
||||
|
||||
loss = model(
|
||||
inputs_dict["input_values"],
|
||||
attention_mask=inputs_dict["attention_mask"],
|
||||
mask_time_indices=mask_time_indices,
|
||||
).loss
|
||||
|
||||
mask_time_indices[:, : mask_time_indices.shape[-1] // 2] = True
|
||||
loss_more_masked = model(
|
||||
inputs_dict["input_values"],
|
||||
attention_mask=inputs_dict["attention_mask"],
|
||||
mask_time_indices=mask_time_indices,
|
||||
).loss
|
||||
|
||||
# loss_more_masked has to be bigger or equal loss since more masked inputs have to be predicted
|
||||
self.assertTrue(loss.detach().item() <= loss_more_masked.detach().item())
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
@ -484,24 +541,56 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
|
||||
def test_compute_mask_indices_overlap(self):
|
||||
batch_size = 4
|
||||
sequence_length = 60
|
||||
sequence_length = 80
|
||||
mask_prob = 0.5
|
||||
mask_length = 4
|
||||
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
|
||||
|
||||
# because of overlap there is a range of possible masks
|
||||
# because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
|
||||
for batch_sum in mask.sum(axis=-1):
|
||||
self.assertIn(
|
||||
int(batch_sum),
|
||||
list(range(int(mask_prob // mask_length * sequence_length), int(mask_prob * sequence_length))),
|
||||
)
|
||||
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
|
||||
|
||||
def test_compute_perplexity(self):
|
||||
probs = torch.arange(100, device=torch_device).reshape(2, 5, 10) / 100
|
||||
|
||||
ppl = Wav2Vec2GumbelVectorQuantizer._compute_perplexity(probs)
|
||||
self.assertTrue(abs(ppl.item() - 141.4291) < 1e-3)
|
||||
|
||||
# mask half of the input
|
||||
mask = torch.ones((2,), device=torch_device, dtype=torch.bool)
|
||||
mask[0] = 0
|
||||
|
||||
ppl = Wav2Vec2GumbelVectorQuantizer._compute_perplexity(probs, mask)
|
||||
self.assertTrue(abs(ppl.item() - 58.6757) < 1e-3)
|
||||
|
||||
def test_sample_negatives(self):
|
||||
batch_size = 2
|
||||
sequence_length = 10
|
||||
hidden_size = 4
|
||||
num_negatives = 3
|
||||
|
||||
features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
|
||||
sequence_length, hidden_size
|
||||
) # each value in vector consits of same value
|
||||
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
|
||||
|
||||
negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives)
|
||||
|
||||
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
|
||||
|
||||
# make sure no negatively sampled vector is actually a positive one
|
||||
for negative in negatives:
|
||||
self.assertTrue(((negative - features) == 0).sum() == 0.0)
|
||||
|
||||
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
|
||||
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
|
||||
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
@require_datasets
|
||||
@require_soundfile
|
||||
@slow
|
||||
class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
def _load_datasamples(self, num_samples):
|
||||
from datasets import load_dataset
|
||||
@ -586,3 +675,160 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
"his instant panic was followed by a small sharp blow high on his chest",
|
||||
]
|
||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
def test_inference_integration(self):
|
||||
model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
|
||||
model.to(torch_device)
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
"patrickvonplaten/wav2vec2-base", return_attention_mask=True
|
||||
)
|
||||
input_speech = self._load_datasamples(2)
|
||||
|
||||
inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
|
||||
|
||||
features_shape = (
|
||||
inputs_dict["input_values"].shape[0],
|
||||
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
features_shape,
|
||||
model.config.mask_time_prob,
|
||||
model.config.mask_time_length,
|
||||
device=inputs_dict["input_values"].device,
|
||||
min_masks=2,
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(
|
||||
inputs_dict.input_values.to(torch_device),
|
||||
attention_mask=inputs_dict.attention_mask.to(torch_device),
|
||||
mask_time_indices=mask_time_indices,
|
||||
)
|
||||
|
||||
# compute cosine similarity
|
||||
cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
|
||||
|
||||
# retrieve cosine sim of masked features
|
||||
cosine_sim_masked = cosine_sim[mask_time_indices]
|
||||
|
||||
# fmt: off
|
||||
expected_cosine_sim_masked = torch.tensor(
|
||||
[0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831, 0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651, 0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371, 0.6997],
|
||||
device=torch_device,
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))
|
||||
|
||||
def test_inference_pretrained(self):
|
||||
model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
|
||||
model.to(torch_device)
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
"patrickvonplaten/wav2vec2-base", return_attention_mask=True
|
||||
)
|
||||
input_speech = self._load_datasamples(2)
|
||||
|
||||
inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
|
||||
|
||||
features_shape = (
|
||||
inputs_dict["input_values"].shape[0],
|
||||
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
features_shape,
|
||||
model.config.mask_time_prob,
|
||||
model.config.mask_time_length,
|
||||
device=inputs_dict["input_values"].device,
|
||||
min_masks=2,
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(
|
||||
inputs_dict.input_values.to(torch_device),
|
||||
attention_mask=inputs_dict.attention_mask.to(torch_device),
|
||||
mask_time_indices=mask_time_indices,
|
||||
)
|
||||
|
||||
# compute cosine similarity
|
||||
cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
|
||||
|
||||
# retrieve cosine sim of masked features
|
||||
cosine_sim_masked = cosine_sim[mask_time_indices]
|
||||
|
||||
# ... now compare to randomly initialized model
|
||||
|
||||
config = Wav2Vec2Config.from_pretrained("patrickvonplaten/wav2vec2-base")
|
||||
model_rand = Wav2Vec2ForPreTraining(config).to(torch_device).eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs_rand = model_rand(
|
||||
inputs_dict.input_values.to(torch_device),
|
||||
attention_mask=inputs_dict.attention_mask.to(torch_device),
|
||||
mask_time_indices=mask_time_indices,
|
||||
)
|
||||
|
||||
# compute cosine similarity
|
||||
cosine_sim_rand = torch.cosine_similarity(
|
||||
outputs_rand.projected_states, outputs_rand.projected_quantized_states, dim=-1
|
||||
)
|
||||
|
||||
# retrieve cosine sim of masked features
|
||||
cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]
|
||||
|
||||
# a pretrained wav2vec2 model has learned to predict the quantized latent states
|
||||
# => the cosine similarity between quantized states and predicted states > 0.5
|
||||
# a random wav2vec2 model has not learned to predict the quantized latent states
|
||||
# => the cosine similarity between quantized states and predicted states is very likely < 0.1
|
||||
self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)
|
||||
|
||||
def test_loss_pretraining(self):
|
||||
model = Wav2Vec2ForPreTraining.from_pretrained(
|
||||
"patrickvonplaten/wav2vec2-base",
|
||||
attention_dropout=0.0,
|
||||
feat_proj_dropout=0.0,
|
||||
hidden_dropout=0.0,
|
||||
layerdrop=0.0,
|
||||
)
|
||||
model.to(torch_device).train()
|
||||
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
"patrickvonplaten/wav2vec2-base", return_attention_mask=True
|
||||
)
|
||||
input_speech = self._load_datasamples(2)
|
||||
|
||||
inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
|
||||
|
||||
features_shape = (
|
||||
inputs_dict["input_values"].shape[0],
|
||||
model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]),
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
features_shape,
|
||||
model.config.mask_time_prob,
|
||||
model.config.mask_time_length,
|
||||
device=inputs_dict["input_values"].device,
|
||||
min_masks=2,
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(
|
||||
inputs_dict.input_values.to(torch_device),
|
||||
attention_mask=inputs_dict.attention_mask.to(torch_device),
|
||||
mask_time_indices=mask_time_indices,
|
||||
)
|
||||
|
||||
# check diversity loss
|
||||
num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
|
||||
diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
|
||||
self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3)
|
||||
|
||||
# check overall loss (contrastive loss + diversity loss)
|
||||
expected_loss = 62.5170 if model.device.type == "cpu" else 50.3612
|
||||
|
||||
self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)
|
||||
|
Loading…
Reference in New Issue
Block a user