diff --git a/docs/source/model_doc/wav2vec2.rst b/docs/source/model_doc/wav2vec2.rst index cd0b6e0cc78..bcfc3f26e4b 100644 --- a/docs/source/model_doc/wav2vec2.rst +++ b/docs/source/model_doc/wav2vec2.rst @@ -79,3 +79,10 @@ Wav2Vec2ForCTC .. autoclass:: transformers.Wav2Vec2ForCTC :members: forward + + +Wav2Vec2ForPreTraining +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.Wav2Vec2ForPreTraining + :members: forward diff --git a/examples/research_projects/wav2vec2/README.md b/examples/research_projects/wav2vec2/README.md index 39bbda38118..238ee6f629a 100644 --- a/examples/research_projects/wav2vec2/README.md +++ b/examples/research_projects/wav2vec2/README.md @@ -184,3 +184,35 @@ run_asr.py \ --preprocessing_num_workers=1 --group_by_length --freeze_feature_extractor --verbose_logging \ --deepspeed ds_config_wav2vec2_zero3.json ``` + +### Pretraining Wav2Vec2 + +The `run_pretrain.py` script allows one to pretrain a Wav2Vec2 model from scratch using Wav2Vec2's contrastive loss objective (see official [paper](https://arxiv.org/abs/2006.11477) for more information). +It is recommended to pre-train Wav2Vec2 with Trainer + Deepspeed (please refer to [this guide](https://huggingface.co/transformers/master/main_classes/deepspeed.html#deepspeed-trainer-integration) for more information). + +Here is an example of how you can use DeepSpeed ZeRO-2 to pretrain a small Wav2Vec2 model: + +``` +PYTHONPATH=../../../src deepspeed --num_gpus 2 run_pretrain.py \ +--output_dir="./wav2vec2-base-libri-100h" \ +--num_train_epochs="3" \ +--per_device_train_batch_size="32" \ +--per_device_eval_batch_size="32" \ +--gradient_accumulation_steps="2" \ +--save_total_limit="3" \ +--save_steps="500" \ +--logging_steps="10" \ +--learning_rate="5e-4" \ +--weight_decay="0.01" \ +--warmup_steps="3000" \ +--model_name_or_path="patrickvonplaten/wav2vec2-base-libri-100h" \ +--dataset_name="librispeech_asr" \ +--dataset_config_name="clean" \ +--train_split_name="train.100" \ +--preprocessing_num_workers="4" \ +--max_duration_in_seconds="10.0" \ +--group_by_length \ +--verbose_logging \ +--fp16 \ +--deepspeed ds_config_wav2vec2_zero2.json \ +``` diff --git a/examples/research_projects/wav2vec2/run_pretrain.py b/examples/research_projects/wav2vec2/run_pretrain.py new file mode 100755 index 00000000000..a34fa404a71 --- /dev/null +++ b/examples/research_projects/wav2vec2/run_pretrain.py @@ -0,0 +1,370 @@ +#!/usr/bin/env python3 +import logging +import sys +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.nn as nn +from datasets import DatasetDict, load_dataset +from packaging import version + +import librosa +from transformers import ( + HfArgumentParser, + Trainer, + TrainingArguments, + Wav2Vec2Config, + Wav2Vec2FeatureExtractor, + Wav2Vec2ForPreTraining, + is_apex_available, + trainer_utils, +) +from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices + + +if is_apex_available(): + from apex import amp + +if version.parse(torch.__version__) >= version.parse("1.6"): + _is_native_amp_available = True + from torch.cuda.amp import autocast + + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
+ """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + ) + freeze_feature_extractor: Optional[bool] = field( + default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."} + ) + gradient_checkpointing: Optional[bool] = field( + default=False, metadata={"help": "Whether to freeze the feature extractor layers of the model."} + ) + verbose_logging: Optional[bool] = field( + default=False, + metadata={"help": "Whether to log verbose messages or not."}, + ) + max_gumbel_temperature: Optional[float] = field( + default=2.0, metadata={"help": "Maximum temperature for gumbel softmax."} + ) + min_gumbel_temperature: Optional[float] = field( + default=0.5, metadata={"help": "Minimum temperature for gumbel softmax."} + ) + gumbel_temperature_decay: Optional[float] = field( + default=0.999995, metadata={"help": "Decay of gumbel temperature during training."} + ) + + +def configure_logger(model_args: ModelArguments, training_args: TrainingArguments): + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + logging_level = logging.WARNING + if model_args.verbose_logging: + logging_level = logging.DEBUG + elif trainer_utils.is_main_process(training_args.local_rank): + logging_level = logging.INFO + logger.setLevel(logging_level) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + + Using `HfArgumentParser` we can turn this class + into argparse arguments to be able to specify them on + the command line. + """ + + dataset_name: str = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_split_name: Optional[str] = field( + default="train", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + validation_split_name: Optional[str] = field( + default="validation", + metadata={ + "help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'" + }, + ) + speech_file_column: Optional[str] = field( + default="file", + metadata={"help": "Column in the dataset that contains speech file path. Defaults to 'file'"}, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_duration_in_seconds: Optional[float] = field( + default=20.0, metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"} + ) + + +@dataclass +class DataCollatorForWav2Vec2Pretraining: + """ + Data collator that will dynamically pad the inputs received and prepare masked indices + for self-supervised pretraining. + + Args: + model (:class:`~transformers.Wav2Vec2ForPreTraining`): + The Wav2Vec2 model used for pretraining. 
The data collator needs to have access + to config and ``_get_feat_extract_output_lengths`` function for correct padding. + feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`): + The processor used for proccessing the data. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding index) + among: + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + max_length (:obj:`int`, `optional`): + Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). + pad_to_multiple_of (:obj:`int`, `optional`): + If set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + """ + + model: Wav2Vec2ForPreTraining + feature_extractor: Wav2Vec2FeatureExtractor + padding: Union[bool, str] = "longest" + pad_to_multiple_of: Optional[int] = None + max_length: Optional[int] = None + + def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: + # reformat list to dict and set to pytorch format + batch = self.feature_extractor.pad( + features, + max_length=self.max_length, + padding=self.padding, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors="pt", + ) + mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1]) + + # sample randomly masked indices + batch["mask_time_indices"] = _compute_mask_indices( + (batch["input_values"].shape[0], mask_indices_seq_length), + self.model.config.mask_time_prob, + self.model.config.mask_time_length, + device=batch["input_values"].device, + min_masks=2, + ) + + return batch + + +class Wav2Vec2PreTrainer(Trainer): + """ + Subclassed :class:`~transformers.Trainer` for Wav2Vec2-like pretraining. Trainer can decay gumbel softmax temperature during training. + """ + + def __init__(self, *args, max_gumbel_temp=1, min_gumbel_temp=0, gumbel_temp_decay=1.0, **kwargs): + super().__init__(*args, **kwargs) + self.num_update_step = 0 + self.max_gumbel_temp = max_gumbel_temp + self.min_gumbel_temp = min_gumbel_temp + self.gumbel_temp_decay = gumbel_temp_decay + + def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + """ + Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (:obj:`nn.Module`): + The model to train. + inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument :obj:`labels`. Check your model's documentation for all accepted arguments. + + Return: + :obj:`torch.Tensor`: The tensor with training loss on this batch. 
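+
+        In addition to the regular training step, the Gumbel-softmax temperature is decayed after every update to
+        ``max(max_gumbel_temp * gumbel_temp_decay ** num_update_step, min_gumbel_temp)``.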
+ """ + + model.train() + inputs = self._prepare_inputs(inputs) + + if self.use_amp: + with autocast(): + loss = self.compute_loss(model, inputs) + else: + loss = self.compute_loss(model, inputs) + + if self.args.n_gpu > 1 or self.deepspeed: + if model.module.config.ctc_loss_reduction == "mean": + loss = loss.mean() + elif model.module.config.ctc_loss_reduction == "sum": + loss = loss.sum() / (inputs["mask_time_indices"]).sum() + else: + raise ValueError(f"{model.config.ctc_loss_reduction} is not valid. Choose one of ['mean', 'sum']") + + if self.args.gradient_accumulation_steps > 1: + loss = loss / self.args.gradient_accumulation_steps + + if self.use_amp: + self.scaler.scale(loss).backward() + elif self.use_apex: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + elif self.deepspeed: + self.deepspeed.backward(loss) + else: + loss.backward() + + self.num_update_step += 1 + # make sure gumbel softmax temperature is decayed + if self.args.n_gpu > 1 or self.deepspeed: + model.module.set_gumbel_temperature( + max(self.max_gumbel_temp * self.gumbel_temp_decay ** self.num_update_step, self.min_gumbel_temp) + ) + else: + model.set_gumbel_temperature( + max(self.max_gumbel_temp * self.gumbel_temp_decay ** self.num_update_step, self.min_gumbel_temp) + ) + + return loss.detach() + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + configure_logger(model_args, training_args) + + # Downloading and loading a dataset from the hub. 
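+    # note: if the dataset has no dedicated "validation" split, a slice of `train_split_name` is held out below,
+    # which assumes a `validation_split_percentage` data argument is available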
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir) + + if "validation" not in datasets.keys(): + # make sure only "validation" and "train" keys remain" + datasets = DatasetDict() + datasets["validation"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"{data_args.train_split_name}[:{data_args.validation_split_percentage}%]", + cache_dir=model_args.cache_dir, + ) + datasets["train"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"{data_args.train_split_name}[{data_args.validation_split_percentage}%:]", + cache_dir=model_args.cache_dir, + ) + else: + # make sure only "validation" and "train" keys remain" + datasets = DatasetDict() + datasets["validation"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split="validation", + cache_dir=model_args.cache_dir, + ) + datasets["train"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"{data_args.train_split_name}", + cache_dir=model_args.cache_dir, + ) + + # only normalized-inputs-training is supported + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True + ) + + def prepare_dataset(batch): + # check that all files have the correct sampling rate + batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate) + return batch + + # load audio files into numpy arrays + vectorized_datasets = datasets.map( + prepare_dataset, num_proc=data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names + ) + + # filter audio files that are too long + vectorized_datasets = vectorized_datasets.filter( + lambda data: len(data["speech"]) < int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate) + ) + + def normalize(batch): + return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate) + + # normalize and transform to `BatchFeatures` + vectorized_datasets = vectorized_datasets.map( + normalize, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + remove_columns=vectorized_datasets["train"].column_names, + ) + + # pretraining is only supported for "newer" stable layer norm architecture + # apply_spec_augment has to be True, mask_feature_prob has to be 0.0 + config = Wav2Vec2Config.from_pretrained( + model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + gradient_checkpointing=model_args.gradient_checkpointing, + ) + + if not config.do_stable_layer_norm or config.feat_extract_norm != "layer": + raise ValueError( + "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'" + ) + + model = Wav2Vec2ForPreTraining(config) + + data_collator = DataCollatorForWav2Vec2Pretraining(model=model, feature_extractor=feature_extractor) + + trainer = Wav2Vec2PreTrainer( + model=model, + data_collator=data_collator, + args=training_args, + train_dataset=vectorized_datasets["train"], + eval_dataset=vectorized_datasets["validation"], + tokenizer=feature_extractor, + max_gumbel_temp=model_args.max_gumbel_temperature, + min_gumbel_temp=model_args.min_gumbel_temperature, + gumbel_temp_decay=model_args.gumbel_temperature_decay, + ) + trainer.train() + + +if __name__ == "__main__": + main() diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py 
index 386d64892cc..387e8b938ad 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1046,6 +1046,7 @@ if is_torch_available(): "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", "Wav2Vec2ForCTC", "Wav2Vec2ForMaskedLM", + "Wav2Vec2ForPreTraining", "Wav2Vec2Model", "Wav2Vec2PreTrainedModel", ] @@ -2411,6 +2412,7 @@ if TYPE_CHECKING: WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, + Wav2Vec2ForPreTraining, Wav2Vec2Model, Wav2Vec2PreTrainedModel, ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 8b144b83c71..ce8b3592df3 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -269,7 +269,7 @@ from ..tapas.modeling_tapas import ( from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel from ..visual_bert.modeling_visual_bert import VisualBertForPreTraining, VisualBertModel from ..vit.modeling_vit import ViTForImageClassification, ViTModel -from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2ForMaskedLM, Wav2Vec2Model +from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining, Wav2Vec2Model from ..xlm.modeling_xlm import ( XLMForMultipleChoice, XLMForQuestionAnsweringSimple, @@ -463,6 +463,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( (IBertConfig, IBertForMaskedLM), (DebertaConfig, DebertaForMaskedLM), (DebertaV2Config, DebertaV2ForMaskedLM), + (Wav2Vec2Config, Wav2Vec2ForPreTraining), ] ) diff --git a/src/transformers/models/wav2vec2/__init__.py b/src/transformers/models/wav2vec2/__init__.py index 183f85b82d3..b9de364d51f 100644 --- a/src/transformers/models/wav2vec2/__init__.py +++ b/src/transformers/models/wav2vec2/__init__.py @@ -32,6 +32,7 @@ if is_torch_available(): "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", "Wav2Vec2ForCTC", "Wav2Vec2ForMaskedLM", + "Wav2Vec2ForPreTraining", "Wav2Vec2Model", "Wav2Vec2PreTrainedModel", ] @@ -48,6 +49,7 @@ if TYPE_CHECKING: WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, + Wav2Vec2ForPreTraining, Wav2Vec2Model, Wav2Vec2PreTrainedModel, ) diff --git a/src/transformers/models/wav2vec2/configuration_wav2vec2.py b/src/transformers/models/wav2vec2/configuration_wav2vec2.py index 33b0e9584c9..88200133d54 100644 --- a/src/transformers/models/wav2vec2/configuration_wav2vec2.py +++ b/src/transformers/models/wav2vec2/configuration_wav2vec2.py @@ -71,6 +71,8 @@ class Wav2Vec2Config(PretrainedConfig): feat_extract_activation (:obj:`str, `optional`, defaults to :obj:`"gelu"`): The non-linear activation function (function or string) in the 1D convolutional layers of the feature extractor. If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported. + feat_quantizer_dropout (obj:`float`, `optional`, defaults to 0.0): + The dropout probabilitiy for quantized feature extractor states. conv_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 512, 512, 512)`): A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the feature extractor. The length of `conv_dim` defines the number of 1D convolutional layers. @@ -108,6 +110,22 @@ class Wav2Vec2Config(PretrainedConfig): masked along the time axis. This is only relevant if ``apply_spec_augment is True``. mask_feature_length (:obj:`int`, `optional`, defaults to 10): Length of vector span along the feature axis. 
+ num_codevectors_per_group (:obj:`int`, `optional`, defaults to 320): + Number of entries in each quantization codebook (group). + num_codevector_groups (:obj:`int`, `optional`, defaults to 2): + Number of codevector groups for product codevector quantization. + contrastive_logits_temperature (:obj:`float`, `optional`, defaults to 0.1): + The temperature `kappa` in the contrastive loss. + feat_quantizer_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout probabilitiy for the output of the feature extractor that's used by the quantizer. + num_negatives (:obj:`int`, `optional`, defaults to 100): + Number of negative samples for the contrastive loss. + codevector_dim (:obj:`int`, `optional`, defaults to 256): + Dimensionality of the quantized feature vectors. + proj_codevector_dim (:obj:`int`, `optional`, defaults to 256): + Dimensionality of the final projection of both the quantized and the transformer features. + diversity_loss_weight (:obj:`int`, `optional`, defaults to 0.1): + The weight of the codebook diversity loss component. ctc_loss_reduction (:obj:`str`, `optional`, defaults to :obj:`"sum"`): Specifies the reduction to apply to the output of ``torch.nn.CTCLoss``. Only relevant when training an instance of :class:`~transformers.Wav2Vec2ForCTC`. @@ -145,6 +163,7 @@ class Wav2Vec2Config(PretrainedConfig): activation_dropout=0.1, attention_dropout=0.1, feat_proj_dropout=0.1, + feat_quantizer_dropout=0.0, final_dropout=0.1, layerdrop=0.1, initializer_range=0.02, @@ -163,6 +182,13 @@ class Wav2Vec2Config(PretrainedConfig): mask_time_length=10, mask_feature_prob=0.0, mask_feature_length=10, + num_codevectors_per_group=320, + num_codevector_groups=2, + contrastive_logits_temperature=0.1, + num_negatives=100, + codevector_dim=256, + proj_codevector_dim=256, + diversity_loss_weight=0.1, ctc_loss_reduction="sum", ctc_zero_infinity=False, gradient_checkpointing=False, @@ -217,6 +243,16 @@ class Wav2Vec2Config(PretrainedConfig): self.mask_feature_prob = mask_feature_prob self.mask_feature_length = mask_feature_length + # parameters for pretraining with codevector quantized representations + self.num_codevectors_per_group = num_codevectors_per_group + self.num_codevector_groups = num_codevector_groups + self.contrastive_logits_temperature = contrastive_logits_temperature + self.feat_quantizer_dropout = feat_quantizer_dropout + self.num_negatives = num_negatives + self.codevector_dim = codevector_dim + self.proj_codevector_dim = proj_codevector_dim + self.diversity_loss_weight = diversity_loss_weight + # ctc loss self.ctc_loss_reduction = ctc_loss_reduction self.ctc_zero_infinity = ctc_zero_infinity diff --git a/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py index 2ba66c70be8..f27d7168a94 100644 --- a/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py @@ -28,7 +28,7 @@ from transformers import ( Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2ForCTC, - Wav2Vec2Model, + Wav2Vec2ForPreTraining, Wav2Vec2Processor, logging, ) @@ -50,9 +50,20 @@ MAPPING = { "final_layer_norm": "encoder.layers.*.final_layer_norm", "encoder.layer_norm": "encoder.layer_norm", "w2v_model.layer_norm": "feature_projection.layer_norm", + "quantizer.weight_proj": "quantizer.weight_proj", + "quantizer.vars": 
"quantizer.codevectors", + "project_q": "project_q", + "final_proj": "project_hid", "w2v_encoder.proj": "lm_head", "mask_emb": "masked_spec_embed", } +TOP_LEVEL_KEYS = [ + "lm_head", + "quantizer.weight_proj", + "quantizer.codevectors", + "project_q", + "project_hid", +] def set_recursively(hf_pointer, key, value, full_name, weight_type): @@ -82,11 +93,11 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type): logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.") -def recursively_load_weights(fairseq_model, hf_model, is_finetuned): +def recursively_load_weights(fairseq_model, hf_model, is_headless): unused_weights = [] fairseq_dict = fairseq_model.state_dict() - feature_extractor = hf_model.wav2vec2.feature_extractor if is_finetuned else hf_model.feature_extractor + feature_extractor = hf_model.wav2vec2.feature_extractor for name, value in fairseq_dict.items(): is_used = False @@ -101,9 +112,8 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned): is_used = True else: for key, mapped_key in MAPPING.items(): - mapped_key = "wav2vec2." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key - - if key in name or (key.split("w2v_model.")[-1] == name.split(".")[0] and not is_finetuned): + mapped_key = "wav2vec2." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key + if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]: is_used = True if "*" in mapped_key: layer_index = name.split(key)[0].split(".")[-2] @@ -112,10 +122,11 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned): weight_type = "weight_g" elif "weight_v" in name: weight_type = "weight_v" - elif "weight" in name: - weight_type = "weight" elif "bias" in name: weight_type = "bias" + elif "weight" in name: + # TODO: don't match quantizer.weight_proj + weight_type = "weight" else: weight_type = None set_recursively(hf_model, mapped_key, value, name, weight_type) @@ -213,7 +224,7 @@ def convert_wav2vec2_checkpoint( hf_wav2vec = Wav2Vec2ForCTC(config) else: - hf_wav2vec = Wav2Vec2Model(config) + hf_wav2vec = Wav2Vec2ForPreTraining(config) if is_finetuned: model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( @@ -224,7 +235,7 @@ def convert_wav2vec2_checkpoint( model = model[0].eval() - recursively_load_weights(model, hf_wav2vec, is_finetuned) + recursively_load_weights(model, hf_wav2vec, not is_finetuned) hf_wav2vec.save_pretrained(pytorch_dump_folder_path) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index edaad028523..5039595a29c 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -15,7 +15,8 @@ """ PyTorch Wav2Vec2 model. 
""" import warnings -from typing import Optional, Tuple +from dataclasses import dataclass +from typing import Optional, Tuple, Union import numpy as np import torch @@ -26,7 +27,12 @@ from torch import nn from transformers.deepspeed import is_deepspeed_zero3_enabled from ...activations import ACT2FN -from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...file_utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput from ...modeling_utils import PreTrainedModel from ...utils import logging @@ -46,6 +52,71 @@ WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] +@dataclass +class Wav2Vec2BaseModelOutput(ModelOutput): + """ + Output type of :class:`~transformers.Wav2Vec2BaseModelOutput`, with potential hidden states and attentions. + + Args: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + extract_features (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, conv_dim[-1])`): + Sequence of extracted feature vectors of the last convolutional layer of the model. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + extract_features: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class Wav2Vec2ForPreTrainingOutput(ModelOutput): + """ + Output type of :class:`~transformers.Wav2Vec2ForPreTrainingOutput`, with potential hidden states and attentions. + + Args: + loss (`optional`, returned when model is in train mode, ``torch.FloatTensor`` of shape :obj:`(1,)`): + Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the `official + paper `__ . (classification) loss. + projected_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.proj_codevector_dim)`): + Hidden-states of the model projected to `config.proj_codevector_dim` that can be used to predict the masked + projected quantized states. + projected_quantized_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.proj_codevector_dim)`): + Quantized extracted feature vectors projected to `config.proj_codevector_dim` representing the positive + target vectors for contrastive loss. 
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + projected_states: torch.FloatTensor = None + projected_quantized_states: torch.FloatTensor = None + codevector_perplexity: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + def _compute_mask_indices( shape: Tuple[int, int], mask_prob: float, @@ -271,10 +342,11 @@ class Wav2Vec2FeatureProjection(nn.Module): self.dropout = nn.Dropout(config.feat_proj_dropout) def forward(self, hidden_states): - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.projection(hidden_states) + # non-projected hidden states are needed for quantization + norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(norm_hidden_states) hidden_states = self.dropout(hidden_states) - return hidden_states + return hidden_states, norm_hidden_states # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Wav2Vec2 @@ -685,6 +757,86 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module): ) +class Wav2Vec2GumbelVectorQuantizer(nn.Module): + """ + Vector quantization using gumbel softmax. See `CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX + `__ for more information. 
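+
+    The incoming features are projected to ``num_codevector_groups * num_codevectors_per_group`` logits. During
+    training, one codevector per group is selected with a hard Gumbel-softmax (argmax at inference time), and the
+    selected codevectors of all groups are concatenated to form the quantized representation.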
+ """ + + def __init__(self, config): + super().__init__() + self.num_groups = config.num_codevector_groups + self.num_vars = config.num_codevectors_per_group + + assert ( + config.codevector_dim % self.num_groups == 0 + ), f"`config.codevector_dim {config.codevector_dim} must be divisible by `config.num_codevector_groups` {self.num_groups} for concatenation" + + # storage for codebook variables (codewords) + self.codevectors = nn.Parameter( + torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups) + ) + self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars) + + # can be decayed for training + self.temperature = 1 + + def set_temperature(self, temperature: int): + self.temperature = temperature + + @staticmethod + def _compute_perplexity(probs, mask=None): + if mask is not None: + mask_extended = mask.flatten()[:, None, None].expand(probs.shape) + probs = torch.where(mask_extended, probs, torch.zeros_like(probs)) + marginal_probs = probs.sum(dim=0) / mask.sum() + else: + marginal_probs = probs.mean(dim=0) + + perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum() + return perplexity + + def forward(self, hidden_states, mask_time_indices=None): + batch_size, sequence_length, hidden_size = hidden_states.shape + + # project to codevector dim + hidden_states = self.weight_proj(hidden_states) + hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1) + + if self.training: + # sample code vector probs via gumbel in differentiateable way + codevector_probs = F.gumbel_softmax(hidden_states.float(), tau=self.temperature, hard=True).type_as( + hidden_states + ) + + # compute perplexity + codevector_soft_dist = torch.softmax( + hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1 + ) + perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices) + else: + # take argmax in non-differentiable way + # comptute hard codevector distribution (one hot) + codevector_idx = hidden_states.argmax(dim=-1) + codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_( + -1, codevector_idx.view(-1, 1), 1.0 + ) + codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1) + + perplexity = self._compute_perplexity(codevector_probs, mask_time_indices) + + codevector_probs = codevector_probs.view(batch_size * sequence_length, -1) + # use probs to retrieve codevectors + codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors + codevectors = ( + codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1) + .sum(-2) + .view(batch_size, sequence_length, -1) + ) + + return codevectors, perplexity + + class Wav2Vec2PreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -697,7 +849,12 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, nn.Linear): + # gumbel softmax requires special init + if isinstance(module, Wav2Vec2GumbelVectorQuantizer): + module.weight_proj.weight.data.normal_(mean=0.0, std=1) + module.weight_proj.bias.data.zero_() + nn.init.uniform_(module.codevectors) + elif isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) @@ -720,7 +877,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() - def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ Computes the output length of the convolutional layers """ @@ -733,7 +890,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): input_lengths = _conv_out_length(input_lengths, kernel_size, stride) - return input_lengths.to(torch.long) + return input_lengths WAV_2_VEC_2_START_DOCSTRING = r""" @@ -797,7 +954,7 @@ WAV_2_VEC_2_INPUTS_DOCSTRING = r""" WAV_2_VEC_2_START_DOCSTRING, ) class Wav2Vec2Model(Wav2Vec2PreTrainedModel): - def __init__(self, config): + def __init__(self, config: Wav2Vec2Config): super().__init__(config) self.config = config self.feature_extractor = Wav2Vec2FeatureExtractor(config) @@ -812,12 +969,53 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): self.init_weights() + def _mask_hidden_states( + self, hidden_states: torch.FloatTensor, mask_time_indices: Optional[torch.FloatTensor] = None + ): + """ + Masks extracted features along time axis and/or along feature axis according to `SpecAugment + `__ . + """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return hidden_states + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + elif self.config.mask_time_prob > 0 and self.training: + # generate indices & apply SpecAugment along time axis + batch_size, sequence_length, hidden_size = hidden_states.size() + + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + device=hidden_states.device, + min_masks=2, + ) + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + device=hidden_states.device, + ) + hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 + + return hidden_states + @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_values, attention_mask=None, + mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -852,49 +1050,30 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - hidden_states = self.feature_extractor(input_values) - hidden_states = hidden_states.transpose(1, 2) + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) if attention_mask is not None: # compute real output lengths according to convolution formula - output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)) + output_lengths = 
self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) attention_mask = torch.zeros( - hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device + extract_features.shape[:2], dtype=extract_features.dtype, device=extract_features.device ) # these two operations makes sure that all values # before the output lengths indices are attended to attention_mask[ - (torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1) + (torch.arange(attention_mask.shape[0], device=extract_features.device), output_lengths - 1) ] = 1 attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() - hidden_states = self.feature_projection(hidden_states) + hidden_states, extract_features = self.feature_projection(extract_features) - if self.config.apply_spec_augment and self.training: - batch_size, sequence_length, hidden_size = hidden_states.size() + if mask_time_indices is not None: # apply SpecAugment along time axis with given indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) - # apply SpecAugment along time axis - if self.config.mask_time_prob > 0: - mask_time_indices = _compute_mask_indices( - (batch_size, sequence_length), - mask_prob=self.config.mask_time_prob, - mask_length=self.config.mask_time_length, - device=hidden_states.device, - min_masks=2, - ) - hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) - - # apply SpecAugment along feature axis - if self.config.mask_feature_prob > 0: - mask_feature_indices = _compute_mask_indices( - (batch_size, hidden_size), - mask_prob=self.config.mask_feature_prob, - mask_length=self.config.mask_feature_length, - device=hidden_states.device, - ) - hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 + hidden_states = self._mask_hidden_states(hidden_states) encoder_outputs = self.encoder( hidden_states, @@ -907,15 +1086,240 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): hidden_states = encoder_outputs[0] if not return_dict: - return (hidden_states,) + encoder_outputs[1:] + return (hidden_states, extract_features) + encoder_outputs[1:] - return BaseModelOutput( + return Wav2Vec2BaseModelOutput( last_hidden_state=hidden_states, + extract_features=extract_features, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) +@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top. """, WAV_2_VEC_2_START_DOCSTRING) +class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel): + def __init__(self, config: Wav2Vec2Config): + super().__init__(config) + self.wav2vec2 = Wav2Vec2Model(config) + self.dropout_features = nn.Dropout(config.feat_quantizer_dropout) + + self.quantizer = Wav2Vec2GumbelVectorQuantizer(config) + self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim) + self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim) + + self.init_weights() + + def set_gumbel_temperature(self, temperature: int): + """ + Set the Gumbel softmax temperature to a given value. Only necessary for training + """ + return self.quantizer.set_temperature(temperature) + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature extractor so that its parameters + will not be updated during training. 
+ """ + self.wav2vec2.feature_extractor._freeze_parameters() + + @staticmethod + def _sample_negatives(features: torch.FloatTensor, num_negatives: int): + """ + Sample `num_negatives` vectors from feature vectors. + """ + batch_size, sequence_length, hidden_size = features.shape + if sequence_length <= 1: + raise ValueError( + f"`features should have `sequence_length` > 1, but are of shape (batch_size, sequence_length, hidden_size) = ({batch_size, sequence_length, hidden_size})." + ) + + features = features.view(-1, hidden_size) # BTC => (BxT)C + + with torch.no_grad(): + # get `num_negatives` random vector indices from the same utterance + sampled_negative_indices = torch.randint( + low=0, + high=sequence_length - 1, + size=(batch_size, num_negatives * sequence_length), + device=features.device, + ) + + # generate indices of the positive vectors themselves, repeat them `num_negatives` times + feature_indices = ( + torch.arange(sequence_length, device=features.device)[:, None] + .expand(sequence_length, num_negatives) + .flatten() + ) + + # avoid sampling the same positive vector, but keep the distribution uniform + sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1 + + # correct for batch size + for batch_idx in range(1, batch_size): + sampled_negative_indices[batch_idx] += batch_idx * sequence_length + + # take negative vectors from sampled indices + sampled_negatives = features[sampled_negative_indices.view(-1)] + sampled_negatives = sampled_negatives.view(batch_size, sequence_length, num_negatives, hidden_size).permute( + 2, 0, 1, 3 + ) + + return sampled_negatives + + @staticmethod + def compute_contrastive_logits( + target_features: torch.FloatTensor, + negative_features: torch.FloatTensor, + predicted_features: torch.FloatTensor, + temperature: int = 1, + ): + """ + Compute logits for contrastive loss based using cosine similarity as the distance measure between + `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied. + """ + target_features = torch.cat([target_features, negative_features], dim=0) + + logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as( + target_features + ) + + # apply temperature + logits = logits / temperature + return logits + + @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_values, + attention_mask=None, + mask_time_indices=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + mask_time_indices (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict + masked extracted features in `config.proj_codevector_dim` space. + + Returns: + + Example:: + + >>> import torch + >>> from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining + >>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices + >>> from datasets import load_dataset + >>> import soundfile as sf + + >>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base") + >>> model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base") + + + >>> def map_to_array(batch): + ... speech, _ = sf.read(batch["file"]) + ... batch["speech"] = speech + ... 
return batch + + + >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.map(map_to_array) + + >>> input_values = feature_extractor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1 + + >>> # compute masked indices + >>> batch_size, raw_sequence_length = input_values.shape + >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length) + >>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2, device=model.device) + + >>> with torch.no_grad(): + ... outputs = model(input_values, mask_time_indices=mask_time_indices) + + >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states) + >>> cosine_sim = torch.cosine_similarity( + ... outputs.projected_states, outputs.projected_quantized_states, dim=-1 + ... ) + + >>> # show that cosine similarity is much higher than random + >>> assert cosine_sim[mask_time_indices].mean() > 0.5 + + >>> # for contrastive loss training model should be put into train mode + >>> model.train() + >>> loss = model(input_values, mask_time_indices=mask_time_indices).loss + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if mask_time_indices is not None: + mask_time_indices = mask_time_indices.to(torch.bool) + + outputs = self.wav2vec2( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + mask_time_indices=mask_time_indices, + return_dict=return_dict, + ) + + # 1. project all transformed features (including masked) to final vq dim + transformer_features = self.project_hid(outputs[0]) + + # 2. quantize all (unmasked) extracted features and project to final vq dim + extract_features = self.dropout_features(outputs[1]) + quantized_features, codevector_perplexity = self.quantizer(extract_features, mask_time_indices) + quantized_features = self.project_q(quantized_features) + + loss = None + if self.training: + # for training, we sample negatives + # 3. sample K negatives (distractors) quantized states for contrastive loss + negative_quantized_features = self._sample_negatives(quantized_features, self.config.num_negatives) + + # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa` + # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf + logits = self.compute_contrastive_logits( + quantized_features[None, :], + negative_quantized_features, + transformer_features, + self.config.contrastive_logits_temperature, + ) + + # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low), + # its cosine similarity will be masked + neg_is_pos = (quantized_features == negative_quantized_features).all(-1) + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + + # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) = + # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa)) + preds = logits.transpose(0, 2).reshape(-1, logits.size(0)) + target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten() + contrastive_loss = F.cross_entropy(preds.float(), target, reduction="sum") + + # 7. compute diversity loss: \mathbf{L}_d + num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups + diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors + + # 8. 
\mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d + loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss + + if not return_dict: + if loss is not None: + return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:] + return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:] + + return Wav2Vec2ForPreTrainingOutput( + loss=loss, + projected_states=transformer_features, + projected_quantized_states=quantized_features, + codevector_perplexity=codevector_perplexity, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @add_start_docstrings("""Wav2Vec2 Model with a `language modeling` head on top. """, WAV_2_VEC_2_START_DOCSTRING) class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel): def __init__(self, config): @@ -986,7 +1390,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel): logits = self.lm_head(hidden_states) if not return_dict: - output = (logits,) + outputs[1:] + output = (logits,) + outputs[2:] return output return MaskedLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) @@ -1089,7 +1493,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel): attention_mask = ( attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) ) - input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)) + input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) # assuming that padded tokens are filled with -100 # when not being attended to @@ -1112,7 +1516,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel): ) if not return_dict: - output = (logits,) + outputs[1:] + output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return CausalLMOutput( diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 036a0a1c5ac..f3b8e813488 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -2969,6 +2969,11 @@ class Wav2Vec2ForMaskedLM: requires_backends(self, ["torch"]) +class Wav2Vec2ForPreTraining: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Wav2Vec2Model: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index c43515df0d7..0934967dc28 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -29,8 +29,16 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init if is_torch_available(): import torch - from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Processor - from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices + from transformers import ( + Wav2Vec2Config, + Wav2Vec2FeatureExtractor, + Wav2Vec2ForCTC, + Wav2Vec2ForMaskedLM, + Wav2Vec2ForPreTraining, + Wav2Vec2Model, + Wav2Vec2Processor, + ) + from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2GumbelVectorQuantizer, _compute_mask_indices class Wav2Vec2ModelTester: @@ -219,13 +227,7 @@ class Wav2Vec2ModelTester: @require_torch class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = ( - ( - Wav2Vec2ForCTC, - Wav2Vec2Model, - Wav2Vec2ForMaskedLM, - ) - if is_torch_available() - else () + (Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else 
() ) test_pruning = False test_headmasking = False @@ -316,8 +318,14 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): for model_class in self.all_model_classes: model = model_class(config=configs_no_init) for name, param in model.named_parameters(): + uniform_init_parms = [ + "conv.weight", + "masked_spec_embed", + "codevectors", + "quantizer.weight_proj.weight", + ] if param.requires_grad: - if "conv.weight" in name or "masked_spec_embed" in name: + if any([x in name for x in uniform_init_parms]): self.assertTrue( -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, msg=f"Parameter {name} of model {model_class} seems not properly initialized", @@ -333,10 +341,14 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: module.weight.data.fill_(3) - if hasattr(module, "weight_g") and module.weight is not None: + if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) + if hasattr(module, "weight_v") and module.weight_v is not None: + module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: module.bias.data.fill_(3) + if hasattr(module, "codevectors") and module.codevectors is not None: + module.codevectors.data.fill_(3) @slow def test_model_from_pretrained(self): @@ -346,7 +358,9 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): @require_torch class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else () + all_model_classes = ( + (Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else () + ) test_pruning = False test_headmasking = False test_torchscript = False @@ -442,8 +456,14 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): for model_class in self.all_model_classes: model = model_class(config=configs_no_init) for name, param in model.named_parameters(): + uniform_init_parms = [ + "conv.weight", + "masked_spec_embed", + "codevectors", + "quantizer.weight_proj.weight", + ] if param.requires_grad: - if "conv.weight" in name or "masked_spec_embed" in name: + if any([x in name for x in uniform_init_parms]): self.assertTrue( -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, msg=f"Parameter {name} of model {model_class} seems not properly initialized", @@ -459,10 +479,47 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: module.weight.data.fill_(3) - if hasattr(module, "weight_g") and module.weight is not None: + if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) + if hasattr(module, "weight_v") and module.weight_v is not None: + module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: module.bias.data.fill_(3) + if hasattr(module, "codevectors") and module.codevectors is not None: + module.codevectors.data.fill_(3) + + def test_model_for_pretraining(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = Wav2Vec2ForPreTraining(config).to(torch_device) + + features_shape = ( + inputs_dict["input_values"].shape[0], + model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])), + ) + + mask_time_indices = _compute_mask_indices( + 
+            features_shape,
+            model.config.mask_time_prob,
+            model.config.mask_time_length,
+            device=inputs_dict["input_values"].device,
+            min_masks=2,
+        ).to(torch_device)
+
+        loss = model(
+            inputs_dict["input_values"],
+            attention_mask=inputs_dict["attention_mask"],
+            mask_time_indices=mask_time_indices,
+        ).loss
+
+        mask_time_indices[:, : mask_time_indices.shape[-1] // 2] = True
+        loss_more_masked = model(
+            inputs_dict["input_values"],
+            attention_mask=inputs_dict["attention_mask"],
+            mask_time_indices=mask_time_indices,
+        ).loss
+
+        # loss_more_masked has to be bigger than or equal to loss, since more masked inputs have to be predicted
+        self.assertTrue(loss.detach().item() <= loss_more_masked.detach().item())

     @slow
     def test_model_from_pretrained(self):
@@ -484,24 +541,56 @@ class Wav2Vec2UtilsTest(unittest.TestCase):

     def test_compute_mask_indices_overlap(self):
         batch_size = 4
-        sequence_length = 60
+        sequence_length = 80
         mask_prob = 0.5
         mask_length = 4

         mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)

-        # because of overlap there is a range of possible masks
+        # because of overlap, masks don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
         for batch_sum in mask.sum(axis=-1):
-            self.assertIn(
-                int(batch_sum),
-                list(range(int(mask_prob // mask_length * sequence_length), int(mask_prob * sequence_length))),
-            )
+            self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
+
+    def test_compute_perplexity(self):
+        probs = torch.arange(100, device=torch_device).reshape(2, 5, 10) / 100
+
+        ppl = Wav2Vec2GumbelVectorQuantizer._compute_perplexity(probs)
+        self.assertTrue(abs(ppl.item() - 141.4291) < 1e-3)
+
+        # mask half of the input
+        mask = torch.ones((2,), device=torch_device, dtype=torch.bool)
+        mask[0] = 0
+
+        ppl = Wav2Vec2GumbelVectorQuantizer._compute_perplexity(probs, mask)
+        self.assertTrue(abs(ppl.item() - 58.6757) < 1e-3)
+
+    def test_sample_negatives(self):
+        batch_size = 2
+        sequence_length = 10
+        hidden_size = 4
+        num_negatives = 3
+
+        features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
+            sequence_length, hidden_size
+        )  # each value in vector consists of same value
+        features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
+
+        negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives)
+
+        self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
+
+        # make sure no negatively sampled vector is actually a positive one
+        for negative in negatives:
+            self.assertTrue(((negative - features) == 0).sum() == 0.0)
+
+        # make sure that full vectors are sampled and not single values => `unique()` along the `hidden_size` dim should yield a single value
+        self.assertEqual(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))


 @require_torch
-@slow
 @require_datasets
 @require_soundfile
+@slow
 class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
     def _load_datasamples(self, num_samples):
         from datasets import load_dataset
@@ -586,3 +675,160 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
             "his instant panic was followed by a small sharp blow high on his chest",
         ]
         self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
+
+    def test_inference_integration(self):
+        model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
+        model.to(torch_device)
+        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+            "patrickvonplaten/wav2vec2-base", return_attention_mask=True
+        )
+        input_speech = self._load_datasamples(2)
+
+        inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
+
+        features_shape = (
+            inputs_dict["input_values"].shape[0],
+            model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
+        )
+
+        torch.manual_seed(0)
+        mask_time_indices = _compute_mask_indices(
+            features_shape,
+            model.config.mask_time_prob,
+            model.config.mask_time_length,
+            device=inputs_dict["input_values"].device,
+            min_masks=2,
+        ).to(torch_device)
+
+        with torch.no_grad():
+            outputs = model(
+                inputs_dict.input_values.to(torch_device),
+                attention_mask=inputs_dict.attention_mask.to(torch_device),
+                mask_time_indices=mask_time_indices,
+            )
+
+        # compute cosine similarity
+        cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
+
+        # retrieve cosine sim of masked features
+        cosine_sim_masked = cosine_sim[mask_time_indices]
+
+        # fmt: off
+        expected_cosine_sim_masked = torch.tensor(
+            [0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831, 0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651, 0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371, 0.6997],
+            device=torch_device,
+        )
+        # fmt: on
+
+        self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))
+
+    def test_inference_pretrained(self):
+        model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
+        model.to(torch_device)
+        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+            "patrickvonplaten/wav2vec2-base", return_attention_mask=True
+        )
+        input_speech = self._load_datasamples(2)
+
+        inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
+
+        features_shape = (
+            inputs_dict["input_values"].shape[0],
+            model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
+        )
+
+        torch.manual_seed(0)
+        mask_time_indices = _compute_mask_indices(
+            features_shape,
+            model.config.mask_time_prob,
+            model.config.mask_time_length,
+            device=inputs_dict["input_values"].device,
+            min_masks=2,
+        ).to(torch_device)
+
+        with torch.no_grad():
+            outputs = model(
+                inputs_dict.input_values.to(torch_device),
+                attention_mask=inputs_dict.attention_mask.to(torch_device),
+                mask_time_indices=mask_time_indices,
+            )
+
+        # compute cosine similarity
+        cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
+
+        # retrieve cosine sim of masked features
+        cosine_sim_masked = cosine_sim[mask_time_indices]
+
+        # ... now compare to randomly initialized model
+
+        config = Wav2Vec2Config.from_pretrained("patrickvonplaten/wav2vec2-base")
+        model_rand = Wav2Vec2ForPreTraining(config).to(torch_device).eval()
+
+        with torch.no_grad():
+            outputs_rand = model_rand(
+                inputs_dict.input_values.to(torch_device),
+                attention_mask=inputs_dict.attention_mask.to(torch_device),
+                mask_time_indices=mask_time_indices,
+            )
+
+        # compute cosine similarity
+        cosine_sim_rand = torch.cosine_similarity(
+            outputs_rand.projected_states, outputs_rand.projected_quantized_states, dim=-1
+        )
+
+        # retrieve cosine sim of masked features
+        cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]
+
+        # a pretrained wav2vec2 model has learned to predict the quantized latent states
+        # => the cosine similarity between quantized states and predicted states > 0.5
+        # a random wav2vec2 model has not learned to predict the quantized latent states
+        # => the cosine similarity between quantized states and predicted states is very likely < 0.1
+        self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)
+
+    def test_loss_pretraining(self):
+        model = Wav2Vec2ForPreTraining.from_pretrained(
+            "patrickvonplaten/wav2vec2-base",
+            attention_dropout=0.0,
+            feat_proj_dropout=0.0,
+            hidden_dropout=0.0,
+            layerdrop=0.0,
+        )
+        model.to(torch_device).train()
+
+        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+            "patrickvonplaten/wav2vec2-base", return_attention_mask=True
+        )
+        input_speech = self._load_datasamples(2)
+
+        inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
+
+        features_shape = (
+            inputs_dict["input_values"].shape[0],
+            model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]),
+        )
+
+        torch.manual_seed(0)
+        mask_time_indices = _compute_mask_indices(
+            features_shape,
+            model.config.mask_time_prob,
+            model.config.mask_time_length,
+            device=inputs_dict["input_values"].device,
+            min_masks=2,
+        ).to(torch_device)
+
+        with torch.no_grad():
+            outputs = model(
+                inputs_dict.input_values.to(torch_device),
+                attention_mask=inputs_dict.attention_mask.to(torch_device),
+                mask_time_indices=mask_time_indices,
+            )
+
+        # check diversity loss
+        num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
+        diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
+        self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3)
+
+        # check overall loss (contrastive loss + diversity loss)
+        expected_loss = 62.5170 if model.device.type == "cpu" else 50.3612
+
+        self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)