diff --git a/setup.py b/setup.py index 6993962bf9f..d3271b5e54a 100644 --- a/setup.py +++ b/setup.py @@ -307,7 +307,12 @@ extras["hub-kernels"] = deps_list("kernels") extras["integrations"] = extras["hub-kernels"] + extras["optuna"] + extras["ray"] + extras["sigopt"] extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette") -extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5") +extras["audio"] = deps_list( + "librosa", + "pyctcdecode", + "phonemizer", + "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5", +) # `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead extras["speech"] = deps_list("torchaudio") + extras["audio"] extras["torch-speech"] = deps_list("torchaudio") + extras["audio"] diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index 7c73f3e185d..3d1bdaeca94 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Tuple import torch import torch.nn as nn @@ -465,7 +465,7 @@ class Cohere2Model(Gemma2Model): cache_position: Optional[torch.LongTensor] = None, last_cache_position: Optional[int] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index f3ee5e11fc6..7f799e81297 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -1,26 +1,15 @@ -# coding=utf-8 -# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Data2VecAudio model.""" - +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/data2vec/modular_data2vec_audio.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_data2vec_audio.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math import warnings from typing import Optional, Tuple, Union import numpy as np import torch -import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss @@ -50,141 +39,14 @@ from .configuration_data2vec_audio import Data2VecAudioConfig if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward + logger = logging.get_logger(__name__) - -_HIDDEN_STATES_START_POSITION = 2 - -# General docstring -_CONFIG_FOR_DOC = "Data2VecAudioConfig" - # Base docstring _CHECKPOINT_FOR_DOC = "facebook/data2vec-audio-base-960h" -_EXPECTED_OUTPUT_SHAPE = [1, 292, 768] -# CTC docstring -_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'" -_CTC_EXPECTED_LOSS = 66.95 - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices -def _compute_mask_indices( - shape: Tuple[int, int], - mask_prob: float, - mask_length: int, - attention_mask: Optional[torch.LongTensor] = None, - min_masks: int = 0, -) -> np.ndarray: - """ - Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for - ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on - CPU as part of the preprocessing during training. - - Args: - shape: The shape for which to compute masks. This should be of a tuple of size 2 where - the first element is the batch size and the second element is the length of the axis to span. - mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of - independently generated mask spans of length `mask_length` is computed by - `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the - actual percentage will be smaller. - mask_length: size of the mask - min_masks: minimum number of masked spans - attention_mask: A (right-padded) attention mask which independently shortens the feature axis of - each batch dimension. 
- """ - batch_size, sequence_length = shape - - if mask_length < 1: - raise ValueError("`mask_length` has to be bigger than 0.") - - if mask_length > sequence_length: - raise ValueError( - f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" - f" and `sequence_length`: {sequence_length}`" - ) - - # epsilon is used for probabilistic rounding - epsilon = np.random.rand(1).item() - - def compute_num_masked_span(input_length): - """Given input length, compute how many spans should be masked""" - num_masked_span = int(mask_prob * input_length / mask_length + epsilon) - num_masked_span = max(num_masked_span, min_masks) - - # make sure num masked span <= sequence_length - if num_masked_span * mask_length > sequence_length: - num_masked_span = sequence_length // mask_length - - # make sure num_masked span is also <= input_length - (mask_length - 1) - if input_length - (mask_length - 1) < num_masked_span: - num_masked_span = max(input_length - (mask_length - 1), 0) - - return num_masked_span - - # compute number of masked spans in batch - input_lengths = ( - attention_mask.detach().sum(-1).tolist() - if attention_mask is not None - else [sequence_length for _ in range(batch_size)] - ) - - # SpecAugment mask to fill - spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) - spec_aug_mask_idxs = [] - - max_num_masked_span = compute_num_masked_span(sequence_length) - - if max_num_masked_span == 0: - return spec_aug_mask - - for input_length in input_lengths: - # compute num of masked spans for this input - num_masked_span = compute_num_masked_span(input_length) - - # get random indices to mask - spec_aug_mask_idx = np.random.choice( - np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False - ) - - # pick first sampled index that will serve as a dummy index to pad vector - # to ensure same dimension for all batches due to probabilistic rounding - # Picking first sample just pads those vectors twice. 
- if len(spec_aug_mask_idx) == 0: - # this case can only happen if `input_length` is strictly smaller then - # `sequence_length` in which case the last token has to be a padding - # token which we can use as a dummy mask id - dummy_mask_idx = sequence_length - 1 - else: - dummy_mask_idx = spec_aug_mask_idx[0] - - spec_aug_mask_idx = np.concatenate( - [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] - ) - spec_aug_mask_idxs.append(spec_aug_mask_idx) - - spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) - - # expand masked indices to masked spans - spec_aug_mask_idxs = np.broadcast_to( - spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) - ) - spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) - - # add offset to the starting indexes so that indexes now create a span - offsets = np.arange(mask_length)[None, None, :] - offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( - batch_size, max_num_masked_span * mask_length - ) - spec_aug_mask_idxs = spec_aug_mask_idxs + offsets - - # ensure that we cannot have indices larger than sequence_length - if spec_aug_mask_idxs.max() > sequence_length - 1: - spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 - - # scatter indices to mask - np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) - - return spec_aug_mask +# General docstring +_CONFIG_FOR_DOC = "Data2VecAudioConfig" class Data2VecAudioConvLayer(nn.Module): @@ -214,7 +76,6 @@ class Data2VecAudioConvLayer(nn.Module): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Data2VecAudio class Data2VecAudioPadLayer(nn.Module): def __init__(self, num_conv_pos_embeddings): super().__init__() @@ -279,13 +140,11 @@ class Data2VecAudioFeatureEncoder(nn.Module): self.gradient_checkpointing = False self._requires_grad = True - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder._freeze_parameters def _freeze_parameters(self): for param in self.parameters(): param.requires_grad = False self._requires_grad = False - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder.forward def forward(self, input_values): hidden_states = input_values[:, None] @@ -305,7 +164,6 @@ class Data2VecAudioFeatureEncoder(nn.Module): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Data2VecAudio class Data2VecAudioFeatureProjection(nn.Module): def __init__(self, config): super().__init__() @@ -321,7 +179,6 @@ class Data2VecAudioFeatureProjection(nn.Module): return hidden_states, norm_hidden_states -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Data2VecAudio class Data2VecAudioAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -480,7 +337,6 @@ class Data2VecAudioAttention(nn.Module): return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Data2VecAudio class Data2VecAudioFlashAttention2(Data2VecAudioAttention): """ Data2VecAudio flash attention module. 
This module inherits from `Data2VecAudioAttention` as the weights of the module stays @@ -608,7 +464,6 @@ class Data2VecAudioFlashAttention2(Data2VecAudioAttention): class Data2VecAudioSdpaAttention(Data2VecAudioAttention): - # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Data2VecAudio def forward( self, hidden_states: torch.Tensor, @@ -714,14 +569,6 @@ class Data2VecAudioSdpaAttention(Data2VecAudioAttention): return attn_output, None, past_key_value -DATA2VEC2AUDIO_ATTENTION_CLASSES = { - "eager": Data2VecAudioAttention, - "sdpa": Data2VecAudioSdpaAttention, - "flash_attention_2": Data2VecAudioFlashAttention2, -} - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Data2VecAudio class Data2VecAudioFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -746,11 +593,17 @@ class Data2VecAudioFeedForward(nn.Module): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->Data2VecAudio, WAV2VEC2->DATA2VEC2AUDIO +DATA2VEC_AUDIO_ATTENTION_CLASSES = { + "eager": Data2VecAudioAttention, + "sdpa": Data2VecAudioSdpaAttention, + "flash_attention_2": Data2VecAudioFlashAttention2, +} + + class Data2VecAudioEncoderLayer(nn.Module): def __init__(self, config): super().__init__() - self.attention = DATA2VEC2AUDIO_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = DATA2VEC_AUDIO_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, @@ -782,7 +635,6 @@ class Data2VecAudioEncoderLayer(nn.Module): return outputs -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder with Wav2Vec2->Data2VecAudio class Data2VecAudioEncoder(nn.Module): def __init__(self, config): super().__init__() @@ -868,7 +720,24 @@ class Data2VecAudioEncoder(nn.Module): ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Data2VecAudio +class Data2VecAudioAdapterLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.output_hidden_size, + 2 * config.output_hidden_size, + config.adapter_kernel_size, + stride=config.adapter_stride, + padding=1, + ) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = nn.functional.glu(hidden_states, dim=1) + + return hidden_states + + class Data2VecAudioAdapter(nn.Module): def __init__(self, config): super().__init__() @@ -900,25 +769,6 @@ class Data2VecAudioAdapter(nn.Module): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Data2VecAudio -class Data2VecAudioAdapterLayer(nn.Module): - def __init__(self, config): - super().__init__() - self.conv = nn.Conv1d( - config.output_hidden_size, - 2 * config.output_hidden_size, - config.adapter_kernel_size, - stride=config.adapter_stride, - padding=1, - ) - - def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - hidden_states = nn.functional.glu(hidden_states, dim=1) - - return hidden_states - - class Data2VecAudioPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -957,7 +807,6 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel): k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, 
a=-k, b=k) - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel._get_feat_extract_output_lengths with def _get_feat_extract_output_lengths( self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None ): @@ -981,7 +830,6 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel): return input_lengths - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel._get_feature_vector_attention_mask def _get_feature_vector_attention_mask( self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None ): @@ -1003,6 +851,128 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel): return attention_mask +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. 
+ """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. 
+ if len(spec_aug_mask_idx) == 0: + # this case can only happen if `input_length` is strictly smaller then + # `sequence_length` in which case the last token has to be a padding + # token which we can use as a dummy mask id + dummy_mask_idx = sequence_length - 1 + else: + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # ensure that we cannot have indices larger than sequence_length + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + +_EXPECTED_OUTPUT_SHAPE = [1, 292, 768] + + DATA2VEC_AUDIO_START_DOCSTRING = r""" Data2VecAudio was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/pdf/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu and @@ -1021,7 +991,6 @@ DATA2VEC_AUDIO_START_DOCSTRING = r""" configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ - DATA2VEC_AUDIO_INPUTS_DOCSTRING = r""" Args: input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): @@ -1058,6 +1027,8 @@ DATA2VEC_AUDIO_INPUTS_DOCSTRING = r""" Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
""" +Data2VecAudioBaseModelOutput = Wav2Vec2BaseModelOutput + @add_start_docstrings( "The bare Data2VecAudio Model transformer outputting raw hidden-states without any specific head on top.", @@ -1137,7 +1108,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel): @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, - output_type=Wav2Vec2BaseModelOutput, + output_type=Data2VecAudioBaseModelOutput, config_class=_CONFIG_FOR_DOC, modality="audio", expected_output=_EXPECTED_OUTPUT_SHAPE, @@ -1150,7 +1121,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[Tuple, Wav2Vec2BaseModelOutput]: + ) -> Union[Tuple, Data2VecAudioBaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1187,7 +1158,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel): if not return_dict: return (hidden_states, extract_features) + encoder_outputs[1:] - return Wav2Vec2BaseModelOutput( + return Data2VecAudioBaseModelOutput( last_hidden_state=hidden_states, extract_features=extract_features, hidden_states=encoder_outputs.hidden_states, @@ -1195,6 +1166,13 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel): ) +_HIDDEN_STATES_START_POSITION = 2 + +# CTC docstring +_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'" +_CTC_EXPECTED_LOSS = 66.95 + + @add_start_docstrings( """Data2VecAudio Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", DATA2VEC_AUDIO_START_DOCSTRING, @@ -1248,7 +1226,6 @@ class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel): expected_output=_CTC_EXPECTED_OUTPUT, expected_loss=_CTC_EXPECTED_LOSS, ) - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with wav2vec2->data2vec_audio def forward( self, input_values: Optional[torch.Tensor], @@ -1379,7 +1356,6 @@ class Data2VecAudioForSequenceClassification(Data2VecAudioPreTrainedModel): config_class=_CONFIG_FOR_DOC, modality="audio", ) - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with wav2vec2->data2vec_audio def forward( self, input_values: Optional[torch.Tensor], @@ -1455,8 +1431,7 @@ class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel): if hasattr(config, "add_adapter") and config.add_adapter: raise ValueError( - "Audio frame classification does not support the use of Data2VecAudio adapters" - " (config.add_adapter=True)" + "Audio frame classification does not support the use of Data2VecAudio adapters (config.add_adapter=True)" ) self.data2vec_audio = Data2VecAudioModel(config) num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings @@ -1501,7 +1476,6 @@ class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel): config_class=_CONFIG_FOR_DOC, modality="audio", ) - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->data2vec_audio def forward( self, input_values: Optional[torch.Tensor], @@ -1556,7 +1530,6 @@ class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel): ) -# Copied 
from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss class AMSoftmaxLoss(nn.Module): def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4): super(AMSoftmaxLoss, self).__init__() @@ -1580,7 +1553,6 @@ class AMSoftmaxLoss(nn.Module): return loss -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer class TDNNLayer(nn.Module): def __init__(self, config, layer_id=0): super().__init__() @@ -1596,6 +1568,7 @@ class TDNNLayer(nn.Module): if is_peft_available(): from peft.tuners.lora import LoraLayer + if is_peft_available(): if isinstance(self.kernel, LoraLayer): warnings.warn( "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. " @@ -1687,7 +1660,6 @@ class Data2VecAudioForXVector(Data2VecAudioPreTrainedModel): config_class=_CONFIG_FOR_DOC, modality="audio", ) - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with wav2vec2->data2vec_audio def forward( self, input_values: Optional[torch.Tensor], diff --git a/src/transformers/models/data2vec/modular_data2vec_audio.py b/src/transformers/models/data2vec/modular_data2vec_audio.py new file mode 100644 index 00000000000..052f22a960f --- /dev/null +++ b/src/transformers/models/data2vec/modular_data2vec_audio.py @@ -0,0 +1,400 @@ +import math + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...modeling_outputs import ( + CausalLMOutput, + SequenceClassifierOutput, + TokenClassifierOutput, + Wav2Vec2BaseModelOutput, + XVectorOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, +) +from ..wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2Adapter, + Wav2Vec2Encoder, + Wav2Vec2FeatureEncoder, + Wav2Vec2FeatureProjection, + Wav2Vec2ForAudioFrameClassification, + Wav2Vec2ForCTC, + Wav2Vec2ForSequenceClassification, + Wav2Vec2ForXVector, + Wav2Vec2Model, + Wav2Vec2PreTrainedModel, + Wav2Vec2SamePadLayer, +) +from .configuration_data2vec_audio import Data2VecAudioConfig + + +_HIDDEN_STATES_START_POSITION = 2 + +# General docstring +_CONFIG_FOR_DOC = "Data2VecAudioConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/data2vec-audio-base-960h" +_EXPECTED_OUTPUT_SHAPE = [1, 292, 768] + +# CTC docstring +_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'" +_CTC_EXPECTED_LOSS = 66.95 + + +class Data2VecAudioConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +class Data2VecAudioPadLayer(Wav2Vec2SamePadLayer): + pass + + +class Data2VecAudioPositionalConvLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + 
config.hidden_size, + kernel_size=config.conv_pos_kernel_size, + padding=config.conv_pos_kernel_size // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + self.padding = Data2VecAudioPadLayer(config.conv_pos_kernel_size) + self.activation = ACT2FN[config.feat_extract_activation] + # no learnable parameters + self.layer_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class Data2VecAudioPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.layers = nn.ModuleList( + [Data2VecAudioPositionalConvLayer(config) for _ in range(config.num_conv_pos_embeddings)] + ) + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + for layer in self.layers: + hidden_states = layer(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class Data2VecAudioFeatureEncoder(Wav2Vec2FeatureEncoder, nn.Module): + def __init__(self, config): + nn.Module.__init__() + self.conv_layers = nn.ModuleList( + [Data2VecAudioConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)] + ) + self.gradient_checkpointing = False + self._requires_grad = True + + +class Data2VecAudioFeatureProjection(Wav2Vec2FeatureProjection): + pass + + +class Data2VecAudioEncoder(Wav2Vec2Encoder): + pass + + +class Data2VecAudioAdapter(Wav2Vec2Adapter): + pass + + +class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = Data2VecAudioConfig + base_model_prefix = "data2vec_audio" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, Data2VecAudioFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, Data2VecAudioPositionalConvLayer): + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + if module.bias is not None: + module.bias.data.zero_() + if module.weight is not None: + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def _get_adapters(self): + raise AttributeError("Not needed for Data2VecAudio") + + def init_adapter_layers(self): + raise AttributeError("Not needed for Data2VecAudio") + + def load_adapter(self): + raise AttributeError("Not needed for Data2VecAudio") + + +DATA2VEC_AUDIO_START_DOCSTRING = r""" + Data2VecAudio was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and + Language](https://arxiv.org/pdf/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu and + Michael Auli. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving etc.). + + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`Data2VecAudioConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DATA2VEC_AUDIO_INPUTS_DOCSTRING = r""" + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file + into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install + soundfile*). To prepare the array into *input_values*, the [`AutoProcessor`] should be used for padding and + conversion into a tensor of type *torch.FloatTensor*. See [`Wav2Vec2Processor.__call__`] for details. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0, + 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + + + + `attention_mask` should be passed if the corresponding processor has `config.return_attention_mask == + True`, which is the case for all pre-trained Data2Vec Audio models. Be aware that that even with + `attention_mask`, zero-padded inputs will have slightly different outputs compared to non-padded inputs + because there are more than one convolutional layer in the positional encodings. For a more detailed + explanation, see [here](https://github.com/huggingface/transformers/issues/25621#issuecomment-1713759349). + + + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +Data2VecAudioBaseModelOutput = Wav2Vec2BaseModelOutput + + +@add_start_docstrings( + "The bare Data2VecAudio Model transformer outputting raw hidden-states without any specific head on top.", + DATA2VEC_AUDIO_START_DOCSTRING, +) +class Data2VecAudioModel(Data2VecAudioPreTrainedModel, Wav2Vec2Model): + def __init__(self, config: Data2VecAudioConfig): + Data2VecAudioPreTrainedModel.__init__(config) + self.config = config + self.feature_extractor = Data2VecAudioFeatureEncoder(config) + self.feature_projection = Data2VecAudioFeatureProjection(config) + + # model only needs masking vector if mask prob is > 0.0 + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) + + self.encoder = Data2VecAudioEncoder(config) + + self.adapter = Data2VecAudioAdapter(config) if config.add_adapter else None + + # Initialize weights and apply final processing + self.post_init() + + def freeze_feature_extractor(self): + raise AttributeError("Not needed for Data2VecAudio") + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.feature_extractor._freeze_parameters() + + @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Data2VecAudioBaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward(self, **super_kwargs): + return super().forward(**super_kwargs) + + +@add_start_docstrings( + """Data2VecAudio Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", + DATA2VEC_AUDIO_START_DOCSTRING, +) +class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel, Wav2Vec2ForCTC): + def __init__(self, config): + Data2VecAudioPreTrainedModel.__init__(config) + + self.data2vec_audio = Data2VecAudioModel(config) + self.dropout = nn.Dropout(config.final_dropout) + + if config.vocab_size is None: + raise ValueError( + f"You are trying to instantiate {self.__class__} with a configuration that " + "does not define the vocabulary size of the language model head. Please " + "instantiate the model as follows: `Data2VecAudioForCTC.from_pretrained(..., vocab_size=vocab_size)`. " + "or define `vocab_size` of your model's configuration." 
+ ) + output_hidden_size = ( + config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size + ) + self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + def freeze_base_model(self): + raise AttributeError("Not needed for Data2VecAudio") + + def tie_weights(self): + raise AttributeError("Not needed for Data2VecAudio") + + @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_CTC_EXPECTED_OUTPUT, + expected_loss=_CTC_EXPECTED_LOSS, + ) + def forward(self, **super_kwargs): + return super().forward(**super_kwargs) + + +@add_start_docstrings( + """ + Data2VecAudio Model with a sequence classification head on top (a linear layer over the pooled output) for tasks + like SUPERB Keyword Spotting. + """, + DATA2VEC_AUDIO_START_DOCSTRING, +) +class Data2VecAudioForSequenceClassification(Wav2Vec2ForSequenceClassification): + @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward(self, **super_kwargs): + return super().forward(**super_kwargs) + + +@add_start_docstrings( + """ + Data2VecAudio Model with a frame classification head on top for tasks like Speaker Diarization. + """, + DATA2VEC_AUDIO_START_DOCSTRING, +) +class Data2VecAudioForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification): + @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward(self, **super_kwargs): + return super().forward(**super_kwargs) + + +@add_start_docstrings( + """ + Data2VecAudio Model with an XVector feature extraction head on top for tasks like Speaker Verification. + """, + DATA2VEC_AUDIO_START_DOCSTRING, +) +class Data2VecAudioForXVector(Wav2Vec2ForXVector): + @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=XVectorOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward(self, **super_kwargs): + return super().forward(**super_kwargs) + + +__all__ = [ + "Data2VecAudioForAudioFrameClassification", + "Data2VecAudioForCTC", + "Data2VecAudioForSequenceClassification", + "Data2VecAudioForXVector", + "Data2VecAudioModel", + "Data2VecAudioPreTrainedModel", +] diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py index 2aeb2005805..e372817bf71 100644 --- a/src/transformers/models/gemma/configuration_gemma.py +++ b/src/transformers/models/gemma/configuration_gemma.py @@ -19,7 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- from ...configuration_utils import PretrainedConfig diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 735c508f6d7..fa6af70ecfd 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import sentencepiece as spm import torch @@ -379,7 +379,7 @@ class GemmaModel(LlamaModel): output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, # NOOP kwarg for now - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index d197d978e8d..384f3e08023 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -415,7 +415,7 @@ class Gemma2Model(GemmaModel): cache_position: Optional[torch.LongTensor] = None, last_cache_position: Optional[int] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -588,7 +588,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM): cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 6dad88c1bc1..6364f890219 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -326,7 +326,7 @@ class Gemma3Attention(nn.Module): past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -1203,7 +1203,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): return causal_mask - def get_image_features(self, pixel_values: torch.Tensor): + def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: """ Projects the last hidden state from the vision model into language model space. 
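# --- Illustrative note, not part of the patch ----------------------------------------
# Several hunks above narrow forward() return annotations from
# `Union[Tuple, BaseModelOutputWithPast]` to the ModelOutput subclass itself.
# A minimal sketch (assumptions: the checkpoint name and the token ids below are
# placeholders chosen for illustration) of why call sites indexing the output as a
# tuple keep working: ModelOutput subclasses support both attribute and index access.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("google/gemma-2-2b")       # assumed checkpoint
input_ids = torch.tensor([[2, 651, 6037, 603]])              # arbitrary token ids
outputs = model(input_ids=input_ids)                         # BaseModelOutputWithPast

# Attribute access and positional access resolve to the same tensor.
assert torch.equal(outputs.last_hidden_state, outputs[0])

# An explicit tuple view is still available if older code needs it.
hidden_state, *rest = outputs.to_tuple()
# --------------------------------------------------------------------------------------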
diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index c4de8d928d5..3f7292f13a0 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -597,7 +597,7 @@ class Gemma3TextModel(Gemma2Model): cache_position: Optional[torch.LongTensor] = None, last_cache_position: Optional[int] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 1278af4faab..c474ef36900 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -753,7 +753,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - ) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]: + ) -> GotOcr2CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py index 36d2db007b6..4b7e7f1adb5 100644 --- a/src/transformers/models/got_ocr2/modular_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py @@ -397,7 +397,7 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - ) -> LlavaCausalLMOutputWithPast: + ) -> GotOcr2CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index 1399f1f18f1..58b697c3f10 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -343,7 +343,7 @@ class GPTNeoXModel(LlamaModel, nn.Module): output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index 494ab5f1825..25929dbb337 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -132,7 +132,7 @@ class GraniteModel(LlamaModel): output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -244,7 +244,7 @@ class GraniteForCausalLM(LlamaForCausalLM): cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 7445802e624..ae03cea1c13 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -1,26 +1,15 @@ -# coding=utf-8 -# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Hubert model.""" - +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/hubert/modular_hubert.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_hubert.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import warnings from typing import Optional, Tuple, Union import numpy as np import torch -import torch.utils.checkpoint -from torch import nn +import torch.nn as nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN @@ -45,219 +34,12 @@ if is_flash_attn_available(): logger = logging.get_logger(__name__) -_HIDDEN_STATES_START_POSITION = 1 +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft" # General docstring _CONFIG_FOR_DOC = "HubertConfig" -# Base docstring -_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft" -_EXPECTED_OUTPUT_SHAPE = [1, 292, 768] - -# CTC docstring -_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'" -_CTC_EXPECTED_LOSS = 22.68 - -# Audio class docstring -_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks" -_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'" -_SEQ_CLASS_EXPECTED_LOSS = 8.53 - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices -def _compute_mask_indices( - shape: Tuple[int, int], - mask_prob: float, - mask_length: int, - attention_mask: Optional[torch.LongTensor] = None, - min_masks: int = 0, -) -> np.ndarray: - """ - Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for - ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on - CPU as part of the preprocessing during training. - - Args: - shape: The shape for which to compute masks. This should be of a tuple of size 2 where - the first element is the batch size and the second element is the length of the axis to span. - mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of - independently generated mask spans of length `mask_length` is computed by - `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the - actual percentage will be smaller. - mask_length: size of the mask - min_masks: minimum number of masked spans - attention_mask: A (right-padded) attention mask which independently shortens the feature axis of - each batch dimension. 
- """ - batch_size, sequence_length = shape - - if mask_length < 1: - raise ValueError("`mask_length` has to be bigger than 0.") - - if mask_length > sequence_length: - raise ValueError( - f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" - f" and `sequence_length`: {sequence_length}`" - ) - - # epsilon is used for probabilistic rounding - epsilon = np.random.rand(1).item() - - def compute_num_masked_span(input_length): - """Given input length, compute how many spans should be masked""" - num_masked_span = int(mask_prob * input_length / mask_length + epsilon) - num_masked_span = max(num_masked_span, min_masks) - - # make sure num masked span <= sequence_length - if num_masked_span * mask_length > sequence_length: - num_masked_span = sequence_length // mask_length - - # make sure num_masked span is also <= input_length - (mask_length - 1) - if input_length - (mask_length - 1) < num_masked_span: - num_masked_span = max(input_length - (mask_length - 1), 0) - - return num_masked_span - - # compute number of masked spans in batch - input_lengths = ( - attention_mask.detach().sum(-1).tolist() - if attention_mask is not None - else [sequence_length for _ in range(batch_size)] - ) - - # SpecAugment mask to fill - spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) - spec_aug_mask_idxs = [] - - max_num_masked_span = compute_num_masked_span(sequence_length) - - if max_num_masked_span == 0: - return spec_aug_mask - - for input_length in input_lengths: - # compute num of masked spans for this input - num_masked_span = compute_num_masked_span(input_length) - - # get random indices to mask - spec_aug_mask_idx = np.random.choice( - np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False - ) - - # pick first sampled index that will serve as a dummy index to pad vector - # to ensure same dimension for all batches due to probabilistic rounding - # Picking first sample just pads those vectors twice. 
- if len(spec_aug_mask_idx) == 0: - # this case can only happen if `input_length` is strictly smaller then - # `sequence_length` in which case the last token has to be a padding - # token which we can use as a dummy mask id - dummy_mask_idx = sequence_length - 1 - else: - dummy_mask_idx = spec_aug_mask_idx[0] - - spec_aug_mask_idx = np.concatenate( - [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] - ) - spec_aug_mask_idxs.append(spec_aug_mask_idx) - - spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) - - # expand masked indices to masked spans - spec_aug_mask_idxs = np.broadcast_to( - spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) - ) - spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) - - # add offset to the starting indexes so that indexes now create a span - offsets = np.arange(mask_length)[None, None, :] - offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( - batch_size, max_num_masked_span * mask_length - ) - spec_aug_mask_idxs = spec_aug_mask_idxs + offsets - - # ensure that we cannot have indices larger than sequence_length - if spec_aug_mask_idxs.max() > sequence_length - 1: - spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 - - # scatter indices to mask - np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) - - return spec_aug_mask - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Hubert -class HubertNoLayerNormConvLayer(nn.Module): - def __init__(self, config, layer_id=0): - super().__init__() - self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = nn.Conv1d( - self.in_conv_dim, - self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - stride=config.conv_stride[layer_id], - bias=config.conv_bias, - ) - self.activation = ACT2FN[config.feat_extract_activation] - - def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Hubert -class HubertLayerNormConvLayer(nn.Module): - def __init__(self, config, layer_id=0): - super().__init__() - self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = nn.Conv1d( - self.in_conv_dim, - self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - stride=config.conv_stride[layer_id], - bias=config.conv_bias, - ) - self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) - self.activation = ACT2FN[config.feat_extract_activation] - - def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - - hidden_states = hidden_states.transpose(-2, -1) - hidden_states = self.layer_norm(hidden_states) - hidden_states = hidden_states.transpose(-2, -1) - - hidden_states = self.activation(hidden_states) - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Hubert -class HubertGroupNormConvLayer(nn.Module): - def __init__(self, config, layer_id=0): - super().__init__() - self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = nn.Conv1d( - 
self.in_conv_dim, - self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - stride=config.conv_stride[layer_id], - bias=config.conv_bias, - ) - self.activation = ACT2FN[config.feat_extract_activation] - - self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) - - def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - class HubertPositionalConvEmbedding(nn.Module): def __init__(self, config): @@ -309,7 +91,6 @@ class HubertPositionalConvEmbedding(nn.Module): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Hubert class HubertSamePadLayer(nn.Module): def __init__(self, num_conv_pos_embeddings): super().__init__() @@ -321,7 +102,78 @@ class HubertSamePadLayer(nn.Module): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Hubert +class HubertNoLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class HubertLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +class HubertGroupNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + class HubertFeatureEncoder(nn.Module): """Construct the features from raw audio waveform""" @@ -366,17 +218,6 @@ class HubertFeatureEncoder(nn.Module): return hidden_states -class 
HubertFeatureExtractor(HubertFeatureEncoder): - def __init__(self, config): - super().__init__(config) - warnings.warn( - f"The class `{self.__class__.__name__}` has been depreciated " - "and will be removed in Transformers v5. " - f"Use `{self.__class__.__bases__[0].__name__}` instead.", - FutureWarning, - ) - - class HubertFeatureProjection(nn.Module): def __init__(self, config): super().__init__() @@ -395,7 +236,6 @@ class HubertFeatureProjection(nn.Module): return hidden_states -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Hubert class HubertAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -554,7 +394,6 @@ class HubertAttention(nn.Module): return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Hubert class HubertFlashAttention2(HubertAttention): """ Hubert flash attention module. This module inherits from `HubertAttention` as the weights of the module stays @@ -682,7 +521,6 @@ class HubertFlashAttention2(HubertAttention): class HubertSdpaAttention(HubertAttention): - # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Hubert def forward( self, hidden_states: torch.Tensor, @@ -788,14 +626,6 @@ class HubertSdpaAttention(HubertAttention): return attn_output, None, past_key_value -HUBERT_ATTENTION_CLASSES = { - "eager": HubertAttention, - "sdpa": HubertSdpaAttention, - "flash_attention_2": HubertFlashAttention2, -} - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Hubert class HubertFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -820,7 +650,13 @@ class HubertFeedForward(nn.Module): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->Hubert, WAV2VEC2->HUBERT +HUBERT_ATTENTION_CLASSES = { + "eager": HubertAttention, + "sdpa": HubertSdpaAttention, + "flash_attention_2": HubertFlashAttention2, +} + + class HubertEncoderLayer(nn.Module): def __init__(self, config): super().__init__() @@ -856,79 +692,6 @@ class HubertEncoderLayer(nn.Module): return outputs -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AttnAdapterLayer with Wav2Vec2->Hubert -class HubertAttnAdapterLayer(nn.Module): - def __init__(self, config): - """ - Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed - up training throughput. 
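# Illustrative sketch, not part of the patch: the hunks above only relocate
# HUBERT_ATTENTION_CLASSES below HubertFeedForward; the mapping itself is unchanged and is
# still resolved through `config._attn_implementation`. A minimal example of how a caller
# selects one of the three backends at load time -- the checkpoint name is the one used in
# the Hubert docstrings later in this diff, and the choice of "sdpa" is an assumption for
# illustration only.
from transformers import HubertModel

model = HubertModel.from_pretrained(
    "facebook/hubert-large-ls960-ft",
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2", the keys of HUBERT_ATTENTION_CLASSES
)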
- """ - super().__init__() - self.input_dim = config.adapter_attn_dim - self.hidden_dim = config.hidden_size - - self.norm = nn.LayerNorm(self.hidden_dim) - self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim) - self.act_fn = nn.ReLU() - self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim) - - def forward(self, hidden_states: torch.FloatTensor): - hidden_states = self.norm(hidden_states) - - hidden_states = self.linear_1(hidden_states) - hidden_states = self.act_fn(hidden_states) - hidden_states = self.linear_2(hidden_states) - - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->Hubert, WAV2VEC2->HUBERT -class HubertEncoderLayerStableLayerNorm(nn.Module): - def __init__(self, config): - super().__init__() - self.attention = HUBERT_ATTENTION_CLASSES[config._attn_implementation]( - embed_dim=config.hidden_size, - num_heads=config.num_attention_heads, - dropout=config.attention_dropout, - is_decoder=False, - ) - self.dropout = nn.Dropout(config.hidden_dropout) - self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.feed_forward = HubertFeedForward(config) - self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if getattr(config, "adapter_attn_dim", None) is not None: - self.adapter_layer = HubertAttnAdapterLayer(config) - else: - self.adapter_layer = None - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ): - attn_residual = hidden_states - hidden_states = self.layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.attention( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) - hidden_states = self.dropout(hidden_states) - hidden_states = attn_residual + hidden_states - hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) - - if self.adapter_layer is not None: - hidden_states = hidden_states + self.adapter_layer(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder with Wav2Vec2->Hubert class HubertEncoder(nn.Module): def __init__(self, config): super().__init__() @@ -1014,7 +777,76 @@ class HubertEncoder(nn.Module): ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderStableLayerNorm with Wav2Vec2->Hubert +class HubertAttnAdapterLayer(nn.Module): + def __init__(self, config): + """ + Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed + up training throughput. 
+ """ + super().__init__() + self.input_dim = config.adapter_attn_dim + self.hidden_dim = config.hidden_size + + self.norm = nn.LayerNorm(self.hidden_dim) + self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim) + self.act_fn = nn.ReLU() + self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim) + + def forward(self, hidden_states: torch.FloatTensor): + hidden_states = self.norm(hidden_states) + + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states + + +class HubertEncoderLayerStableLayerNorm(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = HUBERT_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=False, + ) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = HubertFeedForward(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if getattr(config, "adapter_attn_dim", None) is not None: + self.adapter_layer = HubertAttnAdapterLayer(config) + else: + self.adapter_layer = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + attn_residual = hidden_states + hidden_states = self.layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.attention( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = self.dropout(hidden_states) + hidden_states = attn_residual + hidden_states + hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) + + if self.adapter_layer is not None: + hidden_states = hidden_states + self.adapter_layer(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + class HubertEncoderStableLayerNorm(nn.Module): def __init__(self, config): super().__init__() @@ -1170,6 +1002,125 @@ class HubertPreTrainedModel(PreTrainedModel): return attention_mask +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. 
+ """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. 
+ if len(spec_aug_mask_idx) == 0: + # this case can only happen if `input_length` is strictly smaller then + # `sequence_length` in which case the last token has to be a padding + # token which we can use as a dummy mask id + dummy_mask_idx = sequence_length - 1 + else: + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # ensure that we cannot have indices larger than sequence_length + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + HUBERT_START_DOCSTRING = r""" Hubert was proposed in [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, @@ -1238,6 +1189,7 @@ class HubertModel(HubertPreTrainedModel): self.feature_extractor = HubertFeatureEncoder(config) self.feature_projection = HubertFeatureProjection(config) + # model only needs masking vector if mask prob is > 0.0 if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) @@ -1249,7 +1201,6 @@ class HubertModel(HubertPreTrainedModel): # Initialize weights and apply final processing self.post_init() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states def _mask_hidden_states( self, hidden_states: torch.FloatTensor, @@ -1370,11 +1321,17 @@ class HubertModel(HubertPreTrainedModel): ) +_HIDDEN_STATES_START_POSITION = 1 + + +_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'" +_CTC_EXPECTED_LOSS = 22.68 + + @add_start_docstrings( """Hubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", HUBERT_START_DOCSTRING, ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT class HubertForCTC(HubertPreTrainedModel): def __init__(self, config, target_lang: Optional[str] = None): super().__init__(config) @@ -1526,6 +1483,11 @@ class HubertForCTC(HubertPreTrainedModel): ) +_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks" +_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'" +_SEQ_CLASS_EXPECTED_LOSS = 8.53 + + @add_start_docstrings( """ Hubert Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like @@ -1533,7 +1495,6 @@ class HubertForCTC(HubertPreTrainedModel): """, HUBERT_START_DOCSTRING, ) -# Copied from 
transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT class HubertForSequenceClassification(HubertPreTrainedModel): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/hubert/modular_hubert.py b/src/transformers/models/hubert/modular_hubert.py new file mode 100644 index 00000000000..a42bb5e598b --- /dev/null +++ b/src/transformers/models/hubert/modular_hubert.py @@ -0,0 +1,400 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ..wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2Encoder, + Wav2Vec2EncoderStableLayerNorm, + Wav2Vec2FeatureEncoder, + Wav2Vec2ForCTC, + Wav2Vec2ForSequenceClassification, + Wav2Vec2Model, + Wav2Vec2SamePadLayer, +) +from .configuration_hubert import HubertConfig + + +_HIDDEN_STATES_START_POSITION = 1 + +# General docstring +_CONFIG_FOR_DOC = "HubertConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft" +_EXPECTED_OUTPUT_SHAPE = [1, 292, 768] + + +_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'" +_CTC_EXPECTED_LOSS = 22.68 + + +_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks" +_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'" +_SEQ_CLASS_EXPECTED_LOSS = 8.53 + + +class HubertPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + self.batch_norm = None + if config.conv_pos_batch_norm: + self.batch_norm = nn.BatchNorm1d(config.hidden_size) + else: + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = weight_norm(self.conv, name="weight", dim=2) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) + else: + self.conv = weight_norm(self.conv, name="weight", dim=2) + + self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + if self.batch_norm is not None: + hidden_states = self.batch_norm(hidden_states) + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class HubertSamePadLayer(Wav2Vec2SamePadLayer): + pass + + +class 
HubertFeatureEncoder(Wav2Vec2FeatureEncoder): + pass + + +class HubertFeatureProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.feat_proj_layer_norm = config.feat_proj_layer_norm + if self.feat_proj_layer_norm: + self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps) + self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size) + self.dropout = nn.Dropout(config.feat_proj_dropout) + + def forward(self, hidden_states): + # non-projected hidden states are needed for quantization + if self.feat_proj_layer_norm: + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class HubertEncoder(Wav2Vec2Encoder): + pass + + +class HubertEncoderStableLayerNorm(Wav2Vec2EncoderStableLayerNorm): + pass + + +class HubertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = HubertConfig + base_model_prefix = "hubert" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + if is_deepspeed_zero3_enabled(): + import deepspeed + + if hasattr(module, "weight_v") and hasattr(module, "weight_g"): + with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): + nn.init.kaiming_normal_(module.weight.data) + else: + with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): + nn.init.kaiming_normal_(module.weight.data) + else: + nn.init.kaiming_normal_(module.weight.data) + + if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: + module.bias.data.zero_() + + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths + + def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): + output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = 
attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + +HUBERT_START_DOCSTRING = r""" + Hubert was proposed in [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden + Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, + Ruslan Salakhutdinov, Abdelrahman Mohamed. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving etc.). + + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`HubertConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +HUBERT_INPUTS_DOCSTRING = r""" + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file + into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install + soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and + conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0, + 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + + + `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask == + True`. For all models whose processor has `config.return_attention_mask == False`, such as + [hubert-base](https://huggingface.co/facebook/hubert-base-ls960), `attention_mask` should **not** be passed + to avoid degraded performance when doing batched inference. For such models `input_values` should simply be + padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly different + results depending on whether `input_values` is padded or not. + + + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
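# Illustrative sketch of the flip/cumsum/flip trick in _get_feature_vector_attention_mask
# above: placing a single 1 at index (output_length - 1) and running a reversed cumulative
# sum yields a mask that is True everywhere up to and including that index. The lengths below
# are toy values, not taken from any real checkpoint.
import torch

feature_vector_length = 8
output_lengths = torch.tensor([5, 3])  # assumed per-example downsampled lengths
mask = torch.zeros((2, feature_vector_length), dtype=torch.long)
mask[(torch.arange(2), output_lengths - 1)] = 1
mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
# tensor([[ True,  True,  True,  True,  True, False, False, False],
#         [ True,  True,  True, False, False, False, False, False]])
print(mask)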
+""" + + +@add_start_docstrings( + "The bare Hubert Model transformer outputting raw hidden-states without any specific head on top.", + HUBERT_START_DOCSTRING, +) +class HubertModel(Wav2Vec2Model, HubertPreTrainedModel): + def __init__(self, config: HubertConfig): + super().__init__(config) + self.config = config + self.feature_extractor = HubertFeatureEncoder(config) + self.feature_projection = HubertFeatureProjection(config) + + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) + + if config.do_stable_layer_norm: + self.encoder = HubertEncoderStableLayerNorm(config) + else: + self.encoder = HubertEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + del self.adapter + + def freeze_feature_extractor(self): + raise AttributeError("Not needed for Hubert") + + def freeze_feature_encoder(self): + raise AttributeError("Not needed for Hubert") + + @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + """ + + Returns: + + Example: + + ```python + >>> from transformers import AutoProcessor, HubertModel + >>> from datasets import load_dataset + >>> import soundfile as sf + + >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft") + >>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft") + + + >>> def map_to_array(batch): + ... speech, _ = sf.read(batch["file"]) + ... batch["speech"] = speech + ... 
return batch + + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.map(map_to_array) + + >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1 + >>> hidden_states = model(input_values).last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) + + hidden_states = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """Hubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", + HUBERT_START_DOCSTRING, +) +class HubertForCTC(Wav2Vec2ForCTC): + pass + + @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_CTC_EXPECTED_OUTPUT, + expected_loss=_CTC_EXPECTED_LOSS, + ) + def forward(self, **super_kwargs): + super().forward(**super_kwargs) + + +@add_start_docstrings( + """ + Hubert Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like + SUPERB Keyword Spotting. 
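# Hedged usage sketch for the HubertForCTC declared above. The checkpoint and dummy dataset
# names are the ones already referenced in this file's docstrings; the decoded `audio` column
# and greedy argmax decoding are assumptions borrowed from the standard CTC doc examples.
import torch
from datasets import load_dataset
from transformers import AutoProcessor, HubertForCTC

processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = processor(ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits           # (batch, frames, vocab_size)
predicted_ids = torch.argmax(logits, dim=-1)  # greedy CTC path
print(processor.batch_decode(predicted_ids))  # should line up with _CTC_EXPECTED_OUTPUT above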
+ """, + HUBERT_START_DOCSTRING, +) +class HubertForSequenceClassification(Wav2Vec2ForSequenceClassification): + pass + + @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_SEQ_CLASS_CHECKPOINT, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, + ) + def forward(self, **super_kwargs): + super().forward(**super_kwargs) + + +__all__ = ["HubertForCTC", "HubertForSequenceClassification", "HubertModel", "HubertPreTrainedModel"] diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 4548e69a738..6026f5a7e07 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -303,7 +303,7 @@ class MistralForQuestionAnswering(LlamaForQuestionAnswering): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, **kwargs, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: + ) -> QuestionAnsweringModelOutput: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 013f04ab36b..bc3f09a778b 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -620,7 +620,7 @@ class MixtralModel(MixtralPreTrainedModel): output_router_logits: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> BaseModelOutputWithPast: + ) -> MoeModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits @@ -1013,7 +1013,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> CausalLMOutputWithPast: + ) -> MoeCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 09c48edbf24..bc46c9ab0c1 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -853,7 +853,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel): encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> BaseModelOutputWithPast: + ) -> Union[Tuple, BaseModelOutputWithPast]: """ Args: encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): @@ -1415,7 +1415,7 @@ class MoonshineModel(MoonshinePreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: + ) -> Seq2SeqModelOutput: r""" Returns: diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index cb9b1583382..f0eb31058c9 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Tuple import torch import torch.nn as nn @@ -191,7 +191,7 @@ class PhiModel(LlamaModel): output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index 58c02036b6e..8e0b86da407 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -14,7 +14,7 @@ # limitations under the License. """PyTorch Qwen3 model.""" -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Tuple import torch import torch.utils.checkpoint @@ -146,7 +146,7 @@ class Qwen3ForCausalLM(LlamaForCausalLM): def forward( self, **super_kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index b9ffae15a21..d239f79ec7d 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -633,7 +633,7 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel): output_router_logits: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> BaseModelOutputWithPast: + ) -> MoeModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits @@ -1026,7 +1026,7 @@ class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin): cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> CausalLMOutputWithPast: + ) -> MoeCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., diff --git a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py index d8a5cc54f24..385f338c78e 100644 --- a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py @@ -257,7 +257,7 @@ class Qwen3MoeForCausalLM(MixtralForCausalLM): cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + ) -> MoeCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index d5f773299c9..572c07e3c9d 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -1218,7 +1218,7 @@ class SEWModel(SEWPreTrainedModel): """SEW Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", SEW_START_DOCSTRING, ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV_2_VEC_2->SEW +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV2VEC2->SEW class SEWForCTC(SEWPreTrainedModel): def __init__(self, config, target_lang: Optional[str] = None): super().__init__(config) @@ -1377,7 +1377,7 @@ class SEWForCTC(SEWPreTrainedModel): """, SEW_START_DOCSTRING, ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEW, wav2vec2->sew, WAV_2_VEC_2->SEW +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEW, wav2vec2->sew, WAV2VEC2->SEW class SEWForSequenceClassification(SEWPreTrainedModel): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 9f49a46a05f..96c587fbb2e 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -1468,7 +1468,7 @@ class SEWDModel(SEWDPreTrainedModel): """SEW-D Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", SEWD_START_DOCSTRING, ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV_2_VEC_2->SEWD +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV2VEC2->SEWD class SEWDForCTC(SEWDPreTrainedModel): def __init__(self, config, target_lang: Optional[str] = None): super().__init__(config) @@ -1627,7 +1627,7 @@ class SEWDForCTC(SEWDPreTrainedModel): """, SEWD_START_DOCSTRING, ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV_2_VEC_2->SEWD +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV2VEC2->SEWD class SEWDForSequenceClassification(SEWDPreTrainedModel): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/siglip2/modeling_siglip2.py b/src/transformers/models/siglip2/modeling_siglip2.py index b4787439497..612e149d544 100644 --- a/src/transformers/models/siglip2/modeling_siglip2.py +++ b/src/transformers/models/siglip2/modeling_siglip2.py @@ -999,7 +999,7 @@ class Siglip2MultiheadAttentionPoolingHead(nn.Module): self.mlp = Siglip2MLP(config) self.num_heads = config.num_attention_heads - def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): + def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: batch_size = hidden_state.shape[0] probe = self.probe.repeat(batch_size, 1, 1) diff --git a/src/transformers/models/siglip2/modular_siglip2.py b/src/transformers/models/siglip2/modular_siglip2.py index 92e106bc59b..23df3b0413d 100644 --- 
a/src/transformers/models/siglip2/modular_siglip2.py +++ b/src/transformers/models/siglip2/modular_siglip2.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional import torch import torch.nn as nn @@ -243,7 +243,7 @@ class Siglip2VisionTransformer(SiglipVisionTransformer): spatial_shapes: torch.LongTensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: + ) -> BaseModelOutputWithPooling: r""" Returns: diff --git a/src/transformers/models/smolvlm/configuration_smolvlm.py b/src/transformers/models/smolvlm/configuration_smolvlm.py index f4a42f348de..cd854415683 100644 --- a/src/transformers/models/smolvlm/configuration_smolvlm.py +++ b/src/transformers/models/smolvlm/configuration_smolvlm.py @@ -19,7 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from ...configuration_utils import PretrainedConfig from ...utils import logging from ..auto import CONFIG_MAPPING, AutoConfig diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 2ee2e054aeb..74608797ab6 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -1,19 +1,9 @@ -# coding=utf-8 -# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch UniSpeech model.""" - +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/unispeech/modular_unispeech.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_unispeech.py file directly. One of our CI enforces this. 
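# Context for the generated-file banner above (an assumption-level sketch, mirroring the new
# modular_hubert.py earlier in this diff): the hand-written modular file only declares the
# delta against wav2vec2, and the checked-in modeling file is regenerated from it, so fixes
# belong in modular_*.py rather than in the generated modeling_*.py.
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2SamePadLayer


class UniSpeechSamePadLayer(Wav2Vec2SamePadLayer):
    # `pass` is enough: the converter inlines the parent's body under the new name in the
    # generated modeling file, removing the cross-model import. Whether the real
    # modular_unispeech.py spells it exactly this way is not shown in this diff.
    pass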
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math import warnings from dataclasses import dataclass @@ -21,18 +11,22 @@ from typing import Optional, Tuple, Union import numpy as np import torch -import torch.utils.checkpoint -from torch import nn +import torch.nn as nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available -from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, Wav2Vec2BaseModelOutput +from ...modeling_outputs import ( + BaseModelOutput, + CausalLMOutput, + ModelOutput, + SequenceClassifierOutput, + Wav2Vec2BaseModelOutput, +) from ...modeling_utils import PreTrainedModel from ...utils import ( - ModelOutput, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, @@ -48,20 +42,12 @@ if is_flash_attn_available(): logger = logging.get_logger(__name__) - -_HIDDEN_STATES_START_POSITION = 2 +# Base docstring +_CHECKPOINT_FOR_DOC = "patrickvonplaten/unispeech-large-1500h-cv-timit" # General docstring _CONFIG_FOR_DOC = "UniSpeechConfig" -# Base docstring -_CHECKPOINT_FOR_DOC = "patrickvonplaten/unispeech-large-1500h-cv-timit" -_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024] - -# CTC docstring -_CTC_EXPECTED_OUTPUT = "'mister quilter is the apposl of the midle classes and weare glad to welcom his gosepl'" -_CTC_EXPECTED_LOSS = 17.17 - @dataclass class UniSpeechForPreTrainingOutput(ModelOutput): @@ -99,202 +85,17 @@ class UniSpeechForPreTrainingOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None -# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices -def _compute_mask_indices( - shape: Tuple[int, int], - mask_prob: float, - mask_length: int, - attention_mask: Optional[torch.LongTensor] = None, - min_masks: int = 0, -) -> np.ndarray: - """ - Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for - ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on - CPU as part of the preprocessing during training. - - Args: - shape: The shape for which to compute masks. This should be of a tuple of size 2 where - the first element is the batch size and the second element is the length of the axis to span. - mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of - independently generated mask spans of length `mask_length` is computed by - `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the - actual percentage will be smaller. - mask_length: size of the mask - min_masks: minimum number of masked spans - attention_mask: A (right-padded) attention mask which independently shortens the feature axis of - each batch dimension. 
- """ - batch_size, sequence_length = shape - - if mask_length < 1: - raise ValueError("`mask_length` has to be bigger than 0.") - - if mask_length > sequence_length: - raise ValueError( - f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" - f" and `sequence_length`: {sequence_length}`" - ) - - # epsilon is used for probabilistic rounding - epsilon = np.random.rand(1).item() - - def compute_num_masked_span(input_length): - """Given input length, compute how many spans should be masked""" - num_masked_span = int(mask_prob * input_length / mask_length + epsilon) - num_masked_span = max(num_masked_span, min_masks) - - # make sure num masked span <= sequence_length - if num_masked_span * mask_length > sequence_length: - num_masked_span = sequence_length // mask_length - - # make sure num_masked span is also <= input_length - (mask_length - 1) - if input_length - (mask_length - 1) < num_masked_span: - num_masked_span = max(input_length - (mask_length - 1), 0) - - return num_masked_span - - # compute number of masked spans in batch - input_lengths = ( - attention_mask.detach().sum(-1).tolist() - if attention_mask is not None - else [sequence_length for _ in range(batch_size)] - ) - - # SpecAugment mask to fill - spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) - spec_aug_mask_idxs = [] - - max_num_masked_span = compute_num_masked_span(sequence_length) - - if max_num_masked_span == 0: - return spec_aug_mask - - for input_length in input_lengths: - # compute num of masked spans for this input - num_masked_span = compute_num_masked_span(input_length) - - # get random indices to mask - spec_aug_mask_idx = np.random.choice( - np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False - ) - - # pick first sampled index that will serve as a dummy index to pad vector - # to ensure same dimension for all batches due to probabilistic rounding - # Picking first sample just pads those vectors twice. 
- if len(spec_aug_mask_idx) == 0: - # this case can only happen if `input_length` is strictly smaller then - # `sequence_length` in which case the last token has to be a padding - # token which we can use as a dummy mask id - dummy_mask_idx = sequence_length - 1 - else: - dummy_mask_idx = spec_aug_mask_idx[0] - - spec_aug_mask_idx = np.concatenate( - [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] - ) - spec_aug_mask_idxs.append(spec_aug_mask_idx) - - spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) - - # expand masked indices to masked spans - spec_aug_mask_idxs = np.broadcast_to( - spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) - ) - spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) - - # add offset to the starting indexes so that indexes now create a span - offsets = np.arange(mask_length)[None, None, :] - offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( - batch_size, max_num_masked_span * mask_length - ) - spec_aug_mask_idxs = spec_aug_mask_idxs + offsets - - # ensure that we cannot have indices larger than sequence_length - if spec_aug_mask_idxs.max() > sequence_length - 1: - spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 - - # scatter indices to mask - np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) - - return spec_aug_mask - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->UniSpeech -class UniSpeechNoLayerNormConvLayer(nn.Module): - def __init__(self, config, layer_id=0): +class UniSpeechSamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): super().__init__() - self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = nn.Conv1d( - self.in_conv_dim, - self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - stride=config.conv_stride[layer_id], - bias=config.conv_bias, - ) - self.activation = ACT2FN[config.feat_extract_activation] + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - hidden_states = self.activation(hidden_states) + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->UniSpeech -class UniSpeechLayerNormConvLayer(nn.Module): - def __init__(self, config, layer_id=0): - super().__init__() - self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = nn.Conv1d( - self.in_conv_dim, - self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - stride=config.conv_stride[layer_id], - bias=config.conv_bias, - ) - self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) - self.activation = ACT2FN[config.feat_extract_activation] - - def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - - hidden_states = hidden_states.transpose(-2, -1) - hidden_states = self.layer_norm(hidden_states) - hidden_states = hidden_states.transpose(-2, -1) - - hidden_states = self.activation(hidden_states) - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->UniSpeech -class 
UniSpeechGroupNormConvLayer(nn.Module): - def __init__(self, config, layer_id=0): - super().__init__() - self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = nn.Conv1d( - self.in_conv_dim, - self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - stride=config.conv_stride[layer_id], - bias=config.conv_bias, - ) - self.activation = ACT2FN[config.feat_extract_activation] - - self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) - - def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->UniSpeech class UniSpeechPositionalConvEmbedding(nn.Module): def __init__(self, config): super().__init__() @@ -340,19 +141,78 @@ class UniSpeechPositionalConvEmbedding(nn.Module): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->UniSpeech -class UniSpeechSamePadLayer(nn.Module): - def __init__(self, num_conv_pos_embeddings): +class UniSpeechNoLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): super().__init__() - self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] def forward(self, hidden_states): - if self.num_pad_remove > 0: - hidden_states = hidden_states[:, :, : -self.num_pad_remove] + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class UniSpeechLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +class UniSpeechGroupNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + 
hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->UniSpeech class UniSpeechFeatureEncoder(nn.Module): """Construct the features from raw audio waveform""" @@ -400,18 +260,6 @@ class UniSpeechFeatureEncoder(nn.Module): return hidden_states -class UniSpeechFeatureExtractor(UniSpeechFeatureEncoder): - def __init__(self, config): - super().__init__(config) - warnings.warn( - f"The class `{self.__class__.__name__}` has been depreciated " - "and will be removed in Transformers v5. " - f"Use `{self.__class__.__bases__[0].__name__}` instead.", - FutureWarning, - ) - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->UniSpeech class UniSpeechFeatureProjection(nn.Module): def __init__(self, config): super().__init__() @@ -427,7 +275,6 @@ class UniSpeechFeatureProjection(nn.Module): return hidden_states, norm_hidden_states -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->UniSpeech class UniSpeechAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -586,7 +433,6 @@ class UniSpeechAttention(nn.Module): return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->UniSpeech class UniSpeechFlashAttention2(UniSpeechAttention): """ UniSpeech flash attention module. This module inherits from `UniSpeechAttention` as the weights of the module stays @@ -714,7 +560,6 @@ class UniSpeechFlashAttention2(UniSpeechAttention): class UniSpeechSdpaAttention(UniSpeechAttention): - # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->UniSpeech def forward( self, hidden_states: torch.Tensor, @@ -820,14 +665,6 @@ class UniSpeechSdpaAttention(UniSpeechAttention): return attn_output, None, past_key_value -UNISPEECH_ATTENTION_CLASSES = { - "eager": UniSpeechAttention, - "sdpa": UniSpeechSdpaAttention, - "flash_attention_2": UniSpeechFlashAttention2, -} - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->UniSpeech class UniSpeechFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -852,7 +689,13 @@ class UniSpeechFeedForward(nn.Module): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->UniSpeech, WAV2VEC2->UNISPEECH +UNISPEECH_ATTENTION_CLASSES = { + "eager": UniSpeechAttention, + "sdpa": UniSpeechSdpaAttention, + "flash_attention_2": UniSpeechFlashAttention2, +} + + class UniSpeechEncoderLayer(nn.Module): def __init__(self, config): super().__init__() @@ -888,79 +731,6 @@ class UniSpeechEncoderLayer(nn.Module): return outputs -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AttnAdapterLayer with Wav2Vec2->UniSpeech -class UniSpeechAttnAdapterLayer(nn.Module): - def __init__(self, config): - """ - Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed - up training throughput. 
- """ - super().__init__() - self.input_dim = config.adapter_attn_dim - self.hidden_dim = config.hidden_size - - self.norm = nn.LayerNorm(self.hidden_dim) - self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim) - self.act_fn = nn.ReLU() - self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim) - - def forward(self, hidden_states: torch.FloatTensor): - hidden_states = self.norm(hidden_states) - - hidden_states = self.linear_1(hidden_states) - hidden_states = self.act_fn(hidden_states) - hidden_states = self.linear_2(hidden_states) - - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->UniSpeech, WAV2VEC2->UNISPEECH -class UniSpeechEncoderLayerStableLayerNorm(nn.Module): - def __init__(self, config): - super().__init__() - self.attention = UNISPEECH_ATTENTION_CLASSES[config._attn_implementation]( - embed_dim=config.hidden_size, - num_heads=config.num_attention_heads, - dropout=config.attention_dropout, - is_decoder=False, - ) - self.dropout = nn.Dropout(config.hidden_dropout) - self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.feed_forward = UniSpeechFeedForward(config) - self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if getattr(config, "adapter_attn_dim", None) is not None: - self.adapter_layer = UniSpeechAttnAdapterLayer(config) - else: - self.adapter_layer = None - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ): - attn_residual = hidden_states - hidden_states = self.layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.attention( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) - hidden_states = self.dropout(hidden_states) - hidden_states = attn_residual + hidden_states - hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) - - if self.adapter_layer is not None: - hidden_states = hidden_states + self.adapter_layer(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder with Wav2Vec2->UniSpeech class UniSpeechEncoder(nn.Module): def __init__(self, config): super().__init__() @@ -1046,7 +816,76 @@ class UniSpeechEncoder(nn.Module): ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderStableLayerNorm with Wav2Vec2->UniSpeech +class UniSpeechAttnAdapterLayer(nn.Module): + def __init__(self, config): + """ + Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed + up training throughput. 
+ """ + super().__init__() + self.input_dim = config.adapter_attn_dim + self.hidden_dim = config.hidden_size + + self.norm = nn.LayerNorm(self.hidden_dim) + self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim) + self.act_fn = nn.ReLU() + self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim) + + def forward(self, hidden_states: torch.FloatTensor): + hidden_states = self.norm(hidden_states) + + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states + + +class UniSpeechEncoderLayerStableLayerNorm(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = UNISPEECH_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=False, + ) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = UniSpeechFeedForward(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if getattr(config, "adapter_attn_dim", None) is not None: + self.adapter_layer = UniSpeechAttnAdapterLayer(config) + else: + self.adapter_layer = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + attn_residual = hidden_states + hidden_states = self.layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.attention( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = self.dropout(hidden_states) + hidden_states = attn_residual + hidden_states + hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) + + if self.adapter_layer is not None: + hidden_states = hidden_states + self.adapter_layer(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + class UniSpeechEncoderStableLayerNorm(nn.Module): def __init__(self, config): super().__init__() @@ -1138,7 +977,7 @@ class UniSpeechEncoderStableLayerNorm(nn.Module): class UniSpeechGumbelVectorQuantizer(nn.Module): """ - Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH + Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information. """ @@ -1149,8 +988,8 @@ class UniSpeechGumbelVectorQuantizer(nn.Module): if config.codevector_dim % self.num_groups != 0: raise ValueError( - f"`config.codevector_dim {config.codevector_dim} must be divisible by `config.num_codevector_groups`" - f" {self.num_groups} for concatenation" + f"`config.codevector_dim {config.codevector_dim} must be divisible " + f"by `config.num_codevector_groups` {self.num_groups} for concatenation" ) # storage for codebook variables (codewords) @@ -1283,6 +1122,128 @@ class UniSpeechPreTrainedModel(PreTrainedModel): return attention_mask +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). 
Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. + """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. 
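+        # e.g. if max_num_masked_span is 3 but only 2 spans were sampled for this row, the dummy index is appended once so every row ends up with exactly 3 start indices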
+ if len(spec_aug_mask_idx) == 0: + # this case can only happen if `input_length` is strictly smaller then + # `sequence_length` in which case the last token has to be a padding + # token which we can use as a dummy mask id + dummy_mask_idx = sequence_length - 1 + else: + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # ensure that we cannot have indices larger than sequence_length + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + +_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024] + + UNISPEECH_START_DOCSTRING = r""" UniSpeech was proposed in [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, @@ -1301,7 +1262,6 @@ UNISPEECH_START_DOCSTRING = r""" configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
""" - UNISPEECH_INPUTS_DOCSTRING = r""" Args: input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): @@ -1339,6 +1299,9 @@ UNISPEECH_INPUTS_DOCSTRING = r""" """ +UniSpeechBaseModelOutput = Wav2Vec2BaseModelOutput + + @add_start_docstrings( "The bare UniSpeech Model transformer outputting raw hidden-states without any specific head on top.", UNISPEECH_START_DOCSTRING, @@ -1361,7 +1324,6 @@ class UniSpeechModel(UniSpeechPreTrainedModel): # Initialize weights and apply final processing self.post_init() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states def _mask_hidden_states( self, hidden_states: torch.FloatTensor, @@ -1411,7 +1373,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel): @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, - output_type=Wav2Vec2BaseModelOutput, + output_type=UniSpeechBaseModelOutput, config_class=_CONFIG_FOR_DOC, modality="audio", expected_output=_EXPECTED_OUTPUT_SHAPE, @@ -1424,7 +1386,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[Tuple, Wav2Vec2BaseModelOutput]: + ) -> Union[Tuple, UniSpeechBaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1456,7 +1418,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel): if not return_dict: return (hidden_states, extract_features) + encoder_outputs[1:] - return Wav2Vec2BaseModelOutput( + return UniSpeechBaseModelOutput( last_hidden_state=hidden_states, extract_features=extract_features, hidden_states=encoder_outputs.hidden_states, @@ -1610,6 +1572,13 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel): ) +_HIDDEN_STATES_START_POSITION = 2 + +# CTC docstring +_CTC_EXPECTED_OUTPUT = "'mister quilter is the apposl of the midle classes and weare glad to welcom his gosepl'" +_CTC_EXPECTED_LOSS = 17.17 + + @add_start_docstrings( """UniSpeech Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", UNISPEECH_START_DOCSTRING, @@ -1620,7 +1589,6 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel): by default. 
""", ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->UniSpeech, wav2vec2->unispeech, WAV_2_VEC_2->UNISPEECH class UniSpeechForCTC(UniSpeechPreTrainedModel): def __init__(self, config, target_lang: Optional[str] = None): super().__init__(config) @@ -1797,7 +1765,6 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel): # Initialize weights and apply final processing self.post_init() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor def freeze_feature_extractor(self): """ Calling this function will disable the gradient computation for the feature encoder so that its parameters will @@ -1810,7 +1777,6 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel): ) self.freeze_feature_encoder() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->unispeech def freeze_feature_encoder(self): """ Calling this function will disable the gradient computation for the feature encoder so that its parameter will @@ -1818,7 +1784,6 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel): """ self.unispeech.feature_extractor._freeze_parameters() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->unispeech def freeze_base_model(self): """ Calling this function will disable the gradient computation for the base model so that its parameters will not @@ -1834,7 +1799,6 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel): config_class=_CONFIG_FOR_DOC, modality="audio", ) - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->UniSpeech, wav2vec2->unispeech def forward( self, input_values: Optional[torch.Tensor], diff --git a/src/transformers/models/unispeech/modular_unispeech.py b/src/transformers/models/unispeech/modular_unispeech.py new file mode 100644 index 00000000000..1096bc559b4 --- /dev/null +++ b/src/transformers/models/unispeech/modular_unispeech.py @@ -0,0 +1,563 @@ +import math +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...modeling_outputs import CausalLMOutput, ModelOutput, SequenceClassifierOutput, Wav2Vec2BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ..wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2Encoder, + Wav2Vec2EncoderStableLayerNorm, + Wav2Vec2FeatureEncoder, + Wav2Vec2FeatureProjection, + Wav2Vec2ForCTC, + Wav2Vec2ForSequenceClassification, + Wav2Vec2GumbelVectorQuantizer, + Wav2Vec2Model, + Wav2Vec2PositionalConvEmbedding, +) +from .configuration_unispeech import UniSpeechConfig + + +logger = logging.get_logger(__name__) + + +_HIDDEN_STATES_START_POSITION = 2 + +# General docstring +_CONFIG_FOR_DOC = "UniSpeechConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "patrickvonplaten/unispeech-large-1500h-cv-timit" +_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024] + +# CTC docstring +_CTC_EXPECTED_OUTPUT = "'mister quilter is the apposl of the midle classes and weare glad to welcom his gosepl'" +_CTC_EXPECTED_LOSS = 17.17 + + +@dataclass +class UniSpeechForPreTrainingOutput(ModelOutput): + """ + Output type of 
[`UniSpeechForPreTrainingOutput`], with potential hidden states and attentions. + + Args: + loss (*optional*, returned when model is in train mode, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official + paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss. + projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`): + Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked + projected quantized states. + projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`): + Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive + target vectors for contrastive loss. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + projected_states: Optional[torch.FloatTensor] = None + projected_quantized_states: Optional[torch.FloatTensor] = None + codevector_perplexity: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class UniSpeechPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding): + pass + + +class UniSpeechFeatureEncoder(Wav2Vec2FeatureEncoder): + pass + + +class UniSpeechFeatureProjection(Wav2Vec2FeatureProjection): + pass + + +class UniSpeechEncoder(Wav2Vec2Encoder): + pass + + +class UniSpeechEncoderStableLayerNorm(Wav2Vec2EncoderStableLayerNorm): + pass + + +class UniSpeechGumbelVectorQuantizer(Wav2Vec2GumbelVectorQuantizer): + @staticmethod + def _compute_perplexity(probs): + marginal_probs = probs.mean(dim=0) + perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum() + return perplexity + + def forward(self, hidden_states): + batch_size, sequence_length, hidden_size = hidden_states.shape + + # project to codevector dim + hidden_states = self.weight_proj(hidden_states) + hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1) + + if self.training: + # sample code vector probs via gumbel in differentiateable way + codevector_probs = nn.functional.gumbel_softmax( + hidden_states.float(), tau=self.temperature, hard=True + ).type_as(hidden_states) + + # compute perplexity + codevector_soft_dist = torch.softmax( + hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1 + ) + perplexity = self._compute_perplexity(codevector_soft_dist) + else: + # take argmax in non-differentiable way + # comptute hard codevector distribution (one hot) + codevector_idx = hidden_states.argmax(dim=-1) + 
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_( + -1, codevector_idx.view(-1, 1), 1.0 + ) + codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1) + + perplexity = self._compute_perplexity(codevector_probs) + + codevector_probs = codevector_probs.view(batch_size * sequence_length, -1) + # use probs to retrieve codevectors + codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors + codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1) + codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1) + + return codevectors, perplexity + + +class UniSpeechPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = UniSpeechConfig + base_model_prefix = "unispeech" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + # gumbel softmax requires special init + if isinstance(module, UniSpeechGumbelVectorQuantizer): + module.weight_proj.weight.data.normal_(mean=0.0, std=1) + module.weight_proj.bias.data.zero_() + nn.init.uniform_(module.codevectors) + elif isinstance(module, UniSpeechPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, UniSpeechFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths + + def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): + # Effectively attention_mask.sum(-1), but not inplace to be able to run + # on inference mode. 
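+        # illustrative note: with the default feature extractor (conv strides 5,2,2,2,2,2,2, i.e. an overall stride of 320), roughly one second of 16 kHz audio reduces to about 49 valid feature frames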
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1] + output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long) + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + +UNISPEECH_START_DOCSTRING = r""" + UniSpeech was proposed in [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled + Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, + Michael Zeng, Xuedong Huang. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving etc.). + + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`UniSpeechConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +UNISPEECH_INPUTS_DOCSTRING = r""" + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file + into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install + soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and + conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0, + 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + + + `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask == + True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should + **not** be passed to avoid degraded performance when doing batched inference. For such models + `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these + models also yield slightly different results depending on whether `input_values` is padded or not. + + + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. 
+ return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +UniSpeechBaseModelOutput = Wav2Vec2BaseModelOutput + + +@add_start_docstrings( + "The bare UniSpeech Model transformer outputting raw hidden-states without any specific head on top.", + UNISPEECH_START_DOCSTRING, +) +class UniSpeechModel(UniSpeechPreTrainedModel, Wav2Vec2Model): + def __init__(self, config: UniSpeechConfig): + UniSpeechPreTrainedModel.__init__(config) + self.config = config + self.feature_extractor = UniSpeechFeatureEncoder(config) + self.feature_projection = UniSpeechFeatureProjection(config) + + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) + + if config.do_stable_layer_norm: + self.encoder = UniSpeechEncoderStableLayerNorm(config) + else: + self.encoder = UniSpeechEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def freeze_feature_extractor(self): + raise AttributeError("Not needed for UniSpeech") + + def freeze_feature_encoder(self): + raise AttributeError("Not needed for UniSpeech") + + @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=UniSpeechBaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, UniSpeechBaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if not return_dict: + return (hidden_states, extract_features) + encoder_outputs[1:] + + return UniSpeechBaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """UniSpeech Model with a vector-quantization module and ctc loss for pre-training.""", UNISPEECH_START_DOCSTRING +) +class UniSpeechForPreTraining(UniSpeechPreTrainedModel): + def __init__(self, config: UniSpeechConfig): + super().__init__(config) + self.unispeech = UniSpeechModel(config) + self.dropout_features = 
nn.Dropout(config.feat_quantizer_dropout) + + self.quantizer = UniSpeechGumbelVectorQuantizer(config) + self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim) + self.project_hid = nn.Linear(config.proj_codevector_dim, config.hidden_size) + + self.ctc_proj = nn.Linear(config.hidden_size, config.num_ctc_classes) + self.dropout = nn.Dropout(config.final_dropout) + + # Initialize weights and apply final processing + self.post_init() + + def set_gumbel_temperature(self, temperature: int): + """ + Set the Gumbel softmax temperature to a given value. Only necessary for training + """ + self.quantizer.temperature = temperature + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameters will + not be updated during training. + """ + warnings.warn( + "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " + "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.unispeech.feature_extractor._freeze_parameters() + + @staticmethod + def compute_contrastive_logits( + target_features: torch.FloatTensor, + negative_features: torch.FloatTensor, + predicted_features: torch.FloatTensor, + temperature: int = 1, + ): + """ + Compute logits for contrastive loss based using cosine similarity as the distance measure between + `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied. + """ + target_features = torch.cat([target_features, negative_features], dim=0) + + logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1) + logits = logits.type_as(target_features) + + # apply temperature + logits = logits / temperature + return logits + + @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=UniSpeechForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, UniSpeechForPreTrainingOutput]: + r""" + mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict + masked extracted features in *config.proj_codevector_dim* space. + sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*): + Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss. + Required input for pre-training. 
+ + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, UniSpeechForPreTraining + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-large-1500h-cv") + >>> model = UniSpeechForPreTraining.from_pretrained("microsoft/unispeech-large-1500h-cv") + >>> # TODO: Add full pretraining example + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.unispeech( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + transformer_features = outputs[0] + + # quantize all (unmasked) extracted features and project to final vq dim + extract_features = self.dropout_features(outputs[1]) + quantized_features, codevector_perplexity = self.quantizer(extract_features) + + # project quantized features twice + quantized_features = self.project_q(quantized_features.to(self.project_q.weight.dtype)) + quantized_features = self.project_hid(quantized_features) + + prob_replace_matrix = torch.empty(transformer_features.size(0), transformer_features.size(1)).fill_( + self.config.replace_prob + ) + prob_replace_matrix = prob_replace_matrix.transpose(0, 1) + sampled_replace_matrix = torch.bernoulli(prob_replace_matrix).bool().to(transformer_features.device) + sampled_replace_matrix = sampled_replace_matrix.transpose(0, 1) + sampled_replace_matrix = sampled_replace_matrix.unsqueeze(-1) + logits = transformer_features.masked_fill(sampled_replace_matrix, 0.0) + ( + quantized_features.masked_fill(~sampled_replace_matrix, 0.0) + ) + + # project to ctc units + logits = self.dropout(logits) + logits = self.ctc_proj(logits) + + # TODO(PVP) - add negative sampling & loss computation + loss = None + if not return_dict: + if loss is not None: + return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:] + return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:] + + return UniSpeechForPreTrainingOutput( + loss=loss, + projected_states=transformer_features, + projected_quantized_states=quantized_features, + codevector_perplexity=codevector_perplexity, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """UniSpeech Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", + UNISPEECH_START_DOCSTRING, + """ + target_lang (`str`, *optional*): + Language id of adapter weights. Adapter weights are stored in the format adapter..safetensors or + adapter..bin. Only relevant when using an instance of [`UniSpeechForCTC`] with adapters. Uses 'eng' + by default. + """, +) +class UniSpeechForCTC(Wav2Vec2ForCTC): + @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_CTC_EXPECTED_OUTPUT, + expected_loss=_CTC_EXPECTED_LOSS, + ) + def forward(self, **super_kwargs): + super().forward(**super_kwargs) + + +@add_start_docstrings( + """ + UniSpeech Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like + SUPERB Keyword Spotting. 
+ """, + UNISPEECH_START_DOCSTRING, +) +class UniSpeechForSequenceClassification(Wav2Vec2ForSequenceClassification): + @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward(self, **super_kwargs): + super().forward(**super_kwargs) + + +__all__ = [ + "UniSpeechForCTC", + "UniSpeechForPreTraining", + "UniSpeechForSequenceClassification", + "UniSpeechModel", + "UniSpeechPreTrainedModel", +] diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 0eb035bac17..b1f8c4c3466 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -1,19 +1,9 @@ -# coding=utf-8 -# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch UniSpeechSat model.""" - +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/unispeech_sat/modular_unispeech_sat.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_unispeech_sat.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math import warnings from dataclasses import dataclass @@ -21,8 +11,7 @@ from typing import Optional, Tuple, Union import numpy as np import torch -import torch.utils.checkpoint -from torch import nn +import torch.nn as nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN @@ -32,6 +21,7 @@ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, + ModelOutput, SequenceClassifierOutput, TokenClassifierOutput, Wav2Vec2BaseModelOutput, @@ -39,7 +29,6 @@ from ...modeling_outputs import ( ) from ...modeling_utils import PreTrainedModel from ...utils import ( - ModelOutput, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, @@ -56,28 +45,12 @@ if is_flash_attn_available(): logger = logging.get_logger(__name__) - -_HIDDEN_STATES_START_POSITION = 2 +# Base docstring +_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-100h-libri-ft" # General docstring _CONFIG_FOR_DOC = "UniSpeechSatConfig" -# Base docstring -_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-100h-libri-ft" -_EXPECTED_OUTPUT_SHAPE = [1, 292, 768] - -# CTC docstring -_CTC_EXPECTED_OUTPUT = "'MISTER QUILDER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'" -_CTC_EXPECTED_LOSS = 39.88 - -# Frame class docstring -_FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd" -_FRAME_EXPECTED_OUTPUT = [0, 0] - -# Speaker Verification docstring -_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv" -_XVECTOR_EXPECTED_OUTPUT = 0.97 - @dataclass class UniSpeechSatForPreTrainingOutput(ModelOutput): @@ -116,202 +89,17 @@ class UniSpeechSatForPreTrainingOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None -# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices -def _compute_mask_indices( - shape: Tuple[int, int], - mask_prob: float, - mask_length: int, - attention_mask: Optional[torch.LongTensor] = None, - min_masks: int = 0, -) -> np.ndarray: - """ - Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for - ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on - CPU as part of the preprocessing during training. - - Args: - shape: The shape for which to compute masks. This should be of a tuple of size 2 where - the first element is the batch size and the second element is the length of the axis to span. - mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of - independently generated mask spans of length `mask_length` is computed by - `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the - actual percentage will be smaller. - mask_length: size of the mask - min_masks: minimum number of masked spans - attention_mask: A (right-padded) attention mask which independently shortens the feature axis of - each batch dimension. 
- """ - batch_size, sequence_length = shape - - if mask_length < 1: - raise ValueError("`mask_length` has to be bigger than 0.") - - if mask_length > sequence_length: - raise ValueError( - f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" - f" and `sequence_length`: {sequence_length}`" - ) - - # epsilon is used for probabilistic rounding - epsilon = np.random.rand(1).item() - - def compute_num_masked_span(input_length): - """Given input length, compute how many spans should be masked""" - num_masked_span = int(mask_prob * input_length / mask_length + epsilon) - num_masked_span = max(num_masked_span, min_masks) - - # make sure num masked span <= sequence_length - if num_masked_span * mask_length > sequence_length: - num_masked_span = sequence_length // mask_length - - # make sure num_masked span is also <= input_length - (mask_length - 1) - if input_length - (mask_length - 1) < num_masked_span: - num_masked_span = max(input_length - (mask_length - 1), 0) - - return num_masked_span - - # compute number of masked spans in batch - input_lengths = ( - attention_mask.detach().sum(-1).tolist() - if attention_mask is not None - else [sequence_length for _ in range(batch_size)] - ) - - # SpecAugment mask to fill - spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) - spec_aug_mask_idxs = [] - - max_num_masked_span = compute_num_masked_span(sequence_length) - - if max_num_masked_span == 0: - return spec_aug_mask - - for input_length in input_lengths: - # compute num of masked spans for this input - num_masked_span = compute_num_masked_span(input_length) - - # get random indices to mask - spec_aug_mask_idx = np.random.choice( - np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False - ) - - # pick first sampled index that will serve as a dummy index to pad vector - # to ensure same dimension for all batches due to probabilistic rounding - # Picking first sample just pads those vectors twice. 
- if len(spec_aug_mask_idx) == 0: - # this case can only happen if `input_length` is strictly smaller then - # `sequence_length` in which case the last token has to be a padding - # token which we can use as a dummy mask id - dummy_mask_idx = sequence_length - 1 - else: - dummy_mask_idx = spec_aug_mask_idx[0] - - spec_aug_mask_idx = np.concatenate( - [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] - ) - spec_aug_mask_idxs.append(spec_aug_mask_idx) - - spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) - - # expand masked indices to masked spans - spec_aug_mask_idxs = np.broadcast_to( - spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) - ) - spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) - - # add offset to the starting indexes so that indexes now create a span - offsets = np.arange(mask_length)[None, None, :] - offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( - batch_size, max_num_masked_span * mask_length - ) - spec_aug_mask_idxs = spec_aug_mask_idxs + offsets - - # ensure that we cannot have indices larger than sequence_length - if spec_aug_mask_idxs.max() > sequence_length - 1: - spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 - - # scatter indices to mask - np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) - - return spec_aug_mask - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->UniSpeechSat -class UniSpeechSatNoLayerNormConvLayer(nn.Module): - def __init__(self, config, layer_id=0): +class UniSpeechSatSamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): super().__init__() - self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = nn.Conv1d( - self.in_conv_dim, - self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - stride=config.conv_stride[layer_id], - bias=config.conv_bias, - ) - self.activation = ACT2FN[config.feat_extract_activation] + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - hidden_states = self.activation(hidden_states) + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->UniSpeechSat -class UniSpeechSatLayerNormConvLayer(nn.Module): - def __init__(self, config, layer_id=0): - super().__init__() - self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = nn.Conv1d( - self.in_conv_dim, - self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - stride=config.conv_stride[layer_id], - bias=config.conv_bias, - ) - self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) - self.activation = ACT2FN[config.feat_extract_activation] - - def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - - hidden_states = hidden_states.transpose(-2, -1) - hidden_states = self.layer_norm(hidden_states) - hidden_states = hidden_states.transpose(-2, -1) - - hidden_states = self.activation(hidden_states) - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with 
Wav2Vec2->UniSpeechSat -class UniSpeechSatGroupNormConvLayer(nn.Module): - def __init__(self, config, layer_id=0): - super().__init__() - self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = nn.Conv1d( - self.in_conv_dim, - self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - stride=config.conv_stride[layer_id], - bias=config.conv_bias, - ) - self.activation = ACT2FN[config.feat_extract_activation] - - self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) - - def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->UniSpeechSat class UniSpeechSatPositionalConvEmbedding(nn.Module): def __init__(self, config): super().__init__() @@ -357,19 +145,78 @@ class UniSpeechSatPositionalConvEmbedding(nn.Module): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->UniSpeechSat -class UniSpeechSatSamePadLayer(nn.Module): - def __init__(self, num_conv_pos_embeddings): +class UniSpeechSatNoLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): super().__init__() - self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] def forward(self, hidden_states): - if self.num_pad_remove > 0: - hidden_states = hidden_states[:, :, : -self.num_pad_remove] + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class UniSpeechSatLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +class UniSpeechSatGroupNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, 
affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->UniSpeechSat class UniSpeechSatFeatureEncoder(nn.Module): """Construct the features from raw audio waveform""" @@ -417,18 +264,6 @@ class UniSpeechSatFeatureEncoder(nn.Module): return hidden_states -class UniSpeechSatFeatureExtractor(UniSpeechSatFeatureEncoder): - def __init__(self, config): - super().__init__(config) - warnings.warn( - f"The class `{self.__class__.__name__}` has been depreciated " - "and will be removed in Transformers v5. " - f"Use `{self.__class__.__bases__[0].__name__}` instead.", - FutureWarning, - ) - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->UniSpeechSat class UniSpeechSatFeatureProjection(nn.Module): def __init__(self, config): super().__init__() @@ -444,7 +279,6 @@ class UniSpeechSatFeatureProjection(nn.Module): return hidden_states, norm_hidden_states -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->UniSpeechSat class UniSpeechSatAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -603,7 +437,6 @@ class UniSpeechSatAttention(nn.Module): return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->UniSpeechSat class UniSpeechSatFlashAttention2(UniSpeechSatAttention): """ UniSpeechSat flash attention module. This module inherits from `UniSpeechSatAttention` as the weights of the module stays @@ -731,7 +564,6 @@ class UniSpeechSatFlashAttention2(UniSpeechSatAttention): class UniSpeechSatSdpaAttention(UniSpeechSatAttention): - # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->UniSpeechSat def forward( self, hidden_states: torch.Tensor, @@ -837,14 +669,6 @@ class UniSpeechSatSdpaAttention(UniSpeechSatAttention): return attn_output, None, past_key_value -UNISPEECHSAT_ATTENTION_CLASSES = { - "eager": UniSpeechSatAttention, - "sdpa": UniSpeechSatSdpaAttention, - "flash_attention_2": UniSpeechSatFlashAttention2, -} - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->UniSpeechSat class UniSpeechSatFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -869,11 +693,17 @@ class UniSpeechSatFeedForward(nn.Module): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->UniSpeechSat, WAV2VEC2->UNISPEECHSAT +UNISPEECH_SAT_ATTENTION_CLASSES = { + "eager": UniSpeechSatAttention, + "sdpa": UniSpeechSatSdpaAttention, + "flash_attention_2": UniSpeechSatFlashAttention2, +} + + class UniSpeechSatEncoderLayer(nn.Module): def __init__(self, config): super().__init__() - self.attention = UNISPEECHSAT_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = UNISPEECH_SAT_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, @@ -905,79 +735,6 @@ class UniSpeechSatEncoderLayer(nn.Module): return outputs -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AttnAdapterLayer with Wav2Vec2->UniSpeechSat -class UniSpeechSatAttnAdapterLayer(nn.Module): - def __init__(self, config): - """ - 
Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed - up training throughput. - """ - super().__init__() - self.input_dim = config.adapter_attn_dim - self.hidden_dim = config.hidden_size - - self.norm = nn.LayerNorm(self.hidden_dim) - self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim) - self.act_fn = nn.ReLU() - self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim) - - def forward(self, hidden_states: torch.FloatTensor): - hidden_states = self.norm(hidden_states) - - hidden_states = self.linear_1(hidden_states) - hidden_states = self.act_fn(hidden_states) - hidden_states = self.linear_2(hidden_states) - - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->UniSpeechSat, WAV2VEC2->UNISPEECHSAT -class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module): - def __init__(self, config): - super().__init__() - self.attention = UNISPEECHSAT_ATTENTION_CLASSES[config._attn_implementation]( - embed_dim=config.hidden_size, - num_heads=config.num_attention_heads, - dropout=config.attention_dropout, - is_decoder=False, - ) - self.dropout = nn.Dropout(config.hidden_dropout) - self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.feed_forward = UniSpeechSatFeedForward(config) - self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if getattr(config, "adapter_attn_dim", None) is not None: - self.adapter_layer = UniSpeechSatAttnAdapterLayer(config) - else: - self.adapter_layer = None - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ): - attn_residual = hidden_states - hidden_states = self.layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.attention( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) - hidden_states = self.dropout(hidden_states) - hidden_states = attn_residual + hidden_states - hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) - - if self.adapter_layer is not None: - hidden_states = hidden_states + self.adapter_layer(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder with Wav2Vec2->UniSpeechSat class UniSpeechSatEncoder(nn.Module): def __init__(self, config): super().__init__() @@ -1063,7 +820,76 @@ class UniSpeechSatEncoder(nn.Module): ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderStableLayerNorm with Wav2Vec2->UniSpeechSat +class UniSpeechSatAttnAdapterLayer(nn.Module): + def __init__(self, config): + """ + Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed + up training throughput. 
+ """ + super().__init__() + self.input_dim = config.adapter_attn_dim + self.hidden_dim = config.hidden_size + + self.norm = nn.LayerNorm(self.hidden_dim) + self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim) + self.act_fn = nn.ReLU() + self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim) + + def forward(self, hidden_states: torch.FloatTensor): + hidden_states = self.norm(hidden_states) + + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states + + +class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = UNISPEECH_SAT_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=False, + ) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = UniSpeechSatFeedForward(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if getattr(config, "adapter_attn_dim", None) is not None: + self.adapter_layer = UniSpeechSatAttnAdapterLayer(config) + else: + self.adapter_layer = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + attn_residual = hidden_states + hidden_states = self.layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.attention( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = self.dropout(hidden_states) + hidden_states = attn_residual + hidden_states + hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) + + if self.adapter_layer is not None: + hidden_states = hidden_states + self.adapter_layer(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + class UniSpeechSatEncoderStableLayerNorm(nn.Module): def __init__(self, config): super().__init__() @@ -1155,7 +981,7 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module): class UniSpeechSatGumbelVectorQuantizer(nn.Module): """ - Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH + Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information. """ @@ -1166,8 +992,8 @@ class UniSpeechSatGumbelVectorQuantizer(nn.Module): if config.codevector_dim % self.num_groups != 0: raise ValueError( - f"`config.codevector_dim {config.codevector_dim} must be divisible by `config.num_codevector_groups`" - f" {self.num_groups} for concatenation" + f"`config.codevector_dim {config.codevector_dim} must be divisible " + f"by `config.num_codevector_groups` {self.num_groups} for concatenation" ) # storage for codebook variables (codewords) @@ -1300,6 +1126,128 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): return attention_mask +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). 
Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. + """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. 
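+        # Illustrative note (example values, not from the original source): with max_num_masked_span=3 and
+        # num_masked_span=2, sampled starts [4, 9] are padded to [4, 9, 4]; re-marking the span starting at 4
+        # is a no-op, so the padding never masks additional positions.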
+ if len(spec_aug_mask_idx) == 0: + # this case can only happen if `input_length` is strictly smaller then + # `sequence_length` in which case the last token has to be a padding + # token which we can use as a dummy mask id + dummy_mask_idx = sequence_length - 1 + else: + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # ensure that we cannot have indices larger than sequence_length + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + +_EXPECTED_OUTPUT_SHAPE = [1, 292, 768] + + UNISPEECH_SAT_START_DOCSTRING = r""" UniSpeechSat was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael @@ -1318,7 +1266,6 @@ UNISPEECH_SAT_START_DOCSTRING = r""" configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ - UNISPEECH_SAT_INPUTS_DOCSTRING = r""" Args: input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): @@ -1338,12 +1285,10 @@ UNISPEECH_SAT_INPUTS_DOCSTRING = r""" `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask == - True`. For all models whose processor has `config.return_attention_mask == False`, such as - [microsoft/unispeech-sat-base-100h-libri-ft](https://huggingface.co/microsoft/unispeech-sat-base-100h-libri-ft), - `attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For - such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware - that these models also yield slightly different results depending on whether `input_values` is padded or - not. + True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should + **not** be passed to avoid degraded performance when doing batched inference. For such models + `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these + models also yield slightly different results depending on whether `input_values` is padded or not. @@ -1357,6 +1302,8 @@ UNISPEECH_SAT_INPUTS_DOCSTRING = r""" Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
""" +UniSpeechSatBaseModelOutput = Wav2Vec2BaseModelOutput + @add_start_docstrings( "The bare UniSpeechSat Model transformer outputting raw hidden-states without any specific head on top.", @@ -1379,7 +1326,6 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel): # Initialize weights and apply final processing self.post_init() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states def _mask_hidden_states( self, hidden_states: torch.FloatTensor, @@ -1429,7 +1375,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel): @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, - output_type=Wav2Vec2BaseModelOutput, + output_type=UniSpeechSatBaseModelOutput, config_class=_CONFIG_FOR_DOC, modality="audio", expected_output=_EXPECTED_OUTPUT_SHAPE, @@ -1442,7 +1388,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[Tuple, Wav2Vec2BaseModelOutput]: + ) -> Union[Tuple, UniSpeechSatBaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1474,7 +1420,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel): if not return_dict: return (hidden_states, extract_features) + encoder_outputs[1:] - return Wav2Vec2BaseModelOutput( + return UniSpeechSatBaseModelOutput( last_hidden_state=hidden_states, extract_features=extract_features, hidden_states=encoder_outputs.hidden_states, @@ -1482,7 +1428,10 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel): ) -@add_start_docstrings("""UniSpeechSat Model with a quantizer and `VQ` head on top.""", UNISPEECH_SAT_START_DOCSTRING) +@add_start_docstrings( + """UniSpeechSat Model with a vector-quantization module and ctc loss for pre-training.""", + UNISPEECH_SAT_START_DOCSTRING, +) class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel): def __init__(self, config: UniSpeechSatConfig): super().__init__(config) @@ -1529,7 +1478,7 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel): Calling this function will disable the gradient computation for the feature encoder so that its parameter will not be updated during training. 
""" - self.wav2vec2.feature_extractor._freeze_parameters() + self.unispeech_sat.feature_extractor._freeze_parameters() @staticmethod def compute_contrastive_logits( @@ -1594,16 +1543,6 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel): logits = extract_features loss = quantized_features = codevector_perplexity = None - # layer normalization (has no effect when `config.do_stable_layer_norm == False`) - # extract_features = self.layer_norm_for_extract(extract_features) - # quantized_features, codevector_perplexity = self.quantizer(extract_features) - # - # project quantized features twice - # quantized_features = self.project_q(quantized_features) - # quantized_features = self.project_hid(quantized_features) - # - # loss = None - # logits = quantized_features if not return_dict: if loss is not None: return (loss, logits, transformer_features, quantized_features, codevector_perplexity) + outputs[2:] @@ -1620,6 +1559,13 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel): ) +_HIDDEN_STATES_START_POSITION = 2 + +# CTC docstring +_CTC_EXPECTED_OUTPUT = "'MISTER QUILDER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'" +_CTC_EXPECTED_LOSS = 39.88 + + @add_start_docstrings( """UniSpeechSat Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", UNISPEECH_SAT_START_DOCSTRING, @@ -1630,7 +1576,6 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel): 'eng' by default. """, ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel): def __init__(self, config, target_lang: Optional[str] = None): super().__init__(config) @@ -1784,8 +1729,8 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel): @add_start_docstrings( """ - UniSpeechSat Model with a sequence classification head on top (a linear layer over the pooled output) for tasks - like SUPERB Keyword Spotting. + UniSpeechSat Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like + SUPERB Keyword Spotting. 
""", UNISPEECH_SAT_START_DOCSTRING, ) @@ -1807,7 +1752,6 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel): # Initialize weights and apply final processing self.post_init() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor def freeze_feature_extractor(self): """ Calling this function will disable the gradient computation for the feature encoder so that its parameters will @@ -1820,7 +1764,6 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel): ) self.freeze_feature_encoder() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->unispeech_sat def freeze_feature_encoder(self): """ Calling this function will disable the gradient computation for the feature encoder so that its parameter will @@ -1828,7 +1771,6 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel): """ self.unispeech_sat.feature_extractor._freeze_parameters() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->unispeech_sat def freeze_base_model(self): """ Calling this function will disable the gradient computation for the base model so that its parameters will not @@ -1844,7 +1786,6 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel): config_class=_CONFIG_FOR_DOC, modality="audio", ) - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat def forward( self, input_values: Optional[torch.Tensor], @@ -1908,13 +1849,17 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel): ) +# Frame class docstring +_FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd" +_FRAME_EXPECTED_OUTPUT = [0, 0] + + @add_start_docstrings( """ - UniSpeech-SAT Model with a frame classification head on top for tasks like Speaker Diarization. + UniSpeechSat Model with a frame classification head on top for tasks like Speaker Diarization. """, UNISPEECH_SAT_START_DOCSTRING, ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -2021,7 +1966,6 @@ class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel): ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss class AMSoftmaxLoss(nn.Module): def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4): super(AMSoftmaxLoss, self).__init__() @@ -2045,7 +1989,6 @@ class AMSoftmaxLoss(nn.Module): return loss -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer class TDNNLayer(nn.Module): def __init__(self, config, layer_id=0): super().__init__() @@ -2061,6 +2004,7 @@ class TDNNLayer(nn.Module): if is_peft_available(): from peft.tuners.lora import LoraLayer + if is_peft_available(): if isinstance(self.kernel, LoraLayer): warnings.warn( "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. 
" @@ -2077,13 +2021,17 @@ class TDNNLayer(nn.Module): return hidden_states +# Speaker Verification docstring +_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv" +_XVECTOR_EXPECTED_OUTPUT = 0.97 + + @add_start_docstrings( """ - UniSpeech-SAT Model with an XVector feature extraction head on top for tasks like Speaker Verification. + UniSpeechSat Model with an XVector feature extraction head on top for tasks like Speaker Verification. """, UNISPEECH_SAT_START_DOCSTRING, ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT class UniSpeechSatForXVector(UniSpeechSatPreTrainedModel): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py new file mode 100644 index 00000000000..44e566068ef --- /dev/null +++ b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py @@ -0,0 +1,610 @@ +import math +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...modeling_outputs import ( + CausalLMOutput, + ModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, + Wav2Vec2BaseModelOutput, + XVectorOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ..wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2Encoder, + Wav2Vec2EncoderStableLayerNorm, + Wav2Vec2FeatureEncoder, + Wav2Vec2FeatureProjection, + Wav2Vec2ForAudioFrameClassification, + Wav2Vec2ForCTC, + Wav2Vec2ForSequenceClassification, + Wav2Vec2ForXVector, + Wav2Vec2GumbelVectorQuantizer, + Wav2Vec2Model, + Wav2Vec2PositionalConvEmbedding, +) +from .configuration_unispeech_sat import UniSpeechSatConfig + + +logger = logging.get_logger(__name__) + + +_HIDDEN_STATES_START_POSITION = 2 + +# General docstring +_CONFIG_FOR_DOC = "UniSpeechSatConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-100h-libri-ft" +_EXPECTED_OUTPUT_SHAPE = [1, 292, 768] + +# CTC docstring +_CTC_EXPECTED_OUTPUT = "'MISTER QUILDER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'" +_CTC_EXPECTED_LOSS = 39.88 + +# Frame class docstring +_FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd" +_FRAME_EXPECTED_OUTPUT = [0, 0] + +# Speaker Verification docstring +_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv" +_XVECTOR_EXPECTED_OUTPUT = 0.97 + + +@dataclass +class UniSpeechSatForPreTrainingOutput(ModelOutput): + """ + Output type of [`UniSpeechSatForPreTrainingOutput`], with potential hidden states and attentions. + + Args: + loss (*optional*, returned when model is in train mode, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official + paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss. + projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`): + Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked + projected quantized states. 
+    projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+        Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
+        target vectors for contrastive loss.
+    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+        shape `(batch_size, sequence_length, hidden_size)`.
+
+        Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+        sequence_length)`.
+
+        Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+        heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    projected_states: Optional[torch.FloatTensor] = None
+    projected_quantized_states: Optional[torch.FloatTensor] = None
+    codevector_perplexity: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class UniSpeechSatPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding):
+    pass
+
+
+class UniSpeechSatFeatureEncoder(Wav2Vec2FeatureEncoder):
+    pass
+
+
+class UniSpeechSatFeatureProjection(Wav2Vec2FeatureProjection):
+    pass
+
+
+class UniSpeechSatEncoder(Wav2Vec2Encoder):
+    pass
+
+
+class UniSpeechSatEncoderStableLayerNorm(Wav2Vec2EncoderStableLayerNorm):
+    pass
+
+
+class UniSpeechSatGumbelVectorQuantizer(Wav2Vec2GumbelVectorQuantizer):
+    def __init__(self, config):
+        super().__init__()
+        self.weight_proj = nn.Linear(config.hidden_size, self.num_groups * self.num_vars)
+
+    @staticmethod
+    def _compute_perplexity(probs, mask=None):
+        marginal_probs = probs.mean(dim=0)
+        perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
+        return perplexity
+
+    def forward(self, hidden_states):
+        batch_size, sequence_length, hidden_size = hidden_states.shape
+
+        # project to codevector dim
+        hidden_states = self.weight_proj(hidden_states)
+        hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
+
+        if self.training:
+            # sample code vector probs via gumbel in differentiable way
+            codevector_probs = nn.functional.gumbel_softmax(
+                hidden_states.float(), tau=self.temperature, hard=True
+            ).type_as(hidden_states)
+
+            # compute perplexity
+            codevector_soft_dist = torch.softmax(
+                hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
+            )
+            perplexity = self._compute_perplexity(codevector_soft_dist)
+        else:
+            # take argmax in non-differentiable way
+            # compute hard codevector distribution (one hot)
+            codevector_idx = hidden_states.argmax(dim=-1)
+            codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
+                -1, codevector_idx.view(-1, 1), 1.0
+            )
+            codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
+
+            perplexity = self._compute_perplexity(codevector_probs)
+
+        codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
+        # use probs to retrieve codevectors
+        codevectors_per_group = 
codevector_probs.unsqueeze(-1) * self.codevectors + codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1) + codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1) + + return codevectors, perplexity + + +class UniSpeechSatPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = UniSpeechSatConfig + base_model_prefix = "unispeech_sat" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + # gumbel softmax requires special init + if isinstance(module, UniSpeechSatGumbelVectorQuantizer): + module.weight_proj.weight.data.normal_(mean=0.0, std=1) + module.weight_proj.bias.data.zero_() + nn.init.uniform_(module.codevectors) + elif isinstance(module, UniSpeechSatPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, UniSpeechSatFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths + + def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): + # Effectively attention_mask.sum(-1), but not inplace to be able to run + # on inference mode. 
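+        # Worked example (values assumed for illustration): attention_mask [1, 1, 1, 1, 0, 0] yields a
+        # non-padded length of 4; if the conv stack maps those 4 samples to 2 frames and
+        # feature_vector_length is 3, the mask built below is [True, True, False].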
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1] + output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long) + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + +UNISPEECH_SAT_START_DOCSTRING = r""" + UniSpeechSat was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech + Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael + Auli. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving etc.). + + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`UniSpeechSatConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +UNISPEECH_SAT_INPUTS_DOCSTRING = r""" + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file + into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install + soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and + conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0, + 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + + + `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask == + True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should + **not** be passed to avoid degraded performance when doing batched inference. For such models + `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these + models also yield slightly different results depending on whether `input_values` is padded or not. + + + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + +UniSpeechSatBaseModelOutput = Wav2Vec2BaseModelOutput + + +@add_start_docstrings( + "The bare UniSpeechSat Model transformer outputting raw hidden-states without any specific head on top.", + UNISPEECH_SAT_START_DOCSTRING, +) +class UniSpeechSatModel(UniSpeechSatPreTrainedModel, Wav2Vec2Model): + def __init__(self, config: UniSpeechSatConfig): + UniSpeechSatPreTrainedModel.__init__(config) + self.config = config + self.feature_extractor = UniSpeechSatFeatureEncoder(config) + self.feature_projection = UniSpeechSatFeatureProjection(config) + + self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) + + if config.do_stable_layer_norm: + self.encoder = UniSpeechSatEncoderStableLayerNorm(config) + else: + self.encoder = UniSpeechSatEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def freeze_feature_extractor(self): + raise AttributeError("Not needed for UniSpeechSat") + + def freeze_feature_encoder(self): + raise AttributeError("Not needed for UniSpeechSat") + + @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=UniSpeechSatBaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, UniSpeechSatBaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if not return_dict: + return (hidden_states, extract_features) + encoder_outputs[1:] + + return UniSpeechSatBaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """UniSpeechSat Model with a vector-quantization module and ctc loss for pre-training.""", + UNISPEECH_SAT_START_DOCSTRING, +) +class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel): + def __init__(self, config: UniSpeechSatConfig): + super().__init__(config) + self.unispeech_sat = UniSpeechSatModel(config) + self.dropout_features = nn.Dropout(config.feat_quantizer_dropout) + + self.quantizer = UniSpeechSatGumbelVectorQuantizer(config) + 
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim) + self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim) + + self.dropout = nn.Dropout(config.final_dropout) + + self.speaker_proj = nn.Linear(config.hidden_size, config.codevector_dim) + self.label_embeddings_concat = nn.Parameter(torch.FloatTensor(config.num_clusters, config.codevector_dim)) + self.label_embeddings_concat.data.zero_() + + self.layer_norm_for_extract = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + if self.config.do_stable_layer_norm: + self.layer_norm_for_extract.requires_grad = False + + # Initialize weights and apply final processing + self.post_init() + + def set_gumbel_temperature(self, temperature: int): + """ + Set the Gumbel softmax temperature to a given value. Only necessary for training + """ + self.quantizer.temperature = temperature + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameters will + not be updated during training. + """ + warnings.warn( + "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " + "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.unispeech_sat.feature_extractor._freeze_parameters() + + @staticmethod + def compute_contrastive_logits( + target_features: torch.FloatTensor, + negative_features: torch.FloatTensor, + predicted_features: torch.FloatTensor, + temperature: int = 1, + ): + """ + Compute logits for contrastive loss based using cosine similarity as the distance measure between + `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied. 
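+
+        As a shape sketch (shapes assumed for illustration, following the wav2vec 2.0-style pre-training usage):
+        with `target_features` of shape `(1, batch_size, sequence_length, hidden_size)`, `negative_features` of
+        shape `(num_negatives, batch_size, sequence_length, hidden_size)` and `predicted_features` of shape
+        `(batch_size, sequence_length, hidden_size)`, the returned logits have shape
+        `(1 + num_negatives, batch_size, sequence_length)`.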
+ """ + target_features = torch.cat([target_features, negative_features], dim=0) + + logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1) + logits = logits.type_as(target_features) + + # apply temperature + logits = logits / temperature + return logits + + @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=UniSpeechSatForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, UniSpeechSatForPreTrainingOutput]: + r""" + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, UniSpeechSatForPreTraining + >>> from transformers.models.unispeech_sat.modeling_unispeech_sat import _compute_mask_indices + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-sat-base") + >>> model = UniSpeechSatForPreTraining.from_pretrained("microsoft/unispeech-sat-base") + >>> # TODO: Add full pretraining example + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.unispeech_sat( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + transformer_features = outputs[0] + + # quantize all (unmasked) extracted features and project to final vq dim + extract_features = self.dropout_features(outputs[1]) + + # TODO(PVP) - add pretraining logic and add to tests + logits = extract_features + loss = quantized_features = codevector_perplexity = None + + if not return_dict: + if loss is not None: + return (loss, logits, transformer_features, quantized_features, codevector_perplexity) + outputs[2:] + return (logits, transformer_features, quantized_features, codevector_perplexity) + outputs[2:] + + return UniSpeechSatForPreTrainingOutput( + loss=loss, + logits=logits, + projected_states=transformer_features, + projected_quantized_states=quantized_features, + codevector_perplexity=codevector_perplexity, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """UniSpeechSat Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", + UNISPEECH_SAT_START_DOCSTRING, + """ + target_lang (`str`, *optional*): + Language id of adapter weights. Adapter weights are stored in the format adapter..safetensors or + adapter..bin. Only relevant when using an instance of [`UniSpeechSatForCTC`] with adapters. Uses + 'eng' by default. + """, +) +class UniSpeechSatForCTC(Wav2Vec2ForCTC): + @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_CTC_EXPECTED_OUTPUT, + expected_loss=_CTC_EXPECTED_LOSS, + ) + def forward(self, **super_kwargs): + return super().forward(**super_kwargs) + + +@add_start_docstrings( + """ + UniSpeechSat Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like + SUPERB Keyword Spotting. 
+    """,
+    UNISPEECH_SAT_START_DOCSTRING,
+)
+class UniSpeechSatForSequenceClassification(Wav2Vec2ForSequenceClassification):
+    @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=SequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="audio",
+    )
+    def forward(self, **super_kwargs):
+        return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+    """
+    UniSpeechSat Model with a frame classification head on top for tasks like Speaker Diarization.
+    """,
+    UNISPEECH_SAT_START_DOCSTRING,
+)
+class UniSpeechSatForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification):
+    @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_FRAME_CLASS_CHECKPOINT,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="audio",
+        expected_output=_FRAME_EXPECTED_OUTPUT,
+    )
+    def forward(self, **super_kwargs):
+        return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+    """
+    UniSpeechSat Model with an XVector feature extraction head on top for tasks like Speaker Verification.
+    """,
+    UNISPEECH_SAT_START_DOCSTRING,
+)
+class UniSpeechSatForXVector(Wav2Vec2ForXVector):
+    @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_XVECTOR_CHECKPOINT,
+        output_type=XVectorOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="audio",
+        expected_output=_XVECTOR_EXPECTED_OUTPUT,
+    )
+    def forward(self, **super_kwargs):
+        return super().forward(**super_kwargs)
+
+
+__all__ = [
+    "UniSpeechSatForAudioFrameClassification",
+    "UniSpeechSatForCTC",
+    "UniSpeechSatForPreTraining",
+    "UniSpeechSatForSequenceClassification",
+    "UniSpeechSatForXVector",
+    "UniSpeechSatModel",
+    "UniSpeechSatPreTrainedModel",
+]
diff --git a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
index 55d34b84ef5..54707620501 100644
--- a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
@@ -212,7 +212,7 @@ def _sample_negative_indices(features_shape: Tuple, num_negatives: int, attentio
     return sampled_negative_indices
 
 
-WAV_2_VEC_2_START_DOCSTRING = r"""
+WAV2VEC2_START_DOCSTRING = r"""
     Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
     Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
     Auli.
@@ -251,7 +251,7 @@ WAV_2_VEC_2_START_DOCSTRING = r"""
 """
 
 
-WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
+WAV2VEC2_INPUTS_DOCSTRING = r"""
     Args:
         input_values (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. 
Values can be obtained by loading a `.flac` or `.wav` audio file @@ -885,7 +885,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel): else: return random_params - @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) def __call__( self, input_values, @@ -1050,7 +1050,7 @@ class FlaxWav2Vec2Module(nn.Module): @add_start_docstrings( "The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.", - WAV_2_VEC_2_START_DOCSTRING, + WAV2VEC2_START_DOCSTRING, ) class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel): module_class = FlaxWav2Vec2Module @@ -1088,7 +1088,7 @@ FLAX_WAV2VEC2_MODEL_DOCSTRING = """ overwrite_call_docstring( FlaxWav2Vec2Model, - WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_MODEL_DOCSTRING, + WAV2VEC2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_MODEL_DOCSTRING, ) append_replace_return_docstrings( FlaxWav2Vec2Model, output_type=FlaxWav2Vec2BaseModelOutput, config_class=Wav2Vec2Config @@ -1168,7 +1168,7 @@ class FlaxWav2Vec2ForCTCModule(nn.Module): @add_start_docstrings( "Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).", - WAV_2_VEC_2_START_DOCSTRING, + WAV2VEC2_START_DOCSTRING, ) class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel): module_class = FlaxWav2Vec2ForCTCModule @@ -1211,7 +1211,7 @@ FLAX_WAV2VEC2_FOR_CTC_DOCSTRING = """ overwrite_call_docstring( FlaxWav2Vec2ForCTC, - WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_CTC_DOCSTRING, + WAV2VEC2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_CTC_DOCSTRING, ) append_replace_return_docstrings(FlaxWav2Vec2ForCTC, output_type=FlaxCausalLMOutput, config_class=Wav2Vec2Config) @@ -1315,11 +1315,11 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module): return input_lengths -@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top.""", WAV_2_VEC_2_START_DOCSTRING) +@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top.""", WAV2VEC2_START_DOCSTRING) class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel): module_class = FlaxWav2Vec2ForPreTrainingModule - @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) # overwrite since has `gumbel_temperature` input def __call__( self, @@ -1418,7 +1418,7 @@ FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING = """ overwrite_call_docstring( FlaxWav2Vec2ForPreTraining, - WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING, + WAV2VEC2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING, ) append_replace_return_docstrings( FlaxWav2Vec2ForPreTraining, output_type=FlaxWav2Vec2ForPreTrainingOutput, config_class=Wav2Vec2Config diff --git a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py index ad923bef80b..c385c192a98 100644 --- a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py @@ -1397,7 +1397,7 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel): return attention_mask -WAV_2_VEC_2_START_DOCSTRING = r""" +WAV2VEC2_START_DOCSTRING = r""" This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -1439,7 +1439,7 @@ WAV_2_VEC_2_START_DOCSTRING = r""" configuration. 
Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ -WAV_2_VEC_2_INPUTS_DOCSTRING = r""" +WAV2VEC2_INPUTS_DOCSTRING = r""" Args: input_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `({0})`): Indices of input sequence tokens in the vocabulary. @@ -1497,7 +1497,7 @@ WAV_2_VEC_2_INPUTS_DOCSTRING = r""" @add_start_docstrings( "The bare TFWav2Vec2 Model transformer outputing raw hidden-states without any specific head on top.", - WAV_2_VEC_2_START_DOCSTRING, + WAV2VEC2_START_DOCSTRING, ) class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel): def __init__(self, config: Wav2Vec2Config, *inputs, **kwargs): @@ -1505,7 +1505,7 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel): self.config = config self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2") - @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC) @unpack_inputs def call( @@ -1579,7 +1579,7 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel): @add_start_docstrings( """TFWav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", - WAV_2_VEC_2_START_DOCSTRING, + WAV2VEC2_START_DOCSTRING, ) class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel): def __init__(self, config: Wav2Vec2Config, *inputs, **kwargs): @@ -1612,7 +1612,7 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel): self.wav2vec2.feature_extractor.trainable = False @unpack_inputs - @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC) def call( self, diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 93f77372538..2ac0e21486e 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -63,6 +63,7 @@ if is_safetensors_available(): if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward + logger = logging.get_logger(__name__) @@ -1633,7 +1634,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): self.target_lang = target_lang -WAV_2_VEC_2_START_DOCSTRING = r""" +WAV2VEC2_START_DOCSTRING = r""" Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. @@ -1652,7 +1653,7 @@ WAV_2_VEC_2_START_DOCSTRING = r""" """ -WAV_2_VEC_2_INPUTS_DOCSTRING = r""" +WAV2VEC2_INPUTS_DOCSTRING = r""" Args: input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): Float values of input raw speech waveform. 
Values can be obtained by loading a `.flac` or `.wav` audio file @@ -1692,7 +1693,7 @@ WAV_2_VEC_2_INPUTS_DOCSTRING = r""" @add_start_docstrings( "The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.", - WAV_2_VEC_2_START_DOCSTRING, + WAV2VEC2_START_DOCSTRING, ) class Wav2Vec2Model(Wav2Vec2PreTrainedModel): def __init__(self, config: Wav2Vec2Config): @@ -1780,7 +1781,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): return hidden_states - @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=Wav2Vec2BaseModelOutput, @@ -1841,7 +1842,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): ) -@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top.""", WAV_2_VEC_2_START_DOCSTRING) +@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top.""", WAV2VEC2_START_DOCSTRING) class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel): def __init__(self, config: Wav2Vec2Config): super().__init__(config) @@ -1902,7 +1903,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel): logits = logits / temperature return logits - @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Wav2Vec2ForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) def forward( self, @@ -2065,7 +2066,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel): ) -@add_start_docstrings("""Wav2Vec2 Model with a `language modeling` head on top.""", WAV_2_VEC_2_START_DOCSTRING) +@add_start_docstrings("""Wav2Vec2 Model with a `language modeling` head on top.""", WAV2VEC2_START_DOCSTRING) class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel): def __init__(self, config): super().__init__(config) @@ -2081,7 +2082,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel): # Initialize weights and apply final processing self.post_init() - @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) def forward( self, input_values: torch.FloatTensor, @@ -2113,7 +2114,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel): @add_start_docstrings( """Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", - WAV_2_VEC_2_START_DOCSTRING, + WAV2VEC2_START_DOCSTRING, """ target_lang (`str`, *optional*): Language id of adapter weights. Adapter weights are stored in the format adapter..safetensors or @@ -2193,7 +2194,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel): for param in self.wav2vec2.parameters(): param.requires_grad = False - @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=CausalLMOutput, @@ -2277,7 +2278,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel): Wav2Vec2 Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB Keyword Spotting. 
""", - WAV_2_VEC_2_START_DOCSTRING, + WAV2VEC2_START_DOCSTRING, ) class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel): def __init__(self, config): @@ -2324,7 +2325,7 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel): for param in self.wav2vec2.parameters(): param.requires_grad = False - @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_SEQ_CLASS_CHECKPOINT, output_type=SequenceClassifierOutput, @@ -2400,7 +2401,7 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel): """ Wav2Vec2 Model with a frame classification head on top for tasks like Speaker Diarization. """, - WAV_2_VEC_2_START_DOCSTRING, + WAV2VEC2_START_DOCSTRING, ) class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel): def __init__(self, config): @@ -2446,7 +2447,7 @@ class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel): for param in self.wav2vec2.parameters(): param.requires_grad = False - @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_FRAME_CLASS_CHECKPOINT, output_type=TokenClassifierOutput, @@ -2546,6 +2547,7 @@ class TDNNLayer(nn.Module): if is_peft_available(): from peft.tuners.lora import LoraLayer + if is_peft_available(): if isinstance(self.kernel, LoraLayer): warnings.warn( "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. " @@ -2566,7 +2568,7 @@ class TDNNLayer(nn.Module): """ Wav2Vec2 Model with an XVector feature extraction head on top for tasks like Speaker Verification. """, - WAV_2_VEC_2_START_DOCSTRING, + WAV2VEC2_START_DOCSTRING, ) class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel): def __init__(self, config): @@ -2630,7 +2632,7 @@ class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel): return input_lengths - @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_XVECTOR_CHECKPOINT, output_type=XVectorOutput, diff --git a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py index 0fb591542f4..86e9b65b3d5 100644 --- a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +++ b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py @@ -1,26 +1,15 @@ -# coding=utf-8 -# Copyright 2024 The Seamless Authors and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Wav2Vec2-BERT model.""" - +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. 
If any change should be done, please apply the change to the +# modular_wav2vec2_bert.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math import warnings from typing import Optional, Tuple, Union import numpy as np import torch -import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss @@ -42,213 +31,14 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_peft_available, - logging, ) from .configuration_wav2vec2_bert import Wav2Vec2BertConfig -logger = logging.get_logger(__name__) - - -_HIDDEN_STATES_START_POSITION = 2 - # General docstring _CONFIG_FOR_DOC = "Wav2Vec2BertConfig" -# Base docstring -_BASE_CHECKPOINT_FOR_DOC = "facebook/w2v-bert-2.0" -_PRETRAINED_CHECKPOINT_FOR_DOC = "hf-audio/wav2vec2-bert-CV16-en" -_EXPECTED_OUTPUT_SHAPE = [1, 146, 1024] -# CTC docstring -_CTC_EXPECTED_OUTPUT = "'mr quilter is the apostle of the middle classes and we are glad to welcome his gospel'" -_CTC_EXPECTED_LOSS = 17.04 - - -# Copied from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2._compute_new_attention_mask -def _compute_new_attention_mask(hidden_states: torch.Tensor, seq_lens: torch.Tensor): - """ - Computes an attention mask of the form `(batch, seq_len)` with an attention for each element in the batch that - stops at the corresponding element in `seq_lens`. - Args: - hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, *)`): - The sequences to mask, where `*` is any number of sequence-specific dimensions including none. - seq_lens (`torch.Tensor` of shape `(batch)`: - Each element represents the length of the sequence at the same index in `hidden_states` - Returns: - `torch.FloatTensor`: The float attention mask of shape `(batch, seq_len)` - """ - batch_size, mask_seq_len = hidden_states.shape[:2] - - indices = torch.arange(mask_seq_len, device=seq_lens.device).expand(batch_size, -1) - - bool_mask = indices >= seq_lens.unsqueeze(1).expand(-1, mask_seq_len) - - mask = hidden_states.new_ones((batch_size, mask_seq_len)) - - mask = mask.masked_fill(bool_mask, 0) - - return mask - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices -def _compute_mask_indices( - shape: Tuple[int, int], - mask_prob: float, - mask_length: int, - attention_mask: Optional[torch.LongTensor] = None, - min_masks: int = 0, -) -> np.ndarray: - """ - Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for - ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on - CPU as part of the preprocessing during training. - - Args: - shape: The shape for which to compute masks. This should be of a tuple of size 2 where - the first element is the batch size and the second element is the length of the axis to span. - mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of - independently generated mask spans of length `mask_length` is computed by - `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the - actual percentage will be smaller. - mask_length: size of the mask - min_masks: minimum number of masked spans - attention_mask: A (right-padded) attention mask which independently shortens the feature axis of - each batch dimension. 
- """ - batch_size, sequence_length = shape - - if mask_length < 1: - raise ValueError("`mask_length` has to be bigger than 0.") - - if mask_length > sequence_length: - raise ValueError( - f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" - f" and `sequence_length`: {sequence_length}`" - ) - - # epsilon is used for probabilistic rounding - epsilon = np.random.rand(1).item() - - def compute_num_masked_span(input_length): - """Given input length, compute how many spans should be masked""" - num_masked_span = int(mask_prob * input_length / mask_length + epsilon) - num_masked_span = max(num_masked_span, min_masks) - - # make sure num masked span <= sequence_length - if num_masked_span * mask_length > sequence_length: - num_masked_span = sequence_length // mask_length - - # make sure num_masked span is also <= input_length - (mask_length - 1) - if input_length - (mask_length - 1) < num_masked_span: - num_masked_span = max(input_length - (mask_length - 1), 0) - - return num_masked_span - - # compute number of masked spans in batch - input_lengths = ( - attention_mask.detach().sum(-1).tolist() - if attention_mask is not None - else [sequence_length for _ in range(batch_size)] - ) - - # SpecAugment mask to fill - spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) - spec_aug_mask_idxs = [] - - max_num_masked_span = compute_num_masked_span(sequence_length) - - if max_num_masked_span == 0: - return spec_aug_mask - - for input_length in input_lengths: - # compute num of masked spans for this input - num_masked_span = compute_num_masked_span(input_length) - - # get random indices to mask - spec_aug_mask_idx = np.random.choice( - np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False - ) - - # pick first sampled index that will serve as a dummy index to pad vector - # to ensure same dimension for all batches due to probabilistic rounding - # Picking first sample just pads those vectors twice. 
- if len(spec_aug_mask_idx) == 0: - # this case can only happen if `input_length` is strictly smaller then - # `sequence_length` in which case the last token has to be a padding - # token which we can use as a dummy mask id - dummy_mask_idx = sequence_length - 1 - else: - dummy_mask_idx = spec_aug_mask_idx[0] - - spec_aug_mask_idx = np.concatenate( - [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] - ) - spec_aug_mask_idxs.append(spec_aug_mask_idx) - - spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) - - # expand masked indices to masked spans - spec_aug_mask_idxs = np.broadcast_to( - spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) - ) - spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) - - # add offset to the starting indexes so that indexes now create a span - offsets = np.arange(mask_length)[None, None, :] - offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( - batch_size, max_num_masked_span * mask_length - ) - spec_aug_mask_idxs = spec_aug_mask_idxs + offsets - - # ensure that we cannot have indices larger than sequence_length - if spec_aug_mask_idxs.max() > sequence_length - 1: - spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 - - # scatter indices to mask - np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) - - return spec_aug_mask - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices -def _sample_negative_indices( - features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None -): - """ - Sample `num_negatives` vectors from feature vectors. - """ - batch_size, sequence_length = features_shape - - # generate indices of the positive vectors themselves, repeat them `num_negatives` times - sequence_length_range = np.arange(sequence_length) - - # get `num_negatives` random vector indices from the same utterance - sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32) - - mask_time_indices = ( - mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool) - ) - - for batch_idx in range(batch_size): - high = mask_time_indices[batch_idx].sum() - 1 - mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]] - - feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives)) - sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives)) - # avoid sampling the same positive vector, but keep the distribution uniform - sampled_indices[sampled_indices >= feature_indices] += 1 - - # remap to actual indices - sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices] - - # correct for batch size - sampled_negative_indices[batch_idx] += batch_idx * sequence_length - - return sampled_negative_indices - - -# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerRotaryPositionalEmbedding with Wav2Vec2Conformer->Wav2Vec2Bert class Wav2Vec2BertRotaryPositionalEmbedding(nn.Module): """Rotary positional embedding Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf @@ -284,7 +74,6 @@ class Wav2Vec2BertRotaryPositionalEmbedding(nn.Module): return self.cached_rotary_positional_embedding -# Copied from 
transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerRelPositionalEmbedding with Wav2Vec2Conformer->Wav2Vec2Bert class Wav2Vec2BertRelPositionalEmbedding(nn.Module): """Relative positional encoding module.""" @@ -363,7 +152,6 @@ class Wav2Vec2BertFeedForward(nn.Module): self.output_dense = nn.Linear(config.intermediate_size, hidden_size) self.output_dropout = nn.Dropout(config.hidden_dropout) - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward.forward def forward(self, hidden_states): hidden_states = self.intermediate_dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) @@ -556,7 +344,6 @@ class Wav2Vec2BertSelfAttention(nn.Module): return hidden_states, probs - # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention._apply_rotary_embedding def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings): batch_size, sequence_length, hidden_size = hidden_states.size() hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size) @@ -576,7 +363,6 @@ class Wav2Vec2BertSelfAttention(nn.Module): return hidden_states - # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention._apply_relative_embeddings def _apply_relative_embeddings(self, query, key, relative_position_embeddings): # 1. project positional embeddings # => (batch, head, 2*time1-1, d_k) @@ -823,6 +609,32 @@ class Wav2Vec2BertAdapter(nn.Module): return hidden_states +# Copied from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2._compute_new_attention_mask +def _compute_new_attention_mask(hidden_states: torch.Tensor, seq_lens: torch.Tensor): + """ + Computes an attention mask of the form `(batch, seq_len)` with an attention for each element in the batch that + stops at the corresponding element in `seq_lens`. + Args: + hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, *)`): + The sequences to mask, where `*` is any number of sequence-specific dimensions including none. 
+ seq_lens (`torch.Tensor` of shape `(batch)`: + Each element represents the length of the sequence at the same index in `hidden_states` + Returns: + `torch.FloatTensor`: The float attention mask of shape `(batch, seq_len)` + """ + batch_size, mask_seq_len = hidden_states.shape[:2] + + indices = torch.arange(mask_seq_len, device=seq_lens.device).expand(batch_size, -1) + + bool_mask = indices >= seq_lens.unsqueeze(1).expand(-1, mask_seq_len) + + mask = hidden_states.new_ones((batch_size, mask_seq_len)) + + mask = mask.masked_fill(bool_mask, 0) + + return mask + + class Wav2Vec2BertAdapterLayer(nn.Module): def __init__(self, config): super().__init__() @@ -911,7 +723,6 @@ class Wav2Vec2BertAdapterLayer(nn.Module): return hidden_states -# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerPreTrainedModel with Wav2Vec2Conformer->Wav2Vec2Bert,wav2vec2_conformer->wav2vec2_bert, input_values->input_features class Wav2Vec2BertPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -995,6 +806,129 @@ class Wav2Vec2BertPreTrainedModel(PreTrainedModel): return attention_mask +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. 
+ """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. 
+ if len(spec_aug_mask_idx) == 0: + # this case can only happen if `input_length` is strictly smaller then + # `sequence_length` in which case the last token has to be a padding + # token which we can use as a dummy mask id + dummy_mask_idx = sequence_length - 1 + else: + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # ensure that we cannot have indices larger than sequence_length + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + +_PRETRAINED_CHECKPOINT_FOR_DOC = "hf-audio/wav2vec2-bert-CV16-en" +_EXPECTED_OUTPUT_SHAPE = [1, 146, 1024] + + WAV2VEC2_BERT_START_DOCSTRING = r""" Wav2Vec2Bert was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael @@ -1003,8 +937,9 @@ WAV2VEC2_BERT_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving etc.). - This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a - regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. Parameters: config ([`Wav2Vec2BertConfig`]): Model configuration class with all the parameters of the model. @@ -1012,7 +947,6 @@ WAV2VEC2_BERT_START_DOCSTRING = r""" configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
""" - WAV2VEC2_BERT_INPUTS_DOCSTRING = r""" Args: input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): @@ -1039,6 +973,9 @@ WAV2VEC2_BERT_INPUTS_DOCSTRING = r""" """ +Wav2Vec2BertBaseModelOutput = Wav2Vec2BaseModelOutput + + @add_start_docstrings( "The bare Wav2Vec2Bert Model transformer outputting raw hidden-states without any specific head on top.", WAV2VEC2_BERT_START_DOCSTRING, @@ -1064,7 +1001,6 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel): # Initialize weights and apply final processing self.post_init() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states def _mask_hidden_states( self, hidden_states: torch.FloatTensor, @@ -1114,7 +1050,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel): @add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_PRETRAINED_CHECKPOINT_FOR_DOC, - output_type=Wav2Vec2BaseModelOutput, + output_type=Wav2Vec2BertBaseModelOutput, config_class=_CONFIG_FOR_DOC, modality="audio", expected_output=_EXPECTED_OUTPUT_SHAPE, @@ -1127,7 +1063,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[Tuple, Wav2Vec2BaseModelOutput]: + ) -> Union[Tuple, Wav2Vec2BertBaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1159,7 +1095,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel): if not return_dict: return (hidden_states, extract_features) + encoder_outputs[1:] - return Wav2Vec2BaseModelOutput( + return Wav2Vec2BertBaseModelOutput( last_hidden_state=hidden_states, extract_features=extract_features, hidden_states=encoder_outputs.hidden_states, @@ -1167,12 +1103,18 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel): ) +_HIDDEN_STATES_START_POSITION = 2 + +# CTC docstring +_CTC_EXPECTED_OUTPUT = "'mr quilter is the apostle of the middle classes and we are glad to welcome his gospel'" +_CTC_EXPECTED_LOSS = 17.04 + + @add_start_docstrings( """Wav2Vec2Bert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", WAV2VEC2_BERT_START_DOCSTRING, ) class Wav2Vec2BertForCTC(Wav2Vec2BertPreTrainedModel): - # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForCTC.__init__ with Wav2Vec2Conformer->Wav2Vec2Bert,WAV2VEC2_CONFORMER->WAV2VEC2_BERT,wav2vec2_conformer->wav2vec2_bert def __init__(self, config, target_lang: Optional[str] = None): super().__init__(config) @@ -1277,6 +1219,10 @@ class Wav2Vec2BertForCTC(Wav2Vec2BertPreTrainedModel): ) +# Base docstring +_BASE_CHECKPOINT_FOR_DOC = "facebook/w2v-bert-2.0" + + @add_start_docstrings( """ Wav2Vec2Bert Model with a sequence classification head on top (a linear layer over the pooled output) for @@ -1285,7 +1231,6 @@ class Wav2Vec2BertForCTC(Wav2Vec2BertPreTrainedModel): WAV2VEC2_BERT_START_DOCSTRING, ) class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel): - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Bert,wav2vec2->wav2vec2_bert def __init__(self, config): super().__init__(config) @@ -1318,7 +1263,6 @@ class 
Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel): config_class=_CONFIG_FOR_DOC, modality="audio", ) - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Bert,wav2vec2->wav2vec2_bert,WAV_2_VEC_2->WAV2VEC2_BERT, input_values->input_features def forward( self, input_features: Optional[torch.Tensor], @@ -1389,7 +1333,6 @@ class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel): WAV2VEC2_BERT_START_DOCSTRING, ) class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2BertPreTrainedModel): - # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForAudioFrameClassification.__init__ with Wav2Vec2Conformer->Wav2Vec2Bert,WAV2VEC2_CONFORMER->WAV2VEC2_BERT,wav2vec2_conformer->wav2vec2_bert def __init__(self, config): super().__init__(config) @@ -1406,7 +1349,6 @@ class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2BertPreTrainedModel): self.init_weights() - # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForAudioFrameClassification.freeze_base_model with wav2vec2_conformer->wav2vec2_bert def freeze_base_model(self): """ Calling this function will disable the gradient computation for the base model so that its parameters will not @@ -1422,7 +1364,6 @@ class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2BertPreTrainedModel): config_class=_CONFIG_FOR_DOC, modality="audio", ) - # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForAudioFrameClassification.forward with wav2vec2_conformer->wav2vec2_bert, input_values->input_features def forward( self, input_features: Optional[torch.Tensor], @@ -1477,7 +1418,6 @@ class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2BertPreTrainedModel): ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss class AMSoftmaxLoss(nn.Module): def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4): super(AMSoftmaxLoss, self).__init__() @@ -1501,7 +1441,6 @@ class AMSoftmaxLoss(nn.Module): return loss -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer class TDNNLayer(nn.Module): def __init__(self, config, layer_id=0): super().__init__() @@ -1517,6 +1456,7 @@ class TDNNLayer(nn.Module): if is_peft_available(): from peft.tuners.lora import LoraLayer + if is_peft_available(): if isinstance(self.kernel, LoraLayer): warnings.warn( "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. 
" @@ -1540,7 +1480,6 @@ class TDNNLayer(nn.Module): WAV2VEC2_BERT_START_DOCSTRING, ) class Wav2Vec2BertForXVector(Wav2Vec2BertPreTrainedModel): - # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector.__init__ with Wav2Vec2Conformer->Wav2Vec2Bert,WAV2VEC2_CONFORMER->WAV2VEC2_BERT,wav2vec2_conformer->wav2vec2_bert def __init__(self, config): super().__init__(config) @@ -1560,7 +1499,6 @@ class Wav2Vec2BertForXVector(Wav2Vec2BertPreTrainedModel): self.init_weights() - # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector.freeze_base_model with wav2vec2_conformer->wav2vec2_bert def freeze_base_model(self): """ Calling this function will disable the gradient computation for the base model so that its parameters will not @@ -1569,7 +1507,6 @@ class Wav2Vec2BertForXVector(Wav2Vec2BertPreTrainedModel): for param in self.wav2vec2_bert.parameters(): param.requires_grad = False - # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector._get_tdnn_output_lengths def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ Computes the output length of the TDNN layers @@ -1592,7 +1529,6 @@ class Wav2Vec2BertForXVector(Wav2Vec2BertPreTrainedModel): config_class=_CONFIG_FOR_DOC, modality="audio", ) - # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector.forward with wav2vec2_conformer->wav2vec2_bert, input_values->input_features def forward( self, input_features: Optional[torch.Tensor], diff --git a/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py new file mode 100644 index 00000000000..b8dc95c6754 --- /dev/null +++ b/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py @@ -0,0 +1,1169 @@ +import math +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_outputs import ( + BaseModelOutput, + CausalLMOutput, + SequenceClassifierOutput, + TokenClassifierOutput, + Wav2Vec2BaseModelOutput, + XVectorOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2FeedForward, Wav2Vec2ForSequenceClassification, Wav2Vec2Model +from ..wav2vec2_conformer.modeling_wav2vec2_conformer import ( + Wav2Vec2ConformerForAudioFrameClassification, + Wav2Vec2ConformerForCTC, + Wav2Vec2ConformerForXVector, + Wav2Vec2ConformerRelPositionalEmbedding, + Wav2Vec2ConformerRotaryPositionalEmbedding, + Wav2Vec2ConformerSelfAttention, +) +from .configuration_wav2vec2_bert import Wav2Vec2BertConfig + + +logger = logging.get_logger(__name__) + + +_HIDDEN_STATES_START_POSITION = 2 + +# General docstring +_CONFIG_FOR_DOC = "Wav2Vec2BertConfig" + +# Base docstring +_BASE_CHECKPOINT_FOR_DOC = "facebook/w2v-bert-2.0" +_PRETRAINED_CHECKPOINT_FOR_DOC = "hf-audio/wav2vec2-bert-CV16-en" +_EXPECTED_OUTPUT_SHAPE = [1, 146, 1024] + +# CTC docstring +_CTC_EXPECTED_OUTPUT = "'mr quilter is the apostle of the middle classes and we are glad to welcome his 
gospel'" +_CTC_EXPECTED_LOSS = 17.04 + + +# Copied from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2._compute_new_attention_mask +def _compute_new_attention_mask(hidden_states: torch.Tensor, seq_lens: torch.Tensor): + """ + Computes an attention mask of the form `(batch, seq_len)` with an attention for each element in the batch that + stops at the corresponding element in `seq_lens`. + Args: + hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, *)`): + The sequences to mask, where `*` is any number of sequence-specific dimensions including none. + seq_lens (`torch.Tensor` of shape `(batch)`: + Each element represents the length of the sequence at the same index in `hidden_states` + Returns: + `torch.FloatTensor`: The float attention mask of shape `(batch, seq_len)` + """ + batch_size, mask_seq_len = hidden_states.shape[:2] + + indices = torch.arange(mask_seq_len, device=seq_lens.device).expand(batch_size, -1) + + bool_mask = indices >= seq_lens.unsqueeze(1).expand(-1, mask_seq_len) + + mask = hidden_states.new_ones((batch_size, mask_seq_len)) + + mask = mask.masked_fill(bool_mask, 0) + + return mask + + +class Wav2Vec2BertRotaryPositionalEmbedding(Wav2Vec2ConformerRotaryPositionalEmbedding, nn.Module): + def __init__(self, config): + nn.Module.__init__() + dim = config.hidden_size // config.num_attention_heads + base = config.rotary_embedding_base + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + # Ignore copy + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.cached_sequence_length = None + self.cached_rotary_positional_embedding = None + + +class Wav2Vec2BertRelPositionalEmbedding(Wav2Vec2ConformerRelPositionalEmbedding): + pass + + +class Wav2Vec2BertFeatureProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_norm = nn.LayerNorm(config.feature_projection_input_dim, eps=config.layer_norm_eps) + self.projection = nn.Linear(config.feature_projection_input_dim, config.hidden_size) + self.dropout = nn.Dropout(config.feat_proj_dropout) + + def forward(self, hidden_states): + # non-projected hidden states are needed for quantization + norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(norm_hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states, norm_hidden_states + + +class Wav2Vec2BertFeedForward(Wav2Vec2FeedForward, nn.Module): + def __init__(self, config, act_fn=None, hidden_size=None): + nn.Module.__init__() + act_fn = act_fn if act_fn is not None else config.hidden_act + hidden_size = hidden_size if hidden_size is not None else config.hidden_size + self.intermediate_dropout = nn.Dropout(config.activation_dropout) + + self.intermediate_dense = nn.Linear(hidden_size, config.intermediate_size) + self.intermediate_act_fn = ACT2FN[act_fn] if isinstance(act_fn, str) else act_fn + + self.output_dense = nn.Linear(config.intermediate_size, hidden_size) + self.output_dropout = nn.Dropout(config.hidden_dropout) + + +class Wav2Vec2BertConvolutionModule(nn.Module): + """Convolution block used in the conformer block""" + + def __init__(self, config): + super().__init__() + if (config.conv_depthwise_kernel_size - 1) % 2 == 1: + raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding") + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pointwise_conv1 = nn.Conv1d( + config.hidden_size, + 2 * config.hidden_size, + kernel_size=1, + stride=1, + 
padding=0, + bias=False, + ) + self.glu = nn.GLU(dim=1) + self.depthwise_conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + config.conv_depthwise_kernel_size, + stride=1, + padding=0, + groups=config.hidden_size, + bias=False, + ) + + self.depthwise_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.activation = ACT2FN[config.hidden_act] + self.pointwise_conv2 = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.dropout = nn.Dropout(config.conformer_conv_dropout) + + def forward(self, hidden_states, attention_mask=None): + hidden_states = self.layer_norm(hidden_states) + + # Ensure that we do not leak padded positions in depthwise convolution if attention mask is passed. + # Put 0 where necessary + if attention_mask is not None: + hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0) + + # exchange the temporal dimension and the feature dimension + hidden_states = hidden_states.transpose(1, 2) + + # GLU mechanism + # => (batch, 2*channel, dim) + hidden_states = self.pointwise_conv1(hidden_states) + # => (batch, channel, dim) + hidden_states = self.glu(hidden_states) + + # Pad the sequence entirely on the left because of causal convolution. + hidden_states = torch.nn.functional.pad(hidden_states, (self.depthwise_conv.kernel_size[0] - 1, 0)) + + # 1D Depthwise Conv + hidden_states = self.depthwise_conv(hidden_states) + + hidden_states = self.depthwise_layer_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + hidden_states = self.activation(hidden_states) + + hidden_states = self.pointwise_conv2(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class Wav2Vec2BertSelfAttention(Wav2Vec2ConformerSelfAttention, nn.Module): + """Construct an Wav2Vec2BertSelfAttention object. + Can be enhanced with rotary or relative position embeddings. 
+ """ + + def __init__(self, config, is_adapter_attention=False): + nn.Module.__init__() + hidden_size = config.hidden_size if not is_adapter_attention else config.output_hidden_size + + self.head_size = hidden_size // config.num_attention_heads + self.num_heads = config.num_attention_heads + self.position_embeddings_type = config.position_embeddings_type if not is_adapter_attention else None + + self.linear_q = nn.Linear(hidden_size, hidden_size) + self.linear_k = nn.Linear(hidden_size, hidden_size) + self.linear_v = nn.Linear(hidden_size, hidden_size) + self.linear_out = nn.Linear(hidden_size, hidden_size) + + self.dropout = nn.Dropout(p=config.attention_dropout) + + if self.position_embeddings_type == "relative": + # linear transformation for positional encoding + self.linear_pos = nn.Linear(hidden_size, hidden_size, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size)) + self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size)) + + if self.position_embeddings_type == "relative_key": + self.left_max_position_embeddings = config.left_max_position_embeddings + self.right_max_position_embeddings = config.right_max_position_embeddings + num_positions = self.left_max_position_embeddings + self.right_max_position_embeddings + 1 + self.distance_embedding = nn.Embedding(num_positions, self.head_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + relative_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # self-attention mechanism + batch_size, sequence_length, hidden_size = hidden_states.size() + + # make sure query/key states can be != value states + query_key_states = hidden_states + value_states = hidden_states + + if self.position_embeddings_type == "rotary": + if relative_position_embeddings is None: + raise ValueError( + "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'" + ) + query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings) + + # project query_key_states and value_states + query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size) + + # => (batch, head, time1, d_k) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + if self.position_embeddings_type == "relative": + if relative_position_embeddings is None: + raise ValueError( + "`relative_position_embeddings` has to be defined when `self.position_embeddings_type ==" + " 'relative'" + ) + # apply relative_position_embeddings to qk scores + # as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860 + scores = self._apply_relative_embeddings( + query=query, key=key, relative_position_embeddings=relative_position_embeddings + ) + else: + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size) + + if self.position_embeddings_type == "relative_key": + query_length, key_length = query.shape[2], key.shape[2] + + position_ids_l = torch.arange(query_length, dtype=torch.long, 
device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_r - position_ids_l + distance = torch.clamp(distance, -self.left_max_position_embeddings, self.right_max_position_embeddings) + + positional_embedding = self.distance_embedding(distance + self.left_max_position_embeddings) + positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility + + relative_position_attn_weights = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) + scores = scores + (relative_position_attn_weights / math.sqrt(self.head_size)) + + # apply attention_mask if necessary + if attention_mask is not None: + scores = scores + attention_mask + + # => (batch, head, time1, time2) + probs = torch.softmax(scores, dim=-1) + probs = self.dropout(probs) + + # => (batch, head, time1, d_k) + hidden_states = torch.matmul(probs, value) + + # => (batch, time1, hidden_size) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size) + hidden_states = self.linear_out(hidden_states) + + return hidden_states, probs + + +class Wav2Vec2BertEncoderLayer(nn.Module): + """Conformer block based on https://arxiv.org/abs/2005.08100.""" + + def __init__(self, config): + super().__init__() + embed_dim = config.hidden_size + dropout = config.attention_dropout + + # Feed-forward 1 + self.ffn1_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.ffn1 = Wav2Vec2BertFeedForward(config) + + # Self-Attention + self.self_attn_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.self_attn_dropout = nn.Dropout(dropout) + self.self_attn = Wav2Vec2BertSelfAttention(config) + + # Conformer Convolution + self.conv_module = Wav2Vec2BertConvolutionModule(config) + + # Feed-forward 2 + self.ffn2_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.ffn2 = Wav2Vec2BertFeedForward(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states, + attention_mask: Optional[torch.Tensor] = None, + relative_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + conv_attention_mask: Optional[torch.Tensor] = None, + ): + hidden_states = hidden_states + + # 1. Feed-Forward 1 layer + residual = hidden_states + hidden_states = self.ffn1_layer_norm(hidden_states) + hidden_states = self.ffn1(hidden_states) + hidden_states = hidden_states * 0.5 + residual + residual = hidden_states + + # 2. Self-Attention layer + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weigts = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + ) + hidden_states = self.self_attn_dropout(hidden_states) + hidden_states = hidden_states + residual + + # 3. Convolutional Layer + residual = hidden_states + hidden_states = self.conv_module(hidden_states, attention_mask=conv_attention_mask) + hidden_states = residual + hidden_states + + # 4. 
Feed-Forward 2 Layer + residual = hidden_states + hidden_states = self.ffn2_layer_norm(hidden_states) + hidden_states = self.ffn2(hidden_states) + hidden_states = hidden_states * 0.5 + residual + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states, attn_weigts + + +class Wav2Vec2BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + if config.position_embeddings_type == "relative": + self.embed_positions = Wav2Vec2BertRelPositionalEmbedding(config) + elif config.position_embeddings_type == "rotary": + self.embed_positions = Wav2Vec2BertRotaryPositionalEmbedding(config) + else: + self.embed_positions = None + + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList([Wav2Vec2BertEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + conv_attention_mask = attention_mask + if attention_mask is not None: + # make sure padded tokens output 0 + hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0) + + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + hidden_states = self.dropout(hidden_states) + + if self.embed_positions is not None: + relative_position_embeddings = self.embed_positions(hidden_states) + else: + relative_position_embeddings = None + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + relative_position_embeddings, + output_attentions, + conv_attention_mask, + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + conv_attention_mask=conv_attention_mask, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class Wav2Vec2BertAdapter(nn.Module): + def __init__(self, config): + super().__init__() + # feature dim might need to be down-projected + if 
config.output_hidden_size != config.hidden_size: + self.proj = nn.Linear(config.hidden_size, config.output_hidden_size) + self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size, eps=config.layer_norm_eps) + else: + self.proj = self.proj_layer_norm = None + self.layers = nn.ModuleList(Wav2Vec2BertAdapterLayer(config) for _ in range(config.num_adapter_layers)) + self.layerdrop = config.layerdrop + + self.kernel_size = config.adapter_kernel_size + self.stride = config.adapter_stride + + def _compute_sub_sample_lengths_from_attention_mask(self, seq_lens): + if seq_lens is None: + return seq_lens + pad = self.kernel_size // 2 + seq_lens = ((seq_lens + 2 * pad - self.kernel_size) / self.stride) + 1 + return seq_lens.floor() + + def forward(self, hidden_states, attention_mask=None): + # down project hidden_states if necessary + if self.proj is not None and self.proj_layer_norm is not None: + hidden_states = self.proj(hidden_states) + hidden_states = self.proj_layer_norm(hidden_states) + + sub_sampled_lengths = None + if attention_mask is not None: + sub_sampled_lengths = (attention_mask.size(1) - (1 - attention_mask.int()).sum(1)).to(hidden_states.device) + + for layer in self.layers: + layerdrop_prob = torch.rand([]) + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(sub_sampled_lengths) + if not self.training or (layerdrop_prob > self.layerdrop): + hidden_states = layer( + hidden_states, attention_mask=attention_mask, sub_sampled_lengths=sub_sampled_lengths + ) + + return hidden_states + + +class Wav2Vec2BertAdapterLayer(nn.Module): + def __init__(self, config): + super().__init__() + embed_dim = config.output_hidden_size + dropout = config.conformer_conv_dropout + + self.kernel_size = config.adapter_kernel_size + self.stride = config.adapter_stride + + # 1. residual convolution + self.residual_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.residual_conv = nn.Conv1d( + embed_dim, + 2 * embed_dim, + self.kernel_size, + stride=self.stride, + padding=self.stride // 2, + ) + self.activation = nn.GLU(dim=1) + + # Self-Attention + self.self_attn_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.self_attn_conv = nn.Conv1d( + embed_dim, + 2 * embed_dim, + self.kernel_size, + stride=self.stride, + padding=self.stride // 2, + ) + self.self_attn = Wav2Vec2BertSelfAttention(config, is_adapter_attention=True) + self.self_attn_dropout = nn.Dropout(dropout) + + # Feed-forward + self.ffn_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.ffn = Wav2Vec2BertFeedForward(config, act_fn=config.adapter_act, hidden_size=embed_dim) + + def forward( + self, + hidden_states, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + sub_sampled_lengths: Optional[torch.Tensor] = None, + ): + residual = self.residual_layer_norm(hidden_states) + + # Apply pooling to the residual to match the sequence length of the + # multi-head attention output. + # (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len) + residual = residual.transpose(1, 2) + residual = self.residual_conv(residual) + residual = self.activation(residual) + # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim) + residual = residual.transpose(1, 2) + + hidden_states = self.self_attn_layer_norm(hidden_states) + # Apply pooling before feeding to the multihead-attention layer. 
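+        # NOTE: self_attn_conv uses the same kernel_size/stride/padding as residual_conv above,
+        # so both branches are downsampled to the same length; the shortened valid lengths are
+        # carried in `sub_sampled_lengths` and turned into a fresh attention mask further below.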
+ # (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.self_attn_conv(hidden_states) + hidden_states = self.activation(hidden_states) + # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim) + hidden_states = hidden_states.transpose(1, 2) + + if attention_mask is not None: + attention_mask = _compute_new_attention_mask(hidden_states=hidden_states, seq_lens=sub_sampled_lengths) + attention_mask = _prepare_4d_attention_mask( + attention_mask, + hidden_states.dtype, + ) + + # The rest of the computation is identical to a vanilla Transformer + # encoder layer. + hidden_states, attn_weigths = self.self_attn( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.self_attn_dropout(hidden_states) + hidden_states = hidden_states + residual + + residual = hidden_states + + hidden_states = self.ffn_layer_norm(hidden_states) + hidden_states = self.ffn(hidden_states) + residual + + return hidden_states + + +class Wav2Vec2BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Wav2Vec2BertConfig + base_model_prefix = "wav2vec2_bert" + main_input_name = "input_features" + supports_gradient_checkpointing = True + + # Ignore copy + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, Wav2Vec2BertSelfAttention): + if hasattr(module, "pos_bias_u"): + nn.init.xavier_uniform_(module.pos_bias_u) + if hasattr(module, "pos_bias_v"): + nn.init.xavier_uniform_(module.pos_bias_v) + elif isinstance(module, Wav2Vec2BertFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + # Ignore copy + def _get_feat_extract_output_lengths( + self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None + ): + """ + Computes the output length of the convolutional layers + """ + + add_adapter = self.config.add_adapter if add_adapter is None else add_adapter + + def _conv_out_length(input_length, kernel_size, stride, padding): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return torch.div(input_length + 2 * padding - kernel_size, stride, rounding_mode="floor") + 1 + + if add_adapter: + padding = self.config.adapter_kernel_size // 2 + for _ in range(self.config.num_adapter_layers): + input_lengths = _conv_out_length( + input_lengths, self.config.adapter_kernel_size, self.config.adapter_stride, padding + ) + + return input_lengths + + def _get_feature_vector_attention_mask( + self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None + ): + # Effectively attention_mask.sum(-1), but not inplace to be able to run + # on inference mode. 
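+        # e.g. a right-padded mask [1, 1, 1, 0, 0] gives cumsum(-1)[:, -1] == 3, i.e. the number
+        # of attended frames, which is then mapped to the corresponding feature-vector length below.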
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1] + + output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter) + output_lengths = output_lengths.to(torch.long) + + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + +WAV2VEC2_BERT_START_DOCSTRING = None + +WAV2VEC2_BERT_INPUTS_DOCSTRING = r""" + Args: + input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file + into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install + soundfile`). To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and + conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2BertProcessor.__call__`] for details. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0, + 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +Wav2Vec2BertBaseModelOutput = Wav2Vec2BaseModelOutput + + +@add_start_docstrings( + "The bare Wav2Vec2Bert Model transformer outputting raw hidden-states without any specific head on top.", + WAV2VEC2_BERT_START_DOCSTRING, +) +class Wav2Vec2BertModel(Wav2Vec2Model, Wav2Vec2BertPreTrainedModel): + def __init__(self, config: Wav2Vec2BertConfig): + Wav2Vec2BertPreTrainedModel.__init__(config) + self.config = config + self.feature_projection = Wav2Vec2BertFeatureProjection(config) + + # model only needs masking vector if mask prob is > 0.0 + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) + + self.encoder = Wav2Vec2BertEncoder(config) + + self.adapter = Wav2Vec2BertAdapter(config) if config.add_adapter else None + + self.intermediate_ffn = None + if config.use_intermediate_ffn_before_adapter: + self.intermediate_ffn = Wav2Vec2BertFeedForward(config, act_fn="relu") + + # Initialize weights and apply final processing + self.post_init() + + def freeze_feature_extractor(self): + raise AttributeError("Not needed for Wav2Vec2Bert") + + def freeze_feature_encoder(self): + raise AttributeError("Not needed for Wav2Vec2Bert") + + @add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_PRETRAINED_CHECKPOINT_FOR_DOC, + output_type=Wav2Vec2BertBaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_features: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Wav2Vec2BertBaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states, extract_features = self.feature_projection(input_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if self.intermediate_ffn: + expanded_hidden_states = self.intermediate_ffn(hidden_states) + hidden_states = hidden_states + 0.5 * expanded_hidden_states + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states, attention_mask=attention_mask) + + if not return_dict: + return (hidden_states, extract_features) + encoder_outputs[1:] + + return Wav2Vec2BertBaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """Wav2Vec2Bert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", + WAV2VEC2_BERT_START_DOCSTRING, +) +class Wav2Vec2BertForCTC(Wav2Vec2ConformerForCTC): + def __init__(self, config, target_lang: Optional[str] = None): + super().__init__(config) + + def 
freeze_feature_encoder(self): + raise AttributeError("Not needed for Wav2Vec2Bert") + + @add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_PRETRAINED_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_CTC_EXPECTED_OUTPUT, + expected_loss=_CTC_EXPECTED_LOSS, + ) + def forward( + self, + input_features: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple, CausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*): + Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to + the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. + All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., + config.vocab_size - 1]`. + """ + if labels is not None and labels.max() >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.wav2vec2_bert( + input_features, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # retrieve loss input_lengths from attention_mask + attention_mask = ( + attention_mask + if attention_mask is not None + else torch.ones(input_features.shape[:2], device=input_features.device, dtype=torch.long) + ) + input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum([-1])).to(torch.long) + + # assuming that padded tokens are filled with -100 + # when not being attended to + labels_mask = labels >= 0 + target_lengths = labels_mask.sum(-1) + flattened_targets = labels.masked_select(labels_mask) + + # ctc_loss doesn't support fp16 + log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) + + with torch.backends.cudnn.flags(enabled=False): + loss = nn.functional.ctc_loss( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + blank=self.config.pad_token_id, + reduction=self.config.ctc_loss_reduction, + zero_infinity=self.config.ctc_zero_infinity, + ) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@add_start_docstrings( + """ + Wav2Vec2Bert Model with a sequence classification head on top (a linear layer over the pooled output) for + tasks like SUPERB Keyword Spotting. 
+ """, + WAV2VEC2_BERT_START_DOCSTRING, +) +class Wav2Vec2BertForSequenceClassification(Wav2Vec2ForSequenceClassification): + def __init__(self, config): + super().__init__(config) + + def freeze_feature_extractor(self): + raise AttributeError("Not needed for Wav2Vec2Bert") + + def freeze_feature_encoder(self): + raise AttributeError("Not needed for Wav2Vec2Bert") + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.wav2vec2_bert.parameters(): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_BASE_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Bert,wav2vec2->wav2vec2_bert,WAV_2_VEC_2->WAV2VEC2_BERT, input_values->input_features + def forward( + self, + input_features: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.wav2vec2_bert( + input_features, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + if attention_mask is None: + pooled_output = hidden_states.mean(dim=1) + else: + padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) + expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_padding_mask] = 0.0 + pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Wav2Vec2Bert Model with a frame classification head on top for tasks like Speaker Diarization. + """, + WAV2VEC2_BERT_START_DOCSTRING, +) +class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2ConformerForAudioFrameClassification): + def __init__(self, config): + super().__init__(config) + + def freeze_feature_encoder(self): + raise AttributeError("Not needed for Wav2Vec2Bert") + + @add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_BASE_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForAudioFrameClassification.forward with wav2vec2_conformer->wav2vec2_bert, input_values->input_features + def forward( + self, + input_features: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
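The sequence-classification head in the hunk above pools frame features into one vector per utterance: padded frames are zeroed out and the sum is divided by the number of valid frames. A minimal sketch on dummy tensors; it is not part of the patch, and the shapes and mask are invented.

import torch

hidden_states = torch.randn(2, 6, 4)                      # (batch, frames, hidden) after the projector
padding_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                             [1, 1, 1, 1, 1, 1]], dtype=torch.bool)

expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states = hidden_states.masked_fill(~expand_padding_mask, 0.0)              # zero padded frames
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)    # (batch, hidden)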
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.wav2vec2_bert( + input_features, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1)) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Wav2Vec2Bert Model with an XVector feature extraction head on top for tasks like Speaker Verification. + """, + WAV2VEC2_BERT_START_DOCSTRING, +) +class Wav2Vec2BertForXVector(Wav2Vec2ConformerForXVector): + def __init__(self, config): + super().__init__(config) + + def freeze_feature_encoder(self): + raise AttributeError("Not needed for Wav2Vec2Bert") + + @add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_BASE_CHECKPOINT_FOR_DOC, + output_type=XVectorOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector.forward with wav2vec2_conformer->wav2vec2_bert, input_values->input_features + def forward( + self, + input_features: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple, XVectorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
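When `config.use_weighted_layer_sum` is enabled, the classification heads above combine every encoder layer output with a learned, softmax-normalized weight per layer instead of using only the last hidden state. A standalone sketch, not part of the patch; the layer count and tensor sizes are invented.

import torch
import torch.nn as nn

num_layers, batch, frames, hidden = 5, 2, 6, 4
all_hidden_states = [torch.randn(batch, frames, hidden) for _ in range(num_layers)]
layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)   # learned parameter in the model

stacked = torch.stack(all_hidden_states, dim=1)                     # (batch, layers, frames, hidden)
norm_weights = nn.functional.softmax(layer_weights, dim=-1)
weighted_sum = (stacked * norm_weights.view(-1, 1, 1)).sum(dim=1)   # (batch, frames, hidden)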
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.wav2vec2_bert( + input_features, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + + for tdnn_layer in self.tdnn: + hidden_states = tdnn_layer(hidden_states) + + # Statistic Pooling + if attention_mask is None: + mean_features = hidden_states.mean(dim=1) + std_features = hidden_states.std(dim=1) + else: + feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1)) + tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths) + mean_features = [] + std_features = [] + for i, length in enumerate(tdnn_output_lengths): + mean_features.append(hidden_states[i, :length].mean(dim=0)) + std_features.append(hidden_states[i, :length].std(dim=0)) + mean_features = torch.stack(mean_features) + std_features = torch.stack(std_features) + statistic_pooling = torch.cat([mean_features, std_features], dim=-1) + + output_embeddings = self.feature_extractor(statistic_pooling) + logits = self.classifier(output_embeddings) + + loss = None + if labels is not None: + loss = self.objective(logits, labels) + + if not return_dict: + output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return XVectorOutput( + loss=loss, + logits=logits, + embeddings=output_embeddings, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Wav2Vec2BertForAudioFrameClassification", + "Wav2Vec2BertForCTC", + "Wav2Vec2BertForSequenceClassification", + "Wav2Vec2BertForXVector", + "Wav2Vec2BertModel", + "Wav2Vec2BertPreTrainedModel", +] diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index 587ae1c8249..bd94e44b616 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -1,19 +1,9 @@ -# coding=utf-8 -# Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Wav2Vec2-Conformer model.""" - +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py. 
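For reference, the statistic-pooling step of the x-vector head in the hunk above reduces each utterance to the concatenation of the per-dimension mean and standard deviation over its valid TDNN frames. A minimal sketch on dummy tensors, not part of the patch; the lengths and sizes are invented.

import torch

hidden_states = torch.randn(2, 10, 4)        # (batch, frames, tdnn_dim) after the TDNN stack
tdnn_output_lengths = torch.tensor([10, 7])  # assumed valid lengths per utterance

mean_features, std_features = [], []
for i, length in enumerate(tdnn_output_lengths):
    mean_features.append(hidden_states[i, :length].mean(dim=0))
    std_features.append(hidden_states[i, :length].std(dim=0))

statistic_pooling = torch.cat([torch.stack(mean_features), torch.stack(std_features)], dim=-1)
# (batch, 2 * tdnn_dim), fed to the feature_extractor and classifier layers above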
+# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_wav2vec2_conformer.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math import warnings from dataclasses import dataclass @@ -21,7 +11,6 @@ from typing import Optional, Tuple, Union import numpy as np import torch -import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss @@ -43,31 +32,19 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_peft_available, - logging, replace_return_docstrings, ) from .configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig -logger = logging.get_logger(__name__) - - -_HIDDEN_STATES_START_POSITION = 2 +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft" # General docstring _CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig" -# Base docstring -_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft" -_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024] - -# CTC docstring -_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'" -_CTC_EXPECTED_LOSS = 64.21 - @dataclass -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput): """ Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions. @@ -109,239 +86,17 @@ class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput): diversity_loss: Optional[torch.FloatTensor] = None -# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices -def _compute_mask_indices( - shape: Tuple[int, int], - mask_prob: float, - mask_length: int, - attention_mask: Optional[torch.LongTensor] = None, - min_masks: int = 0, -) -> np.ndarray: - """ - Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for - ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on - CPU as part of the preprocessing during training. - - Args: - shape: The shape for which to compute masks. This should be of a tuple of size 2 where - the first element is the batch size and the second element is the length of the axis to span. - mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of - independently generated mask spans of length `mask_length` is computed by - `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the - actual percentage will be smaller. - mask_length: size of the mask - min_masks: minimum number of masked spans - attention_mask: A (right-padded) attention mask which independently shortens the feature axis of - each batch dimension. 
- """ - batch_size, sequence_length = shape - - if mask_length < 1: - raise ValueError("`mask_length` has to be bigger than 0.") - - if mask_length > sequence_length: - raise ValueError( - f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" - f" and `sequence_length`: {sequence_length}`" - ) - - # epsilon is used for probabilistic rounding - epsilon = np.random.rand(1).item() - - def compute_num_masked_span(input_length): - """Given input length, compute how many spans should be masked""" - num_masked_span = int(mask_prob * input_length / mask_length + epsilon) - num_masked_span = max(num_masked_span, min_masks) - - # make sure num masked span <= sequence_length - if num_masked_span * mask_length > sequence_length: - num_masked_span = sequence_length // mask_length - - # make sure num_masked span is also <= input_length - (mask_length - 1) - if input_length - (mask_length - 1) < num_masked_span: - num_masked_span = max(input_length - (mask_length - 1), 0) - - return num_masked_span - - # compute number of masked spans in batch - input_lengths = ( - attention_mask.detach().sum(-1).tolist() - if attention_mask is not None - else [sequence_length for _ in range(batch_size)] - ) - - # SpecAugment mask to fill - spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) - spec_aug_mask_idxs = [] - - max_num_masked_span = compute_num_masked_span(sequence_length) - - if max_num_masked_span == 0: - return spec_aug_mask - - for input_length in input_lengths: - # compute num of masked spans for this input - num_masked_span = compute_num_masked_span(input_length) - - # get random indices to mask - spec_aug_mask_idx = np.random.choice( - np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False - ) - - # pick first sampled index that will serve as a dummy index to pad vector - # to ensure same dimension for all batches due to probabilistic rounding - # Picking first sample just pads those vectors twice. 
- if len(spec_aug_mask_idx) == 0: - # this case can only happen if `input_length` is strictly smaller then - # `sequence_length` in which case the last token has to be a padding - # token which we can use as a dummy mask id - dummy_mask_idx = sequence_length - 1 - else: - dummy_mask_idx = spec_aug_mask_idx[0] - - spec_aug_mask_idx = np.concatenate( - [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] - ) - spec_aug_mask_idxs.append(spec_aug_mask_idx) - - spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) - - # expand masked indices to masked spans - spec_aug_mask_idxs = np.broadcast_to( - spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) - ) - spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) - - # add offset to the starting indexes so that indexes now create a span - offsets = np.arange(mask_length)[None, None, :] - offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( - batch_size, max_num_masked_span * mask_length - ) - spec_aug_mask_idxs = spec_aug_mask_idxs + offsets - - # ensure that we cannot have indices larger than sequence_length - if spec_aug_mask_idxs.max() > sequence_length - 1: - spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 - - # scatter indices to mask - np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) - - return spec_aug_mask - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices -def _sample_negative_indices( - features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None -): - """ - Sample `num_negatives` vectors from feature vectors. - """ - batch_size, sequence_length = features_shape - - # generate indices of the positive vectors themselves, repeat them `num_negatives` times - sequence_length_range = np.arange(sequence_length) - - # get `num_negatives` random vector indices from the same utterance - sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32) - - mask_time_indices = ( - mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool) - ) - - for batch_idx in range(batch_size): - high = mask_time_indices[batch_idx].sum() - 1 - mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]] - - feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives)) - sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives)) - # avoid sampling the same positive vector, but keep the distribution uniform - sampled_indices[sampled_indices >= feature_indices] += 1 - - # remap to actual indices - sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices] - - # correct for batch size - sampled_negative_indices[batch_idx] += batch_idx * sequence_length - - return sampled_negative_indices - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer -class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module): - def __init__(self, config, layer_id=0): +class Wav2Vec2ConformerSamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): super().__init__() - self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = nn.Conv1d( - self.in_conv_dim, - self.out_conv_dim, - 
kernel_size=config.conv_kernel[layer_id], - stride=config.conv_stride[layer_id], - bias=config.conv_bias, - ) - self.activation = ACT2FN[config.feat_extract_activation] + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - hidden_states = self.activation(hidden_states) + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer -class Wav2Vec2ConformerLayerNormConvLayer(nn.Module): - def __init__(self, config, layer_id=0): - super().__init__() - self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = nn.Conv1d( - self.in_conv_dim, - self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - stride=config.conv_stride[layer_id], - bias=config.conv_bias, - ) - self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) - self.activation = ACT2FN[config.feat_extract_activation] - - def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - - hidden_states = hidden_states.transpose(-2, -1) - hidden_states = self.layer_norm(hidden_states) - hidden_states = hidden_states.transpose(-2, -1) - - hidden_states = self.activation(hidden_states) - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer -class Wav2Vec2ConformerGroupNormConvLayer(nn.Module): - def __init__(self, config, layer_id=0): - super().__init__() - self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = nn.Conv1d( - self.in_conv_dim, - self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - stride=config.conv_stride[layer_id], - bias=config.conv_bias, - ) - self.activation = ACT2FN[config.feat_extract_activation] - - self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) - - def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module): def __init__(self, config): super().__init__() @@ -471,19 +226,78 @@ class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module): return relative_position_embeddings -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Wav2Vec2Conformer -class Wav2Vec2ConformerSamePadLayer(nn.Module): - def __init__(self, num_conv_pos_embeddings): +class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): super().__init__() - self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] def forward(self, hidden_states): - if self.num_pad_remove > 0: 
- hidden_states = hidden_states[:, :, : -self.num_pad_remove] + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class Wav2Vec2ConformerLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +class Wav2Vec2ConformerGroupNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer class Wav2Vec2ConformerFeatureEncoder(nn.Module): """Construct the features from raw audio waveform""" @@ -531,7 +345,6 @@ class Wav2Vec2ConformerFeatureEncoder(nn.Module): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer class Wav2Vec2ConformerFeatureProjection(nn.Module): def __init__(self, config): super().__init__() @@ -547,7 +360,6 @@ class Wav2Vec2ConformerFeatureProjection(nn.Module): return hidden_states, norm_hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer class Wav2Vec2ConformerFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -943,7 +755,6 @@ class Wav2Vec2ConformerEncoder(nn.Module): ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module): """ Vector quantization using gumbel softmax. 
See `[CATEGORICAL REPARAMETERIZATION WITH @@ -1020,7 +831,6 @@ class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module): return codevectors, perplexity -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer class Wav2Vec2ConformerAdapter(nn.Module): def __init__(self, config): super().__init__() @@ -1052,7 +862,6 @@ class Wav2Vec2ConformerAdapter(nn.Module): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer class Wav2Vec2ConformerAdapterLayer(nn.Module): def __init__(self, config): super().__init__() @@ -1170,6 +979,128 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel): return attention_mask +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. 
+ """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. 
+ if len(spec_aug_mask_idx) == 0: + # this case can only happen if `input_length` is strictly smaller then + # `sequence_length` in which case the last token has to be a padding + # token which we can use as a dummy mask id + dummy_mask_idx = sequence_length - 1 + else: + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # ensure that we cannot have indices larger than sequence_length + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + +_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024] + + WAV2VEC2_CONFORMER_START_DOCSTRING = r""" Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael @@ -1178,8 +1109,9 @@ WAV2VEC2_CONFORMER_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving etc.). - This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a - regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. Parameters: config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model. @@ -1187,7 +1119,6 @@ WAV2VEC2_CONFORMER_START_DOCSTRING = r""" configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ - WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r""" Args: input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): @@ -1226,6 +1157,8 @@ WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r""" Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
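The `_compute_mask_indices` helper re-added above samples SpecAugment-style time masks on CPU and returns a boolean NumPy array. A small usage sketch, mirroring the pretraining docstring example later in this file; it is not part of the patch, and the shape and probabilities are arbitrary.

from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices

batch_size, sequence_length = 2, 100
mask_time_indices = _compute_mask_indices(
    shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
)
# boolean (batch, sequence) array; at most ~mask_prob * sequence_length frames per row are masked,
# fewer when sampled spans overlap
print(mask_time_indices.shape, mask_time_indices.sum(axis=-1))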
""" +Wav2Vec2ConformerBaseModelOutput = Wav2Vec2BaseModelOutput + @add_start_docstrings( "The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.", @@ -1249,7 +1182,6 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel): # Initialize weights and apply final processing self.post_init() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder def freeze_feature_encoder(self): """ Calling this function will disable the gradient computation for the feature encoder so that its parameter will @@ -1257,7 +1189,6 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel): """ self.feature_extractor._freeze_parameters() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states def _mask_hidden_states( self, hidden_states: torch.FloatTensor, @@ -1307,12 +1238,11 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel): @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, - output_type=Wav2Vec2BaseModelOutput, + output_type=Wav2Vec2ConformerBaseModelOutput, config_class=_CONFIG_FOR_DOC, modality="audio", expected_output=_EXPECTED_OUTPUT_SHAPE, ) - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer def forward( self, input_values: Optional[torch.Tensor], @@ -1321,7 +1251,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[Tuple, Wav2Vec2BaseModelOutput]: + ) -> Union[Tuple, Wav2Vec2ConformerBaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1358,7 +1288,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel): if not return_dict: return (hidden_states, extract_features) + encoder_outputs[1:] - return Wav2Vec2BaseModelOutput( + return Wav2Vec2ConformerBaseModelOutput( last_hidden_state=hidden_states, extract_features=extract_features, hidden_states=encoder_outputs.hidden_states, @@ -1370,7 +1300,6 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel): """Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING ) class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel): - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer def __init__(self, config: Wav2Vec2ConformerConfig): super().__init__(config) self.wav2vec2_conformer = Wav2Vec2ConformerModel(config) @@ -1384,14 +1313,12 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel): # Initialize weights and apply final processing self.post_init() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature def set_gumbel_temperature(self, temperature: int): """ Set the Gumbel softmax temperature to a given value. 
Only necessary for training """ self.quantizer.temperature = temperature - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer def freeze_feature_encoder(self): """ Calling this function will disable the gradient computation for the feature encoder so that its parameter will @@ -1400,7 +1327,6 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel): self.wav2vec2_conformer.feature_extractor._freeze_parameters() @staticmethod - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits def compute_contrastive_logits( target_features: torch.FloatTensor, negative_features: torch.FloatTensor, @@ -1423,7 +1349,6 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel): @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large def forward( self, input_values: Optional[torch.Tensor], @@ -1452,8 +1377,8 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel): >>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices, _sample_negative_indices >>> from datasets import load_dataset - >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large") - >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large") + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2_conformer-base") + >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2_conformer-base") >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1 @@ -1585,12 +1510,18 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel): ) +_HIDDEN_STATES_START_POSITION = 2 + +# CTC docstring +_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'" +_CTC_EXPECTED_LOSS = 64.21 + + @add_start_docstrings( """Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", WAV2VEC2_CONFORMER_START_DOCSTRING, ) class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel): - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer def __init__(self, config, target_lang: Optional[str] = None): super().__init__(config) @@ -1614,7 +1545,6 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel): # Initialize weights and apply final processing self.post_init() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer def freeze_feature_encoder(self): """ Calling this function will disable the gradient computation for the feature encoder so that its parameter will @@ -1630,7 +1560,6 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel): expected_output=_CTC_EXPECTED_OUTPUT, 
expected_loss=_CTC_EXPECTED_LOSS, ) - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer def forward( self, input_values: Optional[torch.Tensor], @@ -1710,7 +1639,6 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel): WAV2VEC2_CONFORMER_START_DOCSTRING, ) class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel): - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer def __init__(self, config): super().__init__(config) @@ -1728,7 +1656,6 @@ class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedMode # Initialize weights and apply final processing self.post_init() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer def freeze_feature_encoder(self): """ Calling this function will disable the gradient computation for the feature encoder so that its parameter will @@ -1751,7 +1678,6 @@ class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedMode config_class=_CONFIG_FOR_DOC, modality="audio", ) - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER def forward( self, input_values: Optional[torch.Tensor], @@ -1822,7 +1748,6 @@ class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedMode WAV2VEC2_CONFORMER_START_DOCSTRING, ) class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel): - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER def __init__(self, config): super().__init__(config) @@ -1839,7 +1764,6 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedMo self.init_weights() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer def freeze_feature_encoder(self): """ Calling this function will disable the gradient computation for the feature encoder so that its parameter will @@ -1847,7 +1771,6 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedMo """ self.wav2vec2_conformer.feature_extractor._freeze_parameters() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer def freeze_base_model(self): """ Calling this function will disable the gradient computation for the base model so that its parameters will not @@ -1863,7 +1786,6 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedMo config_class=_CONFIG_FOR_DOC, modality="audio", ) - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer def forward( self, input_values: Optional[torch.Tensor], @@ -1918,7 +1840,6 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedMo ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss class AMSoftmaxLoss(nn.Module): def __init__(self, input_dim, num_labels, scale=30.0, 
margin=0.4): super(AMSoftmaxLoss, self).__init__() @@ -1942,7 +1863,6 @@ class AMSoftmaxLoss(nn.Module): return loss -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer class TDNNLayer(nn.Module): def __init__(self, config, layer_id=0): super().__init__() @@ -1958,6 +1878,7 @@ class TDNNLayer(nn.Module): if is_peft_available(): from peft.tuners.lora import LoraLayer + if is_peft_available(): if isinstance(self.kernel, LoraLayer): warnings.warn( "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. " @@ -2000,7 +1921,6 @@ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel): self.init_weights() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer def freeze_feature_encoder(self): """ Calling this function will disable the gradient computation for the feature encoder so that its parameter will @@ -2008,7 +1928,6 @@ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel): """ self.wav2vec2_conformer.feature_extractor._freeze_parameters() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer def freeze_base_model(self): """ Calling this function will disable the gradient computation for the base model so that its parameters will not @@ -2017,7 +1936,6 @@ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel): for param in self.wav2vec2_conformer.parameters(): param.requires_grad = False - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ Computes the output length of the TDNN layers @@ -2040,7 +1958,6 @@ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel): config_class=_CONFIG_FOR_DOC, modality="audio", ) - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER def forward( self, input_values: Optional[torch.Tensor], diff --git a/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py new file mode 100644 index 00000000000..c2d101385fa --- /dev/null +++ b/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py @@ -0,0 +1,892 @@ +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_outputs import ( + BaseModelOutput, + CausalLMOutput, + SequenceClassifierOutput, + TokenClassifierOutput, + Wav2Vec2BaseModelOutput, + XVectorOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ..wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2Adapter, + Wav2Vec2AdapterLayer, + Wav2Vec2FeatureEncoder, + Wav2Vec2FeatureProjection, + Wav2Vec2FeedForward, + Wav2Vec2ForAudioFrameClassification, + Wav2Vec2ForCTC, + Wav2Vec2ForPreTraining, + Wav2Vec2ForSequenceClassification, + Wav2Vec2ForXVector, + 
Wav2Vec2GumbelVectorQuantizer, + Wav2Vec2Model, + Wav2Vec2PositionalConvEmbedding, +) +from .configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig + + +logger = logging.get_logger(__name__) + +_HIDDEN_STATES_START_POSITION = 2 + +# General docstring +_CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft" +_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024] + +# CTC docstring +_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'" +_CTC_EXPECTED_LOSS = 64.21 + + +@dataclass +class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput): + """ + Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions. + + Args: + loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official + paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss. + projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`): + Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked + projected quantized states. + projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`): + Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive + target vectors for contrastive loss. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`): + The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) . + diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`): + The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) . 
+ """ + + loss: Optional[torch.FloatTensor] = None + projected_states: Optional[torch.FloatTensor] = None + projected_quantized_states: Optional[torch.FloatTensor] = None + codevector_perplexity: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + contrastive_loss: Optional[torch.FloatTensor] = None + diversity_loss: Optional[torch.FloatTensor] = None + + +class Wav2Vec2ConformerPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding): + pass + + +class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module): + """Rotary positional embedding + Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf + """ + + def __init__(self, config): + super().__init__() + dim = config.hidden_size // config.num_attention_heads + base = config.rotary_embedding_base + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self.cached_sequence_length = None + self.cached_rotary_positional_embedding = None + + def forward(self, hidden_states): + sequence_length = hidden_states.shape[1] + + if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None: + return self.cached_rotary_positional_embedding + + self.cached_sequence_length = sequence_length + # Embeddings are computed in the dtype of the inv_freq constant + time_stamps = torch.arange(sequence_length).type_as(self.inv_freq) + freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq) + embeddings = torch.cat((freqs, freqs), dim=-1) + + cos_embeddings = embeddings.cos()[:, None, None, :] + sin_embeddings = embeddings.sin()[:, None, None, :] + # Computed embeddings are cast to the dtype of the hidden state inputs + self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states) + return self.cached_rotary_positional_embedding + + +class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module): + """Relative positional encoding module.""" + + def __init__(self, config): + super().__init__() + self.max_len = config.max_source_positions + self.d_model = config.hidden_size + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, self.max_len)) + + def extend_pe(self, x): + # Reset the positional encodings + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` is the position of query vector and `j` is the + # position of key vector. We use positive relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i (batch, 2*channel, dim) + hidden_states = self.pointwise_conv1(hidden_states) + # => (batch, channel, dim) + hidden_states = self.glu(hidden_states) + + # 1D Depthwise Conv + hidden_states = self.depthwise_conv(hidden_states) + hidden_states = self.batch_norm(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.pointwise_conv2(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class Wav2Vec2ConformerSelfAttention(nn.Module): + """Construct an Wav2Vec2ConformerSelfAttention object. 
+ Can be enhanced with rotary or relative position embeddings. + """ + + def __init__(self, config): + super().__init__() + + self.head_size = config.hidden_size // config.num_attention_heads + self.num_heads = config.num_attention_heads + self.position_embeddings_type = config.position_embeddings_type + + self.linear_q = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_k = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_v = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_out = nn.Linear(config.hidden_size, config.hidden_size) + + self.dropout = nn.Dropout(p=config.attention_dropout) + + if self.position_embeddings_type == "relative": + # linear transformation for positional encoding + self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size)) + self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size)) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + relative_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # self-attention mechanism + batch_size, sequence_length, hidden_size = hidden_states.size() + + # make sure query/key states can be != value states + query_key_states = hidden_states + value_states = hidden_states + + if self.position_embeddings_type == "rotary": + if relative_position_embeddings is None: + raise ValueError( + "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'" + ) + query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings) + + # project query_key_states and value_states + query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size) + + # => (batch, head, time1, d_k) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + if self.position_embeddings_type == "relative": + if relative_position_embeddings is None: + raise ValueError( + "`relative_position_embeddings` has to be defined when `self.position_embeddings_type ==" + " 'relative'" + ) + # apply relative_position_embeddings to qk scores + # as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860 + scores = self._apply_relative_embeddings( + query=query, key=key, relative_position_embeddings=relative_position_embeddings + ) + else: + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size) + + # apply attention_mask if necessary + if attention_mask is not None: + scores = scores + attention_mask + + # => (batch, head, time1, time2) + probs = torch.softmax(scores, dim=-1) + probs = self.dropout(probs) + + # => (batch, head, time1, d_k) + hidden_states = torch.matmul(probs, value) + + # => (batch, time1, hidden_size) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size) + hidden_states = self.linear_out(hidden_states) + + return hidden_states, probs + + def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings): + 
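        # Editor's note, illustrative and not part of the patch: the body below applies the
        # standard rotary-embedding identity x_rot = x * cos(theta_t) + rotate_half(x) * sin(theta_t),
        # where rotate_half negates the second half of each head dimension and swaps it with the
        # first half; the cos/sin tables come from Wav2Vec2ConformerRotaryPositionalEmbedding above.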
batch_size, sequence_length, hidden_size = hidden_states.size() + hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size) + + cos = relative_position_embeddings[0, :sequence_length, ...] + sin = relative_position_embeddings[1, :sequence_length, ...] + + # rotate hidden_states with rotary embeddings + hidden_states = hidden_states.transpose(0, 1) + rotated_states_begin = hidden_states[..., : self.head_size // 2] + rotated_states_end = hidden_states[..., self.head_size // 2 :] + rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1) + hidden_states = (hidden_states * cos) + (rotated_states * sin) + hidden_states = hidden_states.transpose(0, 1) + + hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size) + + return hidden_states + + def _apply_relative_embeddings(self, query, key, relative_position_embeddings): + # 1. project positional embeddings + # => (batch, head, 2*time1-1, d_k) + proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings) + proj_relative_position_embeddings = proj_relative_position_embeddings.view( + relative_position_embeddings.size(0), -1, self.num_heads, self.head_size + ) + proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2) + proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3) + + # 2. Add bias to query + # => (batch, head, time1, d_k) + query = query.transpose(1, 2) + q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2) + + # 3. attention score: first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # => (batch, head, time1, time2) + scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1)) + + # 4. then compute matrix b and matrix d + # => (batch, head, time1, 2*time1-1) + scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings) + + # 5. shift matrix b and matrix d + zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype) + scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1) + scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2]) + scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape) + scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd) + scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1] + + # 6. 
sum matrices + # => (batch, head, time1, time2) + scores = (scores_ac + scores_bd) / math.sqrt(self.head_size) + + return scores + + +class Wav2Vec2ConformerEncoderLayer(nn.Module): + """Conformer block based on https://arxiv.org/abs/2005.08100.""" + + def __init__(self, config): + super().__init__() + embed_dim = config.hidden_size + dropout = config.attention_dropout + + # Feed-forward 1 + self.ffn1_layer_norm = nn.LayerNorm(embed_dim) + self.ffn1 = Wav2Vec2ConformerFeedForward(config) + + # Self-Attention + self.self_attn_layer_norm = nn.LayerNorm(embed_dim) + self.self_attn_dropout = nn.Dropout(dropout) + self.self_attn = Wav2Vec2ConformerSelfAttention(config) + + # Conformer Convolution + self.conv_module = Wav2Vec2ConformerConvolutionModule(config) + + # Feed-forward 2 + self.ffn2_layer_norm = nn.LayerNorm(embed_dim) + self.ffn2 = Wav2Vec2ConformerFeedForward(config) + self.final_layer_norm = nn.LayerNorm(embed_dim) + + def forward( + self, + hidden_states, + attention_mask: Optional[torch.Tensor] = None, + relative_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + hidden_states = hidden_states + + # 1. Feed-Forward 1 layer + residual = hidden_states + hidden_states = self.ffn1_layer_norm(hidden_states) + hidden_states = self.ffn1(hidden_states) + hidden_states = hidden_states * 0.5 + residual + residual = hidden_states + + # 2. Self-Attention layer + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weigts = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + ) + hidden_states = self.self_attn_dropout(hidden_states) + hidden_states = hidden_states + residual + + # 3. Convolutional Layer + residual = hidden_states + hidden_states = self.conv_module(hidden_states) + hidden_states = residual + hidden_states + + # 4. 
Feed-Forward 2 Layer + residual = hidden_states + hidden_states = self.ffn2_layer_norm(hidden_states) + hidden_states = self.ffn2(hidden_states) + hidden_states = hidden_states * 0.5 + residual + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states, attn_weigts + + +class Wav2Vec2ConformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + if config.position_embeddings_type == "relative": + self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config) + elif config.position_embeddings_type == "rotary": + self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config) + else: + self.embed_positions = None + + self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + # make sure padded tokens output 0 + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_attention_mask] = 0.0 + + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + hidden_states = self.dropout(hidden_states) + + if self.embed_positions is not None: + relative_position_embeddings = self.embed_positions(hidden_states) + else: + relative_position_embeddings = None + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + relative_position_embeddings, + output_attentions, + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + 
) + + +class Wav2Vec2ConformerGumbelVectorQuantizer(Wav2Vec2GumbelVectorQuantizer): + pass + + +class Wav2Vec2ConformerAdapter(Wav2Vec2Adapter): + pass + + +class Wav2Vec2ConformerAdapterLayer(Wav2Vec2AdapterLayer): + pass + + +class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Wav2Vec2ConformerConfig + base_model_prefix = "wav2vec2_conformer" + main_input_name = "input_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init. + if isinstance(module, Wav2Vec2ConformerForPreTraining): + module.project_hid.reset_parameters() + module.project_q.reset_parameters() + module.project_hid._is_hf_initialized = True + module.project_q._is_hf_initialized = True + # gumbel softmax requires special init + elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer): + module.weight_proj.weight.data.normal_(mean=0.0, std=1) + module.weight_proj.bias.data.zero_() + nn.init.uniform_(module.codevectors) + elif isinstance(module, Wav2Vec2ConformerSelfAttention): + if hasattr(module, "pos_bias_u"): + nn.init.xavier_uniform_(module.pos_bias_u) + if hasattr(module, "pos_bias_v"): + nn.init.xavier_uniform_(module.pos_bias_v) + elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, Wav2Vec2ConformerFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def _get_feat_extract_output_lengths( + self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None + ): + """ + Computes the output length of the convolutional layers + """ + + add_adapter = self.config.add_adapter if add_adapter is None else add_adapter + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + if add_adapter: + for _ in range(self.config.num_adapter_layers): + input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride) + + return input_lengths + + def _get_feature_vector_attention_mask( + self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None + ): + # Effectively attention_mask.sum(-1), but not inplace to be able to run + # on inference mode. 
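+ # A small worked example of the trick used below: with feature_vector_length = 5 and a computed
+ # output length of 2, the scatter produces [0, 1, 0, 0, 0]; flipping, taking the cumulative sum and
+ # flipping back expands that single one into [1, 1, 0, 0, 0], so exactly the first `output_length`
+ # feature frames are marked as attended.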
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1] + + output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter) + output_lengths = output_lengths.to(torch.long) + + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + +WAV2VEC2_CONFORMER_START_DOCSTRING = None # will be automatically redefined + +WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r""" + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file + into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install + soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and + conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0, + 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + + + `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask == + True`. For all models whose processor has `config.return_attention_mask == False`, such as + [wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large), + `attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For + such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware + that these models also yield slightly different results depending on whether `input_values` is padded or + not. + + + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
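+
+    A minimal usage sketch (illustrative only; it assumes a 16 kHz mono waveform already loaded into a 1-D
+    float array and uses the
+    [wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large)
+    checkpoint mentioned above, passing no `attention_mask` as recommended for that model):
+
+        import numpy as np
+        import torch
+        from transformers import AutoProcessor, Wav2Vec2ConformerModel
+
+        processor = AutoProcessor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+        model = Wav2Vec2ConformerModel.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+
+        # stand-in waveform: one second of silence; replace with real 16 kHz audio
+        raw_speech = np.zeros(16000, dtype=np.float32)
+        inputs = processor(raw_speech, sampling_rate=16000, return_tensors="pt")
+        with torch.no_grad():
+            outputs = model(inputs.input_values)
+        print(outputs.last_hidden_state.shape)  # (batch, frames, hidden_size)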
+""" + +Wav2Vec2ConformerBaseModelOutput = Wav2Vec2BaseModelOutput + + +@add_start_docstrings( + "The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.", + WAV2VEC2_CONFORMER_START_DOCSTRING, +) +class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel, Wav2Vec2Model): + def __init__(self, config: Wav2Vec2ConformerConfig): + Wav2Vec2ConformerPreTrainedModel.__init__(config) + self.config = config + self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config) + self.feature_projection = Wav2Vec2ConformerFeatureProjection(config) + + # model only needs masking vector if mask prob is > 0.0 + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) + + self.encoder = Wav2Vec2ConformerEncoder(config) + + self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None + + # Initialize weights and apply final processing + self.post_init() + + def freeze_feature_extractor(self): + raise AttributeError("Not needed for Wav2Vec2Conformer") + + @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Wav2Vec2ConformerBaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward(self, **super_kwargs): + return super().forward(**super_kwargs) + + +@add_start_docstrings( + """Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING +) +class Wav2Vec2ConformerForPreTraining(Wav2Vec2ForPreTraining): + def __init__(self, config: Wav2Vec2ConformerConfig): + super().__init__(config) + + def freeze_feature_extractor(self): + raise AttributeError("Not needed for Wav2Vec2Conformer") + + @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward(self, **super_kwargs) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]: + return super().forward(**super_kwargs) + + +@add_start_docstrings( + """Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", + WAV2VEC2_CONFORMER_START_DOCSTRING, +) +class Wav2Vec2ConformerForCTC(Wav2Vec2ForCTC): + def __init__(self, config, target_lang: Optional[str] = None): + super().__init__(config) + + def tie_weights(self): + raise AttributeError("Not needed for Wav2Vec2Conformer") + + def freeze_feature_extractor(self): + raise AttributeError("Not needed for Wav2Vec2Conformer") + + def freeze_base_model(self): + raise AttributeError("Not needed for Wav2Vec2Conformer") + + @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_CTC_EXPECTED_OUTPUT, + expected_loss=_CTC_EXPECTED_LOSS, + ) + def forward(self, **super_kwargs): + return super().forward(**super_kwargs) + + +@add_start_docstrings( + """ + Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for + tasks like SUPERB Keyword Spotting. 
+ """, + WAV2VEC2_CONFORMER_START_DOCSTRING, +) +class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ForSequenceClassification): + def __init__(self, config): + super().__init__(config) + + def freeze_feature_extractor(self): + raise AttributeError("Not needed for Wav2Vec2Conformer") + + @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward(self, **super_kwargs): + return super().forward(**super_kwargs) + + +@add_start_docstrings( + """ + Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization. + """, + WAV2VEC2_CONFORMER_START_DOCSTRING, +) +class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification): + def __init__(self, config): + super().__init__(config) + + def freeze_feature_extractor(self): + raise AttributeError("Not needed for Wav2Vec2Conformer") + + @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward(self, **super_kwargs): + return super().forward(**super_kwargs) + + +@add_start_docstrings( + """ + Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification. + """, + WAV2VEC2_CONFORMER_START_DOCSTRING, +) +class Wav2Vec2ConformerForXVector(Wav2Vec2ForXVector): + def __init__(self, config): + super().__init__(config) + + def freeze_feature_extractor(self): + raise AttributeError("Not needed for Wav2Vec2Conformer") + + @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=XVectorOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward(self, **super_kwargs): + return super().forward(**super_kwargs) + + +__all__ = [ + "Wav2Vec2ConformerForAudioFrameClassification", + "Wav2Vec2ConformerForCTC", + "Wav2Vec2ConformerForPreTraining", + "Wav2Vec2ConformerForSequenceClassification", + "Wav2Vec2ConformerForXVector", + "Wav2Vec2ConformerModel", + "Wav2Vec2ConformerPreTrainedModel", +] diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 270d48f8378..1c3c09d1a70 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -1,28 +1,17 @@ -# coding=utf-8 -# Copyright 2021 The Fairseq Authors, Microsoft Research, and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch WavLM model.""" - +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/wavlm/modular_wavlm.py. 
+# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_wavlm.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math import warnings from typing import Optional, Tuple, Union import numpy as np import torch +import torch.nn as nn import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN @@ -49,225 +38,22 @@ from .configuration_wavlm import WavLMConfig logger = logging.get_logger(__name__) +_CHECKPOINT_FOR_DOC = "patrickvonplaten/wavlm-libri-clean-100h-base-plus" -_HIDDEN_STATES_START_POSITION = 2 - -# General docstring _CONFIG_FOR_DOC = "WavLMConfig" -# Base docstring -_CHECKPOINT_FOR_DOC = "patrickvonplaten/wavlm-libri-clean-100h-base-plus" -_EXPECTED_OUTPUT_SHAPE = [1, 292, 768] -# CTC docstring -_CTC_EXPECTED_OUTPUT = "'mister quilter is the aposle of the middle classes and we are glad to welcome his gospel'" -_CTC_EXPECTED_LOSS = 12.51 - -# Frame class docstring -_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd" -_FRAME_EXPECTED_OUTPUT = [0, 0] - -# Speaker Verification docstring -_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv" -_XVECTOR_EXPECTED_OUTPUT = 0.97 - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices -def _compute_mask_indices( - shape: Tuple[int, int], - mask_prob: float, - mask_length: int, - attention_mask: Optional[torch.LongTensor] = None, - min_masks: int = 0, -) -> np.ndarray: - """ - Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for - ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on - CPU as part of the preprocessing during training. - - Args: - shape: The shape for which to compute masks. This should be of a tuple of size 2 where - the first element is the batch size and the second element is the length of the axis to span. - mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of - independently generated mask spans of length `mask_length` is computed by - `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the - actual percentage will be smaller. - mask_length: size of the mask - min_masks: minimum number of masked spans - attention_mask: A (right-padded) attention mask which independently shortens the feature axis of - each batch dimension. 
- """ - batch_size, sequence_length = shape - - if mask_length < 1: - raise ValueError("`mask_length` has to be bigger than 0.") - - if mask_length > sequence_length: - raise ValueError( - f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" - f" and `sequence_length`: {sequence_length}`" - ) - - # epsilon is used for probabilistic rounding - epsilon = np.random.rand(1).item() - - def compute_num_masked_span(input_length): - """Given input length, compute how many spans should be masked""" - num_masked_span = int(mask_prob * input_length / mask_length + epsilon) - num_masked_span = max(num_masked_span, min_masks) - - # make sure num masked span <= sequence_length - if num_masked_span * mask_length > sequence_length: - num_masked_span = sequence_length // mask_length - - # make sure num_masked span is also <= input_length - (mask_length - 1) - if input_length - (mask_length - 1) < num_masked_span: - num_masked_span = max(input_length - (mask_length - 1), 0) - - return num_masked_span - - # compute number of masked spans in batch - input_lengths = ( - attention_mask.detach().sum(-1).tolist() - if attention_mask is not None - else [sequence_length for _ in range(batch_size)] - ) - - # SpecAugment mask to fill - spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) - spec_aug_mask_idxs = [] - - max_num_masked_span = compute_num_masked_span(sequence_length) - - if max_num_masked_span == 0: - return spec_aug_mask - - for input_length in input_lengths: - # compute num of masked spans for this input - num_masked_span = compute_num_masked_span(input_length) - - # get random indices to mask - spec_aug_mask_idx = np.random.choice( - np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False - ) - - # pick first sampled index that will serve as a dummy index to pad vector - # to ensure same dimension for all batches due to probabilistic rounding - # Picking first sample just pads those vectors twice. 
- if len(spec_aug_mask_idx) == 0: - # this case can only happen if `input_length` is strictly smaller then - # `sequence_length` in which case the last token has to be a padding - # token which we can use as a dummy mask id - dummy_mask_idx = sequence_length - 1 - else: - dummy_mask_idx = spec_aug_mask_idx[0] - - spec_aug_mask_idx = np.concatenate( - [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] - ) - spec_aug_mask_idxs.append(spec_aug_mask_idx) - - spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) - - # expand masked indices to masked spans - spec_aug_mask_idxs = np.broadcast_to( - spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) - ) - spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) - - # add offset to the starting indexes so that indexes now create a span - offsets = np.arange(mask_length)[None, None, :] - offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( - batch_size, max_num_masked_span * mask_length - ) - spec_aug_mask_idxs = spec_aug_mask_idxs + offsets - - # ensure that we cannot have indices larger than sequence_length - if spec_aug_mask_idxs.max() > sequence_length - 1: - spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 - - # scatter indices to mask - np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) - - return spec_aug_mask - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->WavLM -class WavLMNoLayerNormConvLayer(nn.Module): - def __init__(self, config, layer_id=0): +class WavLMSamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): super().__init__() - self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = nn.Conv1d( - self.in_conv_dim, - self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - stride=config.conv_stride[layer_id], - bias=config.conv_bias, - ) - self.activation = ACT2FN[config.feat_extract_activation] + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - hidden_states = self.activation(hidden_states) + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->WavLM -class WavLMLayerNormConvLayer(nn.Module): - def __init__(self, config, layer_id=0): - super().__init__() - self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = nn.Conv1d( - self.in_conv_dim, - self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - stride=config.conv_stride[layer_id], - bias=config.conv_bias, - ) - self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) - self.activation = ACT2FN[config.feat_extract_activation] - - def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - - hidden_states = hidden_states.transpose(-2, -1) - hidden_states = self.layer_norm(hidden_states) - hidden_states = hidden_states.transpose(-2, -1) - - hidden_states = self.activation(hidden_states) - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->WavLM -class 
WavLMGroupNormConvLayer(nn.Module): - def __init__(self, config, layer_id=0): - super().__init__() - self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = nn.Conv1d( - self.in_conv_dim, - self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - stride=config.conv_stride[layer_id], - bias=config.conv_bias, - ) - self.activation = ACT2FN[config.feat_extract_activation] - - self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) - - def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->WavLM class WavLMPositionalConvEmbedding(nn.Module): def __init__(self, config): super().__init__() @@ -313,75 +99,6 @@ class WavLMPositionalConvEmbedding(nn.Module): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->WavLM -class WavLMSamePadLayer(nn.Module): - def __init__(self, num_conv_pos_embeddings): - super().__init__() - self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 - - def forward(self, hidden_states): - if self.num_pad_remove > 0: - hidden_states = hidden_states[:, :, : -self.num_pad_remove] - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->WavLM -class WavLMFeatureEncoder(nn.Module): - """Construct the features from raw audio waveform""" - - def __init__(self, config): - super().__init__() - - if config.feat_extract_norm == "group": - conv_layers = [WavLMGroupNormConvLayer(config, layer_id=0)] + [ - WavLMNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1) - ] - elif config.feat_extract_norm == "layer": - conv_layers = [WavLMLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)] - else: - raise ValueError( - f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" - ) - self.conv_layers = nn.ModuleList(conv_layers) - self.gradient_checkpointing = False - self._requires_grad = True - - def _freeze_parameters(self): - for param in self.parameters(): - param.requires_grad = False - self._requires_grad = False - - def forward(self, input_values): - hidden_states = input_values[:, None] - - # make sure hidden_states require grad for gradient_checkpointing - if self._requires_grad and self.training: - hidden_states.requires_grad = True - - for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) - - return hidden_states - - -class WavLMFeatureExtractor(WavLMFeatureEncoder): - def __init__(self, config): - super().__init__(config) - warnings.warn( - f"The class `{self.__class__.__name__}` has been depreciated " - "and will be removed in Transformers v5. 
" - f"Use `{self.__class__.__bases__[0].__name__}` instead.", - FutureWarning, - ) - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->WavLM class WavLMFeatureProjection(nn.Module): def __init__(self, config): super().__init__() @@ -563,7 +280,6 @@ class WavLMAttention(nn.Module): return relative_buckets -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->WavLM class WavLMFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -903,57 +619,6 @@ class WavLMGumbelVectorQuantizer(nn.Module): return codevectors, perplexity -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->WavLM -class WavLMAdapter(nn.Module): - def __init__(self, config): - super().__init__() - - # feature dim might need to be down-projected - if config.output_hidden_size != config.hidden_size: - self.proj = nn.Linear(config.hidden_size, config.output_hidden_size) - self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size) - else: - self.proj = self.proj_layer_norm = None - - self.layers = nn.ModuleList(WavLMAdapterLayer(config) for _ in range(config.num_adapter_layers)) - self.layerdrop = config.layerdrop - - def forward(self, hidden_states): - # down project hidden_states if necessary - if self.proj is not None and self.proj_layer_norm is not None: - hidden_states = self.proj(hidden_states) - hidden_states = self.proj_layer_norm(hidden_states) - - hidden_states = hidden_states.transpose(1, 2) - - for layer in self.layers: - layerdrop_prob = np.random.random() - if not self.training or (layerdrop_prob > self.layerdrop): - hidden_states = layer(hidden_states) - - hidden_states = hidden_states.transpose(1, 2) - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->WavLM -class WavLMAdapterLayer(nn.Module): - def __init__(self, config): - super().__init__() - self.conv = nn.Conv1d( - config.output_hidden_size, - 2 * config.output_hidden_size, - config.adapter_kernel_size, - stride=config.adapter_stride, - padding=1, - ) - - def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - hidden_states = nn.functional.glu(hidden_states, dim=1) - - return hidden_states - - class WavLMPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -964,6 +629,8 @@ class WavLMPreTrainedModel(PreTrainedModel): base_model_prefix = "wavlm" main_input_name = "input_values" supports_gradient_checkpointing = True + _supports_flash_attn_2 = False + _supports_sdpa = False def _init_weights(self, module): """Initialize the weights""" @@ -1042,6 +709,293 @@ class WavLMPreTrainedModel(PreTrainedModel): return attention_mask +class WavLMNoLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class WavLMLayerNormConvLayer(nn.Module): + def __init__(self, config, 
layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +class WavLMGroupNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class WavLMFeatureEncoder(nn.Module): + """Construct the features from raw audio waveform""" + + def __init__(self, config): + super().__init__() + + if config.feat_extract_norm == "group": + conv_layers = [WavLMGroupNormConvLayer(config, layer_id=0)] + [ + WavLMNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1) + ] + elif config.feat_extract_norm == "layer": + conv_layers = [WavLMLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)] + else: + raise ValueError( + f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" + ) + self.conv_layers = nn.ModuleList(conv_layers) + self.gradient_checkpointing = False + self._requires_grad = True + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def forward(self, input_values): + hidden_states = input_values[:, None] + + # make sure hidden_states require grad for gradient_checkpointing + if self._requires_grad and self.training: + hidden_states.requires_grad = True + + for conv_layer in self.conv_layers: + if self._requires_grad and self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + conv_layer.__call__, + hidden_states, + ) + else: + hidden_states = conv_layer(hidden_states) + + return hidden_states + + +class WavLMAdapterLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.output_hidden_size, + 2 * config.output_hidden_size, + config.adapter_kernel_size, + stride=config.adapter_stride, + padding=1, + ) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = nn.functional.glu(hidden_states, dim=1) + + return hidden_states + + +class WavLMAdapter(nn.Module): + def __init__(self, config): + super().__init__() + + # feature dim might need to be down-projected + if 
config.output_hidden_size != config.hidden_size: + self.proj = nn.Linear(config.hidden_size, config.output_hidden_size) + self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size) + else: + self.proj = self.proj_layer_norm = None + + self.layers = nn.ModuleList(WavLMAdapterLayer(config) for _ in range(config.num_adapter_layers)) + self.layerdrop = config.layerdrop + + def forward(self, hidden_states): + # down project hidden_states if necessary + if self.proj is not None and self.proj_layer_norm is not None: + hidden_states = self.proj(hidden_states) + hidden_states = self.proj_layer_norm(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + + for layer in self.layers: + layerdrop_prob = np.random.random() + if not self.training or (layerdrop_prob > self.layerdrop): + hidden_states = layer(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. 
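+
+ As a rough illustrative example: for `shape=(1, 100)`, `mask_prob=0.2` and `mask_length=10`, about
+ `0.2 * 100 / 10 = 2` spans of 10 consecutive positions each are sampled, so at most roughly 20
+ positions (fewer when spans overlap) end up masked.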
+ """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. 
+ if len(spec_aug_mask_idx) == 0: + # this case can only happen if `input_length` is strictly smaller then + # `sequence_length` in which case the last token has to be a padding + # token which we can use as a dummy mask id + dummy_mask_idx = sequence_length - 1 + else: + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # ensure that we cannot have indices larger than sequence_length + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + +_EXPECTED_OUTPUT_SHAPE = [1, 292, 768] + + WAVLM_START_DOCSTRING = r""" WavLM was proposed in [WavLM: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo @@ -1061,7 +1015,6 @@ WAVLM_START_DOCSTRING = r""" configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ - WAVLM_INPUTS_DOCSTRING = r""" Args: input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): @@ -1098,12 +1051,13 @@ WAVLM_INPUTS_DOCSTRING = r""" Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
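+
+    A short CTC usage sketch (illustrative only; it assumes a 16 kHz mono waveform in a 1-D float array and
+    the [wavlm-libri-clean-100h-base-plus](https://huggingface.co/patrickvonplaten/wavlm-libri-clean-100h-base-plus)
+    checkpoint used for the docstring examples in this file):
+
+        import numpy as np
+        import torch
+        from transformers import AutoProcessor, WavLMForCTC
+
+        processor = AutoProcessor.from_pretrained("patrickvonplaten/wavlm-libri-clean-100h-base-plus")
+        model = WavLMForCTC.from_pretrained("patrickvonplaten/wavlm-libri-clean-100h-base-plus")
+
+        # stand-in waveform: one second of silence; replace with real 16 kHz speech
+        raw_speech = np.zeros(16000, dtype=np.float32)
+        inputs = processor(raw_speech, sampling_rate=16000, return_tensors="pt")
+        with torch.no_grad():
+            logits = model(**inputs).logits
+        predicted_ids = torch.argmax(logits, dim=-1)
+        transcription = processor.batch_decode(predicted_ids)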
""" +WavLMBaseModelOutput = Wav2Vec2BaseModelOutput + @add_start_docstrings( "The bare WavLM Model transformer outputting raw hidden-states without any specific head on top.", WAVLM_START_DOCSTRING, ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM, WavLMBaseModelOutput->Wav2Vec2BaseModelOutput class WavLMModel(WavLMPreTrainedModel): def __init__(self, config: WavLMConfig): super().__init__(config) @@ -1193,7 +1147,7 @@ class WavLMModel(WavLMPreTrainedModel): @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, - output_type=Wav2Vec2BaseModelOutput, + output_type=WavLMBaseModelOutput, config_class=_CONFIG_FOR_DOC, modality="audio", expected_output=_EXPECTED_OUTPUT_SHAPE, @@ -1206,7 +1160,7 @@ class WavLMModel(WavLMPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[Tuple, Wav2Vec2BaseModelOutput]: + ) -> Union[Tuple, WavLMBaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1243,7 +1197,7 @@ class WavLMModel(WavLMPreTrainedModel): if not return_dict: return (hidden_states, extract_features) + encoder_outputs[1:] - return Wav2Vec2BaseModelOutput( + return WavLMBaseModelOutput( last_hidden_state=hidden_states, extract_features=extract_features, hidden_states=encoder_outputs.hidden_states, @@ -1251,11 +1205,16 @@ class WavLMModel(WavLMPreTrainedModel): ) +_HIDDEN_STATES_START_POSITION = 2 + +_CTC_EXPECTED_OUTPUT = "'mister quilter is the aposle of the middle classes and we are glad to welcome his gospel'" +_CTC_EXPECTED_LOSS = 12.51 + + @add_start_docstrings( """WavLM Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", WAVLM_START_DOCSTRING, ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM class WavLMForCTC(WavLMPreTrainedModel): def __init__(self, config, target_lang: Optional[str] = None): super().__init__(config) @@ -1432,7 +1391,6 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel): # Initialize weights and apply final processing self.post_init() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor def freeze_feature_extractor(self): """ Calling this function will disable the gradient computation for the feature encoder so that its parameters will @@ -1445,7 +1403,6 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel): ) self.freeze_feature_encoder() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wavlm def freeze_feature_encoder(self): """ Calling this function will disable the gradient computation for the feature encoder so that its parameter will @@ -1453,7 +1410,6 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel): """ self.wavlm.feature_extractor._freeze_parameters() - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->wavlm def freeze_base_model(self): """ Calling this function will disable the gradient computation for the base model so that its 
parameters will not @@ -1469,7 +1425,6 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel): config_class=_CONFIG_FOR_DOC, modality="audio", ) - # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->WavLM, wav2vec2->wavlm def forward( self, input_values: Optional[torch.Tensor], @@ -1533,13 +1488,16 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel): ) +_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd" +_FRAME_EXPECTED_OUTPUT = [0, 0] + + @add_start_docstrings( """ WavLM Model with a frame classification head on top for tasks like Speaker Diarization. """, WAVLM_START_DOCSTRING, ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM class WavLMForAudioFrameClassification(WavLMPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1646,7 +1604,6 @@ class WavLMForAudioFrameClassification(WavLMPreTrainedModel): ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss class AMSoftmaxLoss(nn.Module): def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4): super(AMSoftmaxLoss, self).__init__() @@ -1670,7 +1627,6 @@ class AMSoftmaxLoss(nn.Module): return loss -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer class TDNNLayer(nn.Module): def __init__(self, config, layer_id=0): super().__init__() @@ -1686,6 +1642,7 @@ class TDNNLayer(nn.Module): if is_peft_available(): from peft.tuners.lora import LoraLayer + if is_peft_available(): if isinstance(self.kernel, LoraLayer): warnings.warn( "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. " @@ -1702,13 +1659,16 @@ class TDNNLayer(nn.Module): return hidden_states +_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv" +_XVECTOR_EXPECTED_OUTPUT = 0.97 + + @add_start_docstrings( """ WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification. 
""", WAVLM_START_DOCSTRING, ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM class WavLMForXVector(WavLMPreTrainedModel): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/wavlm/modular_wavlm.py b/src/transformers/models/wavlm/modular_wavlm.py new file mode 100644 index 00000000000..9ae9170fec5 --- /dev/null +++ b/src/transformers/models/wavlm/modular_wavlm.py @@ -0,0 +1,758 @@ +import math +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_outputs import ( + BaseModelOutput, + CausalLMOutput, + SequenceClassifierOutput, + TokenClassifierOutput, + Wav2Vec2BaseModelOutput, + XVectorOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from ..wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2FeatureProjection, + Wav2Vec2FeedForward, + Wav2Vec2ForAudioFrameClassification, + Wav2Vec2ForCTC, + Wav2Vec2ForSequenceClassification, + Wav2Vec2ForXVector, + Wav2Vec2Model, + Wav2Vec2PositionalConvEmbedding, + Wav2Vec2PreTrainedModel, +) +from .configuration_wavlm import WavLMConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "WavLMConfig" + +_CHECKPOINT_FOR_DOC = "patrickvonplaten/wavlm-libri-clean-100h-base-plus" +_EXPECTED_OUTPUT_SHAPE = [1, 292, 768] + +_CTC_EXPECTED_OUTPUT = "'mister quilter is the aposle of the middle classes and we are glad to welcome his gospel'" +_CTC_EXPECTED_LOSS = 12.51 + +_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd" +_FRAME_EXPECTED_OUTPUT = [0, 0] + +_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv" +_XVECTOR_EXPECTED_OUTPUT = 0.97 + + +class WavLMPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding): + pass + + +class WavLMFeatureProjection(Wav2Vec2FeatureProjection): + pass + + +class WavLMAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + num_buckets: int = 320, + max_distance: int = 800, + has_relative_position_bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.out_proj = nn.Linear(embed_dim, embed_dim) + + self.num_buckets = num_buckets + self.max_distance = max_distance + + self.gru_rel_pos_const = nn.Parameter(torch.ones(1, self.num_heads, 1, 1)) + self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8) + + if has_relative_position_bias: + self.rel_attn_embed = nn.Embedding(self.num_buckets, self.num_heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_bias: Optional[torch.Tensor] = None, + output_attentions: bool = False, + index=0, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Attention layer with relative attention""" + bsz, tgt_len, _ = hidden_states.size() + + # first pass of attention layer creates position bias + if position_bias is None: + position_bias = self.compute_bias(tgt_len, tgt_len) + position_bias = ( + position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, tgt_len) + ) + + # Compute relative position bias: + # 1) get reshape hidden_states + gated_hidden_states = hidden_states.view(hidden_states.shape[:-1] + (self.num_heads, -1)) + gated_hidden_states = gated_hidden_states.permute(0, 2, 1, 3) + + # 2) project hidden states + relative_position_proj = self.gru_rel_pos_linear(gated_hidden_states) + relative_position_proj = relative_position_proj.view(gated_hidden_states.shape[:-1] + (2, 4)).sum(-1) + + # 3) compute gate for position bias from projected hidden states + gate_a, gate_b = torch.sigmoid(relative_position_proj).chunk(2, dim=-1) + gate_output = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0 + + # 4) apply gate to position bias to compute gated position_bias + gated_position_bias = gate_output.view(bsz * self.num_heads, -1, 1) * position_bias + gated_position_bias = gated_position_bias.view((-1, tgt_len, tgt_len)) + + attn_output, attn_weights = self.torch_multi_head_self_attention( + hidden_states, attention_mask, gated_position_bias, output_attentions + ) + + return attn_output, attn_weights, position_bias + + def torch_multi_head_self_attention( + self, + hidden_states: torch.FloatTensor, + attention_mask: Union[torch.LongTensor, torch.BoolTensor], + gated_position_bias: torch.FloatTensor, + output_attentions: bool, + ) -> (torch.FloatTensor, torch.FloatTensor): + """simple wrapper around torch's multi_head_attention_forward function""" + # self-attention assumes q = k = v + query = key = value = hidden_states.transpose(0, 1) + key_padding_mask = attention_mask.ne(1) if attention_mask is not None else None + + # disable bias and add_zero_attn + bias_k = bias_v = None + add_zero_attn = False + + # PyTorch 1.3.0 has F.multi_head_attention_forward defined + # so no problem with backwards compatibility + attn_output, attn_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + bias_k, + bias_v, + add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + self.training, + key_padding_mask, + output_attentions, + gated_position_bias, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + + # [Seq_Len, Batch Size, ...] 
-> [Batch Size, Seq_Len, ...] + attn_output = attn_output.transpose(0, 1) + + if attn_weights is not None: + # IMPORTANT: Attention weights are averaged weights + # here which should not be the case. This is an open issue + # on PyTorch: https://github.com/pytorch/pytorch/issues/32590 + attn_weights = attn_weights[:, None].broadcast_to( + attn_weights.shape[:1] + (self.num_heads,) + attn_weights.shape[1:] + ) + + return attn_output, attn_weights + + def compute_bias(self, query_length: int, key_length: int) -> torch.FloatTensor: + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket(relative_position) + relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device) + values = self.rel_attn_embed(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def _relative_positions_bucket(self, relative_positions: torch.FloatTensor) -> torch.FloatTensor: + num_buckets = self.num_buckets // 2 + + relative_buckets = (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_positions_if_large = torch.log(relative_positions.float() / max_exact) + relative_positions_if_large = relative_positions_if_large / math.log(self.max_distance / max_exact) + relative_positions_if_large = relative_positions_if_large * (num_buckets - max_exact) + relative_position_if_large = (max_exact + relative_positions_if_large).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large) + return relative_buckets + + +class WavLMFeedForward(Wav2Vec2FeedForward): + pass + + +class WavLMEncoderLayer(nn.Module): + def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True): + super().__init__() + self.attention = WavLMAttention( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + num_buckets=config.num_buckets, + max_distance=config.max_bucket_distance, + has_relative_position_bias=has_relative_position_bias, + ) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = WavLMFeedForward(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0): + attn_residual = hidden_states + hidden_states, attn_weights, position_bias = self.attention( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + index=index, + ) + hidden_states = self.dropout(hidden_states) + hidden_states = attn_residual + hidden_states + + hidden_states = self.layer_norm(hidden_states) + + hidden_states = hidden_states + self.feed_forward(hidden_states) + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states, position_bias) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class WavLMEncoderLayerStableLayerNorm(nn.Module): + def __init__(self, config: 
+
+
+class WavLMFeedForward(Wav2Vec2FeedForward):
+    pass
+
+
+class WavLMEncoderLayer(nn.Module):
+    def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
+        super().__init__()
+        self.attention = WavLMAttention(
+            embed_dim=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            dropout=config.attention_dropout,
+            num_buckets=config.num_buckets,
+            max_distance=config.max_bucket_distance,
+            has_relative_position_bias=has_relative_position_bias,
+        )
+        self.dropout = nn.Dropout(config.hidden_dropout)
+        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.feed_forward = WavLMFeedForward(config)
+        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
+        attn_residual = hidden_states
+        hidden_states, attn_weights, position_bias = self.attention(
+            hidden_states,
+            attention_mask=attention_mask,
+            position_bias=position_bias,
+            output_attentions=output_attentions,
+            index=index,
+        )
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = attn_residual + hidden_states
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        hidden_states = hidden_states + self.feed_forward(hidden_states)
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        outputs = (hidden_states, position_bias)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+class WavLMEncoderLayerStableLayerNorm(nn.Module):
+    def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
+        super().__init__()
+        self.attention = WavLMAttention(
+            embed_dim=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            dropout=config.attention_dropout,
+            num_buckets=config.num_buckets,
+            max_distance=config.max_bucket_distance,
+            has_relative_position_bias=has_relative_position_bias,
+        )
+        self.dropout = nn.Dropout(config.hidden_dropout)
+        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.feed_forward = WavLMFeedForward(config)
+        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
+        attn_residual = hidden_states
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states, attn_weights, position_bias = self.attention(
+            hidden_states,
+            attention_mask=attention_mask,
+            position_bias=position_bias,
+            output_attentions=output_attentions,
+        )
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = attn_residual + hidden_states
+        hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
+
+        outputs = (hidden_states, position_bias)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
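+
+
+# Note (descriptive comment): `WavLMEncoderLayer` is the post-norm variant (LayerNorm after the attention
+# residual), while `WavLMEncoderLayerStableLayerNorm` is the pre-norm variant (LayerNorm before attention
+# and before the feed-forward block). The two encoders below each use one layer type, mirroring the
+# wav2vec2-style `config.do_stable_layer_norm` switch in the model class.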
+
+
+class WavLMEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
+        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout)
+        self.layers = nn.ModuleList(
+            [WavLMEncoderLayer(config, has_relative_position_bias=(i == 0)) for i in range(config.num_hidden_layers)]
+        )
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        if attention_mask is not None:
+            # make sure padded tokens output 0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
+
+        position_embeddings = self.pos_conv_embed(hidden_states)
+        hidden_states = hidden_states + position_embeddings
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
+        position_bias = None
+
+        for i, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            dropout_probability = torch.rand([])
+
+            skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
+            if not skip_the_layer or synced_gpus:
+                # under fsdp or deepspeed zero3 all gpus must run in sync
+                if self.gradient_checkpointing and self.training:
+                    layer_outputs = self._gradient_checkpointing_func(
+                        layer.__call__,
+                        hidden_states,
+                        attention_mask,
+                        position_bias,
+                        output_attentions,
+                    )
+                else:
+                    layer_outputs = layer(
+                        hidden_states,
+                        attention_mask=attention_mask,
+                        position_bias=position_bias,
+                        output_attentions=output_attentions,
+                        index=i,
+                    )
+
+                hidden_states, position_bias = layer_outputs[:2]
+
+            if skip_the_layer:
+                layer_outputs = (None, None, None)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[2],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+
+class WavLMEncoderStableLayerNorm(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
+        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout)
+        self.layers = nn.ModuleList(
+            [
+                WavLMEncoderLayerStableLayerNorm(config, has_relative_position_bias=(i == 0))
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        if attention_mask is not None:
+            # make sure padded tokens are not attended to
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
+
+        position_embeddings = self.pos_conv_embed(hidden_states)
+        hidden_states = hidden_states + position_embeddings
+        hidden_states = self.dropout(hidden_states)
+
+        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
+        position_bias = None
+
+        for i, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            dropout_probability = torch.rand([])
+
+            skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
+            if not skip_the_layer or synced_gpus:
+                # under fsdp or deepspeed zero3 all gpus must run in sync
+                # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
+                if self.gradient_checkpointing and self.training:
+                    layer_outputs = self._gradient_checkpointing_func(
+                        layer.__call__,
+                        hidden_states,
+                        attention_mask,
+                        position_bias,
+                        output_attentions,
+                    )
+                else:
+                    layer_outputs = layer(
+                        hidden_states,
+                        attention_mask=attention_mask,
+                        output_attentions=output_attentions,
+                        position_bias=position_bias,
+                    )
+                hidden_states, position_bias = layer_outputs[:2]
+
+            if skip_the_layer:
+                layer_outputs = (None, None, None)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[2],)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
+        )
+ """ + + def __init__(self, config): + super().__init__() + self.num_groups = config.num_codevector_groups + self.num_vars = config.num_codevectors_per_group + + if config.codevector_dim % self.num_groups != 0: + raise ValueError( + f"`config.codevector_dim {config.codevector_dim} must be divisible" + f" by `config.num_codevector_groups` {self.num_groups} " + "for concatenation." + ) + + # storage for codebook variables (codewords) + self.codevectors = nn.Parameter( + torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups) + ) + self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars) + + # can be decayed for training + self.temperature = 2 + + @staticmethod + def _compute_perplexity(probs): + marginal_probs = probs.mean(dim=0) + perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum() + return perplexity + + def forward(self, hidden_states): + batch_size, sequence_length, hidden_size = hidden_states.shape + + # project to codevector dim + hidden_states = self.weight_proj(hidden_states) + hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1) + + if self.training: + # sample code vector probs via gumbel in differentiateable way + codevector_probs = nn.functional.gumbel_softmax(hidden_states.float(), tau=self.temperature, hard=True) + codevector_probs = codevector_probs.type_as(hidden_states) + + # compute perplexity + codevector_soft_dist = torch.softmax( + hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1 + ) + perplexity = self._compute_perplexity(codevector_soft_dist) + else: + # take argmax in non-differentiable way + # comptute hard codevector distribution (one hot) + codevector_idx = hidden_states.argmax(dim=-1) + codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_( + -1, codevector_idx.view(-1, 1), 1.0 + ) + codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1) + + perplexity = self._compute_perplexity(codevector_probs) + + codevector_probs = codevector_probs.view(batch_size * sequence_length, -1) + # use probs to retrieve codevectors + codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors + codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1) + codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1) + + return codevectors, perplexity + + +class WavLMPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = WavLMConfig + base_model_prefix = "wavlm" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = False + _supports_sdpa = False + + def _init_weights(self, module): + """Initialize the weights""" + # gumbel softmax requires special init + if isinstance(module, WavLMGumbelVectorQuantizer): + module.weight_proj.weight.data.normal_(mean=0.0, std=1) + module.weight_proj.bias.data.zero_() + nn.init.uniform_(module.codevectors) + elif isinstance(module, WavLMPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, WavLMFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def _get_adapters(self): + raise AttributeError("Not needed for WavLM") + + def init_adapter_layers(self): + raise AttributeError("Not needed for WavLM") + + def load_adapter(self): + raise AttributeError("Not needed for WavLM") + + +WAVLM_START_DOCSTRING = r""" + WavLM was proposed in [WavLM: Unified Speech Representation Learning with Labeled and Unlabeled + Data](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo + Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, + Jian Wu, Michael Zeng, Xiangzhan Yu, Furu Wei. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving etc.). + + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`WavLMConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +WAVLM_INPUTS_DOCSTRING = r""" + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file + into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install + soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and + conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details. 
+
+
+@add_start_docstrings(
+    "The bare WavLM Model transformer outputting raw hidden-states without any specific head on top.",
+    WAVLM_START_DOCSTRING,
+)
+class WavLMModel(Wav2Vec2Model):
+    @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=WavLMBaseModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="audio",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+    )
+    def forward(self, **super_kwargs):
+        return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+    """WavLM Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+    WAVLM_START_DOCSTRING,
+)
+class WavLMForCTC(Wav2Vec2ForCTC):
+    @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=CausalLMOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_CTC_EXPECTED_OUTPUT,
+        expected_loss=_CTC_EXPECTED_LOSS,
+    )
+    def forward(self, **super_kwargs):
+        return super().forward(**super_kwargs)
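+
+
+# Usage sketch for CTC transcription with `WavLMForCTC` above (comments only; the checkpoint name and
+# `raw_audio` -- a 1-D float waveform at 16 kHz -- are illustrative assumptions):
+#
+#     import torch
+#     from transformers import AutoProcessor, WavLMForCTC
+#
+#     processor = AutoProcessor.from_pretrained("patrickvonplaten/wavlm-libri-clean-100h-base-plus")
+#     model = WavLMForCTC.from_pretrained("patrickvonplaten/wavlm-libri-clean-100h-base-plus")
+#     inputs = processor(raw_audio, sampling_rate=16000, return_tensors="pt")
+#     with torch.no_grad():
+#         logits = model(**inputs).logits
+#     predicted_ids = torch.argmax(logits, dim=-1)
+#     transcription = processor.batch_decode(predicted_ids)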
+ """, + WAVLM_START_DOCSTRING, +) +class WavLMForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification): + @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_FRAME_CLASS_CHECKPOINT, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_FRAME_EXPECTED_OUTPUT, + ) + def forward(self, **super_kwargs): + super().forward(**super_kwargs) + + +@add_start_docstrings( + """ + WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification. + """, + WAVLM_START_DOCSTRING, +) +class WavLMForXVector(Wav2Vec2ForXVector): + pass + + @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_XVECTOR_CHECKPOINT, + output_type=XVectorOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_XVECTOR_EXPECTED_OUTPUT, + ) + def forward(self, **super_kwargs): + super().forward(**super_kwargs) + + +__all__ = [ + "WavLMForAudioFrameClassification", + "WavLMForCTC", + "WavLMForSequenceClassification", + "WavLMForXVector", + "WavLMModel", + "WavLMPreTrainedModel", +] diff --git a/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py b/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py index 4cce51d7ace..7f1e1b92242 100644 --- a/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py +++ b/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py @@ -50,9 +50,9 @@ if is_torch_available(): Wav2Vec2BertForXVector, Wav2Vec2BertModel, ) + from transformers.models.wav2vec2.modeling_wav2vec2 import _sample_negative_indices from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import ( _compute_mask_indices, - _sample_negative_indices, ) diff --git a/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py b/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py index f276b13d7be..6a884ba36ba 100644 --- a/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py +++ b/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py @@ -55,10 +55,10 @@ if is_torch_available(): Wav2Vec2FeatureExtractor, Wav2Vec2Processor, ) + from transformers.models.wav2vec2.modeling_wav2vec2 import _sample_negative_indices from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import ( Wav2Vec2ConformerGumbelVectorQuantizer, _compute_mask_indices, - _sample_negative_indices, ) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 55e71c4cc91..0962056270d 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -79,8 +79,11 @@ def preserve_case_replace(text, patterns: dict, default_name: str): def get_cased_name(lowercase_name: str) -> str: """From a model name in lowercase in the format `my_model`, return the cased name in the format `MyModel`.""" + alt_lowercase_name = lowercase_name.replace("_", "-") if lowercase_name in CONFIG_MAPPING_NAMES: return CONFIG_MAPPING_NAMES[lowercase_name].replace("Config", "") + elif alt_lowercase_name in CONFIG_MAPPING_NAMES: + return CONFIG_MAPPING_NAMES[alt_lowercase_name].replace("Config", "") else: return "".join(x.title() for x in lowercase_name.split("_")) @@ -106,6 +109,8 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer): def __init__(self, old_name: str, new_name: str, original_new_model_name: str = "", only_doc: bool = False): super().__init__() + old_name = old_name.replace("-", "_") + new_name = new_name.replace("-", "_") 
         self.old_name = old_name
         self.new_name = new_name
         self.cased_new_name = get_cased_name(self.new_name)
@@ -535,7 +540,7 @@ def find_all_dependencies(
 
 
 # Top-level variables that match the following patterns will always use the value in the `modular_xxx.py` file
-ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC"]
+ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC", r"_HIDDEN_STATES_START_POSITION"]
 # Top-level variables that match the following patterns will use the value in the `modular_xxx.py` file only if they are not None
 ASSIGNMENTS_REGEX_TO_KEEP_IF_NOT_NONE = [r"_DOCSTRING"]
 
@@ -616,6 +621,7 @@ class ModuleMapper(CSTVisitor, ABC):
         self.object_dependency_mapping = defaultdict(set)  # immediate function/assignment dependency mapping (i.e. dependencies immediately in the function/assignment definition)
         self.assignments: Dict[str, cst.SimpleStatementLine] = {}  # mapping of global assignments names to Nodes
         self.current_function = None  # this keeps track of the current module-scope function
+        self.current_class = None  # this keeps track of the current module-scope class
         self.current_assignment = None  # this keeps track of the current module-scope assignment
         # this keeps track of objects imported from modeling files (`from .configuration import Config`) -> `Config` should not be a dependency
         self.objects_imported_from_modeling = set()
@@ -672,7 +678,7 @@ class ModuleMapper(CSTVisitor, ABC):
 
     def visit_If(self, node):
         # If we are inside a function, do not add the import to the list of imports
-        if self.current_function is None:
+        if self.current_function is None and self.current_class is None:
             for stmt in node.body.body:
                 if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
                     self.imports.append(node)
@@ -680,6 +686,10 @@ class ModuleMapper(CSTVisitor, ABC):
     def visit_ClassDef(self, node: ClassDef) -> None:
         """Record class nodes to create their dependencies at the end."""
         self.classes[node.name.value] = node
+        self.current_class = node.name.value
+
+    def leave_ClassDef(self, node):
+        self.current_class = None
 
     def visit_Name(self, node: cst.Call):
         """This is used to create a mapping from module-scope functions and assignments to objects used inside them."""
@@ -1024,11 +1034,20 @@ def replace_class_node(
 
             new_decorators = (
                 updated_methods[name].decorators if len(updated_methods[name].decorators) > 0 else func.decorators
             )
+
+            # Keep return annotation in `modular_xxx.py` if any, else original return annotation
+            new_return_annotation = updated_methods[name].returns if updated_methods[name].returns else func.returns
+
             if not re.match(
                 r"\ndef .*\(.*\):\n    raise.*Error\(.*",
                 mapper.python_module.code_for_node(updated_methods[name]),
             ):
-                func = func.with_changes(body=updated_methods[name].body, params=new_params, decorators=new_decorators)
+                func = func.with_changes(
+                    body=updated_methods[name].body,
+                    params=new_params,
+                    decorators=new_decorators,
+                    returns=new_return_annotation,
+                )
             else:
                 continue
@@ -1136,7 +1155,7 @@ def append_new_import_node(
     import_node = node.body[0]
     names_to_keep = []
     for name in import_node.names:
-        name_value = name.evaluated_name
+        name_value = name.evaluated_alias or name.evaluated_name
         if name_value not in unused_imports and name_value not in added_names:
            names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT))
            added_names.add(name_value)