diff --git a/setup.py b/setup.py
index 6993962bf9f..d3271b5e54a 100644
--- a/setup.py
+++ b/setup.py
@@ -307,7 +307,12 @@ extras["hub-kernels"] = deps_list("kernels")
extras["integrations"] = extras["hub-kernels"] + extras["optuna"] + extras["ray"] + extras["sigopt"]
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
-extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5")
+extras["audio"] = deps_list(
+ "librosa",
+ "pyctcdecode",
+ "phonemizer",
+ "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
+)
# `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead
extras["speech"] = deps_list("torchaudio") + extras["audio"]
extras["torch-speech"] = deps_list("torchaudio") + extras["audio"]
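For context on the hunk above: `deps_list` in this setup.py resolves each argument to its pinned requirement string, and judging from the call, git-pinned packages such as kenlm are keyed by their full specifier. A minimal, self-contained sketch of that pattern (the pins in `_deps` below are hypothetical illustration values, not the real ones):

```python
# Hypothetical sketch of the deps_list pattern; not the actual setup.py code.
_deps = [
    "librosa",
    "pyctcdecode>=0.4.0",
    "phonemizer",
    "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
]

# Key each entry by the part before any version specifier so extras can look it up by name.
deps = {spec.split(">=")[0].split("==")[0]: spec for spec in _deps}


def deps_list(*pkgs):
    # Unknown names raise KeyError, which keeps the extras definitions honest.
    return [deps[pkg] for pkg in pkgs]


extras = {}
extras["audio"] = deps_list(
    "librosa",
    "pyctcdecode",
    "phonemizer",
    "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
)
print(extras["audio"])
```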
diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py
index 7c73f3e185d..3d1bdaeca94 100644
--- a/src/transformers/models/cohere2/modular_cohere2.py
+++ b/src/transformers/models/cohere2/modular_cohere2.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple
import torch
import torch.nn as nn
@@ -465,7 +465,7 @@ class Cohere2Model(Gemma2Model):
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
- ) -> Union[Tuple, BaseModelOutputWithPast]:
+ ) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
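The same `Union[Tuple, ...]` to `ModelOutput` tightening of return annotations recurs in several files below (gemma, gemma2, gemma3, gpt_neox, granite, got_ocr2). A short sketch of why this stays convenient for callers, assuming only documented `ModelOutput` behavior: output dataclasses already support tuple-style access, so dropping the `Tuple` branch from the hint does not change how results can be consumed.

```python
# Illustration only: ModelOutput subclasses behave like both a dataclass and a tuple.
import torch

from transformers.modeling_outputs import BaseModelOutputWithPast

out = BaseModelOutputWithPast(last_hidden_state=torch.zeros(1, 4, 8))

assert out.last_hidden_state is out[0]             # tuple-style indexing still works
assert out.to_tuple()[0] is out.last_hidden_state  # explicit tuple conversion is available
```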
diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py
index f3ee5e11fc6..7f799e81297 100755
--- a/src/transformers/models/data2vec/modeling_data2vec_audio.py
+++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -1,26 +1,15 @@
-# coding=utf-8
-# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch Data2VecAudio model."""
-
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/data2vec/modular_data2vec_audio.py.
+# Do NOT edit this file manually, as any edits will be overwritten when the file is
+# regenerated from the modular file. If a change is needed, please apply it to
+# modular_data2vec_audio.py directly. One of our CI checks enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
import warnings
from typing import Optional, Tuple, Union
import numpy as np
import torch
-import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
@@ -50,141 +39,14 @@ from .configuration_data2vec_audio import Data2VecAudioConfig
if is_flash_attn_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+
logger = logging.get_logger(__name__)
-
-_HIDDEN_STATES_START_POSITION = 2
-
-# General docstring
-_CONFIG_FOR_DOC = "Data2VecAudioConfig"
-
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/data2vec-audio-base-960h"
-_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
-# CTC docstring
-_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
-_CTC_EXPECTED_LOSS = 66.95
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
-def _compute_mask_indices(
- shape: Tuple[int, int],
- mask_prob: float,
- mask_length: int,
- attention_mask: Optional[torch.LongTensor] = None,
- min_masks: int = 0,
-) -> np.ndarray:
- """
- Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
- ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
- CPU as part of the preprocessing during training.
-
- Args:
- shape: The shape for which to compute masks. This should be of a tuple of size 2 where
- the first element is the batch size and the second element is the length of the axis to span.
- mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
- independently generated mask spans of length `mask_length` is computed by
- `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
- actual percentage will be smaller.
- mask_length: size of the mask
- min_masks: minimum number of masked spans
- attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
- each batch dimension.
- """
- batch_size, sequence_length = shape
-
- if mask_length < 1:
- raise ValueError("`mask_length` has to be bigger than 0.")
-
- if mask_length > sequence_length:
- raise ValueError(
- f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
- f" and `sequence_length`: {sequence_length}`"
- )
-
- # epsilon is used for probabilistic rounding
- epsilon = np.random.rand(1).item()
-
- def compute_num_masked_span(input_length):
- """Given input length, compute how many spans should be masked"""
- num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
- num_masked_span = max(num_masked_span, min_masks)
-
- # make sure num masked span <= sequence_length
- if num_masked_span * mask_length > sequence_length:
- num_masked_span = sequence_length // mask_length
-
- # make sure num_masked span is also <= input_length - (mask_length - 1)
- if input_length - (mask_length - 1) < num_masked_span:
- num_masked_span = max(input_length - (mask_length - 1), 0)
-
- return num_masked_span
-
- # compute number of masked spans in batch
- input_lengths = (
- attention_mask.detach().sum(-1).tolist()
- if attention_mask is not None
- else [sequence_length for _ in range(batch_size)]
- )
-
- # SpecAugment mask to fill
- spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
- spec_aug_mask_idxs = []
-
- max_num_masked_span = compute_num_masked_span(sequence_length)
-
- if max_num_masked_span == 0:
- return spec_aug_mask
-
- for input_length in input_lengths:
- # compute num of masked spans for this input
- num_masked_span = compute_num_masked_span(input_length)
-
- # get random indices to mask
- spec_aug_mask_idx = np.random.choice(
- np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
- )
-
- # pick first sampled index that will serve as a dummy index to pad vector
- # to ensure same dimension for all batches due to probabilistic rounding
- # Picking first sample just pads those vectors twice.
- if len(spec_aug_mask_idx) == 0:
- # this case can only happen if `input_length` is strictly smaller then
- # `sequence_length` in which case the last token has to be a padding
- # token which we can use as a dummy mask id
- dummy_mask_idx = sequence_length - 1
- else:
- dummy_mask_idx = spec_aug_mask_idx[0]
-
- spec_aug_mask_idx = np.concatenate(
- [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
- )
- spec_aug_mask_idxs.append(spec_aug_mask_idx)
-
- spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
-
- # expand masked indices to masked spans
- spec_aug_mask_idxs = np.broadcast_to(
- spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
-
- # add offset to the starting indexes so that indexes now create a span
- offsets = np.arange(mask_length)[None, None, :]
- offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
- batch_size, max_num_masked_span * mask_length
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
-
- # ensure that we cannot have indices larger than sequence_length
- if spec_aug_mask_idxs.max() > sequence_length - 1:
- spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
-
- # scatter indices to mask
- np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
-
- return spec_aug_mask
+# General docstring
+_CONFIG_FOR_DOC = "Data2VecAudioConfig"
class Data2VecAudioConvLayer(nn.Module):
@@ -214,7 +76,6 @@ class Data2VecAudioConvLayer(nn.Module):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Data2VecAudio
class Data2VecAudioPadLayer(nn.Module):
def __init__(self, num_conv_pos_embeddings):
super().__init__()
@@ -279,13 +140,11 @@ class Data2VecAudioFeatureEncoder(nn.Module):
self.gradient_checkpointing = False
self._requires_grad = True
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder._freeze_parameters
def _freeze_parameters(self):
for param in self.parameters():
param.requires_grad = False
self._requires_grad = False
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder.forward
def forward(self, input_values):
hidden_states = input_values[:, None]
@@ -305,7 +164,6 @@ class Data2VecAudioFeatureEncoder(nn.Module):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Data2VecAudio
class Data2VecAudioFeatureProjection(nn.Module):
def __init__(self, config):
super().__init__()
@@ -321,7 +179,6 @@ class Data2VecAudioFeatureProjection(nn.Module):
return hidden_states, norm_hidden_states
-# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Data2VecAudio
class Data2VecAudioAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -480,7 +337,6 @@ class Data2VecAudioAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
-# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Data2VecAudio
class Data2VecAudioFlashAttention2(Data2VecAudioAttention):
"""
Data2VecAudio flash attention module. This module inherits from `Data2VecAudioAttention` as the weights of the module stays
@@ -608,7 +464,6 @@ class Data2VecAudioFlashAttention2(Data2VecAudioAttention):
class Data2VecAudioSdpaAttention(Data2VecAudioAttention):
- # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Data2VecAudio
def forward(
self,
hidden_states: torch.Tensor,
@@ -714,14 +569,6 @@ class Data2VecAudioSdpaAttention(Data2VecAudioAttention):
return attn_output, None, past_key_value
-DATA2VEC2AUDIO_ATTENTION_CLASSES = {
- "eager": Data2VecAudioAttention,
- "sdpa": Data2VecAudioSdpaAttention,
- "flash_attention_2": Data2VecAudioFlashAttention2,
-}
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Data2VecAudio
class Data2VecAudioFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
@@ -746,11 +593,17 @@ class Data2VecAudioFeedForward(nn.Module):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->Data2VecAudio, WAV2VEC2->DATA2VEC2AUDIO
+DATA2VEC_AUDIO_ATTENTION_CLASSES = {
+ "eager": Data2VecAudioAttention,
+ "sdpa": Data2VecAudioSdpaAttention,
+ "flash_attention_2": Data2VecAudioFlashAttention2,
+}
+
+
class Data2VecAudioEncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
- self.attention = DATA2VEC2AUDIO_ATTENTION_CLASSES[config._attn_implementation](
+ self.attention = DATA2VEC_AUDIO_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
@@ -782,7 +635,6 @@ class Data2VecAudioEncoderLayer(nn.Module):
return outputs
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder with Wav2Vec2->Data2VecAudio
class Data2VecAudioEncoder(nn.Module):
def __init__(self, config):
super().__init__()
@@ -868,7 +720,24 @@ class Data2VecAudioEncoder(nn.Module):
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Data2VecAudio
+class Data2VecAudioAdapterLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.output_hidden_size,
+ 2 * config.output_hidden_size,
+ config.adapter_kernel_size,
+ stride=config.adapter_stride,
+ padding=1,
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = nn.functional.glu(hidden_states, dim=1)
+
+ return hidden_states
+
+
class Data2VecAudioAdapter(nn.Module):
def __init__(self, config):
super().__init__()
@@ -900,25 +769,6 @@ class Data2VecAudioAdapter(nn.Module):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Data2VecAudio
-class Data2VecAudioAdapterLayer(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.conv = nn.Conv1d(
- config.output_hidden_size,
- 2 * config.output_hidden_size,
- config.adapter_kernel_size,
- stride=config.adapter_stride,
- padding=1,
- )
-
- def forward(self, hidden_states):
- hidden_states = self.conv(hidden_states)
- hidden_states = nn.functional.glu(hidden_states, dim=1)
-
- return hidden_states
-
-
class Data2VecAudioPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -957,7 +807,6 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
nn.init.uniform_(module.bias, a=-k, b=k)
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel._get_feat_extract_output_lengths with
def _get_feat_extract_output_lengths(
self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
):
@@ -981,7 +830,6 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
return input_lengths
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel._get_feature_vector_attention_mask
def _get_feature_vector_attention_mask(
self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
):
@@ -1003,6 +851,128 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
return attention_mask
+def _compute_mask_indices(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ attention_mask: Optional[torch.LongTensor] = None,
+ min_masks: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+ CPU as part of the preprocessing during training.
+
+ Args:
+        shape: The shape for which to compute masks. This should be a tuple of size 2 where
+ the first element is the batch size and the second element is the length of the axis to span.
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+ independently generated mask spans of length `mask_length` is computed by
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+ actual percentage will be smaller.
+ mask_length: size of the mask
+ min_masks: minimum number of masked spans
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+ each batch dimension.
+ """
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ if mask_length > sequence_length:
+ raise ValueError(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+            f" and `sequence_length`: {sequence_length}"
+ )
+
+ # epsilon is used for probabilistic rounding
+ epsilon = np.random.rand(1).item()
+
+ def compute_num_masked_span(input_length):
+ """Given input length, compute how many spans should be masked"""
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+ num_masked_span = max(num_masked_span, min_masks)
+
+ # make sure num masked span <= sequence_length
+ if num_masked_span * mask_length > sequence_length:
+ num_masked_span = sequence_length // mask_length
+
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
+ if input_length - (mask_length - 1) < num_masked_span:
+ num_masked_span = max(input_length - (mask_length - 1), 0)
+
+ return num_masked_span
+
+ # compute number of masked spans in batch
+ input_lengths = (
+ attention_mask.detach().sum(-1).tolist()
+ if attention_mask is not None
+ else [sequence_length for _ in range(batch_size)]
+ )
+
+ # SpecAugment mask to fill
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+ spec_aug_mask_idxs = []
+
+ max_num_masked_span = compute_num_masked_span(sequence_length)
+
+ if max_num_masked_span == 0:
+ return spec_aug_mask
+
+ for input_length in input_lengths:
+ # compute num of masked spans for this input
+ num_masked_span = compute_num_masked_span(input_length)
+
+ # get random indices to mask
+ spec_aug_mask_idx = np.random.choice(
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+ )
+
+ # pick first sampled index that will serve as a dummy index to pad vector
+ # to ensure same dimension for all batches due to probabilistic rounding
+ # Picking first sample just pads those vectors twice.
+ if len(spec_aug_mask_idx) == 0:
+            # this case can only happen if `input_length` is strictly smaller than
+ # `sequence_length` in which case the last token has to be a padding
+ # token which we can use as a dummy mask id
+ dummy_mask_idx = sequence_length - 1
+ else:
+ dummy_mask_idx = spec_aug_mask_idx[0]
+
+ spec_aug_mask_idx = np.concatenate(
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+ )
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+ # expand masked indices to masked spans
+ spec_aug_mask_idxs = np.broadcast_to(
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+ # add offset to the starting indexes so that indexes now create a span
+ offsets = np.arange(mask_length)[None, None, :]
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+ batch_size, max_num_masked_span * mask_length
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+ # ensure that we cannot have indices larger than sequence_length
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+ # scatter indices to mask
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+ return spec_aug_mask
+
+
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
+
+
DATA2VEC_AUDIO_START_DOCSTRING = r"""
Data2VecAudio was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and
Language](https://arxiv.org/pdf/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu and
@@ -1021,7 +991,6 @@ DATA2VEC_AUDIO_START_DOCSTRING = r"""
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
-
DATA2VEC_AUDIO_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -1058,6 +1027,8 @@ DATA2VEC_AUDIO_INPUTS_DOCSTRING = r"""
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+Data2VecAudioBaseModelOutput = Wav2Vec2BaseModelOutput
+
@add_start_docstrings(
"The bare Data2VecAudio Model transformer outputting raw hidden-states without any specific head on top.",
@@ -1137,7 +1108,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
@add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=Wav2Vec2BaseModelOutput,
+ output_type=Data2VecAudioBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@@ -1150,7 +1121,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+ ) -> Union[Tuple, Data2VecAudioBaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1187,7 +1158,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
- return Wav2Vec2BaseModelOutput(
+ return Data2VecAudioBaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
@@ -1195,6 +1166,13 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
)
+_HIDDEN_STATES_START_POSITION = 2
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 66.95
+
+
@add_start_docstrings(
"""Data2VecAudio Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
DATA2VEC_AUDIO_START_DOCSTRING,
@@ -1248,7 +1226,6 @@ class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel):
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
)
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with wav2vec2->data2vec_audio
def forward(
self,
input_values: Optional[torch.Tensor],
@@ -1379,7 +1356,6 @@ class Data2VecAudioForSequenceClassification(Data2VecAudioPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with wav2vec2->data2vec_audio
def forward(
self,
input_values: Optional[torch.Tensor],
@@ -1455,8 +1431,7 @@ class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel):
if hasattr(config, "add_adapter") and config.add_adapter:
raise ValueError(
- "Audio frame classification does not support the use of Data2VecAudio adapters"
- " (config.add_adapter=True)"
+ "Audio frame classification does not support the use of Data2VecAudio adapters (config.add_adapter=True)"
)
self.data2vec_audio = Data2VecAudioModel(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
@@ -1501,7 +1476,6 @@ class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->data2vec_audio
def forward(
self,
input_values: Optional[torch.Tensor],
@@ -1556,7 +1530,6 @@ class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel):
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
class AMSoftmaxLoss(nn.Module):
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
super(AMSoftmaxLoss, self).__init__()
@@ -1580,7 +1553,6 @@ class AMSoftmaxLoss(nn.Module):
return loss
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
class TDNNLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
@@ -1596,6 +1568,7 @@ class TDNNLayer(nn.Module):
if is_peft_available():
from peft.tuners.lora import LoraLayer
+ if is_peft_available():
if isinstance(self.kernel, LoraLayer):
warnings.warn(
"Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
@@ -1687,7 +1660,6 @@ class Data2VecAudioForXVector(Data2VecAudioPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with wav2vec2->data2vec_audio
def forward(
self,
input_values: Optional[torch.Tensor],
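The relocated `_compute_mask_indices` above keeps its original signature and semantics, so it can still be exercised on its own. A small usage sketch (shape, probability, and span length are arbitrary illustration values):

```python
from transformers.models.data2vec.modeling_data2vec_audio import _compute_mask_indices

# Batch of 2 sequences with 100 feature frames each; mask the time axis in spans of 10 frames.
mask = _compute_mask_indices(shape=(2, 100), mask_prob=0.05, mask_length=10, min_masks=1)

print(mask.shape)         # (2, 100)
print(mask.dtype)         # bool
print(mask.sum(axis=-1))  # masked frames per sequence; at least one span of `mask_length`
```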
diff --git a/src/transformers/models/data2vec/modular_data2vec_audio.py b/src/transformers/models/data2vec/modular_data2vec_audio.py
new file mode 100644
index 00000000000..052f22a960f
--- /dev/null
+++ b/src/transformers/models/data2vec/modular_data2vec_audio.py
@@ -0,0 +1,400 @@
+import math
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import (
+ CausalLMOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+ XVectorOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+)
+from ..wav2vec2.modeling_wav2vec2 import (
+ Wav2Vec2Adapter,
+ Wav2Vec2Encoder,
+ Wav2Vec2FeatureEncoder,
+ Wav2Vec2FeatureProjection,
+ Wav2Vec2ForAudioFrameClassification,
+ Wav2Vec2ForCTC,
+ Wav2Vec2ForSequenceClassification,
+ Wav2Vec2ForXVector,
+ Wav2Vec2Model,
+ Wav2Vec2PreTrainedModel,
+ Wav2Vec2SamePadLayer,
+)
+from .configuration_data2vec_audio import Data2VecAudioConfig
+
+
+_HIDDEN_STATES_START_POSITION = 2
+
+# General docstring
+_CONFIG_FOR_DOC = "Data2VecAudioConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/data2vec-audio-base-960h"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 66.95
+
+
+class Data2VecAudioConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+
+ hidden_states = hidden_states.transpose(-2, -1)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(-2, -1)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class Data2VecAudioPadLayer(Wav2Vec2SamePadLayer):
+ pass
+
+
+class Data2VecAudioPositionalConvLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=config.conv_pos_kernel_size,
+ padding=config.conv_pos_kernel_size // 2,
+ groups=config.num_conv_pos_embedding_groups,
+ )
+
+ self.padding = Data2VecAudioPadLayer(config.conv_pos_kernel_size)
+ self.activation = ACT2FN[config.feat_extract_activation]
+ # no learnable parameters
+ self.layer_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.padding(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class Data2VecAudioPositionalConvEmbedding(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [Data2VecAudioPositionalConvLayer(config) for _ in range(config.num_conv_pos_embeddings)]
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.transpose(1, 2)
+ for layer in self.layers:
+ hidden_states = layer(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class Data2VecAudioFeatureEncoder(Wav2Vec2FeatureEncoder, nn.Module):
+ def __init__(self, config):
+        nn.Module.__init__(self)
+ self.conv_layers = nn.ModuleList(
+ [Data2VecAudioConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
+ )
+ self.gradient_checkpointing = False
+ self._requires_grad = True
+
+
+class Data2VecAudioFeatureProjection(Wav2Vec2FeatureProjection):
+ pass
+
+
+class Data2VecAudioEncoder(Wav2Vec2Encoder):
+ pass
+
+
+class Data2VecAudioAdapter(Wav2Vec2Adapter):
+ pass
+
+
+class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = Data2VecAudioConfig
+ base_model_prefix = "data2vec_audio"
+ main_input_name = "input_values"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, Data2VecAudioFeatureProjection):
+ k = math.sqrt(1 / module.projection.in_features)
+ nn.init.uniform_(module.projection.weight, a=-k, b=k)
+ nn.init.uniform_(module.projection.bias, a=-k, b=k)
+ elif isinstance(module, Data2VecAudioPositionalConvLayer):
+ nn.init.constant_(module.conv.bias, 0)
+ elif isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ if module.bias is not None:
+ module.bias.data.zero_()
+ if module.weight is not None:
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ nn.init.kaiming_normal_(module.weight)
+
+ if module.bias is not None:
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+ nn.init.uniform_(module.bias, a=-k, b=k)
+
+ def _get_adapters(self):
+ raise AttributeError("Not needed for Data2VecAudio")
+
+ def init_adapter_layers(self):
+ raise AttributeError("Not needed for Data2VecAudio")
+
+ def load_adapter(self):
+ raise AttributeError("Not needed for Data2VecAudio")
+
+
+DATA2VEC_AUDIO_START_DOCSTRING = r"""
+ Data2VecAudio was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and
+ Language](https://arxiv.org/pdf/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu and
+ Michael Auli.
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, etc.).
+
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`Data2VecAudioConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+DATA2VEC_AUDIO_INPUTS_DOCSTRING = r"""
+ Args:
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
+ into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
+ soundfile*). To prepare the array into *input_values*, the [`AutoProcessor`] should be used for padding and
+ conversion into a tensor of type *torch.FloatTensor*. See [`Wav2Vec2Processor.__call__`] for details.
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+ 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+
+
+ `attention_mask` should be passed if the corresponding processor has `config.return_attention_mask ==
+        True`, which is the case for all pre-trained Data2Vec Audio models. Be aware that even with
+        `attention_mask`, zero-padded inputs will have slightly different outputs compared to non-padded inputs
+        because there is more than one convolutional layer in the positional encodings. For a more detailed
+ explanation, see [here](https://github.com/huggingface/transformers/issues/25621#issuecomment-1713759349).
+
+
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+Data2VecAudioBaseModelOutput = Wav2Vec2BaseModelOutput
+
+
+@add_start_docstrings(
+ "The bare Data2VecAudio Model transformer outputting raw hidden-states without any specific head on top.",
+ DATA2VEC_AUDIO_START_DOCSTRING,
+)
+class Data2VecAudioModel(Data2VecAudioPreTrainedModel, Wav2Vec2Model):
+ def __init__(self, config: Data2VecAudioConfig):
+        Data2VecAudioPreTrainedModel.__init__(self, config)
+ self.config = config
+ self.feature_extractor = Data2VecAudioFeatureEncoder(config)
+ self.feature_projection = Data2VecAudioFeatureProjection(config)
+
+ # model only needs masking vector if mask prob is > 0.0
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+ self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
+
+ self.encoder = Data2VecAudioEncoder(config)
+
+ self.adapter = Data2VecAudioAdapter(config) if config.add_adapter else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def freeze_feature_extractor(self):
+ raise AttributeError("Not needed for Data2VecAudio")
+
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.feature_extractor._freeze_parameters()
+
+ @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=Data2VecAudioBaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+ """Data2VecAudio Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+ DATA2VEC_AUDIO_START_DOCSTRING,
+)
+class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel, Wav2Vec2ForCTC):
+ def __init__(self, config):
+        Data2VecAudioPreTrainedModel.__init__(self, config)
+
+ self.data2vec_audio = Data2VecAudioModel(config)
+ self.dropout = nn.Dropout(config.final_dropout)
+
+ if config.vocab_size is None:
+ raise ValueError(
+                f"You are trying to instantiate {self.__class__} with a configuration that "
+                "does not define the vocabulary size of the language model head. Please "
+                "instantiate the model as follows: `Data2VecAudioForCTC.from_pretrained(..., vocab_size=vocab_size)`, "
+                "or define `vocab_size` in your model's configuration."
+ )
+ output_hidden_size = (
+ config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
+ )
+ self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def freeze_base_model(self):
+ raise AttributeError("Not needed for Data2VecAudio")
+
+ def tie_weights(self):
+ raise AttributeError("Not needed for Data2VecAudio")
+
+ @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=CausalLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_CTC_EXPECTED_OUTPUT,
+ expected_loss=_CTC_EXPECTED_LOSS,
+ )
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+ """
+ Data2VecAudio Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
+ like SUPERB Keyword Spotting.
+ """,
+ DATA2VEC_AUDIO_START_DOCSTRING,
+)
+class Data2VecAudioForSequenceClassification(Wav2Vec2ForSequenceClassification):
+ @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+ """
+ Data2VecAudio Model with a frame classification head on top for tasks like Speaker Diarization.
+ """,
+ DATA2VEC_AUDIO_START_DOCSTRING,
+)
+class Data2VecAudioForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification):
+ @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+ """
+ Data2VecAudio Model with an XVector feature extraction head on top for tasks like Speaker Verification.
+ """,
+ DATA2VEC_AUDIO_START_DOCSTRING,
+)
+class Data2VecAudioForXVector(Wav2Vec2ForXVector):
+ @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=XVectorOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+__all__ = [
+ "Data2VecAudioForAudioFrameClassification",
+ "Data2VecAudioForCTC",
+ "Data2VecAudioForSequenceClassification",
+ "Data2VecAudioForXVector",
+ "Data2VecAudioModel",
+ "Data2VecAudioPreTrainedModel",
+]
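One detail worth noting in the modular file above: `Data2VecAudioBaseModelOutput` is a plain alias of `Wav2Vec2BaseModelOutput`, so the documented output type changes only in name. A quick check of that assumption:

```python
import torch

from transformers.modeling_outputs import Wav2Vec2BaseModelOutput
from transformers.models.data2vec.modeling_data2vec_audio import Data2VecAudioBaseModelOutput

# The alias is the same class object, so existing isinstance checks keep passing.
assert Data2VecAudioBaseModelOutput is Wav2Vec2BaseModelOutput

out = Data2VecAudioBaseModelOutput(
    last_hidden_state=torch.zeros(1, 10, 768),
    extract_features=torch.zeros(1, 10, 512),
)
assert isinstance(out, Wav2Vec2BaseModelOutput)
```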
diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py
index 2aeb2005805..e372817bf71 100644
--- a/src/transformers/models/gemma/configuration_gemma.py
+++ b/src/transformers/models/gemma/configuration_gemma.py
@@ -19,7 +19,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from ...configuration_utils import PretrainedConfig
diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py
index 735c508f6d7..fa6af70ecfd 100644
--- a/src/transformers/models/gemma/modular_gemma.py
+++ b/src/transformers/models/gemma/modular_gemma.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import sentencepiece as spm
import torch
@@ -379,7 +379,7 @@ class GemmaModel(LlamaModel):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # NOOP kwarg for now
- ) -> Union[Tuple, BaseModelOutputWithPast]:
+ ) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py
index d197d978e8d..384f3e08023 100644
--- a/src/transformers/models/gemma2/modular_gemma2.py
+++ b/src/transformers/models/gemma2/modular_gemma2.py
@@ -415,7 +415,7 @@ class Gemma2Model(GemmaModel):
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
- ) -> Union[Tuple, BaseModelOutputWithPast]:
+ ) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -588,7 +588,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**loss_kwargs,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
+ ) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index 6dad88c1bc1..6364f890219 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -326,7 +326,7 @@ class Gemma3Attention(nn.Module):
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
@@ -1203,7 +1203,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
return causal_mask
- def get_image_features(self, pixel_values: torch.Tensor):
+ def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""
Projects the last hidden state from the vision model into language model space.
diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py
index c4de8d928d5..3f7292f13a0 100644
--- a/src/transformers/models/gemma3/modular_gemma3.py
+++ b/src/transformers/models/gemma3/modular_gemma3.py
@@ -597,7 +597,7 @@ class Gemma3TextModel(Gemma2Model):
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
- ) -> Union[tuple, BaseModelOutputWithPast]:
+ ) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py
index 1278af4faab..c474ef36900 100644
--- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py
+++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py
@@ -753,7 +753,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- ) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]:
+ ) -> GotOcr2CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py
index 36d2db007b6..4b7e7f1adb5 100644
--- a/src/transformers/models/got_ocr2/modular_got_ocr2.py
+++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py
@@ -397,7 +397,7 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- ) -> LlavaCausalLMOutputWithPast:
+ ) -> GotOcr2CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py
index 1399f1f18f1..58b697c3f10 100644
--- a/src/transformers/models/gpt_neox/modular_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py
@@ -343,7 +343,7 @@ class GPTNeoXModel(LlamaModel, nn.Module):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
- ) -> Union[Tuple, BaseModelOutputWithPast]:
+ ) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py
index 494ab5f1825..25929dbb337 100644
--- a/src/transformers/models/granite/modular_granite.py
+++ b/src/transformers/models/granite/modular_granite.py
@@ -132,7 +132,7 @@ class GraniteModel(LlamaModel):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
- ) -> Union[Tuple, BaseModelOutputWithPast]:
+ ) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -244,7 +244,7 @@ class GraniteForCausalLM(LlamaForCausalLM):
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
- ) -> Union[Tuple, CausalLMOutputWithPast]:
+ ) -> CausalLMOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py
index 7445802e624..ae03cea1c13 100755
--- a/src/transformers/models/hubert/modeling_hubert.py
+++ b/src/transformers/models/hubert/modeling_hubert.py
@@ -1,26 +1,15 @@
-# coding=utf-8
-# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch Hubert model."""
-
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/hubert/modular_hubert.py.
+# Do NOT edit this file manually, as any edits will be overwritten when the file is
+# regenerated from the modular file. If a change is needed, please apply it to
+# modular_hubert.py directly. One of our CI checks enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import warnings
from typing import Optional, Tuple, Union
import numpy as np
import torch
-import torch.utils.checkpoint
-from torch import nn
+import torch.nn as nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
@@ -45,219 +34,12 @@ if is_flash_attn_available():
logger = logging.get_logger(__name__)
-_HIDDEN_STATES_START_POSITION = 1
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
# General docstring
_CONFIG_FOR_DOC = "HubertConfig"
-# Base docstring
-_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
-_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
-
-# CTC docstring
-_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
-_CTC_EXPECTED_LOSS = 22.68
-
-# Audio class docstring
-_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
-_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
-_SEQ_CLASS_EXPECTED_LOSS = 8.53
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
-def _compute_mask_indices(
- shape: Tuple[int, int],
- mask_prob: float,
- mask_length: int,
- attention_mask: Optional[torch.LongTensor] = None,
- min_masks: int = 0,
-) -> np.ndarray:
- """
- Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
- ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
- CPU as part of the preprocessing during training.
-
- Args:
- shape: The shape for which to compute masks. This should be of a tuple of size 2 where
- the first element is the batch size and the second element is the length of the axis to span.
- mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
- independently generated mask spans of length `mask_length` is computed by
- `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
- actual percentage will be smaller.
- mask_length: size of the mask
- min_masks: minimum number of masked spans
- attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
- each batch dimension.
- """
- batch_size, sequence_length = shape
-
- if mask_length < 1:
- raise ValueError("`mask_length` has to be bigger than 0.")
-
- if mask_length > sequence_length:
- raise ValueError(
- f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
- f" and `sequence_length`: {sequence_length}`"
- )
-
- # epsilon is used for probabilistic rounding
- epsilon = np.random.rand(1).item()
-
- def compute_num_masked_span(input_length):
- """Given input length, compute how many spans should be masked"""
- num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
- num_masked_span = max(num_masked_span, min_masks)
-
- # make sure num masked span <= sequence_length
- if num_masked_span * mask_length > sequence_length:
- num_masked_span = sequence_length // mask_length
-
- # make sure num_masked span is also <= input_length - (mask_length - 1)
- if input_length - (mask_length - 1) < num_masked_span:
- num_masked_span = max(input_length - (mask_length - 1), 0)
-
- return num_masked_span
-
- # compute number of masked spans in batch
- input_lengths = (
- attention_mask.detach().sum(-1).tolist()
- if attention_mask is not None
- else [sequence_length for _ in range(batch_size)]
- )
-
- # SpecAugment mask to fill
- spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
- spec_aug_mask_idxs = []
-
- max_num_masked_span = compute_num_masked_span(sequence_length)
-
- if max_num_masked_span == 0:
- return spec_aug_mask
-
- for input_length in input_lengths:
- # compute num of masked spans for this input
- num_masked_span = compute_num_masked_span(input_length)
-
- # get random indices to mask
- spec_aug_mask_idx = np.random.choice(
- np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
- )
-
- # pick first sampled index that will serve as a dummy index to pad vector
- # to ensure same dimension for all batches due to probabilistic rounding
- # Picking first sample just pads those vectors twice.
- if len(spec_aug_mask_idx) == 0:
- # this case can only happen if `input_length` is strictly smaller then
- # `sequence_length` in which case the last token has to be a padding
- # token which we can use as a dummy mask id
- dummy_mask_idx = sequence_length - 1
- else:
- dummy_mask_idx = spec_aug_mask_idx[0]
-
- spec_aug_mask_idx = np.concatenate(
- [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
- )
- spec_aug_mask_idxs.append(spec_aug_mask_idx)
-
- spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
-
- # expand masked indices to masked spans
- spec_aug_mask_idxs = np.broadcast_to(
- spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
-
- # add offset to the starting indexes so that indexes now create a span
- offsets = np.arange(mask_length)[None, None, :]
- offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
- batch_size, max_num_masked_span * mask_length
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
-
- # ensure that we cannot have indices larger than sequence_length
- if spec_aug_mask_idxs.max() > sequence_length - 1:
- spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
-
- # scatter indices to mask
- np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
-
- return spec_aug_mask
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Hubert
-class HubertNoLayerNormConvLayer(nn.Module):
- def __init__(self, config, layer_id=0):
- super().__init__()
- self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
- self.out_conv_dim = config.conv_dim[layer_id]
-
- self.conv = nn.Conv1d(
- self.in_conv_dim,
- self.out_conv_dim,
- kernel_size=config.conv_kernel[layer_id],
- stride=config.conv_stride[layer_id],
- bias=config.conv_bias,
- )
- self.activation = ACT2FN[config.feat_extract_activation]
-
- def forward(self, hidden_states):
- hidden_states = self.conv(hidden_states)
- hidden_states = self.activation(hidden_states)
- return hidden_states
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Hubert
-class HubertLayerNormConvLayer(nn.Module):
- def __init__(self, config, layer_id=0):
- super().__init__()
- self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
- self.out_conv_dim = config.conv_dim[layer_id]
-
- self.conv = nn.Conv1d(
- self.in_conv_dim,
- self.out_conv_dim,
- kernel_size=config.conv_kernel[layer_id],
- stride=config.conv_stride[layer_id],
- bias=config.conv_bias,
- )
- self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
- self.activation = ACT2FN[config.feat_extract_activation]
-
- def forward(self, hidden_states):
- hidden_states = self.conv(hidden_states)
-
- hidden_states = hidden_states.transpose(-2, -1)
- hidden_states = self.layer_norm(hidden_states)
- hidden_states = hidden_states.transpose(-2, -1)
-
- hidden_states = self.activation(hidden_states)
- return hidden_states
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Hubert
-class HubertGroupNormConvLayer(nn.Module):
- def __init__(self, config, layer_id=0):
- super().__init__()
- self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
- self.out_conv_dim = config.conv_dim[layer_id]
-
- self.conv = nn.Conv1d(
- self.in_conv_dim,
- self.out_conv_dim,
- kernel_size=config.conv_kernel[layer_id],
- stride=config.conv_stride[layer_id],
- bias=config.conv_bias,
- )
- self.activation = ACT2FN[config.feat_extract_activation]
-
- self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
-
- def forward(self, hidden_states):
- hidden_states = self.conv(hidden_states)
- hidden_states = self.layer_norm(hidden_states)
- hidden_states = self.activation(hidden_states)
- return hidden_states
-
class HubertPositionalConvEmbedding(nn.Module):
def __init__(self, config):
@@ -309,7 +91,6 @@ class HubertPositionalConvEmbedding(nn.Module):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Hubert
class HubertSamePadLayer(nn.Module):
def __init__(self, num_conv_pos_embeddings):
super().__init__()
@@ -321,7 +102,78 @@ class HubertSamePadLayer(nn.Module):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Hubert
+class HubertNoLayerNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class HubertLayerNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+
+ hidden_states = hidden_states.transpose(-2, -1)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(-2, -1)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class HubertGroupNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
class HubertFeatureEncoder(nn.Module):
"""Construct the features from raw audio waveform"""
@@ -366,17 +218,6 @@ class HubertFeatureEncoder(nn.Module):
return hidden_states
-class HubertFeatureExtractor(HubertFeatureEncoder):
- def __init__(self, config):
- super().__init__(config)
- warnings.warn(
- f"The class `{self.__class__.__name__}` has been depreciated "
- "and will be removed in Transformers v5. "
- f"Use `{self.__class__.__bases__[0].__name__}` instead.",
- FutureWarning,
- )
-
-
class HubertFeatureProjection(nn.Module):
def __init__(self, config):
super().__init__()
@@ -395,7 +236,6 @@ class HubertFeatureProjection(nn.Module):
return hidden_states
-# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Hubert
class HubertAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -554,7 +394,6 @@ class HubertAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
-# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Hubert
class HubertFlashAttention2(HubertAttention):
"""
Hubert flash attention module. This module inherits from `HubertAttention` as the weights of the module stays
@@ -682,7 +521,6 @@ class HubertFlashAttention2(HubertAttention):
class HubertSdpaAttention(HubertAttention):
- # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Hubert
def forward(
self,
hidden_states: torch.Tensor,
@@ -788,14 +626,6 @@ class HubertSdpaAttention(HubertAttention):
return attn_output, None, past_key_value
-HUBERT_ATTENTION_CLASSES = {
- "eager": HubertAttention,
- "sdpa": HubertSdpaAttention,
- "flash_attention_2": HubertFlashAttention2,
-}
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Hubert
class HubertFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
@@ -820,7 +650,13 @@ class HubertFeedForward(nn.Module):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->Hubert, WAV2VEC2->HUBERT
+HUBERT_ATTENTION_CLASSES = {
+ "eager": HubertAttention,
+ "sdpa": HubertSdpaAttention,
+ "flash_attention_2": HubertFlashAttention2,
+}
+
+
class HubertEncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
@@ -856,79 +692,6 @@ class HubertEncoderLayer(nn.Module):
return outputs
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AttnAdapterLayer with Wav2Vec2->Hubert
-class HubertAttnAdapterLayer(nn.Module):
- def __init__(self, config):
- """
- Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
- up training throughput.
- """
- super().__init__()
- self.input_dim = config.adapter_attn_dim
- self.hidden_dim = config.hidden_size
-
- self.norm = nn.LayerNorm(self.hidden_dim)
- self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
- self.act_fn = nn.ReLU()
- self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
-
- def forward(self, hidden_states: torch.FloatTensor):
- hidden_states = self.norm(hidden_states)
-
- hidden_states = self.linear_1(hidden_states)
- hidden_states = self.act_fn(hidden_states)
- hidden_states = self.linear_2(hidden_states)
-
- return hidden_states
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->Hubert, WAV2VEC2->HUBERT
-class HubertEncoderLayerStableLayerNorm(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.attention = HUBERT_ATTENTION_CLASSES[config._attn_implementation](
- embed_dim=config.hidden_size,
- num_heads=config.num_attention_heads,
- dropout=config.attention_dropout,
- is_decoder=False,
- )
- self.dropout = nn.Dropout(config.hidden_dropout)
- self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.feed_forward = HubertFeedForward(config)
- self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-
- if getattr(config, "adapter_attn_dim", None) is not None:
- self.adapter_layer = HubertAttnAdapterLayer(config)
- else:
- self.adapter_layer = None
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ):
- attn_residual = hidden_states
- hidden_states = self.layer_norm(hidden_states)
- hidden_states, attn_weights, _ = self.attention(
- hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
- )
- hidden_states = self.dropout(hidden_states)
- hidden_states = attn_residual + hidden_states
- hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
-
- if self.adapter_layer is not None:
- hidden_states = hidden_states + self.adapter_layer(hidden_states)
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (attn_weights,)
-
- return outputs
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder with Wav2Vec2->Hubert
class HubertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
@@ -1014,7 +777,76 @@ class HubertEncoder(nn.Module):
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderStableLayerNorm with Wav2Vec2->Hubert
+class HubertAttnAdapterLayer(nn.Module):
+ def __init__(self, config):
+ """
+ Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
+ up training throughput.
+ """
+ super().__init__()
+ self.input_dim = config.adapter_attn_dim
+ self.hidden_dim = config.hidden_size
+
+ self.norm = nn.LayerNorm(self.hidden_dim)
+ self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
+ self.act_fn = nn.ReLU()
+ self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
+
+ def forward(self, hidden_states: torch.FloatTensor):
+ hidden_states = self.norm(hidden_states)
+
+ hidden_states = self.linear_1(hidden_states)
+ hidden_states = self.act_fn(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+
+ return hidden_states
+
+
+class HubertEncoderLayerStableLayerNorm(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.attention = HUBERT_ATTENTION_CLASSES[config._attn_implementation](
+ embed_dim=config.hidden_size,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=False,
+ )
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.feed_forward = HubertFeedForward(config)
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ if getattr(config, "adapter_attn_dim", None) is not None:
+ self.adapter_layer = HubertAttnAdapterLayer(config)
+ else:
+ self.adapter_layer = None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ):
+ attn_residual = hidden_states
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states, attn_weights, _ = self.attention(
+ hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+ )
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = attn_residual + hidden_states
+ hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
+
+ if self.adapter_layer is not None:
+ hidden_states = hidden_states + self.adapter_layer(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
class HubertEncoderStableLayerNorm(nn.Module):
def __init__(self, config):
super().__init__()
@@ -1170,6 +1002,125 @@ class HubertPreTrainedModel(PreTrainedModel):
return attention_mask
+def _compute_mask_indices(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ attention_mask: Optional[torch.LongTensor] = None,
+ min_masks: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+ CPU as part of the preprocessing during training.
+
+ Args:
+        shape: The shape for which to compute masks. This should be a tuple of size 2 where
+ the first element is the batch size and the second element is the length of the axis to span.
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+ independently generated mask spans of length `mask_length` is computed by
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+ actual percentage will be smaller.
+ mask_length: size of the mask
+ min_masks: minimum number of masked spans
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+ each batch dimension.
+ """
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ if mask_length > sequence_length:
+ raise ValueError(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+            f" and `sequence_length`: {sequence_length}"
+ )
+
+ # epsilon is used for probabilistic rounding
+ epsilon = np.random.rand(1).item()
+
+ def compute_num_masked_span(input_length):
+ """Given input length, compute how many spans should be masked"""
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+ num_masked_span = max(num_masked_span, min_masks)
+
+ # make sure num masked span <= sequence_length
+ if num_masked_span * mask_length > sequence_length:
+ num_masked_span = sequence_length // mask_length
+
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
+ if input_length - (mask_length - 1) < num_masked_span:
+ num_masked_span = max(input_length - (mask_length - 1), 0)
+
+ return num_masked_span
+
+ # compute number of masked spans in batch
+ input_lengths = (
+ attention_mask.detach().sum(-1).tolist()
+ if attention_mask is not None
+ else [sequence_length for _ in range(batch_size)]
+ )
+
+ # SpecAugment mask to fill
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+ spec_aug_mask_idxs = []
+
+ max_num_masked_span = compute_num_masked_span(sequence_length)
+
+ if max_num_masked_span == 0:
+ return spec_aug_mask
+
+ for input_length in input_lengths:
+ # compute num of masked spans for this input
+ num_masked_span = compute_num_masked_span(input_length)
+
+ # get random indices to mask
+ spec_aug_mask_idx = np.random.choice(
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+ )
+
+ # pick first sampled index that will serve as a dummy index to pad vector
+ # to ensure same dimension for all batches due to probabilistic rounding
+ # Picking first sample just pads those vectors twice.
+ if len(spec_aug_mask_idx) == 0:
+            # this case can only happen if `input_length` is strictly smaller than
+ # `sequence_length` in which case the last token has to be a padding
+ # token which we can use as a dummy mask id
+ dummy_mask_idx = sequence_length - 1
+ else:
+ dummy_mask_idx = spec_aug_mask_idx[0]
+
+ spec_aug_mask_idx = np.concatenate(
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+ )
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+ # expand masked indices to masked spans
+ spec_aug_mask_idxs = np.broadcast_to(
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+ # add offset to the starting indexes so that indexes now create a span
+ offsets = np.arange(mask_length)[None, None, :]
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+ batch_size, max_num_masked_span * mask_length
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+ # ensure that we cannot have indices larger than sequence_length
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+ # scatter indices to mask
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+ return spec_aug_mask
+
+
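For reference, a minimal usage sketch of this SpecAugment helper. The numbers are illustrative only (they are not the Hubert defaults and do not come from this diff); note how overlapping spans make `mask_prob` an upper bound on the masked fraction.

```python
import torch

# Illustrative values: 2 sequences of 100 feature frames, masking up to 50% of
# the time axis in spans of 10 frames each.
mask = _compute_mask_indices(
    shape=(2, 100),
    mask_prob=0.5,
    mask_length=10,
    attention_mask=torch.ones(2, 100, dtype=torch.long),
)
print(mask.shape, mask.dtype)  # (2, 100) bool
print(mask.sum(axis=-1))       # at most 50 masked frames per row, fewer when spans overlap
```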
HUBERT_START_DOCSTRING = r"""
Hubert was proposed in [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden
Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia,
@@ -1238,6 +1189,7 @@ class HubertModel(HubertPreTrainedModel):
self.feature_extractor = HubertFeatureEncoder(config)
self.feature_projection = HubertFeatureProjection(config)
+ # model only needs masking vector if mask prob is > 0.0
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
@@ -1249,7 +1201,6 @@ class HubertModel(HubertPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_hidden_states(
self,
hidden_states: torch.FloatTensor,
@@ -1370,11 +1321,17 @@ class HubertModel(HubertPreTrainedModel):
)
+_HIDDEN_STATES_START_POSITION = 1
+
+
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 22.68
+
+
@add_start_docstrings(
"""Hubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
HUBERT_START_DOCSTRING,
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT
class HubertForCTC(HubertPreTrainedModel):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
@@ -1526,6 +1483,11 @@ class HubertForCTC(HubertPreTrainedModel):
)
+_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
+_SEQ_CLASS_EXPECTED_LOSS = 8.53
+
+
@add_start_docstrings(
"""
Hubert Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
@@ -1533,7 +1495,6 @@ class HubertForCTC(HubertPreTrainedModel):
""",
HUBERT_START_DOCSTRING,
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT
class HubertForSequenceClassification(HubertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
diff --git a/src/transformers/models/hubert/modular_hubert.py b/src/transformers/models/hubert/modular_hubert.py
new file mode 100644
index 00000000000..a42bb5e598b
--- /dev/null
+++ b/src/transformers/models/hubert/modular_hubert.py
@@ -0,0 +1,400 @@
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...activations import ACT2FN
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ replace_return_docstrings,
+)
+from ..wav2vec2.modeling_wav2vec2 import (
+ Wav2Vec2Encoder,
+ Wav2Vec2EncoderStableLayerNorm,
+ Wav2Vec2FeatureEncoder,
+ Wav2Vec2ForCTC,
+ Wav2Vec2ForSequenceClassification,
+ Wav2Vec2Model,
+ Wav2Vec2SamePadLayer,
+)
+from .configuration_hubert import HubertConfig
+
+
+_HIDDEN_STATES_START_POSITION = 1
+
+# General docstring
+_CONFIG_FOR_DOC = "HubertConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
+
+
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 22.68
+
+
+_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
+_SEQ_CLASS_EXPECTED_LOSS = 8.53
+
+
+class HubertPositionalConvEmbedding(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=config.num_conv_pos_embeddings,
+ padding=config.num_conv_pos_embeddings // 2,
+ groups=config.num_conv_pos_embedding_groups,
+ )
+
+ self.batch_norm = None
+ if config.conv_pos_batch_norm:
+ self.batch_norm = nn.BatchNorm1d(config.hidden_size)
+ else:
+ weight_norm = nn.utils.weight_norm
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
+ weight_norm = nn.utils.parametrizations.weight_norm
+
+ if is_deepspeed_zero3_enabled():
+ import deepspeed
+
+ with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
+ self.conv = weight_norm(self.conv, name="weight", dim=2)
+ if hasattr(self.conv, "parametrizations"):
+ weight_g = self.conv.parametrizations.weight.original0
+ weight_v = self.conv.parametrizations.weight.original1
+ else:
+ weight_g = self.conv.weight_g
+ weight_v = self.conv.weight_v
+ deepspeed.zero.register_external_parameter(self, weight_v)
+ deepspeed.zero.register_external_parameter(self, weight_g)
+ else:
+ self.conv = weight_norm(self.conv, name="weight", dim=2)
+
+ self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.transpose(1, 2)
+ if self.batch_norm is not None:
+ hidden_states = self.batch_norm(hidden_states)
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.padding(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class HubertSamePadLayer(Wav2Vec2SamePadLayer):
+ pass
+
+
+class HubertFeatureEncoder(Wav2Vec2FeatureEncoder):
+ pass
+
+
+class HubertFeatureProjection(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.feat_proj_layer_norm = config.feat_proj_layer_norm
+ if self.feat_proj_layer_norm:
+ self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+ self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+ self.dropout = nn.Dropout(config.feat_proj_dropout)
+
+ def forward(self, hidden_states):
+ # non-projected hidden states are needed for quantization
+ if self.feat_proj_layer_norm:
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.projection(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class HubertEncoder(Wav2Vec2Encoder):
+ pass
+
+
+class HubertEncoderStableLayerNorm(Wav2Vec2EncoderStableLayerNorm):
+ pass
+
+
+class HubertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = HubertConfig
+ base_model_prefix = "hubert"
+ main_input_name = "input_values"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ if is_deepspeed_zero3_enabled():
+ import deepspeed
+
+ if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
+ with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
+ nn.init.kaiming_normal_(module.weight.data)
+ else:
+ with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
+ nn.init.kaiming_normal_(module.weight.data)
+ else:
+ nn.init.kaiming_normal_(module.weight.data)
+
+ if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
+ module.bias.data.zero_()
+
+ def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+ return input_lengths
+
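As a worked example of the output-length formula above: assuming the usual seven-layer feature extractor with `conv_kernel=(10, 3, 3, 3, 3, 2, 2)` and `conv_stride=(5, 2, 2, 2, 2, 2, 2)` (these default values are an assumption here, they are not shown in this diff), one second of 16 kHz audio reduces to about 49 feature frames:

```python
import torch

# Assumed feature-extractor layout (seven conv layers, wav2vec2-style defaults).
conv_kernel = (10, 3, 3, 3, 3, 2, 2)
conv_stride = (5, 2, 2, 2, 2, 2, 2)

length = torch.tensor(16000)  # 1 second of 16 kHz audio
for kernel_size, stride in zip(conv_kernel, conv_stride):
    # Same formula as `_conv_out_length` above.
    length = torch.div(length - kernel_size, stride, rounding_mode="floor") + 1
print(length)  # tensor(49) -> roughly 49 feature frames per second of audio
```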
+ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
+ output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+ batch_size = attention_mask.shape[0]
+
+ attention_mask = torch.zeros(
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+        # these two operations make sure that all values before the output length indices are attended to
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+ return attention_mask
+
+
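The flip/cumsum/flip in `_get_feature_vector_attention_mask` can be hard to parse at a glance; here is a small standalone illustration (toy lengths, not tied to any checkpoint):

```python
import torch

output_lengths = torch.tensor([4, 2])  # valid frames per sequence (toy values)
mask = torch.zeros((2, 6), dtype=torch.long)
# mark only the last valid frame of each sequence ...
mask[(torch.arange(2), output_lengths - 1)] = 1
# ... then propagate the mark backwards so every earlier frame is attended to
mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
print(mask.long())
# tensor([[1, 1, 1, 1, 0, 0],
#         [1, 1, 0, 0, 0, 0]])
```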
+HUBERT_START_DOCSTRING = r"""
+ Hubert was proposed in [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden
+ Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia,
+ Ruslan Salakhutdinov, Abdelrahman Mohamed.
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, etc.).
+
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`HubertConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+HUBERT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+ into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
+ soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
+ conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+ 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+
+
+ `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
+ True`. For all models whose processor has `config.return_attention_mask == False`, such as
+ [hubert-base](https://huggingface.co/facebook/hubert-base-ls960), `attention_mask` should **not** be passed
+ to avoid degraded performance when doing batched inference. For such models `input_values` should simply be
+ padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly different
+ results depending on whether `input_values` is padded or not.
+
+
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
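To make the `attention_mask` guidance above concrete, a hedged sketch of batched inference (the checkpoint is the one used in the docstring constants above; the audio arrays are random placeholders):

```python
import torch
from transformers import AutoProcessor, HubertModel

processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")

# Two dummy 16 kHz waveforms of different lengths, padded together.
raw_speech = [torch.randn(16000).numpy(), torch.randn(24000).numpy()]
inputs = processor(raw_speech, sampling_rate=16000, padding=True, return_tensors="pt")

# Only forward `attention_mask` when the processor actually returned one,
# i.e. when its feature extractor has `return_attention_mask == True`.
model_kwargs = {"input_values": inputs.input_values}
if "attention_mask" in inputs:
    model_kwargs["attention_mask"] = inputs.attention_mask

with torch.no_grad():
    outputs = model(**model_kwargs)
print(outputs.last_hidden_state.shape)
```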
+@add_start_docstrings(
+ "The bare Hubert Model transformer outputting raw hidden-states without any specific head on top.",
+ HUBERT_START_DOCSTRING,
+)
+class HubertModel(Wav2Vec2Model, HubertPreTrainedModel):
+ def __init__(self, config: HubertConfig):
+ super().__init__(config)
+ self.config = config
+ self.feature_extractor = HubertFeatureEncoder(config)
+ self.feature_projection = HubertFeatureProjection(config)
+
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+ self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
+
+ if config.do_stable_layer_norm:
+ self.encoder = HubertEncoderStableLayerNorm(config)
+ else:
+ self.encoder = HubertEncoder(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ del self.adapter
+
+ def freeze_feature_extractor(self):
+ raise AttributeError("Not needed for Hubert")
+
+ def freeze_feature_encoder(self):
+ raise AttributeError("Not needed for Hubert")
+
+ @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ """
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, HubertModel
+ >>> from datasets import load_dataset
+ >>> import soundfile as sf
+
+ >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
+ >>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
+
+
+ >>> def map_to_array(batch):
+ ... speech, _ = sf.read(batch["file"])
+ ... batch["speech"] = speech
+ ... return batch
+
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> ds = ds.map(map_to_array)
+
+ >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
+ >>> hidden_states = model(input_values).last_hidden_state
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ extract_features = self.feature_extractor(input_values)
+ extract_features = extract_features.transpose(1, 2)
+
+ if attention_mask is not None:
+ # compute reduced attention_mask corresponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
+
+ hidden_states = self.feature_projection(extract_features)
+ hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if not return_dict:
+ return (hidden_states,) + encoder_outputs[1:]
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """Hubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+ HUBERT_START_DOCSTRING,
+)
+class HubertForCTC(Wav2Vec2ForCTC):
+ pass
+
+ @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=CausalLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_CTC_EXPECTED_OUTPUT,
+ expected_loss=_CTC_EXPECTED_LOSS,
+ )
+ def forward(self, **super_kwargs):
+ super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+ """
+ Hubert Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
+ SUPERB Keyword Spotting.
+ """,
+ HUBERT_START_DOCSTRING,
+)
+class HubertForSequenceClassification(Wav2Vec2ForSequenceClassification):
+ pass
+
+ @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_SEQ_CLASS_CHECKPOINT,
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
+ )
+ def forward(self, **super_kwargs):
+ super().forward(**super_kwargs)
+
+
+__all__ = ["HubertForCTC", "HubertForSequenceClassification", "HubertModel", "HubertPreTrainedModel"]
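A quick note on the modular pattern used in this new file: a subclass whose body is just `pass` (e.g. `HubertSamePadLayer(Wav2Vec2SamePadLayer)`) tells the modular converter to materialise a standalone copy of the parent implementation in `modeling_hubert.py`. Before conversion it is still an ordinary subclass, as this small sketch shows (shapes chosen only for illustration):

```python
import torch
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2SamePadLayer


class HubertSamePadLayer(Wav2Vec2SamePadLayer):
    pass


layer = HubertSamePadLayer(num_conv_pos_embeddings=128)
hidden_states = torch.randn(1, 768, 99)  # (batch, channels, time)
print(layer(hidden_states).shape)        # torch.Size([1, 768, 98]) - one pad frame removed
```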
diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py
index 4548e69a738..6026f5a7e07 100644
--- a/src/transformers/models/mistral/modular_mistral.py
+++ b/src/transformers/models/mistral/modular_mistral.py
@@ -303,7 +303,7 @@ class MistralForQuestionAnswering(LlamaForQuestionAnswering):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
- ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+ ) -> QuestionAnsweringModelOutput:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 013f04ab36b..bc3f09a778b 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -620,7 +620,7 @@ class MixtralModel(MixtralPreTrainedModel):
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
- ) -> BaseModelOutputWithPast:
+ ) -> MoeModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
@@ -1013,7 +1013,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin):
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
- ) -> CausalLMOutputWithPast:
+ ) -> MoeCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py
index 09c48edbf24..bc46c9ab0c1 100644
--- a/src/transformers/models/moonshine/modeling_moonshine.py
+++ b/src/transformers/models/moonshine/modeling_moonshine.py
@@ -853,7 +853,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
- ) -> BaseModelOutputWithPast:
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Args:
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
@@ -1415,7 +1415,7 @@ class MoonshineModel(MoonshinePreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
+ ) -> Seq2SeqModelOutput:
r"""
Returns:
diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py
index cb9b1583382..f0eb31058c9 100644
--- a/src/transformers/models/phi/modular_phi.py
+++ b/src/transformers/models/phi/modular_phi.py
@@ -1,5 +1,5 @@
from functools import partial
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple
import torch
import torch.nn as nn
@@ -191,7 +191,7 @@ class PhiModel(LlamaModel):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
- ) -> Union[Tuple, BaseModelOutputWithPast]:
+ ) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py
index 58c02036b6e..8e0b86da407 100644
--- a/src/transformers/models/qwen3/modular_qwen3.py
+++ b/src/transformers/models/qwen3/modular_qwen3.py
@@ -14,7 +14,7 @@
# limitations under the License.
"""PyTorch Qwen3 model."""
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple
import torch
import torch.utils.checkpoint
@@ -146,7 +146,7 @@ class Qwen3ForCausalLM(LlamaForCausalLM):
def forward(
self,
**super_kwargs: Unpack[KwargsForCausalLM],
- ) -> Union[Tuple, CausalLMOutputWithPast]:
+ ) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
index b9ffae15a21..d239f79ec7d 100644
--- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
+++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -633,7 +633,7 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel):
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
- ) -> BaseModelOutputWithPast:
+ ) -> MoeModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
@@ -1026,7 +1026,7 @@ class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin):
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
- ) -> CausalLMOutputWithPast:
+ ) -> MoeCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
diff --git a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py
index d8a5cc54f24..385f338c78e 100644
--- a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py
+++ b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py
@@ -257,7 +257,7 @@ class Qwen3MoeForCausalLM(MixtralForCausalLM):
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
- ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+ ) -> MoeCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py
index d5f773299c9..572c07e3c9d 100644
--- a/src/transformers/models/sew/modeling_sew.py
+++ b/src/transformers/models/sew/modeling_sew.py
@@ -1218,7 +1218,7 @@ class SEWModel(SEWPreTrainedModel):
"""SEW Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
SEW_START_DOCSTRING,
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV_2_VEC_2->SEW
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV2VEC2->SEW
class SEWForCTC(SEWPreTrainedModel):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
@@ -1377,7 +1377,7 @@ class SEWForCTC(SEWPreTrainedModel):
""",
SEW_START_DOCSTRING,
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEW, wav2vec2->sew, WAV_2_VEC_2->SEW
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEW, wav2vec2->sew, WAV2VEC2->SEW
class SEWForSequenceClassification(SEWPreTrainedModel):
def __init__(self, config):
super().__init__(config)
diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py
index 9f49a46a05f..96c587fbb2e 100644
--- a/src/transformers/models/sew_d/modeling_sew_d.py
+++ b/src/transformers/models/sew_d/modeling_sew_d.py
@@ -1468,7 +1468,7 @@ class SEWDModel(SEWDPreTrainedModel):
"""SEW-D Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
SEWD_START_DOCSTRING,
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV_2_VEC_2->SEWD
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV2VEC2->SEWD
class SEWDForCTC(SEWDPreTrainedModel):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
@@ -1627,7 +1627,7 @@ class SEWDForCTC(SEWDPreTrainedModel):
""",
SEWD_START_DOCSTRING,
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV_2_VEC_2->SEWD
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV2VEC2->SEWD
class SEWDForSequenceClassification(SEWDPreTrainedModel):
def __init__(self, config):
super().__init__(config)
diff --git a/src/transformers/models/siglip2/modeling_siglip2.py b/src/transformers/models/siglip2/modeling_siglip2.py
index b4787439497..612e149d544 100644
--- a/src/transformers/models/siglip2/modeling_siglip2.py
+++ b/src/transformers/models/siglip2/modeling_siglip2.py
@@ -999,7 +999,7 @@ class Siglip2MultiheadAttentionPoolingHead(nn.Module):
self.mlp = Siglip2MLP(config)
self.num_heads = config.num_attention_heads
- def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
+ def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size = hidden_state.shape[0]
probe = self.probe.repeat(batch_size, 1, 1)
diff --git a/src/transformers/models/siglip2/modular_siglip2.py b/src/transformers/models/siglip2/modular_siglip2.py
index 92e106bc59b..23df3b0413d 100644
--- a/src/transformers/models/siglip2/modular_siglip2.py
+++ b/src/transformers/models/siglip2/modular_siglip2.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Tuple, Union
+from typing import Optional
import torch
import torch.nn as nn
@@ -243,7 +243,7 @@ class Siglip2VisionTransformer(SiglipVisionTransformer):
spatial_shapes: torch.LongTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ ) -> BaseModelOutputWithPooling:
r"""
Returns:
diff --git a/src/transformers/models/smolvlm/configuration_smolvlm.py b/src/transformers/models/smolvlm/configuration_smolvlm.py
index f4a42f348de..cd854415683 100644
--- a/src/transformers/models/smolvlm/configuration_smolvlm.py
+++ b/src/transformers/models/smolvlm/configuration_smolvlm.py
@@ -19,7 +19,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING, AutoConfig
diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py
index 2ee2e054aeb..74608797ab6 100755
--- a/src/transformers/models/unispeech/modeling_unispeech.py
+++ b/src/transformers/models/unispeech/modeling_unispeech.py
@@ -1,19 +1,9 @@
-# coding=utf-8
-# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch UniSpeech model."""
-
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/unispeech/modular_unispeech.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_unispeech.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
import warnings
from dataclasses import dataclass
@@ -21,18 +11,22 @@ from typing import Optional, Tuple, Union
import numpy as np
import torch
-import torch.utils.checkpoint
-from torch import nn
+import torch.nn as nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
-from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, Wav2Vec2BaseModelOutput
+from ...modeling_outputs import (
+ BaseModelOutput,
+ CausalLMOutput,
+ ModelOutput,
+ SequenceClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+)
from ...modeling_utils import PreTrainedModel
from ...utils import (
- ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
@@ -48,20 +42,12 @@ if is_flash_attn_available():
logger = logging.get_logger(__name__)
-
-_HIDDEN_STATES_START_POSITION = 2
+# Base docstring
+_CHECKPOINT_FOR_DOC = "patrickvonplaten/unispeech-large-1500h-cv-timit"
# General docstring
_CONFIG_FOR_DOC = "UniSpeechConfig"
-# Base docstring
-_CHECKPOINT_FOR_DOC = "patrickvonplaten/unispeech-large-1500h-cv-timit"
-_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
-
-# CTC docstring
-_CTC_EXPECTED_OUTPUT = "'mister quilter is the apposl of the midle classes and weare glad to welcom his gosepl'"
-_CTC_EXPECTED_LOSS = 17.17
-
@dataclass
class UniSpeechForPreTrainingOutput(ModelOutput):
@@ -99,202 +85,17 @@ class UniSpeechForPreTrainingOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
-def _compute_mask_indices(
- shape: Tuple[int, int],
- mask_prob: float,
- mask_length: int,
- attention_mask: Optional[torch.LongTensor] = None,
- min_masks: int = 0,
-) -> np.ndarray:
- """
- Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
- ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
- CPU as part of the preprocessing during training.
-
- Args:
- shape: The shape for which to compute masks. This should be of a tuple of size 2 where
- the first element is the batch size and the second element is the length of the axis to span.
- mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
- independently generated mask spans of length `mask_length` is computed by
- `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
- actual percentage will be smaller.
- mask_length: size of the mask
- min_masks: minimum number of masked spans
- attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
- each batch dimension.
- """
- batch_size, sequence_length = shape
-
- if mask_length < 1:
- raise ValueError("`mask_length` has to be bigger than 0.")
-
- if mask_length > sequence_length:
- raise ValueError(
- f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
- f" and `sequence_length`: {sequence_length}`"
- )
-
- # epsilon is used for probabilistic rounding
- epsilon = np.random.rand(1).item()
-
- def compute_num_masked_span(input_length):
- """Given input length, compute how many spans should be masked"""
- num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
- num_masked_span = max(num_masked_span, min_masks)
-
- # make sure num masked span <= sequence_length
- if num_masked_span * mask_length > sequence_length:
- num_masked_span = sequence_length // mask_length
-
- # make sure num_masked span is also <= input_length - (mask_length - 1)
- if input_length - (mask_length - 1) < num_masked_span:
- num_masked_span = max(input_length - (mask_length - 1), 0)
-
- return num_masked_span
-
- # compute number of masked spans in batch
- input_lengths = (
- attention_mask.detach().sum(-1).tolist()
- if attention_mask is not None
- else [sequence_length for _ in range(batch_size)]
- )
-
- # SpecAugment mask to fill
- spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
- spec_aug_mask_idxs = []
-
- max_num_masked_span = compute_num_masked_span(sequence_length)
-
- if max_num_masked_span == 0:
- return spec_aug_mask
-
- for input_length in input_lengths:
- # compute num of masked spans for this input
- num_masked_span = compute_num_masked_span(input_length)
-
- # get random indices to mask
- spec_aug_mask_idx = np.random.choice(
- np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
- )
-
- # pick first sampled index that will serve as a dummy index to pad vector
- # to ensure same dimension for all batches due to probabilistic rounding
- # Picking first sample just pads those vectors twice.
- if len(spec_aug_mask_idx) == 0:
- # this case can only happen if `input_length` is strictly smaller then
- # `sequence_length` in which case the last token has to be a padding
- # token which we can use as a dummy mask id
- dummy_mask_idx = sequence_length - 1
- else:
- dummy_mask_idx = spec_aug_mask_idx[0]
-
- spec_aug_mask_idx = np.concatenate(
- [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
- )
- spec_aug_mask_idxs.append(spec_aug_mask_idx)
-
- spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
-
- # expand masked indices to masked spans
- spec_aug_mask_idxs = np.broadcast_to(
- spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
-
- # add offset to the starting indexes so that indexes now create a span
- offsets = np.arange(mask_length)[None, None, :]
- offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
- batch_size, max_num_masked_span * mask_length
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
-
- # ensure that we cannot have indices larger than sequence_length
- if spec_aug_mask_idxs.max() > sequence_length - 1:
- spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
-
- # scatter indices to mask
- np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
-
- return spec_aug_mask
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->UniSpeech
-class UniSpeechNoLayerNormConvLayer(nn.Module):
- def __init__(self, config, layer_id=0):
+class UniSpeechSamePadLayer(nn.Module):
+ def __init__(self, num_conv_pos_embeddings):
super().__init__()
- self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
- self.out_conv_dim = config.conv_dim[layer_id]
-
- self.conv = nn.Conv1d(
- self.in_conv_dim,
- self.out_conv_dim,
- kernel_size=config.conv_kernel[layer_id],
- stride=config.conv_stride[layer_id],
- bias=config.conv_bias,
- )
- self.activation = ACT2FN[config.feat_extract_activation]
+ self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
def forward(self, hidden_states):
- hidden_states = self.conv(hidden_states)
- hidden_states = self.activation(hidden_states)
+ if self.num_pad_remove > 0:
+ hidden_states = hidden_states[:, :, : -self.num_pad_remove]
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->UniSpeech
-class UniSpeechLayerNormConvLayer(nn.Module):
- def __init__(self, config, layer_id=0):
- super().__init__()
- self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
- self.out_conv_dim = config.conv_dim[layer_id]
-
- self.conv = nn.Conv1d(
- self.in_conv_dim,
- self.out_conv_dim,
- kernel_size=config.conv_kernel[layer_id],
- stride=config.conv_stride[layer_id],
- bias=config.conv_bias,
- )
- self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
- self.activation = ACT2FN[config.feat_extract_activation]
-
- def forward(self, hidden_states):
- hidden_states = self.conv(hidden_states)
-
- hidden_states = hidden_states.transpose(-2, -1)
- hidden_states = self.layer_norm(hidden_states)
- hidden_states = hidden_states.transpose(-2, -1)
-
- hidden_states = self.activation(hidden_states)
- return hidden_states
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->UniSpeech
-class UniSpeechGroupNormConvLayer(nn.Module):
- def __init__(self, config, layer_id=0):
- super().__init__()
- self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
- self.out_conv_dim = config.conv_dim[layer_id]
-
- self.conv = nn.Conv1d(
- self.in_conv_dim,
- self.out_conv_dim,
- kernel_size=config.conv_kernel[layer_id],
- stride=config.conv_stride[layer_id],
- bias=config.conv_bias,
- )
- self.activation = ACT2FN[config.feat_extract_activation]
-
- self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
-
- def forward(self, hidden_states):
- hidden_states = self.conv(hidden_states)
- hidden_states = self.layer_norm(hidden_states)
- hidden_states = self.activation(hidden_states)
- return hidden_states
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->UniSpeech
class UniSpeechPositionalConvEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
@@ -340,19 +141,78 @@ class UniSpeechPositionalConvEmbedding(nn.Module):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->UniSpeech
-class UniSpeechSamePadLayer(nn.Module):
- def __init__(self, num_conv_pos_embeddings):
+class UniSpeechNoLayerNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
super().__init__()
- self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
- if self.num_pad_remove > 0:
- hidden_states = hidden_states[:, :, : -self.num_pad_remove]
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class UniSpeechLayerNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+
+ hidden_states = hidden_states.transpose(-2, -1)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(-2, -1)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class UniSpeechGroupNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->UniSpeech
class UniSpeechFeatureEncoder(nn.Module):
"""Construct the features from raw audio waveform"""
@@ -400,18 +260,6 @@ class UniSpeechFeatureEncoder(nn.Module):
return hidden_states
-class UniSpeechFeatureExtractor(UniSpeechFeatureEncoder):
- def __init__(self, config):
- super().__init__(config)
- warnings.warn(
- f"The class `{self.__class__.__name__}` has been depreciated "
- "and will be removed in Transformers v5. "
- f"Use `{self.__class__.__bases__[0].__name__}` instead.",
- FutureWarning,
- )
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->UniSpeech
class UniSpeechFeatureProjection(nn.Module):
def __init__(self, config):
super().__init__()
@@ -427,7 +275,6 @@ class UniSpeechFeatureProjection(nn.Module):
return hidden_states, norm_hidden_states
-# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->UniSpeech
class UniSpeechAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -586,7 +433,6 @@ class UniSpeechAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
-# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->UniSpeech
class UniSpeechFlashAttention2(UniSpeechAttention):
"""
UniSpeech flash attention module. This module inherits from `UniSpeechAttention` as the weights of the module stays
@@ -714,7 +560,6 @@ class UniSpeechFlashAttention2(UniSpeechAttention):
class UniSpeechSdpaAttention(UniSpeechAttention):
- # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->UniSpeech
def forward(
self,
hidden_states: torch.Tensor,
@@ -820,14 +665,6 @@ class UniSpeechSdpaAttention(UniSpeechAttention):
return attn_output, None, past_key_value
-UNISPEECH_ATTENTION_CLASSES = {
- "eager": UniSpeechAttention,
- "sdpa": UniSpeechSdpaAttention,
- "flash_attention_2": UniSpeechFlashAttention2,
-}
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->UniSpeech
class UniSpeechFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
@@ -852,7 +689,13 @@ class UniSpeechFeedForward(nn.Module):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->UniSpeech, WAV2VEC2->UNISPEECH
+UNISPEECH_ATTENTION_CLASSES = {
+ "eager": UniSpeechAttention,
+ "sdpa": UniSpeechSdpaAttention,
+ "flash_attention_2": UniSpeechFlashAttention2,
+}
+
+
class UniSpeechEncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
@@ -888,79 +731,6 @@ class UniSpeechEncoderLayer(nn.Module):
return outputs
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AttnAdapterLayer with Wav2Vec2->UniSpeech
-class UniSpeechAttnAdapterLayer(nn.Module):
- def __init__(self, config):
- """
- Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
- up training throughput.
- """
- super().__init__()
- self.input_dim = config.adapter_attn_dim
- self.hidden_dim = config.hidden_size
-
- self.norm = nn.LayerNorm(self.hidden_dim)
- self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
- self.act_fn = nn.ReLU()
- self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
-
- def forward(self, hidden_states: torch.FloatTensor):
- hidden_states = self.norm(hidden_states)
-
- hidden_states = self.linear_1(hidden_states)
- hidden_states = self.act_fn(hidden_states)
- hidden_states = self.linear_2(hidden_states)
-
- return hidden_states
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->UniSpeech, WAV2VEC2->UNISPEECH
-class UniSpeechEncoderLayerStableLayerNorm(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.attention = UNISPEECH_ATTENTION_CLASSES[config._attn_implementation](
- embed_dim=config.hidden_size,
- num_heads=config.num_attention_heads,
- dropout=config.attention_dropout,
- is_decoder=False,
- )
- self.dropout = nn.Dropout(config.hidden_dropout)
- self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.feed_forward = UniSpeechFeedForward(config)
- self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-
- if getattr(config, "adapter_attn_dim", None) is not None:
- self.adapter_layer = UniSpeechAttnAdapterLayer(config)
- else:
- self.adapter_layer = None
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ):
- attn_residual = hidden_states
- hidden_states = self.layer_norm(hidden_states)
- hidden_states, attn_weights, _ = self.attention(
- hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
- )
- hidden_states = self.dropout(hidden_states)
- hidden_states = attn_residual + hidden_states
- hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
-
- if self.adapter_layer is not None:
- hidden_states = hidden_states + self.adapter_layer(hidden_states)
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (attn_weights,)
-
- return outputs
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder with Wav2Vec2->UniSpeech
class UniSpeechEncoder(nn.Module):
def __init__(self, config):
super().__init__()
@@ -1046,7 +816,76 @@ class UniSpeechEncoder(nn.Module):
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderStableLayerNorm with Wav2Vec2->UniSpeech
+class UniSpeechAttnAdapterLayer(nn.Module):
+ def __init__(self, config):
+ """
+ Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
+ up training throughput.
+ """
+ super().__init__()
+ self.input_dim = config.adapter_attn_dim
+ self.hidden_dim = config.hidden_size
+
+ self.norm = nn.LayerNorm(self.hidden_dim)
+ self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
+ self.act_fn = nn.ReLU()
+ self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
+
+ def forward(self, hidden_states: torch.FloatTensor):
+ hidden_states = self.norm(hidden_states)
+
+ hidden_states = self.linear_1(hidden_states)
+ hidden_states = self.act_fn(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+
+ return hidden_states
+
+
+class UniSpeechEncoderLayerStableLayerNorm(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.attention = UNISPEECH_ATTENTION_CLASSES[config._attn_implementation](
+ embed_dim=config.hidden_size,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=False,
+ )
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.feed_forward = UniSpeechFeedForward(config)
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ if getattr(config, "adapter_attn_dim", None) is not None:
+ self.adapter_layer = UniSpeechAttnAdapterLayer(config)
+ else:
+ self.adapter_layer = None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ):
+ attn_residual = hidden_states
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states, attn_weights, _ = self.attention(
+ hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+ )
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = attn_residual + hidden_states
+ hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
+
+ if self.adapter_layer is not None:
+ hidden_states = hidden_states + self.adapter_layer(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
class UniSpeechEncoderStableLayerNorm(nn.Module):
def __init__(self, config):
super().__init__()
@@ -1138,7 +977,7 @@ class UniSpeechEncoderStableLayerNorm(nn.Module):
class UniSpeechGumbelVectorQuantizer(nn.Module):
"""
- Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
+ Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
"""
@@ -1149,8 +988,8 @@ class UniSpeechGumbelVectorQuantizer(nn.Module):
if config.codevector_dim % self.num_groups != 0:
raise ValueError(
- f"`config.codevector_dim {config.codevector_dim} must be divisible by `config.num_codevector_groups`"
- f" {self.num_groups} for concatenation"
+ f"`config.codevector_dim {config.codevector_dim} must be divisible "
+ f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
)
# storage for codebook variables (codewords)
@@ -1283,6 +1122,128 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
return attention_mask
+def _compute_mask_indices(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ attention_mask: Optional[torch.LongTensor] = None,
+ min_masks: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+ CPU as part of the preprocessing during training.
+
+ Args:
+ shape: The shape for which to compute masks. This should be of a tuple of size 2 where
+ the first element is the batch size and the second element is the length of the axis to span.
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+ independently generated mask spans of length `mask_length` is computed by
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+ actual percentage will be smaller.
+ mask_length: size of the mask
+ min_masks: minimum number of masked spans
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+ each batch dimension.
+ """
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ if mask_length > sequence_length:
+ raise ValueError(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+ f" and `sequence_length`: {sequence_length}`"
+ )
+
+ # epsilon is used for probabilistic rounding
+ epsilon = np.random.rand(1).item()
+
+ def compute_num_masked_span(input_length):
+ """Given input length, compute how many spans should be masked"""
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+ num_masked_span = max(num_masked_span, min_masks)
+
+ # make sure num masked span <= sequence_length
+ if num_masked_span * mask_length > sequence_length:
+ num_masked_span = sequence_length // mask_length
+
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
+ if input_length - (mask_length - 1) < num_masked_span:
+ num_masked_span = max(input_length - (mask_length - 1), 0)
+
+ return num_masked_span
+
+ # compute number of masked spans in batch
+ input_lengths = (
+ attention_mask.detach().sum(-1).tolist()
+ if attention_mask is not None
+ else [sequence_length for _ in range(batch_size)]
+ )
+
+ # SpecAugment mask to fill
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+ spec_aug_mask_idxs = []
+
+ max_num_masked_span = compute_num_masked_span(sequence_length)
+
+ if max_num_masked_span == 0:
+ return spec_aug_mask
+
+ for input_length in input_lengths:
+ # compute num of masked spans for this input
+ num_masked_span = compute_num_masked_span(input_length)
+
+ # get random indices to mask
+ spec_aug_mask_idx = np.random.choice(
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+ )
+
+ # pick first sampled index that will serve as a dummy index to pad vector
+ # to ensure same dimension for all batches due to probabilistic rounding
+ # Picking first sample just pads those vectors twice.
+ if len(spec_aug_mask_idx) == 0:
+            # this case can only happen if `input_length` is strictly smaller than
+ # `sequence_length` in which case the last token has to be a padding
+ # token which we can use as a dummy mask id
+ dummy_mask_idx = sequence_length - 1
+ else:
+ dummy_mask_idx = spec_aug_mask_idx[0]
+
+ spec_aug_mask_idx = np.concatenate(
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+ )
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+ # expand masked indices to masked spans
+ spec_aug_mask_idxs = np.broadcast_to(
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+ # add offset to the starting indexes so that indexes now create a span
+ offsets = np.arange(mask_length)[None, None, :]
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+ batch_size, max_num_masked_span * mask_length
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+ # ensure that we cannot have indices larger than sequence_length
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+ # scatter indices to mask
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+ return spec_aug_mask
+
+
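As an aside on the span-masking helper added above: a minimal sketch of how it behaves, assuming the private `_compute_mask_indices` can be imported from the generated `modeling_unispeech.py` (the shapes and probabilities below are illustrative only, not config defaults):

```python
import numpy as np
import torch

from transformers.models.unispeech.modeling_unispeech import _compute_mask_indices

# batch of 2 sequences with 100 feature frames each; mask ~15% of the time axis in spans of 10
mask = _compute_mask_indices(shape=(2, 100), mask_prob=0.15, mask_length=10)
print(mask.shape)         # (2, 100), boolean numpy array
print(mask.sum(axis=-1))  # roughly 10-20 masked frames per row (spans may overlap)

# the model overwrites masked positions with its learned `masked_spec_embed` vector
hidden_states = torch.randn(2, 100, 16)
masked_spec_embed = torch.randn(16)
hidden_states[torch.from_numpy(mask)] = masked_spec_embed
```

The dummy-index padding inside the loop is only there so that every row ends up with the same number of spans before the indices are scattered into the boolean mask.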
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
+
+
UNISPEECH_START_DOCSTRING = r"""
UniSpeech was proposed in [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled
Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei,
@@ -1301,7 +1262,6 @@ UNISPEECH_START_DOCSTRING = r"""
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
-
UNISPEECH_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -1339,6 +1299,9 @@ UNISPEECH_INPUTS_DOCSTRING = r"""
"""
+UniSpeechBaseModelOutput = Wav2Vec2BaseModelOutput
+
+
@add_start_docstrings(
"The bare UniSpeech Model transformer outputting raw hidden-states without any specific head on top.",
UNISPEECH_START_DOCSTRING,
@@ -1361,7 +1324,6 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_hidden_states(
self,
hidden_states: torch.FloatTensor,
@@ -1411,7 +1373,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
@add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=Wav2Vec2BaseModelOutput,
+ output_type=UniSpeechBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@@ -1424,7 +1386,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+ ) -> Union[Tuple, UniSpeechBaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1456,7 +1418,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
- return Wav2Vec2BaseModelOutput(
+ return UniSpeechBaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
@@ -1610,6 +1572,13 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
)
+_HIDDEN_STATES_START_POSITION = 2
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'mister quilter is the apposl of the midle classes and weare glad to welcom his gosepl'"
+_CTC_EXPECTED_LOSS = 17.17
+
+
@add_start_docstrings(
"""UniSpeech Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
UNISPEECH_START_DOCSTRING,
@@ -1620,7 +1589,6 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
by default.
""",
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->UniSpeech, wav2vec2->unispeech, WAV_2_VEC_2->UNISPEECH
class UniSpeechForCTC(UniSpeechPreTrainedModel):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
@@ -1797,7 +1765,6 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
@@ -1810,7 +1777,6 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
)
self.freeze_feature_encoder()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->unispeech
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@@ -1818,7 +1784,6 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
"""
self.unispeech.feature_extractor._freeze_parameters()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->unispeech
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
@@ -1834,7 +1799,6 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->UniSpeech, wav2vec2->unispeech
def forward(
self,
input_values: Optional[torch.Tensor],
diff --git a/src/transformers/models/unispeech/modular_unispeech.py b/src/transformers/models/unispeech/modular_unispeech.py
new file mode 100644
index 00000000000..1096bc559b4
--- /dev/null
+++ b/src/transformers/models/unispeech/modular_unispeech.py
@@ -0,0 +1,563 @@
+import math
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...modeling_outputs import CausalLMOutput, ModelOutput, SequenceClassifierOutput, Wav2Vec2BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from ..wav2vec2.modeling_wav2vec2 import (
+ Wav2Vec2Encoder,
+ Wav2Vec2EncoderStableLayerNorm,
+ Wav2Vec2FeatureEncoder,
+ Wav2Vec2FeatureProjection,
+ Wav2Vec2ForCTC,
+ Wav2Vec2ForSequenceClassification,
+ Wav2Vec2GumbelVectorQuantizer,
+ Wav2Vec2Model,
+ Wav2Vec2PositionalConvEmbedding,
+)
+from .configuration_unispeech import UniSpeechConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_HIDDEN_STATES_START_POSITION = 2
+
+# General docstring
+_CONFIG_FOR_DOC = "UniSpeechConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "patrickvonplaten/unispeech-large-1500h-cv-timit"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'mister quilter is the apposl of the midle classes and weare glad to welcom his gosepl'"
+_CTC_EXPECTED_LOSS = 17.17
+
+
+@dataclass
+class UniSpeechForPreTrainingOutput(ModelOutput):
+ """
+    Output type of [`UniSpeechForPreTraining`], with potential hidden states and attentions.
+
+ Args:
+ loss (*optional*, returned when model is in train mode, `torch.FloatTensor` of shape `(1,)`):
+            Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the
+            [official paper](https://arxiv.org/pdf/2006.11477.pdf).
+ projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+ Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
+ projected quantized states.
+ projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+ Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
+ target vectors for contrastive loss.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ projected_states: Optional[torch.FloatTensor] = None
+ projected_quantized_states: Optional[torch.FloatTensor] = None
+ codevector_perplexity: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class UniSpeechPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding):
+ pass
+
+
+class UniSpeechFeatureEncoder(Wav2Vec2FeatureEncoder):
+ pass
+
+
+class UniSpeechFeatureProjection(Wav2Vec2FeatureProjection):
+ pass
+
+
+class UniSpeechEncoder(Wav2Vec2Encoder):
+ pass
+
+
+class UniSpeechEncoderStableLayerNorm(Wav2Vec2EncoderStableLayerNorm):
+ pass
+
+
+class UniSpeechGumbelVectorQuantizer(Wav2Vec2GumbelVectorQuantizer):
+ @staticmethod
+ def _compute_perplexity(probs):
+ marginal_probs = probs.mean(dim=0)
+ perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
+ return perplexity
+
+ def forward(self, hidden_states):
+ batch_size, sequence_length, hidden_size = hidden_states.shape
+
+ # project to codevector dim
+ hidden_states = self.weight_proj(hidden_states)
+ hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
+
+ if self.training:
+            # sample code vector probs via gumbel in differentiable way
+ codevector_probs = nn.functional.gumbel_softmax(
+ hidden_states.float(), tau=self.temperature, hard=True
+ ).type_as(hidden_states)
+
+ # compute perplexity
+ codevector_soft_dist = torch.softmax(
+ hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
+ )
+ perplexity = self._compute_perplexity(codevector_soft_dist)
+ else:
+ # take argmax in non-differentiable way
+            # compute hard codevector distribution (one hot)
+ codevector_idx = hidden_states.argmax(dim=-1)
+ codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
+ -1, codevector_idx.view(-1, 1), 1.0
+ )
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
+
+ perplexity = self._compute_perplexity(codevector_probs)
+
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
+ # use probs to retrieve codevectors
+ codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
+ codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
+ codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
+
+ return codevectors, perplexity
+
+
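To illustrate the two code paths in `UniSpeechGumbelVectorQuantizer.forward` above, here is a small self-contained sketch (toy group and codebook sizes, not the config defaults) of hard Gumbel-softmax sampling and the perplexity term from `_compute_perplexity`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_groups, num_vars = 2, 8                    # toy sizes, not the UniSpeechConfig defaults
logits = torch.randn(4, num_groups, num_vars)  # stand-in for the (batch*seq, groups, vars) projections

# training path: straight-through Gumbel-softmax gives hard one-hot picks per group
hard = nn.functional.gumbel_softmax(logits.view(-1, num_vars), tau=2.0, hard=True)
print(hard.sum(dim=-1))  # every row sums to 1.0, i.e. exactly one codevector chosen per group

# diversity signal: perplexity of the marginal code usage, as in `_compute_perplexity`
probs = torch.softmax(logits, dim=-1)
marginal = probs.mean(dim=0)
perplexity = torch.exp(-torch.sum(marginal * torch.log(marginal + 1e-7), dim=-1)).sum()
print(perplexity)  # approaches num_groups * num_vars (= 16) when all codes are used evenly
```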
+class UniSpeechPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = UniSpeechConfig
+ base_model_prefix = "unispeech"
+ main_input_name = "input_values"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ # gumbel softmax requires special init
+ if isinstance(module, UniSpeechGumbelVectorQuantizer):
+ module.weight_proj.weight.data.normal_(mean=0.0, std=1)
+ module.weight_proj.bias.data.zero_()
+ nn.init.uniform_(module.codevectors)
+ elif isinstance(module, UniSpeechPositionalConvEmbedding):
+ nn.init.normal_(
+ module.conv.weight,
+ mean=0,
+ std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
+ )
+ nn.init.constant_(module.conv.bias, 0)
+ elif isinstance(module, UniSpeechFeatureProjection):
+ k = math.sqrt(1 / module.projection.in_features)
+ nn.init.uniform_(module.projection.weight, a=-k, b=k)
+ nn.init.uniform_(module.projection.bias, a=-k, b=k)
+ elif isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ nn.init.kaiming_normal_(module.weight)
+
+ if module.bias is not None:
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+ nn.init.uniform_(module.bias, a=-k, b=k)
+
+ def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+ return input_lengths
+
+ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
+ # Effectively attention_mask.sum(-1), but not inplace to be able to run
+ # on inference mode.
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
+ batch_size = attention_mask.shape[0]
+
+ attention_mask = torch.zeros(
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+        # these two operations make sure that all values before the output length indices are attended to
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+ return attention_mask
+
+
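For intuition on `_get_feat_extract_output_lengths` above: assuming the usual wav2vec2-style feature extractor defaults (`conv_kernel=(10, 3, 3, 3, 3, 2, 2)` and `conv_stride=(5, 2, 2, 2, 2, 2, 2)`, an assumption about the config rather than anything fixed by this file), one second of 16 kHz audio reduces to 49 feature frames:

```python
def conv_out_length(input_length: int, kernel_size: int, stride: int) -> int:
    # 1D convolution output-length formula (no padding, no dilation), as in the method above
    return (input_length - kernel_size) // stride + 1

conv_kernel = (10, 3, 3, 3, 3, 2, 2)  # assumed defaults, mirroring Wav2Vec2Config
conv_stride = (5, 2, 2, 2, 2, 2, 2)

length = 16_000  # one second of audio sampled at 16 kHz
for kernel, stride in zip(conv_kernel, conv_stride):
    length = conv_out_length(length, kernel, stride)

print(length)  # 49 -> roughly one feature frame every 20 ms
```

The reduced `attention_mask` built in `_get_feature_vector_attention_mask` then simply marks the first `output_length` of those frames as attended for every example in the batch.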
+UNISPEECH_START_DOCSTRING = r"""
+ UniSpeech was proposed in [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled
+ Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei,
+ Michael Zeng, Xuedong Huang.
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving etc.).
+
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`UniSpeechConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+UNISPEECH_INPUTS_DOCSTRING = r"""
+ Args:
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+ into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
+ soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
+ conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+ 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+
+
+ `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
+ True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should
+ **not** be passed to avoid degraded performance when doing batched inference. For such models
+ `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these
+ models also yield slightly different results depending on whether `input_values` is padded or not.
+
+
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+UniSpeechBaseModelOutput = Wav2Vec2BaseModelOutput
+
+
+@add_start_docstrings(
+ "The bare UniSpeech Model transformer outputting raw hidden-states without any specific head on top.",
+ UNISPEECH_START_DOCSTRING,
+)
+class UniSpeechModel(UniSpeechPreTrainedModel, Wav2Vec2Model):
+ def __init__(self, config: UniSpeechConfig):
+        UniSpeechPreTrainedModel.__init__(self, config)
+ self.config = config
+ self.feature_extractor = UniSpeechFeatureEncoder(config)
+ self.feature_projection = UniSpeechFeatureProjection(config)
+
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+ self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
+
+ if config.do_stable_layer_norm:
+ self.encoder = UniSpeechEncoderStableLayerNorm(config)
+ else:
+ self.encoder = UniSpeechEncoder(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def freeze_feature_extractor(self):
+ raise AttributeError("Not needed for UniSpeech")
+
+ def freeze_feature_encoder(self):
+ raise AttributeError("Not needed for UniSpeech")
+
+ @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=UniSpeechBaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, UniSpeechBaseModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ extract_features = self.feature_extractor(input_values)
+ extract_features = extract_features.transpose(1, 2)
+
+ if attention_mask is not None:
+ # compute reduced attention_mask corresponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
+
+ hidden_states, extract_features = self.feature_projection(extract_features)
+ hidden_states = self._mask_hidden_states(
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+ )
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if not return_dict:
+ return (hidden_states, extract_features) + encoder_outputs[1:]
+
+ return UniSpeechBaseModelOutput(
+ last_hidden_state=hidden_states,
+ extract_features=extract_features,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
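A short usage sketch of the bare model defined above, along the lines of the generated code-sample docstring. The checkpoint and the expected hidden-state size come from the `_CHECKPOINT_FOR_DOC` and `_EXPECTED_OUTPUT_SHAPE` constants earlier in this file; the dummy LibriSpeech dataset name is the one commonly used in the library's doc examples and is an assumption here, as is processor availability for that checkpoint:

```python
import torch
from datasets import load_dataset
from transformers import AutoProcessor, UniSpeechModel

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
processor = AutoProcessor.from_pretrained("patrickvonplaten/unispeech-large-1500h-cv-timit")
model = UniSpeechModel.from_pretrained("patrickvonplaten/unispeech-large-1500h-cv-timit")

inputs = processor(dataset[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# sequence length depends on the clip; _EXPECTED_OUTPUT_SHAPE above is [1, 292, 1024] for the doc sample
print(outputs.last_hidden_state.shape)
```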
+@add_start_docstrings(
+ """UniSpeech Model with a vector-quantization module and ctc loss for pre-training.""", UNISPEECH_START_DOCSTRING
+)
+class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
+ def __init__(self, config: UniSpeechConfig):
+ super().__init__(config)
+ self.unispeech = UniSpeechModel(config)
+ self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
+
+ self.quantizer = UniSpeechGumbelVectorQuantizer(config)
+ self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
+ self.project_hid = nn.Linear(config.proj_codevector_dim, config.hidden_size)
+
+ self.ctc_proj = nn.Linear(config.hidden_size, config.num_ctc_classes)
+ self.dropout = nn.Dropout(config.final_dropout)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def set_gumbel_temperature(self, temperature: int):
+ """
+ Set the Gumbel softmax temperature to a given value. Only necessary for training
+ """
+ self.quantizer.temperature = temperature
+
+ def freeze_feature_extractor(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ warnings.warn(
+ "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+ "Please use the equivalent `freeze_feature_encoder` method instead.",
+ FutureWarning,
+ )
+ self.freeze_feature_encoder()
+
+ def freeze_feature_encoder(self):
+ """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ self.unispeech.feature_extractor._freeze_parameters()
+
+ @staticmethod
+ def compute_contrastive_logits(
+ target_features: torch.FloatTensor,
+ negative_features: torch.FloatTensor,
+ predicted_features: torch.FloatTensor,
+ temperature: int = 1,
+ ):
+ """
+        Compute logits for contrastive loss using cosine similarity as the distance measure between
+ `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
+ """
+ target_features = torch.cat([target_features, negative_features], dim=0)
+
+ logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1)
+ logits = logits.type_as(target_features)
+
+ # apply temperature
+ logits = logits / temperature
+ return logits
+
+ @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=UniSpeechForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, UniSpeechForPreTrainingOutput]:
+ r"""
+ mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
+ masked extracted features in *config.proj_codevector_dim* space.
+ sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
+ Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
+ Required input for pre-training.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoFeatureExtractor, UniSpeechForPreTraining
+
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-large-1500h-cv")
+ >>> model = UniSpeechForPreTraining.from_pretrained("microsoft/unispeech-large-1500h-cv")
+ >>> # TODO: Add full pretraining example
+ ```"""
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.unispeech(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ transformer_features = outputs[0]
+
+ # quantize all (unmasked) extracted features and project to final vq dim
+ extract_features = self.dropout_features(outputs[1])
+ quantized_features, codevector_perplexity = self.quantizer(extract_features)
+
+ # project quantized features twice
+ quantized_features = self.project_q(quantized_features.to(self.project_q.weight.dtype))
+ quantized_features = self.project_hid(quantized_features)
+
+ prob_replace_matrix = torch.empty(transformer_features.size(0), transformer_features.size(1)).fill_(
+ self.config.replace_prob
+ )
+ prob_replace_matrix = prob_replace_matrix.transpose(0, 1)
+ sampled_replace_matrix = torch.bernoulli(prob_replace_matrix).bool().to(transformer_features.device)
+ sampled_replace_matrix = sampled_replace_matrix.transpose(0, 1)
+ sampled_replace_matrix = sampled_replace_matrix.unsqueeze(-1)
+ logits = transformer_features.masked_fill(sampled_replace_matrix, 0.0) + (
+ quantized_features.masked_fill(~sampled_replace_matrix, 0.0)
+ )
+
+ # project to ctc units
+ logits = self.dropout(logits)
+ logits = self.ctc_proj(logits)
+
+ # TODO(PVP) - add negative sampling & loss computation
+ loss = None
+ if not return_dict:
+ if loss is not None:
+ return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
+ return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
+
+ return UniSpeechForPreTrainingOutput(
+ loss=loss,
+ projected_states=transformer_features,
+ projected_quantized_states=quantized_features,
+ codevector_perplexity=codevector_perplexity,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
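The pre-training forward above swaps transformer features for their quantized counterparts at randomly chosen positions (with probability `config.replace_prob`) before projecting to CTC units. A stripped-down sketch of that mixing step with toy shapes (the model reaches the same result via a transposed probability matrix):

```python
import torch

torch.manual_seed(0)
batch, seq, dim = 2, 5, 4
replace_prob = 0.5  # stand-in for config.replace_prob

transformer_features = torch.randn(batch, seq, dim)
quantized_features = torch.randn(batch, seq, dim)

# one Bernoulli draw per (batch, time) position, broadcast across the hidden dimension
sampled_replace = torch.bernoulli(torch.full((batch, seq), replace_prob)).bool().unsqueeze(-1)

# True positions take the quantized features, False positions keep the transformer output
mixed = transformer_features.masked_fill(sampled_replace, 0.0) + quantized_features.masked_fill(
    ~sampled_replace, 0.0
)
print(mixed.shape)  # torch.Size([2, 5, 4])
```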
+@add_start_docstrings(
+ """UniSpeech Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+ UNISPEECH_START_DOCSTRING,
+ """
+ target_lang (`str`, *optional*):
+            Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
+            adapter.<lang>.bin. Only relevant when using an instance of [`UniSpeechForCTC`] with adapters. Uses 'eng'
+ by default.
+ """,
+)
+class UniSpeechForCTC(Wav2Vec2ForCTC):
+ @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=CausalLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_CTC_EXPECTED_OUTPUT,
+ expected_loss=_CTC_EXPECTED_LOSS,
+ )
+ def forward(self, **super_kwargs):
+        return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+ """
+ UniSpeech Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
+ SUPERB Keyword Spotting.
+ """,
+ UNISPEECH_START_DOCSTRING,
+)
+class UniSpeechForSequenceClassification(Wav2Vec2ForSequenceClassification):
+ @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ def forward(self, **super_kwargs):
+        return super().forward(**super_kwargs)
+
+
+__all__ = [
+ "UniSpeechForCTC",
+ "UniSpeechForPreTraining",
+ "UniSpeechForSequenceClassification",
+ "UniSpeechModel",
+ "UniSpeechPreTrainedModel",
+]
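Finally, a short end-to-end sketch of the CTC head defined in this modular file, mirroring the generated code-sample docstring. The checkpoint and the expected transcription come from the `_CHECKPOINT_FOR_DOC` and `_CTC_EXPECTED_OUTPUT` constants above; the dummy dataset name and processor availability for that checkpoint are assumptions based on the library's usual doc examples:

```python
import torch
from datasets import load_dataset
from transformers import AutoProcessor, UniSpeechForCTC

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
processor = AutoProcessor.from_pretrained("patrickvonplaten/unispeech-large-1500h-cv-timit")
model = UniSpeechForCTC.from_pretrained("patrickvonplaten/unispeech-large-1500h-cv-timit")

inputs = processor(dataset[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# greedy CTC decoding: argmax per frame, then collapse repeats and blanks in batch_decode
predicted_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(predicted_ids)[0])  # roughly the _CTC_EXPECTED_OUTPUT string above
```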
diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
index 0eb035bac17..b1f8c4c3466 100755
--- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
+++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -1,19 +1,9 @@
-# coding=utf-8
-# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch UniSpeechSat model."""
-
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/unispeech_sat/modular_unispeech_sat.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_unispeech_sat.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
import warnings
from dataclasses import dataclass
@@ -21,8 +11,7 @@ from typing import Optional, Tuple, Union
import numpy as np
import torch
-import torch.utils.checkpoint
-from torch import nn
+import torch.nn as nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
@@ -32,6 +21,7 @@ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask,
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
+ ModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
Wav2Vec2BaseModelOutput,
@@ -39,7 +29,6 @@ from ...modeling_outputs import (
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
- ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
@@ -56,28 +45,12 @@ if is_flash_attn_available():
logger = logging.get_logger(__name__)
-
-_HIDDEN_STATES_START_POSITION = 2
+# Base docstring
+_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-100h-libri-ft"
# General docstring
_CONFIG_FOR_DOC = "UniSpeechSatConfig"
-# Base docstring
-_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-100h-libri-ft"
-_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
-
-# CTC docstring
-_CTC_EXPECTED_OUTPUT = "'MISTER QUILDER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
-_CTC_EXPECTED_LOSS = 39.88
-
-# Frame class docstring
-_FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd"
-_FRAME_EXPECTED_OUTPUT = [0, 0]
-
-# Speaker Verification docstring
-_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv"
-_XVECTOR_EXPECTED_OUTPUT = 0.97
-
@dataclass
class UniSpeechSatForPreTrainingOutput(ModelOutput):
@@ -116,202 +89,17 @@ class UniSpeechSatForPreTrainingOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
-def _compute_mask_indices(
- shape: Tuple[int, int],
- mask_prob: float,
- mask_length: int,
- attention_mask: Optional[torch.LongTensor] = None,
- min_masks: int = 0,
-) -> np.ndarray:
- """
- Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
- ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
- CPU as part of the preprocessing during training.
-
- Args:
- shape: The shape for which to compute masks. This should be of a tuple of size 2 where
- the first element is the batch size and the second element is the length of the axis to span.
- mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
- independently generated mask spans of length `mask_length` is computed by
- `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
- actual percentage will be smaller.
- mask_length: size of the mask
- min_masks: minimum number of masked spans
- attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
- each batch dimension.
- """
- batch_size, sequence_length = shape
-
- if mask_length < 1:
- raise ValueError("`mask_length` has to be bigger than 0.")
-
- if mask_length > sequence_length:
- raise ValueError(
- f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
- f" and `sequence_length`: {sequence_length}`"
- )
-
- # epsilon is used for probabilistic rounding
- epsilon = np.random.rand(1).item()
-
- def compute_num_masked_span(input_length):
- """Given input length, compute how many spans should be masked"""
- num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
- num_masked_span = max(num_masked_span, min_masks)
-
- # make sure num masked span <= sequence_length
- if num_masked_span * mask_length > sequence_length:
- num_masked_span = sequence_length // mask_length
-
- # make sure num_masked span is also <= input_length - (mask_length - 1)
- if input_length - (mask_length - 1) < num_masked_span:
- num_masked_span = max(input_length - (mask_length - 1), 0)
-
- return num_masked_span
-
- # compute number of masked spans in batch
- input_lengths = (
- attention_mask.detach().sum(-1).tolist()
- if attention_mask is not None
- else [sequence_length for _ in range(batch_size)]
- )
-
- # SpecAugment mask to fill
- spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
- spec_aug_mask_idxs = []
-
- max_num_masked_span = compute_num_masked_span(sequence_length)
-
- if max_num_masked_span == 0:
- return spec_aug_mask
-
- for input_length in input_lengths:
- # compute num of masked spans for this input
- num_masked_span = compute_num_masked_span(input_length)
-
- # get random indices to mask
- spec_aug_mask_idx = np.random.choice(
- np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
- )
-
- # pick first sampled index that will serve as a dummy index to pad vector
- # to ensure same dimension for all batches due to probabilistic rounding
- # Picking first sample just pads those vectors twice.
- if len(spec_aug_mask_idx) == 0:
- # this case can only happen if `input_length` is strictly smaller then
- # `sequence_length` in which case the last token has to be a padding
- # token which we can use as a dummy mask id
- dummy_mask_idx = sequence_length - 1
- else:
- dummy_mask_idx = spec_aug_mask_idx[0]
-
- spec_aug_mask_idx = np.concatenate(
- [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
- )
- spec_aug_mask_idxs.append(spec_aug_mask_idx)
-
- spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
-
- # expand masked indices to masked spans
- spec_aug_mask_idxs = np.broadcast_to(
- spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
-
- # add offset to the starting indexes so that indexes now create a span
- offsets = np.arange(mask_length)[None, None, :]
- offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
- batch_size, max_num_masked_span * mask_length
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
-
- # ensure that we cannot have indices larger than sequence_length
- if spec_aug_mask_idxs.max() > sequence_length - 1:
- spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
-
- # scatter indices to mask
- np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
-
- return spec_aug_mask
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->UniSpeechSat
-class UniSpeechSatNoLayerNormConvLayer(nn.Module):
- def __init__(self, config, layer_id=0):
+class UniSpeechSatSamePadLayer(nn.Module):
+ def __init__(self, num_conv_pos_embeddings):
super().__init__()
- self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
- self.out_conv_dim = config.conv_dim[layer_id]
-
- self.conv = nn.Conv1d(
- self.in_conv_dim,
- self.out_conv_dim,
- kernel_size=config.conv_kernel[layer_id],
- stride=config.conv_stride[layer_id],
- bias=config.conv_bias,
- )
- self.activation = ACT2FN[config.feat_extract_activation]
+ self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
def forward(self, hidden_states):
- hidden_states = self.conv(hidden_states)
- hidden_states = self.activation(hidden_states)
+ if self.num_pad_remove > 0:
+ hidden_states = hidden_states[:, :, : -self.num_pad_remove]
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->UniSpeechSat
-class UniSpeechSatLayerNormConvLayer(nn.Module):
- def __init__(self, config, layer_id=0):
- super().__init__()
- self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
- self.out_conv_dim = config.conv_dim[layer_id]
-
- self.conv = nn.Conv1d(
- self.in_conv_dim,
- self.out_conv_dim,
- kernel_size=config.conv_kernel[layer_id],
- stride=config.conv_stride[layer_id],
- bias=config.conv_bias,
- )
- self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
- self.activation = ACT2FN[config.feat_extract_activation]
-
- def forward(self, hidden_states):
- hidden_states = self.conv(hidden_states)
-
- hidden_states = hidden_states.transpose(-2, -1)
- hidden_states = self.layer_norm(hidden_states)
- hidden_states = hidden_states.transpose(-2, -1)
-
- hidden_states = self.activation(hidden_states)
- return hidden_states
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->UniSpeechSat
-class UniSpeechSatGroupNormConvLayer(nn.Module):
- def __init__(self, config, layer_id=0):
- super().__init__()
- self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
- self.out_conv_dim = config.conv_dim[layer_id]
-
- self.conv = nn.Conv1d(
- self.in_conv_dim,
- self.out_conv_dim,
- kernel_size=config.conv_kernel[layer_id],
- stride=config.conv_stride[layer_id],
- bias=config.conv_bias,
- )
- self.activation = ACT2FN[config.feat_extract_activation]
-
- self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
-
- def forward(self, hidden_states):
- hidden_states = self.conv(hidden_states)
- hidden_states = self.layer_norm(hidden_states)
- hidden_states = self.activation(hidden_states)
- return hidden_states
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->UniSpeechSat
class UniSpeechSatPositionalConvEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
@@ -357,19 +145,78 @@ class UniSpeechSatPositionalConvEmbedding(nn.Module):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->UniSpeechSat
-class UniSpeechSatSamePadLayer(nn.Module):
- def __init__(self, num_conv_pos_embeddings):
+class UniSpeechSatNoLayerNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
super().__init__()
- self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
- if self.num_pad_remove > 0:
- hidden_states = hidden_states[:, :, : -self.num_pad_remove]
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class UniSpeechSatLayerNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+
+ hidden_states = hidden_states.transpose(-2, -1)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(-2, -1)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class UniSpeechSatGroupNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->UniSpeechSat
class UniSpeechSatFeatureEncoder(nn.Module):
"""Construct the features from raw audio waveform"""
@@ -417,18 +264,6 @@ class UniSpeechSatFeatureEncoder(nn.Module):
return hidden_states
-class UniSpeechSatFeatureExtractor(UniSpeechSatFeatureEncoder):
- def __init__(self, config):
- super().__init__(config)
- warnings.warn(
- f"The class `{self.__class__.__name__}` has been depreciated "
- "and will be removed in Transformers v5. "
- f"Use `{self.__class__.__bases__[0].__name__}` instead.",
- FutureWarning,
- )
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->UniSpeechSat
class UniSpeechSatFeatureProjection(nn.Module):
def __init__(self, config):
super().__init__()
@@ -444,7 +279,6 @@ class UniSpeechSatFeatureProjection(nn.Module):
return hidden_states, norm_hidden_states
-# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->UniSpeechSat
class UniSpeechSatAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -603,7 +437,6 @@ class UniSpeechSatAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
-# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->UniSpeechSat
class UniSpeechSatFlashAttention2(UniSpeechSatAttention):
"""
UniSpeechSat flash attention module. This module inherits from `UniSpeechSatAttention` as the weights of the module stays
@@ -731,7 +564,6 @@ class UniSpeechSatFlashAttention2(UniSpeechSatAttention):
class UniSpeechSatSdpaAttention(UniSpeechSatAttention):
- # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->UniSpeechSat
def forward(
self,
hidden_states: torch.Tensor,
@@ -837,14 +669,6 @@ class UniSpeechSatSdpaAttention(UniSpeechSatAttention):
return attn_output, None, past_key_value
-UNISPEECHSAT_ATTENTION_CLASSES = {
- "eager": UniSpeechSatAttention,
- "sdpa": UniSpeechSatSdpaAttention,
- "flash_attention_2": UniSpeechSatFlashAttention2,
-}
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->UniSpeechSat
class UniSpeechSatFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
@@ -869,11 +693,17 @@ class UniSpeechSatFeedForward(nn.Module):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->UniSpeechSat, WAV2VEC2->UNISPEECHSAT
+UNISPEECH_SAT_ATTENTION_CLASSES = {
+ "eager": UniSpeechSatAttention,
+ "sdpa": UniSpeechSatSdpaAttention,
+ "flash_attention_2": UniSpeechSatFlashAttention2,
+}
+
+
class UniSpeechSatEncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
- self.attention = UNISPEECHSAT_ATTENTION_CLASSES[config._attn_implementation](
+ self.attention = UNISPEECH_SAT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
@@ -905,79 +735,6 @@ class UniSpeechSatEncoderLayer(nn.Module):
return outputs
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AttnAdapterLayer with Wav2Vec2->UniSpeechSat
-class UniSpeechSatAttnAdapterLayer(nn.Module):
- def __init__(self, config):
- """
- Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
- up training throughput.
- """
- super().__init__()
- self.input_dim = config.adapter_attn_dim
- self.hidden_dim = config.hidden_size
-
- self.norm = nn.LayerNorm(self.hidden_dim)
- self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
- self.act_fn = nn.ReLU()
- self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
-
- def forward(self, hidden_states: torch.FloatTensor):
- hidden_states = self.norm(hidden_states)
-
- hidden_states = self.linear_1(hidden_states)
- hidden_states = self.act_fn(hidden_states)
- hidden_states = self.linear_2(hidden_states)
-
- return hidden_states
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->UniSpeechSat, WAV2VEC2->UNISPEECHSAT
-class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.attention = UNISPEECHSAT_ATTENTION_CLASSES[config._attn_implementation](
- embed_dim=config.hidden_size,
- num_heads=config.num_attention_heads,
- dropout=config.attention_dropout,
- is_decoder=False,
- )
- self.dropout = nn.Dropout(config.hidden_dropout)
- self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.feed_forward = UniSpeechSatFeedForward(config)
- self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-
- if getattr(config, "adapter_attn_dim", None) is not None:
- self.adapter_layer = UniSpeechSatAttnAdapterLayer(config)
- else:
- self.adapter_layer = None
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ):
- attn_residual = hidden_states
- hidden_states = self.layer_norm(hidden_states)
- hidden_states, attn_weights, _ = self.attention(
- hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
- )
- hidden_states = self.dropout(hidden_states)
- hidden_states = attn_residual + hidden_states
- hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
-
- if self.adapter_layer is not None:
- hidden_states = hidden_states + self.adapter_layer(hidden_states)
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (attn_weights,)
-
- return outputs
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder with Wav2Vec2->UniSpeechSat
class UniSpeechSatEncoder(nn.Module):
def __init__(self, config):
super().__init__()
@@ -1063,7 +820,76 @@ class UniSpeechSatEncoder(nn.Module):
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderStableLayerNorm with Wav2Vec2->UniSpeechSat
+class UniSpeechSatAttnAdapterLayer(nn.Module):
+ def __init__(self, config):
+ """
+ Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
+ up training throughput.
+ """
+ super().__init__()
+ self.input_dim = config.adapter_attn_dim
+ self.hidden_dim = config.hidden_size
+
+ self.norm = nn.LayerNorm(self.hidden_dim)
+ self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
+ self.act_fn = nn.ReLU()
+ self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
+
+ def forward(self, hidden_states: torch.FloatTensor):
+ hidden_states = self.norm(hidden_states)
+
+ hidden_states = self.linear_1(hidden_states)
+ hidden_states = self.act_fn(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+
+ return hidden_states
+
+
+class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.attention = UNISPEECH_SAT_ATTENTION_CLASSES[config._attn_implementation](
+ embed_dim=config.hidden_size,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=False,
+ )
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.feed_forward = UniSpeechSatFeedForward(config)
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ if getattr(config, "adapter_attn_dim", None) is not None:
+ self.adapter_layer = UniSpeechSatAttnAdapterLayer(config)
+ else:
+ self.adapter_layer = None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ):
+ attn_residual = hidden_states
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states, attn_weights, _ = self.attention(
+ hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+ )
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = attn_residual + hidden_states
+ hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
+
+ if self.adapter_layer is not None:
+ hidden_states = hidden_states + self.adapter_layer(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
class UniSpeechSatEncoderStableLayerNorm(nn.Module):
def __init__(self, config):
super().__init__()
@@ -1155,7 +981,7 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module):
class UniSpeechSatGumbelVectorQuantizer(nn.Module):
"""
- Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
+ Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
"""
@@ -1166,8 +992,8 @@ class UniSpeechSatGumbelVectorQuantizer(nn.Module):
if config.codevector_dim % self.num_groups != 0:
raise ValueError(
- f"`config.codevector_dim {config.codevector_dim} must be divisible by `config.num_codevector_groups`"
- f" {self.num_groups} for concatenation"
+ f"`config.codevector_dim` {config.codevector_dim} must be divisible "
+ f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
)
# storage for codebook variables (codewords)
@@ -1300,6 +1126,128 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
return attention_mask
+def _compute_mask_indices(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ attention_mask: Optional[torch.LongTensor] = None,
+ min_masks: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+ CPU as part of the preprocessing during training.
+
+ Args:
+ shape: The shape for which to compute masks. This should be of a tuple of size 2 where
+ the first element is the batch size and the second element is the length of the axis to span.
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+ independently generated mask spans of length `mask_length` is computed by
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+ actual percentage will be smaller.
+ mask_length: size of the mask
+ min_masks: minimum number of masked spans
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+ each batch dimension.
+ """
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ if mask_length > sequence_length:
+ raise ValueError(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+ f" and `sequence_length`: {sequence_length}`"
+ )
+
+ # epsilon is used for probabilistic rounding
+ epsilon = np.random.rand(1).item()
+
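+ # Example with hypothetical values: for shape=(2, 100), mask_prob=0.65 and mask_length=10,
+ # int(0.65 * 100 / 10 + epsilon) yields 6 or 7 spans of 10 frames per sequence, i.e. up to
+ # ~60-70% of the time axis before accounting for overlapping spans.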
+ def compute_num_masked_span(input_length):
+ """Given input length, compute how many spans should be masked"""
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+ num_masked_span = max(num_masked_span, min_masks)
+
+ # make sure num masked span <= sequence_length
+ if num_masked_span * mask_length > sequence_length:
+ num_masked_span = sequence_length // mask_length
+
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
+ if input_length - (mask_length - 1) < num_masked_span:
+ num_masked_span = max(input_length - (mask_length - 1), 0)
+
+ return num_masked_span
+
+ # compute number of masked spans in batch
+ input_lengths = (
+ attention_mask.detach().sum(-1).tolist()
+ if attention_mask is not None
+ else [sequence_length for _ in range(batch_size)]
+ )
+
+ # SpecAugment mask to fill
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+ spec_aug_mask_idxs = []
+
+ max_num_masked_span = compute_num_masked_span(sequence_length)
+
+ if max_num_masked_span == 0:
+ return spec_aug_mask
+
+ for input_length in input_lengths:
+ # compute num of masked spans for this input
+ num_masked_span = compute_num_masked_span(input_length)
+
+ # get random indices to mask
+ spec_aug_mask_idx = np.random.choice(
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+ )
+
+ # pick first sampled index that will serve as a dummy index to pad vector
+ # to ensure same dimension for all batches due to probabilistic rounding
+ # Picking first sample just pads those vectors twice.
+ if len(spec_aug_mask_idx) == 0:
+ # this case can only happen if `input_length` is strictly smaller than
+ # `sequence_length` in which case the last token has to be a padding
+ # token which we can use as a dummy mask id
+ dummy_mask_idx = sequence_length - 1
+ else:
+ dummy_mask_idx = spec_aug_mask_idx[0]
+
+ spec_aug_mask_idx = np.concatenate(
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+ )
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+ # expand masked indices to masked spans
+ spec_aug_mask_idxs = np.broadcast_to(
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+ # add offset to the starting indexes so that indexes now create a span
+ offsets = np.arange(mask_length)[None, None, :]
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+ batch_size, max_num_masked_span * mask_length
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+ # ensure that we cannot have indices larger than sequence_length
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+ # scatter indices to mask
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+ return spec_aug_mask
+
+
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
+
+
UNISPEECH_SAT_START_DOCSTRING = r"""
UniSpeechSat was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
@@ -1318,7 +1266,6 @@ UNISPEECH_SAT_START_DOCSTRING = r"""
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
-
UNISPEECH_SAT_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -1338,12 +1285,10 @@ UNISPEECH_SAT_INPUTS_DOCSTRING = r"""
`attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
- True`. For all models whose processor has `config.return_attention_mask == False`, such as
- [microsoft/unispeech-sat-base-100h-libri-ft](https://huggingface.co/microsoft/unispeech-sat-base-100h-libri-ft),
- `attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
- such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
- that these models also yield slightly different results depending on whether `input_values` is padded or
- not.
+ True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should
+ **not** be passed to avoid degraded performance when doing batched inference. For such models
+ `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these
+ models also yield slightly different results depending on whether `input_values` is padded or not.
@@ -1357,6 +1302,8 @@ UNISPEECH_SAT_INPUTS_DOCSTRING = r"""
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+UniSpeechSatBaseModelOutput = Wav2Vec2BaseModelOutput
+
@add_start_docstrings(
"The bare UniSpeechSat Model transformer outputting raw hidden-states without any specific head on top.",
@@ -1379,7 +1326,6 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_hidden_states(
self,
hidden_states: torch.FloatTensor,
@@ -1429,7 +1375,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
@add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=Wav2Vec2BaseModelOutput,
+ output_type=UniSpeechSatBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@@ -1442,7 +1388,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+ ) -> Union[Tuple, UniSpeechSatBaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1474,7 +1420,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
- return Wav2Vec2BaseModelOutput(
+ return UniSpeechSatBaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
@@ -1482,7 +1428,10 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
)
-@add_start_docstrings("""UniSpeechSat Model with a quantizer and `VQ` head on top.""", UNISPEECH_SAT_START_DOCSTRING)
+@add_start_docstrings(
+ """UniSpeechSat Model with a vector-quantization module and ctc loss for pre-training.""",
+ UNISPEECH_SAT_START_DOCSTRING,
+)
class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
def __init__(self, config: UniSpeechSatConfig):
super().__init__(config)
@@ -1529,7 +1478,7 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
not be updated during training.
"""
- self.wav2vec2.feature_extractor._freeze_parameters()
+ self.unispeech_sat.feature_extractor._freeze_parameters()
@staticmethod
def compute_contrastive_logits(
@@ -1594,16 +1543,6 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
logits = extract_features
loss = quantized_features = codevector_perplexity = None
- # layer normalization (has no effect when `config.do_stable_layer_norm == False`)
- # extract_features = self.layer_norm_for_extract(extract_features)
- # quantized_features, codevector_perplexity = self.quantizer(extract_features)
- #
- # project quantized features twice
- # quantized_features = self.project_q(quantized_features)
- # quantized_features = self.project_hid(quantized_features)
- #
- # loss = None
- # logits = quantized_features
if not return_dict:
if loss is not None:
return (loss, logits, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
@@ -1620,6 +1559,13 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
)
+_HIDDEN_STATES_START_POSITION = 2
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILDER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 39.88
+
+
@add_start_docstrings(
"""UniSpeechSat Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
UNISPEECH_SAT_START_DOCSTRING,
@@ -1630,7 +1576,6 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
'eng' by default.
""",
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
@@ -1784,8 +1729,8 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
@add_start_docstrings(
"""
- UniSpeechSat Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
- like SUPERB Keyword Spotting.
+ UniSpeechSat Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
+ SUPERB Keyword Spotting.
""",
UNISPEECH_SAT_START_DOCSTRING,
)
@@ -1807,7 +1752,6 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
@@ -1820,7 +1764,6 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
)
self.freeze_feature_encoder()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->unispeech_sat
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@@ -1828,7 +1771,6 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
"""
self.unispeech_sat.feature_extractor._freeze_parameters()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->unispeech_sat
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
@@ -1844,7 +1786,6 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat
def forward(
self,
input_values: Optional[torch.Tensor],
@@ -1908,13 +1849,17 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
)
+# Frame class docstring
+_FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd"
+_FRAME_EXPECTED_OUTPUT = [0, 0]
+
+
@add_start_docstrings(
"""
- UniSpeech-SAT Model with a frame classification head on top for tasks like Speaker Diarization.
+ UniSpeechSat Model with a frame classification head on top for tasks like Speaker Diarization.
""",
UNISPEECH_SAT_START_DOCSTRING,
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -2021,7 +1966,6 @@ class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel):
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
class AMSoftmaxLoss(nn.Module):
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
super(AMSoftmaxLoss, self).__init__()
@@ -2045,7 +1989,6 @@ class AMSoftmaxLoss(nn.Module):
return loss
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
class TDNNLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
@@ -2061,6 +2004,7 @@ class TDNNLayer(nn.Module):
if is_peft_available():
from peft.tuners.lora import LoraLayer
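+ # `LoraLayer` is only imported when PEFT is available, so the isinstance check below
+ # must be guarded by the same availability check.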
+ if is_peft_available():
if isinstance(self.kernel, LoraLayer):
warnings.warn(
"Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
@@ -2077,13 +2021,17 @@ class TDNNLayer(nn.Module):
return hidden_states
+# Speaker Verification docstring
+_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv"
+_XVECTOR_EXPECTED_OUTPUT = 0.97
+
+
@add_start_docstrings(
"""
- UniSpeech-SAT Model with an XVector feature extraction head on top for tasks like Speaker Verification.
+ UniSpeechSat Model with an XVector feature extraction head on top for tasks like Speaker Verification.
""",
UNISPEECH_SAT_START_DOCSTRING,
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
class UniSpeechSatForXVector(UniSpeechSatPreTrainedModel):
def __init__(self, config):
super().__init__(config)
diff --git a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py
new file mode 100644
index 00000000000..44e566068ef
--- /dev/null
+++ b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py
@@ -0,0 +1,610 @@
+import math
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...modeling_outputs import (
+ CausalLMOutput,
+ ModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+ XVectorOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from ..wav2vec2.modeling_wav2vec2 import (
+ Wav2Vec2Encoder,
+ Wav2Vec2EncoderStableLayerNorm,
+ Wav2Vec2FeatureEncoder,
+ Wav2Vec2FeatureProjection,
+ Wav2Vec2ForAudioFrameClassification,
+ Wav2Vec2ForCTC,
+ Wav2Vec2ForSequenceClassification,
+ Wav2Vec2ForXVector,
+ Wav2Vec2GumbelVectorQuantizer,
+ Wav2Vec2Model,
+ Wav2Vec2PositionalConvEmbedding,
+)
+from .configuration_unispeech_sat import UniSpeechSatConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_HIDDEN_STATES_START_POSITION = 2
+
+# General docstring
+_CONFIG_FOR_DOC = "UniSpeechSatConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-100h-libri-ft"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILDER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 39.88
+
+# Frame class docstring
+_FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd"
+_FRAME_EXPECTED_OUTPUT = [0, 0]
+
+# Speaker Verification docstring
+_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv"
+_XVECTOR_EXPECTED_OUTPUT = 0.97
+
+
+@dataclass
+class UniSpeechSatForPreTrainingOutput(ModelOutput):
+ """
+ Output type of [`UniSpeechSatForPreTraining`], with potential hidden states and attentions.
+
+ Args:
+ loss (*optional*, returned when model is in train mode, `torch.FloatTensor` of shape `(1,)`):
+ Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the
+ [official paper](https://arxiv.org/pdf/2006.11477.pdf).
+ projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+ Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
+ projected quantized states.
+ projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+ Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
+ target vectors for contrastive loss.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ projected_states: Optional[torch.FloatTensor] = None
+ projected_quantized_states: Optional[torch.FloatTensor] = None
+ codevector_perplexity: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class UniSpeechSatPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding):
+ pass
+
+
+class UniSpeechSatFeatureEncoder(Wav2Vec2FeatureEncoder):
+ pass
+
+
+class UniSpeechSatFeatureProjection(Wav2Vec2FeatureProjection):
+ pass
+
+
+class UniSpeechSatEncoder(Wav2Vec2Encoder):
+ pass
+
+
+class UniSpeechSatEncoderStableLayerNorm(Wav2Vec2EncoderStableLayerNorm):
+ pass
+
+
+class UniSpeechSatGumbelVectorQuantizer(Wav2Vec2GumbelVectorQuantizer):
+ def __init__(self, config):
+ super().__init__(config)
+ self.weight_proj = nn.Linear(config.hidden_size, self.num_groups * self.num_vars)
+
+ @staticmethod
+ def _compute_perplexity(probs, mask=None):
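+ # perplexity is exp(entropy) of the batch-averaged ("marginal") codevector distribution,
+ # summed over codevector groups; higher values indicate more uniform codebook usage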
+ marginal_probs = probs.mean(dim=0)
+ perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
+ return perplexity
+
+ def forward(self, hidden_states):
+ batch_size, sequence_length, hidden_size = hidden_states.shape
+
+ # project to codevector dim
+ hidden_states = self.weight_proj(hidden_states)
+ hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
+
+ if self.training:
+ # sample code vector probs via gumbel in differentiable way
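+ # (`hard=True` returns one-hot samples in the forward pass while gradients flow through
+ # the soft Gumbel-Softmax distribution, i.e. a straight-through estimator)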
+ codevector_probs = nn.functional.gumbel_softmax(
+ hidden_states.float(), tau=self.temperature, hard=True
+ ).type_as(hidden_states)
+
+ # compute perplexity
+ codevector_soft_dist = torch.softmax(
+ hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
+ )
+ perplexity = self._compute_perplexity(codevector_soft_dist)
+ else:
+ # take argmax in non-differentiable way
+ # compute hard codevector distribution (one hot)
+ codevector_idx = hidden_states.argmax(dim=-1)
+ codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
+ -1, codevector_idx.view(-1, 1), 1.0
+ )
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
+
+ perplexity = self._compute_perplexity(codevector_probs)
+
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
+ # use probs to retrieve codevectors
+ codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
+ codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
+ codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
+
+ return codevectors, perplexity
+
+
+class UniSpeechSatPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = UniSpeechSatConfig
+ base_model_prefix = "unispeech_sat"
+ main_input_name = "input_values"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ # gumbel softmax requires special init
+ if isinstance(module, UniSpeechSatGumbelVectorQuantizer):
+ module.weight_proj.weight.data.normal_(mean=0.0, std=1)
+ module.weight_proj.bias.data.zero_()
+ nn.init.uniform_(module.codevectors)
+ elif isinstance(module, UniSpeechSatPositionalConvEmbedding):
+ nn.init.normal_(
+ module.conv.weight,
+ mean=0,
+ std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
+ )
+ nn.init.constant_(module.conv.bias, 0)
+ elif isinstance(module, UniSpeechSatFeatureProjection):
+ k = math.sqrt(1 / module.projection.in_features)
+ nn.init.uniform_(module.projection.weight, a=-k, b=k)
+ nn.init.uniform_(module.projection.bias, a=-k, b=k)
+ elif isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ nn.init.kaiming_normal_(module.weight)
+
+ if module.bias is not None:
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+ nn.init.uniform_(module.bias, a=-k, b=k)
+
+ def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
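+ # e.g. with a wav2vec2-style default feature extractor (kernels (10, 3, 3, 3, 3, 2, 2),
+ # strides (5, 2, 2, 2, 2, 2, 2)) a 1-second, 16000-sample input is reduced to 49 frames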
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+ return input_lengths
+
+ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
+ # Effectively attention_mask.sum(-1), but not in-place, so that it can run
+ # in inference mode.
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
+ batch_size = attention_mask.shape[0]
+
+ attention_mask = torch.zeros(
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+ # these two operations make sure that all values before the output length indices are attended to
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
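+ # flipping, cumulatively summing and flipping back turns the single 1 placed at the last
+ # valid frame into ones over every frame up to and including that position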
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+ return attention_mask
+
+
+UNISPEECH_SAT_START_DOCSTRING = r"""
+ UniSpeechSat was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
+ Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
+ Auli.
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, etc.).
+
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`UniSpeechSatConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+UNISPEECH_SAT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+ into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
+ soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
+ conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+ 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+
+
+ `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
+ True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should
+ **not** be passed to avoid degraded performance when doing batched inference. For such models
+ `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these
+ models also yield slightly different results depending on whether `input_values` is padded or not.
+
+
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+UniSpeechSatBaseModelOutput = Wav2Vec2BaseModelOutput
+
+
+@add_start_docstrings(
+ "The bare UniSpeechSat Model transformer outputting raw hidden-states without any specific head on top.",
+ UNISPEECH_SAT_START_DOCSTRING,
+)
+class UniSpeechSatModel(UniSpeechSatPreTrainedModel, Wav2Vec2Model):
+ def __init__(self, config: UniSpeechSatConfig):
+ UniSpeechSatPreTrainedModel.__init__(self, config)
+ self.config = config
+ self.feature_extractor = UniSpeechSatFeatureEncoder(config)
+ self.feature_projection = UniSpeechSatFeatureProjection(config)
+
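+ # learned vector used to replace the hidden states at time steps masked by SpecAugment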
+ self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
+
+ if config.do_stable_layer_norm:
+ self.encoder = UniSpeechSatEncoderStableLayerNorm(config)
+ else:
+ self.encoder = UniSpeechSatEncoder(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def freeze_feature_extractor(self):
+ raise AttributeError("Not needed for UniSpeechSat")
+
+ def freeze_feature_encoder(self):
+ raise AttributeError("Not needed for UniSpeechSat")
+
+ @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=UniSpeechSatBaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, UniSpeechSatBaseModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ extract_features = self.feature_extractor(input_values)
+ extract_features = extract_features.transpose(1, 2)
+
+ if attention_mask is not None:
+ # compute reduced attention_mask corresponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
+
+ hidden_states, extract_features = self.feature_projection(extract_features)
+ hidden_states = self._mask_hidden_states(
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+ )
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if not return_dict:
+ return (hidden_states, extract_features) + encoder_outputs[1:]
+
+ return UniSpeechSatBaseModelOutput(
+ last_hidden_state=hidden_states,
+ extract_features=extract_features,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """UniSpeechSat Model with a vector-quantization module and ctc loss for pre-training.""",
+ UNISPEECH_SAT_START_DOCSTRING,
+)
+class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
+ def __init__(self, config: UniSpeechSatConfig):
+ super().__init__(config)
+ self.unispeech_sat = UniSpeechSatModel(config)
+ self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
+
+ self.quantizer = UniSpeechSatGumbelVectorQuantizer(config)
+ self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
+ self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
+
+ self.dropout = nn.Dropout(config.final_dropout)
+
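+ # UniSpeech-SAT specific parameters, presumably intended for the speaker-aware clustering
+ # objective; the corresponding pre-training loss is not implemented here (see TODO in `forward`)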
+ self.speaker_proj = nn.Linear(config.hidden_size, config.codevector_dim)
+ self.label_embeddings_concat = nn.Parameter(torch.FloatTensor(config.num_clusters, config.codevector_dim))
+ self.label_embeddings_concat.data.zero_()
+
+ self.layer_norm_for_extract = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ if self.config.do_stable_layer_norm:
+ self.layer_norm_for_extract.requires_grad = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def set_gumbel_temperature(self, temperature: int):
+ """
+ Set the Gumbel softmax temperature to a given value. Only necessary for training
+ """
+ self.quantizer.temperature = temperature
+
+ def freeze_feature_extractor(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ warnings.warn(
+ "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+ "Please use the equivalent `freeze_feature_encoder` method instead.",
+ FutureWarning,
+ )
+ self.freeze_feature_encoder()
+
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ self.unispeech_sat.feature_extractor._freeze_parameters()
+
+ @staticmethod
+ def compute_contrastive_logits(
+ target_features: torch.FloatTensor,
+ negative_features: torch.FloatTensor,
+ predicted_features: torch.FloatTensor,
+ temperature: int = 1,
+ ):
+ """
+ Compute logits for contrastive loss using cosine similarity as the distance measure between
+ `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
+ """
+ target_features = torch.cat([target_features, negative_features], dim=0)
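+ # the positive target is now stacked on top of the negatives along dim 0; for the usual
+ # (1, batch, seq, dim) positives and (num_negatives, batch, seq, dim) negatives, the cosine
+ # similarity below yields logits of shape (1 + num_negatives, batch, seq)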
+
+ logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1)
+ logits = logits.type_as(target_features)
+
+ # apply temperature
+ logits = logits / temperature
+ return logits
+
+ @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=UniSpeechSatForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, UniSpeechSatForPreTrainingOutput]:
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoFeatureExtractor, UniSpeechSatForPreTraining
+ >>> from transformers.models.unispeech_sat.modeling_unispeech_sat import _compute_mask_indices
+
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-sat-base")
+ >>> model = UniSpeechSatForPreTraining.from_pretrained("microsoft/unispeech-sat-base")
+ >>> # TODO: Add full pretraining example
+ ```"""
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.unispeech_sat(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ transformer_features = outputs[0]
+
+ # quantize all (unmasked) extracted features and project to final vq dim
+ extract_features = self.dropout_features(outputs[1])
+
+ # TODO(PVP) - add pretraining logic and add to tests
+ logits = extract_features
+ loss = quantized_features = codevector_perplexity = None
+
+ if not return_dict:
+ if loss is not None:
+ return (loss, logits, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
+ return (logits, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
+
+ return UniSpeechSatForPreTrainingOutput(
+ loss=loss,
+ logits=logits,
+ projected_states=transformer_features,
+ projected_quantized_states=quantized_features,
+ codevector_perplexity=codevector_perplexity,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """UniSpeechSat Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+ UNISPEECH_SAT_START_DOCSTRING,
+ """
+ target_lang (`str`, *optional*):
+ Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
+ adapter.<lang>.bin. Only relevant when using an instance of [`UniSpeechSatForCTC`] with adapters. Uses
+ 'eng' by default.
+ """,
+)
+class UniSpeechSatForCTC(Wav2Vec2ForCTC):
+ @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=CausalLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_CTC_EXPECTED_OUTPUT,
+ expected_loss=_CTC_EXPECTED_LOSS,
+ )
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+ """
+ UniSpeechSat Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
+ SUPERB Keyword Spotting.
+ """,
+ UNISPEECH_SAT_START_DOCSTRING,
+)
+class UniSpeechSatForSequenceClassification(Wav2Vec2ForSequenceClassification):
+ @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+ """
+ UniSpeechSat Model with a frame classification head on top for tasks like Speaker Diarization.
+ """,
+ UNISPEECH_SAT_START_DOCSTRING,
+)
+class UniSpeechSatForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification):
+ @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_FRAME_CLASS_CHECKPOINT,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_FRAME_EXPECTED_OUTPUT,
+ )
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+ """
+ UniSpeechSat Model with an XVector feature extraction head on top for tasks like Speaker Verification.
+ """,
+ UNISPEECH_SAT_START_DOCSTRING,
+)
+class UniSpeechSatForXVector(Wav2Vec2ForXVector):
+
+ @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_XVECTOR_CHECKPOINT,
+ output_type=XVectorOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_XVECTOR_EXPECTED_OUTPUT,
+ )
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+__all__ = [
+ "UniSpeechSatForAudioFrameClassification",
+ "UniSpeechSatForCTC",
+ "UniSpeechSatForPreTraining",
+ "UniSpeechSatForSequenceClassification",
+ "UniSpeechSatForXVector",
+ "UniSpeechSatModel",
+ "UniSpeechSatPreTrainedModel",
+]
diff --git a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
index 55d34b84ef5..54707620501 100644
--- a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
@@ -212,7 +212,7 @@ def _sample_negative_indices(features_shape: Tuple, num_negatives: int, attentio
return sampled_negative_indices
-WAV_2_VEC_2_START_DOCSTRING = r"""
+WAV2VEC2_START_DOCSTRING = r"""
Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
Auli.
@@ -251,7 +251,7 @@ WAV_2_VEC_2_START_DOCSTRING = r"""
"""
-WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
+WAV2VEC2_INPUTS_DOCSTRING = r"""
Args:
input_values (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
@@ -885,7 +885,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
else:
return random_params
- @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
def __call__(
self,
input_values,
@@ -1050,7 +1050,7 @@ class FlaxWav2Vec2Module(nn.Module):
@add_start_docstrings(
"The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.",
- WAV_2_VEC_2_START_DOCSTRING,
+ WAV2VEC2_START_DOCSTRING,
)
class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
module_class = FlaxWav2Vec2Module
@@ -1088,7 +1088,7 @@ FLAX_WAV2VEC2_MODEL_DOCSTRING = """
overwrite_call_docstring(
FlaxWav2Vec2Model,
- WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_MODEL_DOCSTRING,
+ WAV2VEC2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_MODEL_DOCSTRING,
)
append_replace_return_docstrings(
FlaxWav2Vec2Model, output_type=FlaxWav2Vec2BaseModelOutput, config_class=Wav2Vec2Config
@@ -1168,7 +1168,7 @@ class FlaxWav2Vec2ForCTCModule(nn.Module):
@add_start_docstrings(
"Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).",
- WAV_2_VEC_2_START_DOCSTRING,
+ WAV2VEC2_START_DOCSTRING,
)
class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
module_class = FlaxWav2Vec2ForCTCModule
@@ -1211,7 +1211,7 @@ FLAX_WAV2VEC2_FOR_CTC_DOCSTRING = """
overwrite_call_docstring(
FlaxWav2Vec2ForCTC,
- WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_CTC_DOCSTRING,
+ WAV2VEC2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_CTC_DOCSTRING,
)
append_replace_return_docstrings(FlaxWav2Vec2ForCTC, output_type=FlaxCausalLMOutput, config_class=Wav2Vec2Config)
@@ -1315,11 +1315,11 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
return input_lengths
-@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top.""", WAV_2_VEC_2_START_DOCSTRING)
+@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top.""", WAV2VEC2_START_DOCSTRING)
class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel):
module_class = FlaxWav2Vec2ForPreTrainingModule
- @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
# overwrite since has `gumbel_temperature` input
def __call__(
self,
@@ -1418,7 +1418,7 @@ FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING = """
overwrite_call_docstring(
FlaxWav2Vec2ForPreTraining,
- WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING,
+ WAV2VEC2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING,
)
append_replace_return_docstrings(
FlaxWav2Vec2ForPreTraining, output_type=FlaxWav2Vec2ForPreTrainingOutput, config_class=Wav2Vec2Config
diff --git a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
index ad923bef80b..c385c192a98 100644
--- a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
@@ -1397,7 +1397,7 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel):
return attention_mask
-WAV_2_VEC_2_START_DOCSTRING = r"""
+WAV2VEC2_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -1439,7 +1439,7 @@ WAV_2_VEC_2_START_DOCSTRING = r"""
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
-WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
+WAV2VEC2_INPUTS_DOCSTRING = r"""
Args:
input_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `({0})`):
Indices of input sequence tokens in the vocabulary.
@@ -1497,7 +1497,7 @@ WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"The bare TFWav2Vec2 Model transformer outputing raw hidden-states without any specific head on top.",
- WAV_2_VEC_2_START_DOCSTRING,
+ WAV2VEC2_START_DOCSTRING,
)
class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
def __init__(self, config: Wav2Vec2Config, *inputs, **kwargs):
@@ -1505,7 +1505,7 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
self.config = config
self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2")
- @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
@unpack_inputs
def call(
@@ -1579,7 +1579,7 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
@add_start_docstrings(
"""TFWav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
- WAV_2_VEC_2_START_DOCSTRING,
+ WAV2VEC2_START_DOCSTRING,
)
class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
def __init__(self, config: Wav2Vec2Config, *inputs, **kwargs):
@@ -1612,7 +1612,7 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
self.wav2vec2.feature_extractor.trainable = False
@unpack_inputs
- @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index 93f77372538..2ac0e21486e 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -63,6 +63,7 @@ if is_safetensors_available():
if is_flash_attn_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+
logger = logging.get_logger(__name__)
@@ -1633,7 +1634,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
self.target_lang = target_lang
-WAV_2_VEC_2_START_DOCSTRING = r"""
+WAV2VEC2_START_DOCSTRING = r"""
Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
Auli.
@@ -1652,7 +1653,7 @@ WAV_2_VEC_2_START_DOCSTRING = r"""
"""
-WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
+WAV2VEC2_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
@@ -1692,7 +1693,7 @@ WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.",
- WAV_2_VEC_2_START_DOCSTRING,
+ WAV2VEC2_START_DOCSTRING,
)
class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
def __init__(self, config: Wav2Vec2Config):
@@ -1780,7 +1781,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
return hidden_states
- @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Wav2Vec2BaseModelOutput,
@@ -1841,7 +1842,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
)
-@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top.""", WAV_2_VEC_2_START_DOCSTRING)
+@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top.""", WAV2VEC2_START_DOCSTRING)
class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
def __init__(self, config: Wav2Vec2Config):
super().__init__(config)
@@ -1902,7 +1903,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
logits = logits / temperature
return logits
- @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Wav2Vec2ForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
@@ -2065,7 +2066,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
)
-@add_start_docstrings("""Wav2Vec2 Model with a `language modeling` head on top.""", WAV_2_VEC_2_START_DOCSTRING)
+@add_start_docstrings("""Wav2Vec2 Model with a `language modeling` head on top.""", WAV2VEC2_START_DOCSTRING)
class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -2081,7 +2082,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
- @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
def forward(
self,
input_values: torch.FloatTensor,
@@ -2113,7 +2114,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
@add_start_docstrings(
"""Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
- WAV_2_VEC_2_START_DOCSTRING,
+ WAV2VEC2_START_DOCSTRING,
"""
target_lang (`str`, *optional*):
Language id of adapter weights. Adapter weights are stored in the format adapter..safetensors or
@@ -2193,7 +2194,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
for param in self.wav2vec2.parameters():
param.requires_grad = False
- @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput,
@@ -2277,7 +2278,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
Wav2Vec2 Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
SUPERB Keyword Spotting.
""",
- WAV_2_VEC_2_START_DOCSTRING,
+ WAV2VEC2_START_DOCSTRING,
)
class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
def __init__(self, config):
@@ -2324,7 +2325,7 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
for param in self.wav2vec2.parameters():
param.requires_grad = False
- @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_SEQ_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
@@ -2400,7 +2401,7 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
"""
Wav2Vec2 Model with a frame classification head on top for tasks like Speaker Diarization.
""",
- WAV_2_VEC_2_START_DOCSTRING,
+ WAV2VEC2_START_DOCSTRING,
)
class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
def __init__(self, config):
@@ -2446,7 +2447,7 @@ class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
for param in self.wav2vec2.parameters():
param.requires_grad = False
- @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_FRAME_CLASS_CHECKPOINT,
output_type=TokenClassifierOutput,
@@ -2546,6 +2547,7 @@ class TDNNLayer(nn.Module):
if is_peft_available():
from peft.tuners.lora import LoraLayer
+        if is_peft_available():
if isinstance(self.kernel, LoraLayer):
warnings.warn(
"Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
@@ -2566,7 +2568,7 @@ class TDNNLayer(nn.Module):
"""
Wav2Vec2 Model with an XVector feature extraction head on top for tasks like Speaker Verification.
""",
- WAV_2_VEC_2_START_DOCSTRING,
+ WAV2VEC2_START_DOCSTRING,
)
class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel):
def __init__(self, config):
@@ -2630,7 +2632,7 @@ class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel):
return input_lengths
- @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_XVECTOR_CHECKPOINT,
output_type=XVectorOutput,
diff --git a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
index 0fb591542f4..86e9b65b3d5 100644
--- a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
+++ b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
@@ -1,26 +1,15 @@
-# coding=utf-8
-# Copyright 2024 The Seamless Authors and the HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch Wav2Vec2-BERT model."""
-
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_wav2vec2_bert.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
import warnings
from typing import Optional, Tuple, Union
import numpy as np
import torch
-import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
@@ -42,213 +31,14 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_peft_available,
- logging,
)
from .configuration_wav2vec2_bert import Wav2Vec2BertConfig
-logger = logging.get_logger(__name__)
-
-
-_HIDDEN_STATES_START_POSITION = 2
-
# General docstring
_CONFIG_FOR_DOC = "Wav2Vec2BertConfig"
-# Base docstring
-_BASE_CHECKPOINT_FOR_DOC = "facebook/w2v-bert-2.0"
-_PRETRAINED_CHECKPOINT_FOR_DOC = "hf-audio/wav2vec2-bert-CV16-en"
-_EXPECTED_OUTPUT_SHAPE = [1, 146, 1024]
-# CTC docstring
-_CTC_EXPECTED_OUTPUT = "'mr quilter is the apostle of the middle classes and we are glad to welcome his gospel'"
-_CTC_EXPECTED_LOSS = 17.04
-
-
-# Copied from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2._compute_new_attention_mask
-def _compute_new_attention_mask(hidden_states: torch.Tensor, seq_lens: torch.Tensor):
- """
- Computes an attention mask of the form `(batch, seq_len)` with an attention for each element in the batch that
- stops at the corresponding element in `seq_lens`.
- Args:
- hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, *)`):
- The sequences to mask, where `*` is any number of sequence-specific dimensions including none.
- seq_lens (`torch.Tensor` of shape `(batch)`:
- Each element represents the length of the sequence at the same index in `hidden_states`
- Returns:
- `torch.FloatTensor`: The float attention mask of shape `(batch, seq_len)`
- """
- batch_size, mask_seq_len = hidden_states.shape[:2]
-
- indices = torch.arange(mask_seq_len, device=seq_lens.device).expand(batch_size, -1)
-
- bool_mask = indices >= seq_lens.unsqueeze(1).expand(-1, mask_seq_len)
-
- mask = hidden_states.new_ones((batch_size, mask_seq_len))
-
- mask = mask.masked_fill(bool_mask, 0)
-
- return mask
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
-def _compute_mask_indices(
- shape: Tuple[int, int],
- mask_prob: float,
- mask_length: int,
- attention_mask: Optional[torch.LongTensor] = None,
- min_masks: int = 0,
-) -> np.ndarray:
- """
- Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
- ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
- CPU as part of the preprocessing during training.
-
- Args:
- shape: The shape for which to compute masks. This should be of a tuple of size 2 where
- the first element is the batch size and the second element is the length of the axis to span.
- mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
- independently generated mask spans of length `mask_length` is computed by
- `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
- actual percentage will be smaller.
- mask_length: size of the mask
- min_masks: minimum number of masked spans
- attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
- each batch dimension.
- """
- batch_size, sequence_length = shape
-
- if mask_length < 1:
- raise ValueError("`mask_length` has to be bigger than 0.")
-
- if mask_length > sequence_length:
- raise ValueError(
- f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
- f" and `sequence_length`: {sequence_length}`"
- )
-
- # epsilon is used for probabilistic rounding
- epsilon = np.random.rand(1).item()
-
- def compute_num_masked_span(input_length):
- """Given input length, compute how many spans should be masked"""
- num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
- num_masked_span = max(num_masked_span, min_masks)
-
- # make sure num masked span <= sequence_length
- if num_masked_span * mask_length > sequence_length:
- num_masked_span = sequence_length // mask_length
-
- # make sure num_masked span is also <= input_length - (mask_length - 1)
- if input_length - (mask_length - 1) < num_masked_span:
- num_masked_span = max(input_length - (mask_length - 1), 0)
-
- return num_masked_span
-
- # compute number of masked spans in batch
- input_lengths = (
- attention_mask.detach().sum(-1).tolist()
- if attention_mask is not None
- else [sequence_length for _ in range(batch_size)]
- )
-
- # SpecAugment mask to fill
- spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
- spec_aug_mask_idxs = []
-
- max_num_masked_span = compute_num_masked_span(sequence_length)
-
- if max_num_masked_span == 0:
- return spec_aug_mask
-
- for input_length in input_lengths:
- # compute num of masked spans for this input
- num_masked_span = compute_num_masked_span(input_length)
-
- # get random indices to mask
- spec_aug_mask_idx = np.random.choice(
- np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
- )
-
- # pick first sampled index that will serve as a dummy index to pad vector
- # to ensure same dimension for all batches due to probabilistic rounding
- # Picking first sample just pads those vectors twice.
- if len(spec_aug_mask_idx) == 0:
- # this case can only happen if `input_length` is strictly smaller then
- # `sequence_length` in which case the last token has to be a padding
- # token which we can use as a dummy mask id
- dummy_mask_idx = sequence_length - 1
- else:
- dummy_mask_idx = spec_aug_mask_idx[0]
-
- spec_aug_mask_idx = np.concatenate(
- [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
- )
- spec_aug_mask_idxs.append(spec_aug_mask_idx)
-
- spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
-
- # expand masked indices to masked spans
- spec_aug_mask_idxs = np.broadcast_to(
- spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
-
- # add offset to the starting indexes so that indexes now create a span
- offsets = np.arange(mask_length)[None, None, :]
- offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
- batch_size, max_num_masked_span * mask_length
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
-
- # ensure that we cannot have indices larger than sequence_length
- if spec_aug_mask_idxs.max() > sequence_length - 1:
- spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
-
- # scatter indices to mask
- np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
-
- return spec_aug_mask
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
-def _sample_negative_indices(
- features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
-):
- """
- Sample `num_negatives` vectors from feature vectors.
- """
- batch_size, sequence_length = features_shape
-
- # generate indices of the positive vectors themselves, repeat them `num_negatives` times
- sequence_length_range = np.arange(sequence_length)
-
- # get `num_negatives` random vector indices from the same utterance
- sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
-
- mask_time_indices = (
- mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
- )
-
- for batch_idx in range(batch_size):
- high = mask_time_indices[batch_idx].sum() - 1
- mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
-
- feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
- sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
- # avoid sampling the same positive vector, but keep the distribution uniform
- sampled_indices[sampled_indices >= feature_indices] += 1
-
- # remap to actual indices
- sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
-
- # correct for batch size
- sampled_negative_indices[batch_idx] += batch_idx * sequence_length
-
- return sampled_negative_indices
-
-
-# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerRotaryPositionalEmbedding with Wav2Vec2Conformer->Wav2Vec2Bert
class Wav2Vec2BertRotaryPositionalEmbedding(nn.Module):
"""Rotary positional embedding
Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
@@ -284,7 +74,6 @@ class Wav2Vec2BertRotaryPositionalEmbedding(nn.Module):
return self.cached_rotary_positional_embedding
-# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerRelPositionalEmbedding with Wav2Vec2Conformer->Wav2Vec2Bert
class Wav2Vec2BertRelPositionalEmbedding(nn.Module):
"""Relative positional encoding module."""
@@ -363,7 +152,6 @@ class Wav2Vec2BertFeedForward(nn.Module):
self.output_dense = nn.Linear(config.intermediate_size, hidden_size)
self.output_dropout = nn.Dropout(config.hidden_dropout)
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward.forward
def forward(self, hidden_states):
hidden_states = self.intermediate_dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
@@ -556,7 +344,6 @@ class Wav2Vec2BertSelfAttention(nn.Module):
return hidden_states, probs
- # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention._apply_rotary_embedding
def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
batch_size, sequence_length, hidden_size = hidden_states.size()
hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
@@ -576,7 +363,6 @@ class Wav2Vec2BertSelfAttention(nn.Module):
return hidden_states
- # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention._apply_relative_embeddings
def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
# 1. project positional embeddings
# => (batch, head, 2*time1-1, d_k)
@@ -823,6 +609,32 @@ class Wav2Vec2BertAdapter(nn.Module):
return hidden_states
+# Copied from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2._compute_new_attention_mask
+def _compute_new_attention_mask(hidden_states: torch.Tensor, seq_lens: torch.Tensor):
+ """
+ Computes an attention mask of the form `(batch, seq_len)` with an attention for each element in the batch that
+ stops at the corresponding element in `seq_lens`.
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, *)`):
+ The sequences to mask, where `*` is any number of sequence-specific dimensions including none.
+        seq_lens (`torch.Tensor` of shape `(batch)`):
+ Each element represents the length of the sequence at the same index in `hidden_states`
+ Returns:
+ `torch.FloatTensor`: The float attention mask of shape `(batch, seq_len)`
+ """
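+    # e.g. with seq_lens = [2, 4] and a padded length of 4, this returns [[1, 1, 0, 0], [1, 1, 1, 1]]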
+ batch_size, mask_seq_len = hidden_states.shape[:2]
+
+ indices = torch.arange(mask_seq_len, device=seq_lens.device).expand(batch_size, -1)
+
+ bool_mask = indices >= seq_lens.unsqueeze(1).expand(-1, mask_seq_len)
+
+ mask = hidden_states.new_ones((batch_size, mask_seq_len))
+
+ mask = mask.masked_fill(bool_mask, 0)
+
+ return mask
+
+
class Wav2Vec2BertAdapterLayer(nn.Module):
def __init__(self, config):
super().__init__()
@@ -911,7 +723,6 @@ class Wav2Vec2BertAdapterLayer(nn.Module):
return hidden_states
-# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerPreTrainedModel with Wav2Vec2Conformer->Wav2Vec2Bert,wav2vec2_conformer->wav2vec2_bert, input_values->input_features
class Wav2Vec2BertPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -995,6 +806,129 @@ class Wav2Vec2BertPreTrainedModel(PreTrainedModel):
return attention_mask
+def _compute_mask_indices(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ attention_mask: Optional[torch.LongTensor] = None,
+ min_masks: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+ CPU as part of the preprocessing during training.
+
+ Args:
+ shape: The shape for which to compute masks. This should be of a tuple of size 2 where
+ the first element is the batch size and the second element is the length of the axis to span.
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+ independently generated mask spans of length `mask_length` is computed by
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+ actual percentage will be smaller.
+ mask_length: size of the mask
+ min_masks: minimum number of masked spans
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+ each batch dimension.
+ """
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ if mask_length > sequence_length:
+ raise ValueError(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+            f" and `sequence_length`: {sequence_length}"
+ )
+
+ # epsilon is used for probabilistic rounding
+ epsilon = np.random.rand(1).item()
+
+ def compute_num_masked_span(input_length):
+ """Given input length, compute how many spans should be masked"""
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
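+        # e.g. mask_prob=0.05, input_length=100, mask_length=10 -> 0.5 expected spans; adding `epsilon`
+        # (uniform in [0, 1)) rounds this up to 1 roughly half of the time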
+ num_masked_span = max(num_masked_span, min_masks)
+
+ # make sure num masked span <= sequence_length
+ if num_masked_span * mask_length > sequence_length:
+ num_masked_span = sequence_length // mask_length
+
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
+ if input_length - (mask_length - 1) < num_masked_span:
+ num_masked_span = max(input_length - (mask_length - 1), 0)
+
+ return num_masked_span
+
+ # compute number of masked spans in batch
+ input_lengths = (
+ attention_mask.detach().sum(-1).tolist()
+ if attention_mask is not None
+ else [sequence_length for _ in range(batch_size)]
+ )
+
+ # SpecAugment mask to fill
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+ spec_aug_mask_idxs = []
+
+ max_num_masked_span = compute_num_masked_span(sequence_length)
+
+ if max_num_masked_span == 0:
+ return spec_aug_mask
+
+ for input_length in input_lengths:
+ # compute num of masked spans for this input
+ num_masked_span = compute_num_masked_span(input_length)
+
+ # get random indices to mask
+ spec_aug_mask_idx = np.random.choice(
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+ )
+
+ # pick first sampled index that will serve as a dummy index to pad vector
+ # to ensure same dimension for all batches due to probabilistic rounding
+ # Picking first sample just pads those vectors twice.
+ if len(spec_aug_mask_idx) == 0:
+            # this case can only happen if `input_length` is strictly smaller than
+ # `sequence_length` in which case the last token has to be a padding
+ # token which we can use as a dummy mask id
+ dummy_mask_idx = sequence_length - 1
+ else:
+ dummy_mask_idx = spec_aug_mask_idx[0]
+
+ spec_aug_mask_idx = np.concatenate(
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+ )
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+ # expand masked indices to masked spans
+ spec_aug_mask_idxs = np.broadcast_to(
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+ # add offset to the starting indexes so that indexes now create a span
+ offsets = np.arange(mask_length)[None, None, :]
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+ batch_size, max_num_masked_span * mask_length
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+ # ensure that we cannot have indices larger than sequence_length
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+ # scatter indices to mask
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+ return spec_aug_mask
+
+
+_PRETRAINED_CHECKPOINT_FOR_DOC = "hf-audio/wav2vec2-bert-CV16-en"
+_EXPECTED_OUTPUT_SHAPE = [1, 146, 1024]
+
+
WAV2VEC2_BERT_START_DOCSTRING = r"""
Wav2Vec2Bert was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
@@ -1003,8 +937,9 @@ WAV2VEC2_BERT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving etc.).
- This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
- regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
Parameters:
config ([`Wav2Vec2BertConfig`]): Model configuration class with all the parameters of the model.
@@ -1012,7 +947,6 @@ WAV2VEC2_BERT_START_DOCSTRING = r"""
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
-
WAV2VEC2_BERT_INPUTS_DOCSTRING = r"""
Args:
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -1039,6 +973,9 @@ WAV2VEC2_BERT_INPUTS_DOCSTRING = r"""
"""
+Wav2Vec2BertBaseModelOutput = Wav2Vec2BaseModelOutput
+
+
@add_start_docstrings(
"The bare Wav2Vec2Bert Model transformer outputting raw hidden-states without any specific head on top.",
WAV2VEC2_BERT_START_DOCSTRING,
@@ -1064,7 +1001,6 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_hidden_states(
self,
hidden_states: torch.FloatTensor,
@@ -1114,7 +1050,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
@add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_PRETRAINED_CHECKPOINT_FOR_DOC,
- output_type=Wav2Vec2BaseModelOutput,
+ output_type=Wav2Vec2BertBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@@ -1127,7 +1063,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+ ) -> Union[Tuple, Wav2Vec2BertBaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1159,7 +1095,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
- return Wav2Vec2BaseModelOutput(
+ return Wav2Vec2BertBaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
@@ -1167,12 +1103,18 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
)
+_HIDDEN_STATES_START_POSITION = 2
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'mr quilter is the apostle of the middle classes and we are glad to welcome his gospel'"
+_CTC_EXPECTED_LOSS = 17.04
+
+
@add_start_docstrings(
"""Wav2Vec2Bert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
WAV2VEC2_BERT_START_DOCSTRING,
)
class Wav2Vec2BertForCTC(Wav2Vec2BertPreTrainedModel):
- # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForCTC.__init__ with Wav2Vec2Conformer->Wav2Vec2Bert,WAV2VEC2_CONFORMER->WAV2VEC2_BERT,wav2vec2_conformer->wav2vec2_bert
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
@@ -1277,6 +1219,10 @@ class Wav2Vec2BertForCTC(Wav2Vec2BertPreTrainedModel):
)
+# Base docstring
+_BASE_CHECKPOINT_FOR_DOC = "facebook/w2v-bert-2.0"
+
+
@add_start_docstrings(
"""
Wav2Vec2Bert Model with a sequence classification head on top (a linear layer over the pooled output) for
@@ -1285,7 +1231,6 @@ class Wav2Vec2BertForCTC(Wav2Vec2BertPreTrainedModel):
WAV2VEC2_BERT_START_DOCSTRING,
)
class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Bert,wav2vec2->wav2vec2_bert
def __init__(self, config):
super().__init__(config)
@@ -1318,7 +1263,6 @@ class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Bert,wav2vec2->wav2vec2_bert,WAV_2_VEC_2->WAV2VEC2_BERT, input_values->input_features
def forward(
self,
input_features: Optional[torch.Tensor],
@@ -1389,7 +1333,6 @@ class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
WAV2VEC2_BERT_START_DOCSTRING,
)
class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2BertPreTrainedModel):
- # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForAudioFrameClassification.__init__ with Wav2Vec2Conformer->Wav2Vec2Bert,WAV2VEC2_CONFORMER->WAV2VEC2_BERT,wav2vec2_conformer->wav2vec2_bert
def __init__(self, config):
super().__init__(config)
@@ -1406,7 +1349,6 @@ class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2BertPreTrainedModel):
self.init_weights()
- # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForAudioFrameClassification.freeze_base_model with wav2vec2_conformer->wav2vec2_bert
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
@@ -1422,7 +1364,6 @@ class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2BertPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
- # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForAudioFrameClassification.forward with wav2vec2_conformer->wav2vec2_bert, input_values->input_features
def forward(
self,
input_features: Optional[torch.Tensor],
@@ -1477,7 +1418,6 @@ class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2BertPreTrainedModel):
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
class AMSoftmaxLoss(nn.Module):
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
super(AMSoftmaxLoss, self).__init__()
@@ -1501,7 +1441,6 @@ class AMSoftmaxLoss(nn.Module):
return loss
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
class TDNNLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
@@ -1517,6 +1456,7 @@ class TDNNLayer(nn.Module):
if is_peft_available():
from peft.tuners.lora import LoraLayer
+        if is_peft_available():
if isinstance(self.kernel, LoraLayer):
warnings.warn(
"Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
@@ -1540,7 +1480,6 @@ class TDNNLayer(nn.Module):
WAV2VEC2_BERT_START_DOCSTRING,
)
class Wav2Vec2BertForXVector(Wav2Vec2BertPreTrainedModel):
- # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector.__init__ with Wav2Vec2Conformer->Wav2Vec2Bert,WAV2VEC2_CONFORMER->WAV2VEC2_BERT,wav2vec2_conformer->wav2vec2_bert
def __init__(self, config):
super().__init__(config)
@@ -1560,7 +1499,6 @@ class Wav2Vec2BertForXVector(Wav2Vec2BertPreTrainedModel):
self.init_weights()
- # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector.freeze_base_model with wav2vec2_conformer->wav2vec2_bert
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
@@ -1569,7 +1507,6 @@ class Wav2Vec2BertForXVector(Wav2Vec2BertPreTrainedModel):
for param in self.wav2vec2_bert.parameters():
param.requires_grad = False
- # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector._get_tdnn_output_lengths
def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the TDNN layers
@@ -1592,7 +1529,6 @@ class Wav2Vec2BertForXVector(Wav2Vec2BertPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
- # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector.forward with wav2vec2_conformer->wav2vec2_bert, input_values->input_features
def forward(
self,
input_features: Optional[torch.Tensor],
diff --git a/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py
new file mode 100644
index 00000000000..b8dc95c6754
--- /dev/null
+++ b/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py
@@ -0,0 +1,1169 @@
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.fsdp import is_fsdp_managed_module
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_outputs import (
+ BaseModelOutput,
+ CausalLMOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+ XVectorOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2FeedForward, Wav2Vec2ForSequenceClassification, Wav2Vec2Model
+from ..wav2vec2_conformer.modeling_wav2vec2_conformer import (
+ Wav2Vec2ConformerForAudioFrameClassification,
+ Wav2Vec2ConformerForCTC,
+ Wav2Vec2ConformerForXVector,
+ Wav2Vec2ConformerRelPositionalEmbedding,
+ Wav2Vec2ConformerRotaryPositionalEmbedding,
+ Wav2Vec2ConformerSelfAttention,
+)
+from .configuration_wav2vec2_bert import Wav2Vec2BertConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_HIDDEN_STATES_START_POSITION = 2
+
+# General docstring
+_CONFIG_FOR_DOC = "Wav2Vec2BertConfig"
+
+# Base docstring
+_BASE_CHECKPOINT_FOR_DOC = "facebook/w2v-bert-2.0"
+_PRETRAINED_CHECKPOINT_FOR_DOC = "hf-audio/wav2vec2-bert-CV16-en"
+_EXPECTED_OUTPUT_SHAPE = [1, 146, 1024]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'mr quilter is the apostle of the middle classes and we are glad to welcome his gospel'"
+_CTC_EXPECTED_LOSS = 17.04
+
+
+# Copied from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2._compute_new_attention_mask
+def _compute_new_attention_mask(hidden_states: torch.Tensor, seq_lens: torch.Tensor):
+ """
+ Computes an attention mask of the form `(batch, seq_len)` with an attention for each element in the batch that
+ stops at the corresponding element in `seq_lens`.
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, *)`):
+ The sequences to mask, where `*` is any number of sequence-specific dimensions including none.
+        seq_lens (`torch.Tensor` of shape `(batch)`):
+ Each element represents the length of the sequence at the same index in `hidden_states`
+ Returns:
+ `torch.FloatTensor`: The float attention mask of shape `(batch, seq_len)`
+ """
+ batch_size, mask_seq_len = hidden_states.shape[:2]
+
+ indices = torch.arange(mask_seq_len, device=seq_lens.device).expand(batch_size, -1)
+
+ bool_mask = indices >= seq_lens.unsqueeze(1).expand(-1, mask_seq_len)
+
+ mask = hidden_states.new_ones((batch_size, mask_seq_len))
+
+ mask = mask.masked_fill(bool_mask, 0)
+
+ return mask
+
+
+class Wav2Vec2BertRotaryPositionalEmbedding(Wav2Vec2ConformerRotaryPositionalEmbedding, nn.Module):
+ def __init__(self, config):
+        nn.Module.__init__(self)
+ dim = config.hidden_size // config.num_attention_heads
+ base = config.rotary_embedding_base
+
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
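+        # one rotation frequency per channel pair: 1 / base^(2i / dim), as in rotary position embeddings (RoPE)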
+ # Ignore copy
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.cached_sequence_length = None
+ self.cached_rotary_positional_embedding = None
+
+
+class Wav2Vec2BertRelPositionalEmbedding(Wav2Vec2ConformerRelPositionalEmbedding):
+ pass
+
+
+class Wav2Vec2BertFeatureProjection(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.layer_norm = nn.LayerNorm(config.feature_projection_input_dim, eps=config.layer_norm_eps)
+ self.projection = nn.Linear(config.feature_projection_input_dim, config.hidden_size)
+ self.dropout = nn.Dropout(config.feat_proj_dropout)
+
+ def forward(self, hidden_states):
+ # non-projected hidden states are needed for quantization
+ norm_hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.projection(norm_hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states, norm_hidden_states
+
+
+class Wav2Vec2BertFeedForward(Wav2Vec2FeedForward, nn.Module):
+ def __init__(self, config, act_fn=None, hidden_size=None):
+        nn.Module.__init__(self)
+ act_fn = act_fn if act_fn is not None else config.hidden_act
+ hidden_size = hidden_size if hidden_size is not None else config.hidden_size
+ self.intermediate_dropout = nn.Dropout(config.activation_dropout)
+
+ self.intermediate_dense = nn.Linear(hidden_size, config.intermediate_size)
+ self.intermediate_act_fn = ACT2FN[act_fn] if isinstance(act_fn, str) else act_fn
+
+ self.output_dense = nn.Linear(config.intermediate_size, hidden_size)
+ self.output_dropout = nn.Dropout(config.hidden_dropout)
+
+
+class Wav2Vec2BertConvolutionModule(nn.Module):
+ """Convolution block used in the conformer block"""
+
+ def __init__(self, config):
+ super().__init__()
+ if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
+            raise ValueError("`config.conv_depthwise_kernel_size` should be an odd number for 'SAME' padding")
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pointwise_conv1 = nn.Conv1d(
+ config.hidden_size,
+ 2 * config.hidden_size,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ self.glu = nn.GLU(dim=1)
+ self.depthwise_conv = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ config.conv_depthwise_kernel_size,
+ stride=1,
+ padding=0,
+ groups=config.hidden_size,
+ bias=False,
+ )
+
+ self.depthwise_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.activation = ACT2FN[config.hidden_act]
+ self.pointwise_conv2 = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ self.dropout = nn.Dropout(config.conformer_conv_dropout)
+
+ def forward(self, hidden_states, attention_mask=None):
+ hidden_states = self.layer_norm(hidden_states)
+
+ # Ensure that we do not leak padded positions in depthwise convolution if attention mask is passed.
+ # Put 0 where necessary
+ if attention_mask is not None:
+ hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0)
+
+ # exchange the temporal dimension and the feature dimension
+ hidden_states = hidden_states.transpose(1, 2)
+
+ # GLU mechanism
+ # => (batch, 2*channel, dim)
+ hidden_states = self.pointwise_conv1(hidden_states)
+ # => (batch, channel, dim)
+ hidden_states = self.glu(hidden_states)
+
+ # Pad the sequence entirely on the left because of causal convolution.
+ hidden_states = torch.nn.functional.pad(hidden_states, (self.depthwise_conv.kernel_size[0] - 1, 0))
+
+ # 1D Depthwise Conv
+ hidden_states = self.depthwise_conv(hidden_states)
+
+ hidden_states = self.depthwise_layer_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = self.pointwise_conv2(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class Wav2Vec2BertSelfAttention(Wav2Vec2ConformerSelfAttention, nn.Module):
+    """Construct a Wav2Vec2BertSelfAttention object.
+ Can be enhanced with rotary or relative position embeddings.
+ """
+
+ def __init__(self, config, is_adapter_attention=False):
+        nn.Module.__init__(self)
+ hidden_size = config.hidden_size if not is_adapter_attention else config.output_hidden_size
+
+ self.head_size = hidden_size // config.num_attention_heads
+ self.num_heads = config.num_attention_heads
+ self.position_embeddings_type = config.position_embeddings_type if not is_adapter_attention else None
+
+ self.linear_q = nn.Linear(hidden_size, hidden_size)
+ self.linear_k = nn.Linear(hidden_size, hidden_size)
+ self.linear_v = nn.Linear(hidden_size, hidden_size)
+ self.linear_out = nn.Linear(hidden_size, hidden_size)
+
+ self.dropout = nn.Dropout(p=config.attention_dropout)
+
+ if self.position_embeddings_type == "relative":
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(hidden_size, hidden_size, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+ self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+
+ if self.position_embeddings_type == "relative_key":
+ self.left_max_position_embeddings = config.left_max_position_embeddings
+ self.right_max_position_embeddings = config.right_max_position_embeddings
+ num_positions = self.left_max_position_embeddings + self.right_max_position_embeddings + 1
+ self.distance_embedding = nn.Embedding(num_positions, self.head_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ relative_position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # self-attention mechanism
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+
+ # make sure query/key states can be != value states
+ query_key_states = hidden_states
+ value_states = hidden_states
+
+ if self.position_embeddings_type == "rotary":
+ if relative_position_embeddings is None:
+ raise ValueError(
+ "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
+ )
+ query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
+
+ # project query_key_states and value_states
+ query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+ key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+ value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
+
+ # => (batch, head, time1, d_k)
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ if self.position_embeddings_type == "relative":
+ if relative_position_embeddings is None:
+ raise ValueError(
+ "`relative_position_embeddings` has to be defined when `self.position_embeddings_type =="
+ " 'relative'"
+ )
+ # apply relative_position_embeddings to qk scores
+ # as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860
+ scores = self._apply_relative_embeddings(
+ query=query, key=key, relative_position_embeddings=relative_position_embeddings
+ )
+ else:
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size)
+
+ if self.position_embeddings_type == "relative_key":
+ query_length, key_length = query.shape[2], key.shape[2]
+
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_r - position_ids_l
+ distance = torch.clamp(distance, -self.left_max_position_embeddings, self.right_max_position_embeddings)
+
+ positional_embedding = self.distance_embedding(distance + self.left_max_position_embeddings)
+ positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility
+
+ relative_position_attn_weights = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
+ scores = scores + (relative_position_attn_weights / math.sqrt(self.head_size))
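+            # the clamped-distance bias above mirrors BERT-style `relative_key` position embeddings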
+
+ # apply attention_mask if necessary
+ if attention_mask is not None:
+ scores = scores + attention_mask
+
+ # => (batch, head, time1, time2)
+ probs = torch.softmax(scores, dim=-1)
+ probs = self.dropout(probs)
+
+ # => (batch, head, time1, d_k)
+ hidden_states = torch.matmul(probs, value)
+
+ # => (batch, time1, hidden_size)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
+ hidden_states = self.linear_out(hidden_states)
+
+ return hidden_states, probs
+
+
+class Wav2Vec2BertEncoderLayer(nn.Module):
+ """Conformer block based on https://arxiv.org/abs/2005.08100."""
+
+ def __init__(self, config):
+ super().__init__()
+ embed_dim = config.hidden_size
+ dropout = config.attention_dropout
+
+ # Feed-forward 1
+ self.ffn1_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.ffn1 = Wav2Vec2BertFeedForward(config)
+
+ # Self-Attention
+ self.self_attn_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.self_attn_dropout = nn.Dropout(dropout)
+ self.self_attn = Wav2Vec2BertSelfAttention(config)
+
+ # Conformer Convolution
+ self.conv_module = Wav2Vec2BertConvolutionModule(config)
+
+ # Feed-forward 2
+ self.ffn2_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.ffn2 = Wav2Vec2BertFeedForward(config)
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask: Optional[torch.Tensor] = None,
+ relative_position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ conv_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ hidden_states = hidden_states
+
+ # 1. Feed-Forward 1 layer
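+        # (the Conformer block applies half-step residuals around both feed-forward modules, Macaron-style)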
+ residual = hidden_states
+ hidden_states = self.ffn1_layer_norm(hidden_states)
+ hidden_states = self.ffn1(hidden_states)
+ hidden_states = hidden_states * 0.5 + residual
+ residual = hidden_states
+
+ # 2. Self-Attention layer
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ relative_position_embeddings=relative_position_embeddings,
+ output_attentions=output_attentions,
+ )
+ hidden_states = self.self_attn_dropout(hidden_states)
+ hidden_states = hidden_states + residual
+
+ # 3. Convolutional Layer
+ residual = hidden_states
+ hidden_states = self.conv_module(hidden_states, attention_mask=conv_attention_mask)
+ hidden_states = residual + hidden_states
+
+ # 4. Feed-Forward 2 Layer
+ residual = hidden_states
+ hidden_states = self.ffn2_layer_norm(hidden_states)
+ hidden_states = self.ffn2(hidden_states)
+ hidden_states = hidden_states * 0.5 + residual
+ hidden_states = self.final_layer_norm(hidden_states)
+
+        return hidden_states, attn_weights
+
+
+class Wav2Vec2BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ if config.position_embeddings_type == "relative":
+ self.embed_positions = Wav2Vec2BertRelPositionalEmbedding(config)
+ elif config.position_embeddings_type == "rotary":
+ self.embed_positions = Wav2Vec2BertRotaryPositionalEmbedding(config)
+ else:
+ self.embed_positions = None
+
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layers = nn.ModuleList([Wav2Vec2BertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ conv_attention_mask = attention_mask
+ if attention_mask is not None:
+ # make sure padded tokens output 0
+ hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0)
+
+ # extend attention_mask
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
+ attention_mask = attention_mask.expand(
+ attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
+ )
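+            # padded positions now carry the dtype minimum as an additive bias, so their softmax weight is ~0;
+            # the mask is broadcast to shape (batch, 1, seq_len, seq_len)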
+
+ hidden_states = self.dropout(hidden_states)
+
+ if self.embed_positions is not None:
+ relative_position_embeddings = self.embed_positions(hidden_states)
+ else:
+ relative_position_embeddings = None
+
+ synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
+
+ for i, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = torch.rand([])
+
+ skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
+ if not skip_the_layer or synced_gpus:
+ # under fsdp or deepspeed zero3 all gpus must run in sync
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ layer.__call__,
+ hidden_states,
+ attention_mask,
+ relative_position_embeddings,
+ output_attentions,
+ conv_attention_mask,
+ )
+ else:
+ layer_outputs = layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ relative_position_embeddings=relative_position_embeddings,
+ output_attentions=output_attentions,
+ conv_attention_mask=conv_attention_mask,
+ )
+ hidden_states = layer_outputs[0]
+
+ if skip_the_layer:
+ layer_outputs = (None, None)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class Wav2Vec2BertAdapter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ # feature dim might need to be down-projected
+ if config.output_hidden_size != config.hidden_size:
+ self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
+ self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size, eps=config.layer_norm_eps)
+ else:
+ self.proj = self.proj_layer_norm = None
+ self.layers = nn.ModuleList(Wav2Vec2BertAdapterLayer(config) for _ in range(config.num_adapter_layers))
+ self.layerdrop = config.layerdrop
+
+ self.kernel_size = config.adapter_kernel_size
+ self.stride = config.adapter_stride
+
+ def _compute_sub_sample_lengths_from_attention_mask(self, seq_lens):
+ if seq_lens is None:
+ return seq_lens
+ pad = self.kernel_size // 2
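+        # Conv1d output-length formula, ((len + 2 * pad - kernel_size) / stride) + 1, floored on return below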
+ seq_lens = ((seq_lens + 2 * pad - self.kernel_size) / self.stride) + 1
+ return seq_lens.floor()
+
+ def forward(self, hidden_states, attention_mask=None):
+ # down project hidden_states if necessary
+ if self.proj is not None and self.proj_layer_norm is not None:
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.proj_layer_norm(hidden_states)
+
+ sub_sampled_lengths = None
+ if attention_mask is not None:
+ sub_sampled_lengths = (attention_mask.size(1) - (1 - attention_mask.int()).sum(1)).to(hidden_states.device)
+
+ for layer in self.layers:
+ layerdrop_prob = torch.rand([])
+ sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(sub_sampled_lengths)
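+            # each adapter layer down-samples the sequence by `stride`, so the expected lengths are recomputed per layer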
+ if not self.training or (layerdrop_prob > self.layerdrop):
+ hidden_states = layer(
+ hidden_states, attention_mask=attention_mask, sub_sampled_lengths=sub_sampled_lengths
+ )
+
+ return hidden_states
+
+
+class Wav2Vec2BertAdapterLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ embed_dim = config.output_hidden_size
+ dropout = config.conformer_conv_dropout
+
+ self.kernel_size = config.adapter_kernel_size
+ self.stride = config.adapter_stride
+
+ # 1. residual convolution
+ self.residual_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.residual_conv = nn.Conv1d(
+ embed_dim,
+ 2 * embed_dim,
+ self.kernel_size,
+ stride=self.stride,
+ padding=self.stride // 2,
+ )
+ self.activation = nn.GLU(dim=1)
+
+ # Self-Attention
+ self.self_attn_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.self_attn_conv = nn.Conv1d(
+ embed_dim,
+ 2 * embed_dim,
+ self.kernel_size,
+ stride=self.stride,
+ padding=self.stride // 2,
+ )
+ self.self_attn = Wav2Vec2BertSelfAttention(config, is_adapter_attention=True)
+ self.self_attn_dropout = nn.Dropout(dropout)
+
+ # Feed-forward
+ self.ffn_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.ffn = Wav2Vec2BertFeedForward(config, act_fn=config.adapter_act, hidden_size=embed_dim)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ sub_sampled_lengths: Optional[torch.Tensor] = None,
+ ):
+ residual = self.residual_layer_norm(hidden_states)
+
+ # Apply pooling to the residual to match the sequence length of the
+ # multi-head attention output.
+ # (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len)
+ residual = residual.transpose(1, 2)
+ residual = self.residual_conv(residual)
+ residual = self.activation(residual)
+ # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim)
+ residual = residual.transpose(1, 2)
+
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ # Apply pooling before feeding to the multihead-attention layer.
+ # (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len)
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.self_attn_conv(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim)
+ hidden_states = hidden_states.transpose(1, 2)
+
+ if attention_mask is not None:
+ attention_mask = _compute_new_attention_mask(hidden_states=hidden_states, seq_lens=sub_sampled_lengths)
+ attention_mask = _prepare_4d_attention_mask(
+ attention_mask,
+ hidden_states.dtype,
+ )
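+            # `_prepare_4d_attention_mask` expands the (batch, seq_len) mask to an additive
+            # (batch, 1, seq_len, seq_len) bias with the dtype minimum at masked positions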
+
+ # The rest of the computation is identical to a vanilla Transformer
+ # encoder layer.
+        hidden_states, attn_weights = self.self_attn(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = self.self_attn_dropout(hidden_states)
+ hidden_states = hidden_states + residual
+
+ residual = hidden_states
+
+ hidden_states = self.ffn_layer_norm(hidden_states)
+ hidden_states = self.ffn(hidden_states) + residual
+
+ return hidden_states
+
+
+class Wav2Vec2BertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = Wav2Vec2BertConfig
+ base_model_prefix = "wav2vec2_bert"
+ main_input_name = "input_features"
+ supports_gradient_checkpointing = True
+
+ # Ignore copy
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, Wav2Vec2BertSelfAttention):
+ if hasattr(module, "pos_bias_u"):
+ nn.init.xavier_uniform_(module.pos_bias_u)
+ if hasattr(module, "pos_bias_v"):
+ nn.init.xavier_uniform_(module.pos_bias_v)
+ elif isinstance(module, Wav2Vec2BertFeatureProjection):
+ k = math.sqrt(1 / module.projection.in_features)
+ nn.init.uniform_(module.projection.weight, a=-k, b=k)
+ nn.init.uniform_(module.projection.bias, a=-k, b=k)
+ elif isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ nn.init.kaiming_normal_(module.weight)
+
+ if module.bias is not None:
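+                # same bound PyTorch uses by default for conv biases: 1 / sqrt(fan_in), with
+                # fan_in = (in_channels / groups) * kernel_size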
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+ nn.init.uniform_(module.bias, a=-k, b=k)
+
+ # Ignore copy
+ def _get_feat_extract_output_lengths(
+ self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
+ ):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
+ def _conv_out_length(input_length, kernel_size, stride, padding):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return torch.div(input_length + 2 * padding - kernel_size, stride, rounding_mode="floor") + 1
+
+ if add_adapter:
+ padding = self.config.adapter_kernel_size // 2
+ for _ in range(self.config.num_adapter_layers):
+ input_lengths = _conv_out_length(
+ input_lengths, self.config.adapter_kernel_size, self.config.adapter_stride, padding
+ )
+
+ return input_lengths
+
+ def _get_feature_vector_attention_mask(
+ self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
+ ):
+ # Effectively attention_mask.sum(-1), but not inplace to be able to run
+ # on inference mode.
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
+ output_lengths = output_lengths.to(torch.long)
+
+ batch_size = attention_mask.shape[0]
+
+ attention_mask = torch.zeros(
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+        # these two operations make sure that all values before the output length indices are attended to
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
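+        # e.g. feature_vector_length=4, output_lengths=[2]: [0, 1, 0, 0] -> flip/cumsum/flip -> [1, 1, 0, 0]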
+ return attention_mask
+
+
+WAV2VEC2_BERT_START_DOCSTRING = None
+
+WAV2VEC2_BERT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+ into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
+ soundfile`). To prepare the array into `input_features`, the [`AutoProcessor`] should be used for padding and
+ conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2BertProcessor.__call__`] for details.
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+ 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+Wav2Vec2BertBaseModelOutput = Wav2Vec2BaseModelOutput
+
+
+@add_start_docstrings(
+ "The bare Wav2Vec2Bert Model transformer outputting raw hidden-states without any specific head on top.",
+ WAV2VEC2_BERT_START_DOCSTRING,
+)
+class Wav2Vec2BertModel(Wav2Vec2Model, Wav2Vec2BertPreTrainedModel):
+ def __init__(self, config: Wav2Vec2BertConfig):
+ Wav2Vec2BertPreTrainedModel.__init__(self, config)
+ self.config = config
+ self.feature_projection = Wav2Vec2BertFeatureProjection(config)
+
+ # model only needs masking vector if mask prob is > 0.0
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+ self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
+
+ self.encoder = Wav2Vec2BertEncoder(config)
+
+ self.adapter = Wav2Vec2BertAdapter(config) if config.add_adapter else None
+
+ self.intermediate_ffn = None
+ if config.use_intermediate_ffn_before_adapter:
+ self.intermediate_ffn = Wav2Vec2BertFeedForward(config, act_fn="relu")
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def freeze_feature_extractor(self):
+ raise AttributeError("Not needed for Wav2Vec2Bert")
+
+ def freeze_feature_encoder(self):
+ raise AttributeError("Not needed for Wav2Vec2Bert")
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_PRETRAINED_CHECKPOINT_FOR_DOC,
+ output_type=Wav2Vec2BertBaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ input_features: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Wav2Vec2BertBaseModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ hidden_states, extract_features = self.feature_projection(input_features)
+ hidden_states = self._mask_hidden_states(
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+ )
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
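+ # When configured, the intermediate feed-forward output is added back with a 0.5 scale
+ # (a half-step residual) before the optional adapter.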
+ if self.intermediate_ffn:
+ expanded_hidden_states = self.intermediate_ffn(hidden_states)
+ hidden_states = hidden_states + 0.5 * expanded_hidden_states
+
+ if self.adapter is not None:
+ hidden_states = self.adapter(hidden_states, attention_mask=attention_mask)
+
+ if not return_dict:
+ return (hidden_states, extract_features) + encoder_outputs[1:]
+
+ return Wav2Vec2BertBaseModelOutput(
+ last_hidden_state=hidden_states,
+ extract_features=extract_features,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """Wav2Vec2Bert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+ WAV2VEC2_BERT_START_DOCSTRING,
+)
+class Wav2Vec2BertForCTC(Wav2Vec2ConformerForCTC):
+ def __init__(self, config, target_lang: Optional[str] = None):
+ super().__init__(config, target_lang)
+
+ def freeze_feature_encoder(self):
+ raise AttributeError("Not needed for Wav2Vec2Bert")
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_PRETRAINED_CHECKPOINT_FOR_DOC,
+ output_type=CausalLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_CTC_EXPECTED_OUTPUT,
+ expected_loss=_CTC_EXPECTED_LOSS,
+ )
+ def forward(
+ self,
+ input_features: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, CausalLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller than or equal to
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
+ config.vocab_size - 1]`.
+ """
+ if labels is not None and labels.max() >= self.config.vocab_size:
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.wav2vec2_bert(
+ input_features,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ hidden_states = self.dropout(hidden_states)
+
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # retrieve loss input_lengths from attention_mask
+ attention_mask = (
+ attention_mask
+ if attention_mask is not None
+ else torch.ones(input_features.shape[:2], device=input_features.device, dtype=torch.long)
+ )
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum([-1])).to(torch.long)
+
+ # assuming that padded tokens are filled with -100
+ # when not being attended to
+ labels_mask = labels >= 0
+ target_lengths = labels_mask.sum(-1)
+ flattened_targets = labels.masked_select(labels_mask)
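+ # e.g. labels = [[5, 9, -100, -100]] -> labels_mask = [[True, True, False, False]],
+ # target_lengths = [2], flattened_targets = tensor([5, 9])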
+
+ # ctc_loss doesn't support fp16
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+ with torch.backends.cudnn.flags(enabled=False):
+ loss = nn.functional.ctc_loss(
+ log_probs,
+ flattened_targets,
+ input_lengths,
+ target_lengths,
+ blank=self.config.pad_token_id,
+ reduction=self.config.ctc_loss_reduction,
+ zero_infinity=self.config.ctc_zero_infinity,
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutput(
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Bert Model with a sequence classification head on top (a linear layer over the pooled output) for
+ tasks like SUPERB Keyword Spotting.
+ """,
+ WAV2VEC2_BERT_START_DOCSTRING,
+)
+class Wav2Vec2BertForSequenceClassification(Wav2Vec2ForSequenceClassification):
+ def __init__(self, config):
+ super().__init__(config)
+
+ def freeze_feature_extractor(self):
+ raise AttributeError("Not needed for Wav2Vec2Bert")
+
+ def freeze_feature_encoder(self):
+ raise AttributeError("Not needed for Wav2Vec2Bert")
+
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.wav2vec2_bert.parameters():
+ param.requires_grad = False
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_BASE_CHECKPOINT_FOR_DOC,
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Bert,wav2vec2->wav2vec2_bert,WAV_2_VEC_2->WAV2VEC2_BERT, input_values->input_features
+ def forward(
+ self,
+ input_features: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.wav2vec2_bert(
+ input_features,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
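+ # hidden_states stacks the embedding output and every layer output along dim=1, and
+ # norm_weights is a softmax over the learned per-layer weights, so this is a weighted
+ # average of all layer outputs.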
+ else:
+ hidden_states = outputs[0]
+
+ hidden_states = self.projector(hidden_states)
+ if attention_mask is None:
+ pooled_output = hidden_states.mean(dim=1)
+ else:
+ padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+ expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_padding_mask] = 0.0
+ pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
+
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Bert Model with a frame classification head on top for tasks like Speaker Diarization.
+ """,
+ WAV2VEC2_BERT_START_DOCSTRING,
+)
+class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2ConformerForAudioFrameClassification):
+ def __init__(self, config):
+ super().__init__(config)
+
+ def freeze_feature_encoder(self):
+ raise AttributeError("Not needed for Wav2Vec2Bert")
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_BASE_CHECKPOINT_FOR_DOC,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForAudioFrameClassification.forward with wav2vec2_conformer->wav2vec2_bert, input_values->input_features
+ def forward(
+ self,
+ input_features: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.wav2vec2_bert(
+ input_features,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Bert Model with an XVector feature extraction head on top for tasks like Speaker Verification.
+ """,
+ WAV2VEC2_BERT_START_DOCSTRING,
+)
+class Wav2Vec2BertForXVector(Wav2Vec2ConformerForXVector):
+ def __init__(self, config):
+ super().__init__(config)
+
+ def freeze_feature_encoder(self):
+ raise AttributeError("Not needed for Wav2Vec2Bert")
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_BASE_CHECKPOINT_FOR_DOC,
+ output_type=XVectorOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector.forward with wav2vec2_conformer->wav2vec2_bert, input_values->input_features
+ def forward(
+ self,
+ input_features: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, XVectorOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.wav2vec2_bert(
+ input_features,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ hidden_states = self.projector(hidden_states)
+
+ for tdnn_layer in self.tdnn:
+ hidden_states = tdnn_layer(hidden_states)
+
+ # Statistic Pooling
+ if attention_mask is None:
+ mean_features = hidden_states.mean(dim=1)
+ std_features = hidden_states.std(dim=1)
+ else:
+ feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
+ tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
+ mean_features = []
+ std_features = []
+ for i, length in enumerate(tdnn_output_lengths):
+ mean_features.append(hidden_states[i, :length].mean(dim=0))
+ std_features.append(hidden_states[i, :length].std(dim=0))
+ mean_features = torch.stack(mean_features)
+ std_features = torch.stack(std_features)
+ statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
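+ # statistic_pooling concatenates the per-utterance mean and standard deviation of the
+ # TDNN features, giving a tensor of shape (batch_size, 2 * config.tdnn_dim[-1]).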
+
+ output_embeddings = self.feature_extractor(statistic_pooling)
+ logits = self.classifier(output_embeddings)
+
+ loss = None
+ if labels is not None:
+ loss = self.objective(logits, labels)
+
+ if not return_dict:
+ output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return XVectorOutput(
+ loss=loss,
+ logits=logits,
+ embeddings=output_embeddings,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "Wav2Vec2BertForAudioFrameClassification",
+ "Wav2Vec2BertForCTC",
+ "Wav2Vec2BertForSequenceClassification",
+ "Wav2Vec2BertForXVector",
+ "Wav2Vec2BertModel",
+ "Wav2Vec2BertPreTrainedModel",
+]
diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
index 587ae1c8249..bd94e44b616 100644
--- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
+++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
@@ -1,19 +1,9 @@
-# coding=utf-8
-# Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch Wav2Vec2-Conformer model."""
-
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_wav2vec2_conformer.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
import warnings
from dataclasses import dataclass
@@ -21,7 +11,6 @@ from typing import Optional, Tuple, Union
import numpy as np
import torch
-import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
@@ -43,31 +32,19 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_peft_available,
- logging,
replace_return_docstrings,
)
from .configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
-logger = logging.get_logger(__name__)
-
-
-_HIDDEN_STATES_START_POSITION = 2
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
# General docstring
_CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
-# Base docstring
-_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
-_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
-
-# CTC docstring
-_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
-_CTC_EXPECTED_LOSS = 64.21
-
@dataclass
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
"""
Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
@@ -109,239 +86,17 @@ class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
diversity_loss: Optional[torch.FloatTensor] = None
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
-def _compute_mask_indices(
- shape: Tuple[int, int],
- mask_prob: float,
- mask_length: int,
- attention_mask: Optional[torch.LongTensor] = None,
- min_masks: int = 0,
-) -> np.ndarray:
- """
- Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
- ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
- CPU as part of the preprocessing during training.
-
- Args:
- shape: The shape for which to compute masks. This should be of a tuple of size 2 where
- the first element is the batch size and the second element is the length of the axis to span.
- mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
- independently generated mask spans of length `mask_length` is computed by
- `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
- actual percentage will be smaller.
- mask_length: size of the mask
- min_masks: minimum number of masked spans
- attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
- each batch dimension.
- """
- batch_size, sequence_length = shape
-
- if mask_length < 1:
- raise ValueError("`mask_length` has to be bigger than 0.")
-
- if mask_length > sequence_length:
- raise ValueError(
- f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
- f" and `sequence_length`: {sequence_length}`"
- )
-
- # epsilon is used for probabilistic rounding
- epsilon = np.random.rand(1).item()
-
- def compute_num_masked_span(input_length):
- """Given input length, compute how many spans should be masked"""
- num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
- num_masked_span = max(num_masked_span, min_masks)
-
- # make sure num masked span <= sequence_length
- if num_masked_span * mask_length > sequence_length:
- num_masked_span = sequence_length // mask_length
-
- # make sure num_masked span is also <= input_length - (mask_length - 1)
- if input_length - (mask_length - 1) < num_masked_span:
- num_masked_span = max(input_length - (mask_length - 1), 0)
-
- return num_masked_span
-
- # compute number of masked spans in batch
- input_lengths = (
- attention_mask.detach().sum(-1).tolist()
- if attention_mask is not None
- else [sequence_length for _ in range(batch_size)]
- )
-
- # SpecAugment mask to fill
- spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
- spec_aug_mask_idxs = []
-
- max_num_masked_span = compute_num_masked_span(sequence_length)
-
- if max_num_masked_span == 0:
- return spec_aug_mask
-
- for input_length in input_lengths:
- # compute num of masked spans for this input
- num_masked_span = compute_num_masked_span(input_length)
-
- # get random indices to mask
- spec_aug_mask_idx = np.random.choice(
- np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
- )
-
- # pick first sampled index that will serve as a dummy index to pad vector
- # to ensure same dimension for all batches due to probabilistic rounding
- # Picking first sample just pads those vectors twice.
- if len(spec_aug_mask_idx) == 0:
- # this case can only happen if `input_length` is strictly smaller then
- # `sequence_length` in which case the last token has to be a padding
- # token which we can use as a dummy mask id
- dummy_mask_idx = sequence_length - 1
- else:
- dummy_mask_idx = spec_aug_mask_idx[0]
-
- spec_aug_mask_idx = np.concatenate(
- [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
- )
- spec_aug_mask_idxs.append(spec_aug_mask_idx)
-
- spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
-
- # expand masked indices to masked spans
- spec_aug_mask_idxs = np.broadcast_to(
- spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
-
- # add offset to the starting indexes so that indexes now create a span
- offsets = np.arange(mask_length)[None, None, :]
- offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
- batch_size, max_num_masked_span * mask_length
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
-
- # ensure that we cannot have indices larger than sequence_length
- if spec_aug_mask_idxs.max() > sequence_length - 1:
- spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
-
- # scatter indices to mask
- np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
-
- return spec_aug_mask
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
-def _sample_negative_indices(
- features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
-):
- """
- Sample `num_negatives` vectors from feature vectors.
- """
- batch_size, sequence_length = features_shape
-
- # generate indices of the positive vectors themselves, repeat them `num_negatives` times
- sequence_length_range = np.arange(sequence_length)
-
- # get `num_negatives` random vector indices from the same utterance
- sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
-
- mask_time_indices = (
- mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
- )
-
- for batch_idx in range(batch_size):
- high = mask_time_indices[batch_idx].sum() - 1
- mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
-
- feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
- sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
- # avoid sampling the same positive vector, but keep the distribution uniform
- sampled_indices[sampled_indices >= feature_indices] += 1
-
- # remap to actual indices
- sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
-
- # correct for batch size
- sampled_negative_indices[batch_idx] += batch_idx * sequence_length
-
- return sampled_negative_indices
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
-class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
- def __init__(self, config, layer_id=0):
+class Wav2Vec2ConformerSamePadLayer(nn.Module):
+ def __init__(self, num_conv_pos_embeddings):
super().__init__()
- self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
- self.out_conv_dim = config.conv_dim[layer_id]
-
- self.conv = nn.Conv1d(
- self.in_conv_dim,
- self.out_conv_dim,
- kernel_size=config.conv_kernel[layer_id],
- stride=config.conv_stride[layer_id],
- bias=config.conv_bias,
- )
- self.activation = ACT2FN[config.feat_extract_activation]
+ self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
def forward(self, hidden_states):
- hidden_states = self.conv(hidden_states)
- hidden_states = self.activation(hidden_states)
+ if self.num_pad_remove > 0:
+ hidden_states = hidden_states[:, :, : -self.num_pad_remove]
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
-class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
- def __init__(self, config, layer_id=0):
- super().__init__()
- self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
- self.out_conv_dim = config.conv_dim[layer_id]
-
- self.conv = nn.Conv1d(
- self.in_conv_dim,
- self.out_conv_dim,
- kernel_size=config.conv_kernel[layer_id],
- stride=config.conv_stride[layer_id],
- bias=config.conv_bias,
- )
- self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
- self.activation = ACT2FN[config.feat_extract_activation]
-
- def forward(self, hidden_states):
- hidden_states = self.conv(hidden_states)
-
- hidden_states = hidden_states.transpose(-2, -1)
- hidden_states = self.layer_norm(hidden_states)
- hidden_states = hidden_states.transpose(-2, -1)
-
- hidden_states = self.activation(hidden_states)
- return hidden_states
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
-class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
- def __init__(self, config, layer_id=0):
- super().__init__()
- self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
- self.out_conv_dim = config.conv_dim[layer_id]
-
- self.conv = nn.Conv1d(
- self.in_conv_dim,
- self.out_conv_dim,
- kernel_size=config.conv_kernel[layer_id],
- stride=config.conv_stride[layer_id],
- bias=config.conv_bias,
- )
- self.activation = ACT2FN[config.feat_extract_activation]
-
- self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
-
- def forward(self, hidden_states):
- hidden_states = self.conv(hidden_states)
- hidden_states = self.layer_norm(hidden_states)
- hidden_states = self.activation(hidden_states)
- return hidden_states
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
@@ -471,19 +226,78 @@ class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
return relative_position_embeddings
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Wav2Vec2Conformer
-class Wav2Vec2ConformerSamePadLayer(nn.Module):
- def __init__(self, num_conv_pos_embeddings):
+class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
super().__init__()
- self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
- if self.num_pad_remove > 0:
- hidden_states = hidden_states[:, :, : -self.num_pad_remove]
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+
+ hidden_states = hidden_states.transpose(-2, -1)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(-2, -1)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerFeatureEncoder(nn.Module):
"""Construct the features from raw audio waveform"""
@@ -531,7 +345,6 @@ class Wav2Vec2ConformerFeatureEncoder(nn.Module):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerFeatureProjection(nn.Module):
def __init__(self, config):
super().__init__()
@@ -547,7 +360,6 @@ class Wav2Vec2ConformerFeatureProjection(nn.Module):
return hidden_states, norm_hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
@@ -943,7 +755,6 @@ class Wav2Vec2ConformerEncoder(nn.Module):
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
"""
Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
@@ -1020,7 +831,6 @@ class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
return codevectors, perplexity
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerAdapter(nn.Module):
def __init__(self, config):
super().__init__()
@@ -1052,7 +862,6 @@ class Wav2Vec2ConformerAdapter(nn.Module):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerAdapterLayer(nn.Module):
def __init__(self, config):
super().__init__()
@@ -1170,6 +979,128 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
return attention_mask
+def _compute_mask_indices(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ attention_mask: Optional[torch.LongTensor] = None,
+ min_masks: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+ CPU as part of the preprocessing during training.
+
+ Args:
+ shape: The shape for which to compute masks. This should be a tuple of size 2 where
+ the first element is the batch size and the second element is the length of the axis to span.
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+ independently generated mask spans of length `mask_length` is computed by
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+ actual percentage will be smaller.
+ mask_length: size of the mask
+ min_masks: minimum number of masked spans
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+ each batch dimension.
+ """
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ if mask_length > sequence_length:
+ raise ValueError(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+ f" and `sequence_length`: {sequence_length}`"
+ )
+
+ # epsilon is used for probabilistic rounding
+ epsilon = np.random.rand(1).item()
+
+ def compute_num_masked_span(input_length):
+ """Given input length, compute how many spans should be masked"""
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+ num_masked_span = max(num_masked_span, min_masks)
+
+ # make sure num masked span <= sequence_length
+ if num_masked_span * mask_length > sequence_length:
+ num_masked_span = sequence_length // mask_length
+
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
+ if input_length - (mask_length - 1) < num_masked_span:
+ num_masked_span = max(input_length - (mask_length - 1), 0)
+
+ return num_masked_span
+
+ # compute number of masked spans in batch
+ input_lengths = (
+ attention_mask.detach().sum(-1).tolist()
+ if attention_mask is not None
+ else [sequence_length for _ in range(batch_size)]
+ )
+
+ # SpecAugment mask to fill
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+ spec_aug_mask_idxs = []
+
+ max_num_masked_span = compute_num_masked_span(sequence_length)
+
+ if max_num_masked_span == 0:
+ return spec_aug_mask
+
+ for input_length in input_lengths:
+ # compute num of masked spans for this input
+ num_masked_span = compute_num_masked_span(input_length)
+
+ # get random indices to mask
+ spec_aug_mask_idx = np.random.choice(
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+ )
+
+ # pick first sampled index that will serve as a dummy index to pad vector
+ # to ensure same dimension for all batches due to probabilistic rounding
+ # Picking first sample just pads those vectors twice.
+ if len(spec_aug_mask_idx) == 0:
+ # this case can only happen if `input_length` is strictly smaller than
+ # `sequence_length` in which case the last token has to be a padding
+ # token which we can use as a dummy mask id
+ dummy_mask_idx = sequence_length - 1
+ else:
+ dummy_mask_idx = spec_aug_mask_idx[0]
+
+ spec_aug_mask_idx = np.concatenate(
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+ )
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+ # expand masked indices to masked spans
+ spec_aug_mask_idxs = np.broadcast_to(
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+ # add offset to the starting indexes so that indexes now create a span
+ offsets = np.arange(mask_length)[None, None, :]
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+ batch_size, max_num_masked_span * mask_length
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+ # ensure that we cannot have indices larger than sequence_length
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+ # scatter indices to mask
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+ return spec_aug_mask
+
+
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
+
+
WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
@@ -1178,8 +1109,9 @@ WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving etc.).
- This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
- regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
Parameters:
config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model.
@@ -1187,7 +1119,6 @@ WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
-
WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -1226,6 +1157,8 @@ WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+Wav2Vec2ConformerBaseModelOutput = Wav2Vec2BaseModelOutput
+
@add_start_docstrings(
"The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
@@ -1249,7 +1182,6 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@@ -1257,7 +1189,6 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
"""
self.feature_extractor._freeze_parameters()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_hidden_states(
self,
hidden_states: torch.FloatTensor,
@@ -1307,12 +1238,11 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=Wav2Vec2BaseModelOutput,
+ output_type=Wav2Vec2ConformerBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer
def forward(
self,
input_values: Optional[torch.Tensor],
@@ -1321,7 +1251,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+ ) -> Union[Tuple, Wav2Vec2ConformerBaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1358,7 +1288,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
- return Wav2Vec2BaseModelOutput(
+ return Wav2Vec2ConformerBaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
@@ -1370,7 +1300,6 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
"""Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
)
class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
def __init__(self, config: Wav2Vec2ConformerConfig):
super().__init__(config)
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
@@ -1384,14 +1313,12 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
def set_gumbel_temperature(self, temperature: int):
"""
Set the Gumbel softmax temperature to a given value. Only necessary for training
"""
self.quantizer.temperature = temperature
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@@ -1400,7 +1327,6 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
@staticmethod
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits
def compute_contrastive_logits(
target_features: torch.FloatTensor,
negative_features: torch.FloatTensor,
@@ -1423,7 +1349,6 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
def forward(
self,
input_values: Optional[torch.Tensor],
@@ -1452,8 +1377,8 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices, _sample_negative_indices
>>> from datasets import load_dataset
- >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
- >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2_conformer-base")
+ >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2_conformer-base")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
@@ -1585,12 +1510,18 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
)
+_HIDDEN_STATES_START_POSITION = 2
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 64.21
+
+
@add_start_docstrings(
"""Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
@@ -1614,7 +1545,6 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@@ -1630,7 +1560,6 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
)
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
def forward(
self,
input_values: Optional[torch.Tensor],
@@ -1710,7 +1639,6 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
def __init__(self, config):
super().__init__(config)
@@ -1728,7 +1656,6 @@ class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedMode
# Initialize weights and apply final processing
self.post_init()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@@ -1751,7 +1678,6 @@ class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedMode
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
def forward(
self,
input_values: Optional[torch.Tensor],
@@ -1822,7 +1748,6 @@ class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedMode
WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
def __init__(self, config):
super().__init__(config)
@@ -1839,7 +1764,6 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedMo
self.init_weights()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@@ -1847,7 +1771,6 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedMo
"""
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
@@ -1863,7 +1786,6 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedMo
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
def forward(
self,
input_values: Optional[torch.Tensor],
@@ -1918,7 +1840,6 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedMo
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
class AMSoftmaxLoss(nn.Module):
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
super(AMSoftmaxLoss, self).__init__()
@@ -1942,7 +1863,6 @@ class AMSoftmaxLoss(nn.Module):
return loss
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
class TDNNLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
@@ -1958,6 +1878,7 @@ class TDNNLayer(nn.Module):
if is_peft_available():
from peft.tuners.lora import LoraLayer
+ if is_peft_available():
if isinstance(self.kernel, LoraLayer):
warnings.warn(
"Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
@@ -2000,7 +1921,6 @@ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
self.init_weights()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@@ -2008,7 +1928,6 @@ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
"""
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
@@ -2017,7 +1936,6 @@ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
for param in self.wav2vec2_conformer.parameters():
param.requires_grad = False
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer
def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the TDNN layers
@@ -2040,7 +1958,6 @@ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
def forward(
self,
input_values: Optional[torch.Tensor],
diff --git a/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py
new file mode 100644
index 00000000000..c2d101385fa
--- /dev/null
+++ b/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py
@@ -0,0 +1,892 @@
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.fsdp import is_fsdp_managed_module
+from ...modeling_outputs import (
+ BaseModelOutput,
+ CausalLMOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+ XVectorOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from ..wav2vec2.modeling_wav2vec2 import (
+ Wav2Vec2Adapter,
+ Wav2Vec2AdapterLayer,
+ Wav2Vec2FeatureEncoder,
+ Wav2Vec2FeatureProjection,
+ Wav2Vec2FeedForward,
+ Wav2Vec2ForAudioFrameClassification,
+ Wav2Vec2ForCTC,
+ Wav2Vec2ForPreTraining,
+ Wav2Vec2ForSequenceClassification,
+ Wav2Vec2ForXVector,
+ Wav2Vec2GumbelVectorQuantizer,
+ Wav2Vec2Model,
+ Wav2Vec2PositionalConvEmbedding,
+)
+from .configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
+
+
+logger = logging.get_logger(__name__)
+
+_HIDDEN_STATES_START_POSITION = 2
+
+# General docstring
+_CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 64.21
+
+
+@dataclass
+class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
+ """
+ Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
+
+ Args:
+ loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
+ Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the
+ [official paper](https://arxiv.org/pdf/2006.11477.pdf).
+ projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+ Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
+ projected quantized states.
+ projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+ Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
+ target vectors for contrastive loss.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
+ The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
+ diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
+ The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ projected_states: Optional[torch.FloatTensor] = None
+ projected_quantized_states: Optional[torch.FloatTensor] = None
+ codevector_perplexity: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ contrastive_loss: Optional[torch.FloatTensor] = None
+ diversity_loss: Optional[torch.FloatTensor] = None
+
+
+class Wav2Vec2ConformerPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding):
+ pass
+
+
+class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
+ """Rotary positional embedding
+ Reference: https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ dim = config.hidden_size // config.num_attention_heads
+ base = config.rotary_embedding_base
+
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+ self.cached_sequence_length = None
+ self.cached_rotary_positional_embedding = None
+
+ def forward(self, hidden_states):
+ sequence_length = hidden_states.shape[1]
+
+ if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
+ return self.cached_rotary_positional_embedding
+
+ self.cached_sequence_length = sequence_length
+ # Embeddings are computed in the dtype of the inv_freq constant
+ time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
+ freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
+ embeddings = torch.cat((freqs, freqs), dim=-1)
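+ # NOTE: `freqs` has shape (sequence_length, head_size // 2); concatenating it with itself gives one
+ # rotation angle per (first-half, second-half) pair, matching the rotate-half scheme in `_apply_rotary_embedding`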
+
+ cos_embeddings = embeddings.cos()[:, None, None, :]
+ sin_embeddings = embeddings.sin()[:, None, None, :]
+ # Computed embeddings are cast to the dtype of the hidden state inputs
+ self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states)
+ return self.cached_rotary_positional_embedding
+
+
+class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
+ """Relative positional encoding module."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.max_len = config.max_source_positions
+ self.d_model = config.hidden_size
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+
+ def extend_pe(self, x):
+ # Reset the positional encodings
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ # Suppose `i` is the position of query vector and `j` is the
+ # position of key vector. We use positive relative positions when keys
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
+ pe_positive = torch.zeros(x.size(1), self.d_model)
+ pe_negative = torch.zeros(x.size(1), self.d_model)
+ position = torch.arange(0, x.size(1), dtype=torch.int64).float().unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.int64).float() * -(math.log(10000.0) / self.d_model)
+ )
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+ # Reverse the order of positive indices and concat both positive and
+ # negative indices. This is used to support the shifting trick
+ # as in https://arxiv.org/abs/1901.02860
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+ pe_negative = pe_negative[1:].unsqueeze(0)
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+ def forward(self, hidden_states: torch.Tensor):
+ self.extend_pe(hidden_states)
+ start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
+ end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
+ relative_position_embeddings = self.pe[:, start_idx:end_idx]
+
+ return relative_position_embeddings
+
+
+class Wav2Vec2ConformerFeatureEncoder(Wav2Vec2FeatureEncoder):
+ pass
+
+
+class Wav2Vec2ConformerFeatureProjection(Wav2Vec2FeatureProjection):
+ pass
+
+
+class Wav2Vec2ConformerFeedForward(Wav2Vec2FeedForward):
+ pass
+
+
+class Wav2Vec2ConformerConvolutionModule(nn.Module):
+ """Convolution block used in the conformer block"""
+
+ def __init__(self, config):
+ super().__init__()
+ if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
+ raise ValueError("`config.conv_depthwise_kernel_size` should be an odd number for 'SAME' padding")
+ self.layer_norm = nn.LayerNorm(config.hidden_size)
+ self.pointwise_conv1 = nn.Conv1d(
+ config.hidden_size,
+ 2 * config.hidden_size,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ self.glu = nn.GLU(dim=1)
+ self.depthwise_conv = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ config.conv_depthwise_kernel_size,
+ stride=1,
+ padding=(config.conv_depthwise_kernel_size - 1) // 2,
+ groups=config.hidden_size,
+ bias=False,
+ )
+ self.batch_norm = nn.BatchNorm1d(config.hidden_size)
+ self.activation = ACT2FN[config.hidden_act]
+ self.pointwise_conv2 = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ self.dropout = nn.Dropout(config.conformer_conv_dropout)
+
+ def forward(self, hidden_states):
+ hidden_states = self.layer_norm(hidden_states)
+ # exchange the temporal dimension and the feature dimension
+ hidden_states = hidden_states.transpose(1, 2)
+
+ # GLU mechanism
+ # => (batch, 2*channel, dim)
+ hidden_states = self.pointwise_conv1(hidden_states)
+ # => (batch, channel, dim)
+ hidden_states = self.glu(hidden_states)
+
+ # 1D Depthwise Conv
+ hidden_states = self.depthwise_conv(hidden_states)
+ hidden_states = self.batch_norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = self.pointwise_conv2(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class Wav2Vec2ConformerSelfAttention(nn.Module):
+ """Construct an Wav2Vec2ConformerSelfAttention object.
+ Can be enhanced with rotary or relative position embeddings.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.head_size = config.hidden_size // config.num_attention_heads
+ self.num_heads = config.num_attention_heads
+ self.position_embeddings_type = config.position_embeddings_type
+
+ self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
+
+ self.dropout = nn.Dropout(p=config.attention_dropout)
+
+ if self.position_embeddings_type == "relative":
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+ self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ relative_position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ # self-attention mechanism
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+
+ # make sure query/key states can be != value states
+ query_key_states = hidden_states
+ value_states = hidden_states
+
+ if self.position_embeddings_type == "rotary":
+ if relative_position_embeddings is None:
+ raise ValueError(
+ "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
+ )
+ query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
+
+ # project query_key_states and value_states
+ query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+ key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+ value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
+
+ # => (batch, head, time1, d_k)
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ if self.position_embeddings_type == "relative":
+ if relative_position_embeddings is None:
+ raise ValueError(
+ "`relative_position_embeddings` has to be defined when `self.position_embeddings_type =="
+ " 'relative'"
+ )
+ # apply relative_position_embeddings to qk scores
+ # as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860
+ scores = self._apply_relative_embeddings(
+ query=query, key=key, relative_position_embeddings=relative_position_embeddings
+ )
+ else:
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size)
+
+ # apply attention_mask if necessary
+ if attention_mask is not None:
+ scores = scores + attention_mask
+
+ # => (batch, head, time1, time2)
+ probs = torch.softmax(scores, dim=-1)
+ probs = self.dropout(probs)
+
+ # => (batch, head, time1, d_k)
+ hidden_states = torch.matmul(probs, value)
+
+ # => (batch, time1, hidden_size)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
+ hidden_states = self.linear_out(hidden_states)
+
+ return hidden_states, probs
+
+ def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
+
+ cos = relative_position_embeddings[0, :sequence_length, ...]
+ sin = relative_position_embeddings[1, :sequence_length, ...]
+
+ # rotate hidden_states with rotary embeddings
+ hidden_states = hidden_states.transpose(0, 1)
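+ # rotate-half: split the head dimension into two halves and build (-x2, x1), so that
+ # `x * cos + rotate_half(x) * sin` rotates every (x1_i, x2_i) pair by its position-dependent angle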
+ rotated_states_begin = hidden_states[..., : self.head_size // 2]
+ rotated_states_end = hidden_states[..., self.head_size // 2 :]
+ rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
+ hidden_states = (hidden_states * cos) + (rotated_states * sin)
+ hidden_states = hidden_states.transpose(0, 1)
+
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
+
+ return hidden_states
+
+ def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
+ # 1. project positional embeddings
+ # => (batch, head, 2*time1-1, d_k)
+ proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
+ proj_relative_position_embeddings = proj_relative_position_embeddings.view(
+ relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
+ )
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
+
+ # 2. Add bias to query
+ # => (batch, head, time1, d_k)
+ query = query.transpose(1, 2)
+ q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+ q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+ # 3. attention score: first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # => (batch, head, time1, time2)
+ scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+ # 4. then compute matrix b and matrix d
+ # => (batch, head, time1, 2*time1-1)
+ scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
+
+ # 5. shift matrix b and matrix d
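+ # relative-shift trick from Transformer-XL: prepending a zero column and reshaping realigns the
+ # scores so that each query position reads its own window of relative distances from `scores_bd`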
+ zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
+ scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
+ scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
+ scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
+ scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
+ scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
+
+ # 6. sum matrices
+ # => (batch, head, time1, time2)
+ scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
+
+ return scores
+
+
+class Wav2Vec2ConformerEncoderLayer(nn.Module):
+ """Conformer block based on https://arxiv.org/abs/2005.08100."""
+
+ def __init__(self, config):
+ super().__init__()
+ embed_dim = config.hidden_size
+ dropout = config.attention_dropout
+
+ # Feed-forward 1
+ self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
+ self.ffn1 = Wav2Vec2ConformerFeedForward(config)
+
+ # Self-Attention
+ self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
+ self.self_attn_dropout = nn.Dropout(dropout)
+ self.self_attn = Wav2Vec2ConformerSelfAttention(config)
+
+ # Conformer Convolution
+ self.conv_module = Wav2Vec2ConformerConvolutionModule(config)
+
+ # Feed-forward 2
+ self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
+ self.ffn2 = Wav2Vec2ConformerFeedForward(config)
+ self.final_layer_norm = nn.LayerNorm(embed_dim)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask: Optional[torch.Tensor] = None,
+ relative_position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ):
+
+ # 1. Feed-Forward 1 layer
+ residual = hidden_states
+ hidden_states = self.ffn1_layer_norm(hidden_states)
+ hidden_states = self.ffn1(hidden_states)
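+ # the Conformer block uses half-step (0.5) residual connections around both feed-forward modules
+ # (macaron-style FFN, see https://arxiv.org/abs/2005.08100)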
+ hidden_states = hidden_states * 0.5 + residual
+ residual = hidden_states
+
+ # 2. Self-Attention layer
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ relative_position_embeddings=relative_position_embeddings,
+ output_attentions=output_attentions,
+ )
+ hidden_states = self.self_attn_dropout(hidden_states)
+ hidden_states = hidden_states + residual
+
+ # 3. Convolutional Layer
+ residual = hidden_states
+ hidden_states = self.conv_module(hidden_states)
+ hidden_states = residual + hidden_states
+
+ # 4. Feed-Forward 2 Layer
+ residual = hidden_states
+ hidden_states = self.ffn2_layer_norm(hidden_states)
+ hidden_states = self.ffn2(hidden_states)
+ hidden_states = hidden_states * 0.5 + residual
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ return hidden_states, attn_weights
+
+
+class Wav2Vec2ConformerEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
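+ # "relative" adds Transformer-XL style relative position scores inside self-attention, "rotary" rotates
+ # the query/key states; the convolutional positional embedding below is applied in every case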
+ if config.position_embeddings_type == "relative":
+ self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
+ elif config.position_embeddings_type == "rotary":
+ self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
+ else:
+ self.embed_positions = None
+
+ self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if attention_mask is not None:
+ # make sure padded tokens output 0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0.0
+
+ # extend attention_mask
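+ # turn the 2D padding mask into an additive bias: 0.0 for attended positions, dtype-min for padded
+ # ones, broadcastable to (batch, 1, sequence_length, sequence_length)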
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
+ attention_mask = attention_mask.expand(
+ attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
+ )
+
+ hidden_states = self.dropout(hidden_states)
+
+ if self.embed_positions is not None:
+ relative_position_embeddings = self.embed_positions(hidden_states)
+ else:
+ relative_position_embeddings = None
+
+ synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
+
+ for i, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = torch.rand([])
+
+ skip_the_layer = self.training and (dropout_probability < self.config.layerdrop)
+ if not skip_the_layer or synced_gpus:
+ # under fsdp or deepspeed zero3 all gpus must run in sync
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ layer.__call__,
+ hidden_states,
+ attention_mask,
+ relative_position_embeddings,
+ output_attentions,
+ )
+ else:
+ layer_outputs = layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ relative_position_embeddings=relative_position_embeddings,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+
+ if skip_the_layer:
+ layer_outputs = (None, None)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ hidden_states = self.layer_norm(hidden_states)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class Wav2Vec2ConformerGumbelVectorQuantizer(Wav2Vec2GumbelVectorQuantizer):
+ pass
+
+
+class Wav2Vec2ConformerAdapter(Wav2Vec2Adapter):
+ pass
+
+
+class Wav2Vec2ConformerAdapterLayer(Wav2Vec2AdapterLayer):
+ pass
+
+
+class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = Wav2Vec2ConformerConfig
+ base_model_prefix = "wav2vec2_conformer"
+ main_input_name = "input_values"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
+ if isinstance(module, Wav2Vec2ConformerForPreTraining):
+ module.project_hid.reset_parameters()
+ module.project_q.reset_parameters()
+ module.project_hid._is_hf_initialized = True
+ module.project_q._is_hf_initialized = True
+ # gumbel softmax requires special init
+ elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
+ module.weight_proj.weight.data.normal_(mean=0.0, std=1)
+ module.weight_proj.bias.data.zero_()
+ nn.init.uniform_(module.codevectors)
+ elif isinstance(module, Wav2Vec2ConformerSelfAttention):
+ if hasattr(module, "pos_bias_u"):
+ nn.init.xavier_uniform_(module.pos_bias_u)
+ if hasattr(module, "pos_bias_v"):
+ nn.init.xavier_uniform_(module.pos_bias_v)
+ elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding):
+ nn.init.normal_(
+ module.conv.weight,
+ mean=0,
+ std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
+ )
+ nn.init.constant_(module.conv.bias, 0)
+ elif isinstance(module, Wav2Vec2ConformerFeatureProjection):
+ k = math.sqrt(1 / module.projection.in_features)
+ nn.init.uniform_(module.projection.weight, a=-k, b=k)
+ nn.init.uniform_(module.projection.bias, a=-k, b=k)
+ elif isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ nn.init.kaiming_normal_(module.weight)
+
+ if module.bias is not None:
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+ nn.init.uniform_(module.bias, a=-k, b=k)
+
+ def _get_feat_extract_output_lengths(
+ self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
+ ):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
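+ # e.g. with a typical wav2vec2-style first conv layer (kernel_size=10, stride=5),
+ # an input of 16000 samples gives (16000 - 10) // 5 + 1 = 3199 frames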
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+ if add_adapter:
+ for _ in range(self.config.num_adapter_layers):
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
+ return input_lengths
+
+ def _get_feature_vector_attention_mask(
+ self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
+ ):
+ # Effectively attention_mask.sum(-1), but not inplace to be able to run
+ # on inference mode.
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
+ output_lengths = output_lengths.to(torch.long)
+
+ batch_size = attention_mask.shape[0]
+
+ attention_mask = torch.zeros(
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+ # these two operations make sure that all values before the output length indices are attended to
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
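+ # flipping, cumulative-summing and flipping back turns the single 1 at the last valid frame
+ # into 1s for that frame and every frame before it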
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+ return attention_mask
+
+
+WAV2VEC2_CONFORMER_START_DOCSTRING = None # will be automatically redefined
+
+WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+ into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
+ soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
+ conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+ 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ <Tip warning={true}>
+
+ `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
+ True`. For all models whose processor has `config.return_attention_mask == False`, such as
+ [wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large),
+ `attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
+ such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
+ that these models also yield slightly different results depending on whether `input_values` is padded or
+ not.
+
+ </Tip>
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+Wav2Vec2ConformerBaseModelOutput = Wav2Vec2BaseModelOutput
+
+
+@add_start_docstrings(
+ "The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel, Wav2Vec2Model):
+ def __init__(self, config: Wav2Vec2ConformerConfig):
+ Wav2Vec2ConformerPreTrainedModel.__init__(self, config)
+ self.config = config
+ self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
+ self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)
+
+ # model only needs masking vector if mask prob is > 0.0
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+ self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
+
+ self.encoder = Wav2Vec2ConformerEncoder(config)
+
+ self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def freeze_feature_extractor(self):
+ raise AttributeError("Not needed for Wav2Vec2Conformer")
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=Wav2Vec2ConformerBaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+ """Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
+)
+class Wav2Vec2ConformerForPreTraining(Wav2Vec2ForPreTraining):
+ def __init__(self, config: Wav2Vec2ConformerConfig):
+ super().__init__(config)
+
+ def freeze_feature_extractor(self):
+ raise AttributeError("Not needed for Wav2Vec2Conformer")
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(self, **super_kwargs) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
+ return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+ """Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForCTC(Wav2Vec2ForCTC):
+ def __init__(self, config, target_lang: Optional[str] = None):
+ super().__init__(config)
+
+ def tie_weights(self):
+ raise AttributeError("Not needed for Wav2Vec2Conformer")
+
+ def freeze_feature_extractor(self):
+ raise AttributeError("Not needed for Wav2Vec2Conformer")
+
+ def freeze_base_model(self):
+ raise AttributeError("Not needed for Wav2Vec2Conformer")
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=CausalLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_CTC_EXPECTED_OUTPUT,
+ expected_loss=_CTC_EXPECTED_LOSS,
+ )
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for
+ tasks like SUPERB Keyword Spotting.
+ """,
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ForSequenceClassification):
+ def __init__(self, config):
+ super().__init__(config)
+
+ def freeze_feature_extractor(self):
+ raise AttributeError("Not needed for Wav2Vec2Conformer")
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization.
+ """,
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification):
+ def __init__(self, config):
+ super().__init__(config)
+
+ def freeze_feature_extractor(self):
+ raise AttributeError("Not needed for Wav2Vec2Conformer")
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.
+ """,
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForXVector(Wav2Vec2ForXVector):
+ def __init__(self, config):
+ super().__init__(config)
+
+ def freeze_feature_extractor(self):
+ raise AttributeError("Not needed for Wav2Vec2Conformer")
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=XVectorOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+__all__ = [
+ "Wav2Vec2ConformerForAudioFrameClassification",
+ "Wav2Vec2ConformerForCTC",
+ "Wav2Vec2ConformerForPreTraining",
+ "Wav2Vec2ConformerForSequenceClassification",
+ "Wav2Vec2ConformerForXVector",
+ "Wav2Vec2ConformerModel",
+ "Wav2Vec2ConformerPreTrainedModel",
+]
diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py
index 270d48f8378..1c3c09d1a70 100755
--- a/src/transformers/models/wavlm/modeling_wavlm.py
+++ b/src/transformers/models/wavlm/modeling_wavlm.py
@@ -1,28 +1,17 @@
-# coding=utf-8
-# Copyright 2021 The Fairseq Authors, Microsoft Research, and The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch WavLM model."""
-
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/wavlm/modular_wavlm.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_wavlm.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
import warnings
from typing import Optional, Tuple, Union
import numpy as np
import torch
+import torch.nn as nn
import torch.nn.functional as F
-import torch.utils.checkpoint
-from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
@@ -49,225 +38,22 @@ from .configuration_wavlm import WavLMConfig
logger = logging.get_logger(__name__)
+_CHECKPOINT_FOR_DOC = "patrickvonplaten/wavlm-libri-clean-100h-base-plus"
-_HIDDEN_STATES_START_POSITION = 2
-
-# General docstring
_CONFIG_FOR_DOC = "WavLMConfig"
-# Base docstring
-_CHECKPOINT_FOR_DOC = "patrickvonplaten/wavlm-libri-clean-100h-base-plus"
-_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
-# CTC docstring
-_CTC_EXPECTED_OUTPUT = "'mister quilter is the aposle of the middle classes and we are glad to welcome his gospel'"
-_CTC_EXPECTED_LOSS = 12.51
-
-# Frame class docstring
-_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd"
-_FRAME_EXPECTED_OUTPUT = [0, 0]
-
-# Speaker Verification docstring
-_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv"
-_XVECTOR_EXPECTED_OUTPUT = 0.97
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
-def _compute_mask_indices(
- shape: Tuple[int, int],
- mask_prob: float,
- mask_length: int,
- attention_mask: Optional[torch.LongTensor] = None,
- min_masks: int = 0,
-) -> np.ndarray:
- """
- Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
- ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
- CPU as part of the preprocessing during training.
-
- Args:
- shape: The shape for which to compute masks. This should be of a tuple of size 2 where
- the first element is the batch size and the second element is the length of the axis to span.
- mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
- independently generated mask spans of length `mask_length` is computed by
- `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
- actual percentage will be smaller.
- mask_length: size of the mask
- min_masks: minimum number of masked spans
- attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
- each batch dimension.
- """
- batch_size, sequence_length = shape
-
- if mask_length < 1:
- raise ValueError("`mask_length` has to be bigger than 0.")
-
- if mask_length > sequence_length:
- raise ValueError(
- f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
- f" and `sequence_length`: {sequence_length}`"
- )
-
- # epsilon is used for probabilistic rounding
- epsilon = np.random.rand(1).item()
-
- def compute_num_masked_span(input_length):
- """Given input length, compute how many spans should be masked"""
- num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
- num_masked_span = max(num_masked_span, min_masks)
-
- # make sure num masked span <= sequence_length
- if num_masked_span * mask_length > sequence_length:
- num_masked_span = sequence_length // mask_length
-
- # make sure num_masked span is also <= input_length - (mask_length - 1)
- if input_length - (mask_length - 1) < num_masked_span:
- num_masked_span = max(input_length - (mask_length - 1), 0)
-
- return num_masked_span
-
- # compute number of masked spans in batch
- input_lengths = (
- attention_mask.detach().sum(-1).tolist()
- if attention_mask is not None
- else [sequence_length for _ in range(batch_size)]
- )
-
- # SpecAugment mask to fill
- spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
- spec_aug_mask_idxs = []
-
- max_num_masked_span = compute_num_masked_span(sequence_length)
-
- if max_num_masked_span == 0:
- return spec_aug_mask
-
- for input_length in input_lengths:
- # compute num of masked spans for this input
- num_masked_span = compute_num_masked_span(input_length)
-
- # get random indices to mask
- spec_aug_mask_idx = np.random.choice(
- np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
- )
-
- # pick first sampled index that will serve as a dummy index to pad vector
- # to ensure same dimension for all batches due to probabilistic rounding
- # Picking first sample just pads those vectors twice.
- if len(spec_aug_mask_idx) == 0:
- # this case can only happen if `input_length` is strictly smaller then
- # `sequence_length` in which case the last token has to be a padding
- # token which we can use as a dummy mask id
- dummy_mask_idx = sequence_length - 1
- else:
- dummy_mask_idx = spec_aug_mask_idx[0]
-
- spec_aug_mask_idx = np.concatenate(
- [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
- )
- spec_aug_mask_idxs.append(spec_aug_mask_idx)
-
- spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
-
- # expand masked indices to masked spans
- spec_aug_mask_idxs = np.broadcast_to(
- spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
-
- # add offset to the starting indexes so that indexes now create a span
- offsets = np.arange(mask_length)[None, None, :]
- offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
- batch_size, max_num_masked_span * mask_length
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
-
- # ensure that we cannot have indices larger than sequence_length
- if spec_aug_mask_idxs.max() > sequence_length - 1:
- spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
-
- # scatter indices to mask
- np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
-
- return spec_aug_mask
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->WavLM
-class WavLMNoLayerNormConvLayer(nn.Module):
- def __init__(self, config, layer_id=0):
+class WavLMSamePadLayer(nn.Module):
+ def __init__(self, num_conv_pos_embeddings):
super().__init__()
- self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
- self.out_conv_dim = config.conv_dim[layer_id]
-
- self.conv = nn.Conv1d(
- self.in_conv_dim,
- self.out_conv_dim,
- kernel_size=config.conv_kernel[layer_id],
- stride=config.conv_stride[layer_id],
- bias=config.conv_bias,
- )
- self.activation = ACT2FN[config.feat_extract_activation]
+ self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
def forward(self, hidden_states):
- hidden_states = self.conv(hidden_states)
- hidden_states = self.activation(hidden_states)
+ if self.num_pad_remove > 0:
+ hidden_states = hidden_states[:, :, : -self.num_pad_remove]
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->WavLM
-class WavLMLayerNormConvLayer(nn.Module):
- def __init__(self, config, layer_id=0):
- super().__init__()
- self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
- self.out_conv_dim = config.conv_dim[layer_id]
-
- self.conv = nn.Conv1d(
- self.in_conv_dim,
- self.out_conv_dim,
- kernel_size=config.conv_kernel[layer_id],
- stride=config.conv_stride[layer_id],
- bias=config.conv_bias,
- )
- self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
- self.activation = ACT2FN[config.feat_extract_activation]
-
- def forward(self, hidden_states):
- hidden_states = self.conv(hidden_states)
-
- hidden_states = hidden_states.transpose(-2, -1)
- hidden_states = self.layer_norm(hidden_states)
- hidden_states = hidden_states.transpose(-2, -1)
-
- hidden_states = self.activation(hidden_states)
- return hidden_states
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->WavLM
-class WavLMGroupNormConvLayer(nn.Module):
- def __init__(self, config, layer_id=0):
- super().__init__()
- self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
- self.out_conv_dim = config.conv_dim[layer_id]
-
- self.conv = nn.Conv1d(
- self.in_conv_dim,
- self.out_conv_dim,
- kernel_size=config.conv_kernel[layer_id],
- stride=config.conv_stride[layer_id],
- bias=config.conv_bias,
- )
- self.activation = ACT2FN[config.feat_extract_activation]
-
- self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
-
- def forward(self, hidden_states):
- hidden_states = self.conv(hidden_states)
- hidden_states = self.layer_norm(hidden_states)
- hidden_states = self.activation(hidden_states)
- return hidden_states
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->WavLM
class WavLMPositionalConvEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
@@ -313,75 +99,6 @@ class WavLMPositionalConvEmbedding(nn.Module):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->WavLM
-class WavLMSamePadLayer(nn.Module):
- def __init__(self, num_conv_pos_embeddings):
- super().__init__()
- self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
-
- def forward(self, hidden_states):
- if self.num_pad_remove > 0:
- hidden_states = hidden_states[:, :, : -self.num_pad_remove]
- return hidden_states
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->WavLM
-class WavLMFeatureEncoder(nn.Module):
- """Construct the features from raw audio waveform"""
-
- def __init__(self, config):
- super().__init__()
-
- if config.feat_extract_norm == "group":
- conv_layers = [WavLMGroupNormConvLayer(config, layer_id=0)] + [
- WavLMNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
- ]
- elif config.feat_extract_norm == "layer":
- conv_layers = [WavLMLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
- else:
- raise ValueError(
- f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
- )
- self.conv_layers = nn.ModuleList(conv_layers)
- self.gradient_checkpointing = False
- self._requires_grad = True
-
- def _freeze_parameters(self):
- for param in self.parameters():
- param.requires_grad = False
- self._requires_grad = False
-
- def forward(self, input_values):
- hidden_states = input_values[:, None]
-
- # make sure hidden_states require grad for gradient_checkpointing
- if self._requires_grad and self.training:
- hidden_states.requires_grad = True
-
- for conv_layer in self.conv_layers:
- if self._requires_grad and self.gradient_checkpointing and self.training:
- hidden_states = self._gradient_checkpointing_func(
- conv_layer.__call__,
- hidden_states,
- )
- else:
- hidden_states = conv_layer(hidden_states)
-
- return hidden_states
-
-
-class WavLMFeatureExtractor(WavLMFeatureEncoder):
- def __init__(self, config):
- super().__init__(config)
- warnings.warn(
- f"The class `{self.__class__.__name__}` has been depreciated "
- "and will be removed in Transformers v5. "
- f"Use `{self.__class__.__bases__[0].__name__}` instead.",
- FutureWarning,
- )
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->WavLM
class WavLMFeatureProjection(nn.Module):
def __init__(self, config):
super().__init__()
@@ -563,7 +280,6 @@ class WavLMAttention(nn.Module):
return relative_buckets
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->WavLM
class WavLMFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
@@ -903,57 +619,6 @@ class WavLMGumbelVectorQuantizer(nn.Module):
return codevectors, perplexity
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->WavLM
-class WavLMAdapter(nn.Module):
- def __init__(self, config):
- super().__init__()
-
- # feature dim might need to be down-projected
- if config.output_hidden_size != config.hidden_size:
- self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
- self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
- else:
- self.proj = self.proj_layer_norm = None
-
- self.layers = nn.ModuleList(WavLMAdapterLayer(config) for _ in range(config.num_adapter_layers))
- self.layerdrop = config.layerdrop
-
- def forward(self, hidden_states):
- # down project hidden_states if necessary
- if self.proj is not None and self.proj_layer_norm is not None:
- hidden_states = self.proj(hidden_states)
- hidden_states = self.proj_layer_norm(hidden_states)
-
- hidden_states = hidden_states.transpose(1, 2)
-
- for layer in self.layers:
- layerdrop_prob = np.random.random()
- if not self.training or (layerdrop_prob > self.layerdrop):
- hidden_states = layer(hidden_states)
-
- hidden_states = hidden_states.transpose(1, 2)
- return hidden_states
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->WavLM
-class WavLMAdapterLayer(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.conv = nn.Conv1d(
- config.output_hidden_size,
- 2 * config.output_hidden_size,
- config.adapter_kernel_size,
- stride=config.adapter_stride,
- padding=1,
- )
-
- def forward(self, hidden_states):
- hidden_states = self.conv(hidden_states)
- hidden_states = nn.functional.glu(hidden_states, dim=1)
-
- return hidden_states
-
-
class WavLMPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -964,6 +629,8 @@ class WavLMPreTrainedModel(PreTrainedModel):
base_model_prefix = "wavlm"
main_input_name = "input_values"
supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = False
+ _supports_sdpa = False
def _init_weights(self, module):
"""Initialize the weights"""
@@ -1042,6 +709,293 @@ class WavLMPreTrainedModel(PreTrainedModel):
return attention_mask
+class WavLMNoLayerNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class WavLMLayerNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+
+ hidden_states = hidden_states.transpose(-2, -1)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(-2, -1)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class WavLMGroupNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class WavLMFeatureEncoder(nn.Module):
+ """Construct the features from raw audio waveform"""
+
+ def __init__(self, config):
+ super().__init__()
+
+ if config.feat_extract_norm == "group":
+ conv_layers = [WavLMGroupNormConvLayer(config, layer_id=0)] + [
+ WavLMNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
+ ]
+ elif config.feat_extract_norm == "layer":
+ conv_layers = [WavLMLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
+ else:
+ raise ValueError(
+ f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
+ )
+ self.conv_layers = nn.ModuleList(conv_layers)
+ self.gradient_checkpointing = False
+ self._requires_grad = True
+
+ def _freeze_parameters(self):
+ for param in self.parameters():
+ param.requires_grad = False
+ self._requires_grad = False
+
+ def forward(self, input_values):
+ hidden_states = input_values[:, None]
+
+ # make sure hidden_states require grad for gradient_checkpointing
+ if self._requires_grad and self.training:
+ hidden_states.requires_grad = True
+
+ for conv_layer in self.conv_layers:
+ if self._requires_grad and self.gradient_checkpointing and self.training:
+ hidden_states = self._gradient_checkpointing_func(
+ conv_layer.__call__,
+ hidden_states,
+ )
+ else:
+ hidden_states = conv_layer(hidden_states)
+
+ return hidden_states
+
+
+class WavLMAdapterLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.output_hidden_size,
+ 2 * config.output_hidden_size,
+ config.adapter_kernel_size,
+ stride=config.adapter_stride,
+ padding=1,
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = nn.functional.glu(hidden_states, dim=1)
+
+ return hidden_states
+
+
+class WavLMAdapter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ # feature dim might need to be down-projected
+ if config.output_hidden_size != config.hidden_size:
+ self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
+ self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
+ else:
+ self.proj = self.proj_layer_norm = None
+
+ self.layers = nn.ModuleList(WavLMAdapterLayer(config) for _ in range(config.num_adapter_layers))
+ self.layerdrop = config.layerdrop
+
+ def forward(self, hidden_states):
+ # down project hidden_states if necessary
+ if self.proj is not None and self.proj_layer_norm is not None:
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.proj_layer_norm(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+
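+ # LayerDrop (https://arxiv.org/abs/1909.11556): during training each adapter layer is skipped
+ # with probability `config.layerdrop`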
+ for layer in self.layers:
+ layerdrop_prob = np.random.random()
+ if not self.training or (layerdrop_prob > self.layerdrop):
+ hidden_states = layer(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+def _compute_mask_indices(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ attention_mask: Optional[torch.LongTensor] = None,
+ min_masks: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+ CPU as part of the preprocessing during training.
+
+ Args:
+ shape: The shape for which to compute masks. This should be of a tuple of size 2 where
+ the first element is the batch size and the second element is the length of the axis to span.
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+ independently generated mask spans of length `mask_length` is computed by
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+ actual percentage will be smaller.
+ mask_length: size of the mask
+ min_masks: minimum number of masked spans
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+ each batch dimension.
+ """
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ if mask_length > sequence_length:
+ raise ValueError(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+ f" and `sequence_length`: {sequence_length}`"
+ )
+
+ # epsilon is used for probabilistic rounding
+ epsilon = np.random.rand(1).item()
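+ # adding a uniform epsilon before the int() cast rounds `mask_prob * input_length / mask_length`
+ # up with probability equal to its fractional part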
+
+ def compute_num_masked_span(input_length):
+ """Given input length, compute how many spans should be masked"""
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+ num_masked_span = max(num_masked_span, min_masks)
+
+ # make sure num masked span <= sequence_length
+ if num_masked_span * mask_length > sequence_length:
+ num_masked_span = sequence_length // mask_length
+
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
+ if input_length - (mask_length - 1) < num_masked_span:
+ num_masked_span = max(input_length - (mask_length - 1), 0)
+
+ return num_masked_span
+
+ # compute number of masked spans in batch
+ input_lengths = (
+ attention_mask.detach().sum(-1).tolist()
+ if attention_mask is not None
+ else [sequence_length for _ in range(batch_size)]
+ )
+
+ # SpecAugment mask to fill
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+ spec_aug_mask_idxs = []
+
+ max_num_masked_span = compute_num_masked_span(sequence_length)
+
+ if max_num_masked_span == 0:
+ return spec_aug_mask
+
+ for input_length in input_lengths:
+ # compute num of masked spans for this input
+ num_masked_span = compute_num_masked_span(input_length)
+
+ # get random indices to mask
+ spec_aug_mask_idx = np.random.choice(
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+ )
+
+ # pick first sampled index that will serve as a dummy index to pad vector
+ # to ensure same dimension for all batches due to probabilistic rounding
+ # Picking first sample just pads those vectors twice.
+ if len(spec_aug_mask_idx) == 0:
+ # this case can only happen if `input_length` is strictly smaller then
+ # `sequence_length` in which case the last token has to be a padding
+ # token which we can use as a dummy mask id
+ dummy_mask_idx = sequence_length - 1
+ else:
+ dummy_mask_idx = spec_aug_mask_idx[0]
+
+ spec_aug_mask_idx = np.concatenate(
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+ )
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+ # expand masked indices to masked spans
+ spec_aug_mask_idxs = np.broadcast_to(
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+ # add offset to the starting indexes so that indexes now create a span
+ offsets = np.arange(mask_length)[None, None, :]
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+ batch_size, max_num_masked_span * mask_length
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+ # ensure that we cannot have indices larger than sequence_length
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+ # scatter indices to mask
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+ return spec_aug_mask
+
+
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
+
+
WAVLM_START_DOCSTRING = r"""
WavLM was proposed in [WavLM: Unified Speech Representation Learning with Labeled and Unlabeled
Data](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo
@@ -1061,7 +1015,6 @@ WAVLM_START_DOCSTRING = r"""
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
-
WAVLM_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -1098,12 +1051,13 @@ WAVLM_INPUTS_DOCSTRING = r"""
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+WavLMBaseModelOutput = Wav2Vec2BaseModelOutput
+
@add_start_docstrings(
"The bare WavLM Model transformer outputting raw hidden-states without any specific head on top.",
WAVLM_START_DOCSTRING,
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM, WavLMBaseModelOutput->Wav2Vec2BaseModelOutput
class WavLMModel(WavLMPreTrainedModel):
def __init__(self, config: WavLMConfig):
super().__init__(config)
@@ -1193,7 +1147,7 @@ class WavLMModel(WavLMPreTrainedModel):
@add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=Wav2Vec2BaseModelOutput,
+ output_type=WavLMBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@@ -1206,7 +1160,7 @@ class WavLMModel(WavLMPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+ ) -> Union[Tuple, WavLMBaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1243,7 +1197,7 @@ class WavLMModel(WavLMPreTrainedModel):
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
- return Wav2Vec2BaseModelOutput(
+ return WavLMBaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
@@ -1251,11 +1205,16 @@ class WavLMModel(WavLMPreTrainedModel):
)
+_HIDDEN_STATES_START_POSITION = 2
+
+_CTC_EXPECTED_OUTPUT = "'mister quilter is the aposle of the middle classes and we are glad to welcome his gospel'"
+_CTC_EXPECTED_LOSS = 12.51
+
+
@add_start_docstrings(
"""WavLM Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
WAVLM_START_DOCSTRING,
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
class WavLMForCTC(WavLMPreTrainedModel):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
@@ -1432,7 +1391,6 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
@@ -1445,7 +1403,6 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
)
self.freeze_feature_encoder()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wavlm
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@@ -1453,7 +1410,6 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
"""
self.wavlm.feature_extractor._freeze_parameters()
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->wavlm
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
@@ -1469,7 +1425,6 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->WavLM, wav2vec2->wavlm
def forward(
self,
input_values: Optional[torch.Tensor],
@@ -1533,13 +1488,16 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
)
+_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd"
+_FRAME_EXPECTED_OUTPUT = [0, 0]
+
+
@add_start_docstrings(
"""
WavLM Model with a frame classification head on top for tasks like Speaker Diarization.
""",
WAVLM_START_DOCSTRING,
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
class WavLMForAudioFrameClassification(WavLMPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1646,7 +1604,6 @@ class WavLMForAudioFrameClassification(WavLMPreTrainedModel):
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
class AMSoftmaxLoss(nn.Module):
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
super(AMSoftmaxLoss, self).__init__()
@@ -1670,7 +1627,6 @@ class AMSoftmaxLoss(nn.Module):
return loss
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
class TDNNLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
@@ -1686,6 +1642,7 @@ class TDNNLayer(nn.Module):
if is_peft_available():
from peft.tuners.lora import LoraLayer
if isinstance(self.kernel, LoraLayer):
warnings.warn(
"Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
@@ -1702,13 +1659,16 @@ class TDNNLayer(nn.Module):
return hidden_states
+_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv"
+_XVECTOR_EXPECTED_OUTPUT = 0.97
+
+
@add_start_docstrings(
"""
WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification.
""",
WAVLM_START_DOCSTRING,
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
class WavLMForXVector(WavLMPreTrainedModel):
def __init__(self, config):
super().__init__(config)
diff --git a/src/transformers/models/wavlm/modular_wavlm.py b/src/transformers/models/wavlm/modular_wavlm.py
new file mode 100644
index 00000000000..9ae9170fec5
--- /dev/null
+++ b/src/transformers/models/wavlm/modular_wavlm.py
@@ -0,0 +1,758 @@
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.fsdp import is_fsdp_managed_module
+from ...modeling_outputs import (
+ BaseModelOutput,
+ CausalLMOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+ XVectorOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ..wav2vec2.modeling_wav2vec2 import (
+ Wav2Vec2FeatureProjection,
+ Wav2Vec2FeedForward,
+ Wav2Vec2ForAudioFrameClassification,
+ Wav2Vec2ForCTC,
+ Wav2Vec2ForSequenceClassification,
+ Wav2Vec2ForXVector,
+ Wav2Vec2Model,
+ Wav2Vec2PositionalConvEmbedding,
+ Wav2Vec2PreTrainedModel,
+)
+from .configuration_wavlm import WavLMConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "WavLMConfig"
+
+_CHECKPOINT_FOR_DOC = "patrickvonplaten/wavlm-libri-clean-100h-base-plus"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
+
+_CTC_EXPECTED_OUTPUT = "'mister quilter is the aposle of the middle classes and we are glad to welcome his gospel'"
+_CTC_EXPECTED_LOSS = 12.51
+
+_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd"
+_FRAME_EXPECTED_OUTPUT = [0, 0]
+
+_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv"
+_XVECTOR_EXPECTED_OUTPUT = 0.97
+
+
+class WavLMPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding):
+ pass
+
+
+class WavLMFeatureProjection(Wav2Vec2FeatureProjection):
+ pass
+
+
+class WavLMAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ num_buckets: int = 320,
+ max_distance: int = 800,
+ has_relative_position_bias: bool = True,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
+
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+
+ self.gru_rel_pos_const = nn.Parameter(torch.ones(1, self.num_heads, 1, 1))
+ self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)
+
+ if has_relative_position_bias:
+ self.rel_attn_embed = nn.Embedding(self.num_buckets, self.num_heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_bias: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ index=0,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Attention layer with relative attention"""
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # first pass of attention layer creates position bias
+ if position_bias is None:
+ position_bias = self.compute_bias(tgt_len, tgt_len)
+ position_bias = (
+ position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, tgt_len)
+ )
+
+ # Compute relative position bias:
+        # 1) reshape hidden_states into (batch, num_heads, seq_len, head_dim)
+ gated_hidden_states = hidden_states.view(hidden_states.shape[:-1] + (self.num_heads, -1))
+ gated_hidden_states = gated_hidden_states.permute(0, 2, 1, 3)
+
+ # 2) project hidden states
+ relative_position_proj = self.gru_rel_pos_linear(gated_hidden_states)
+ relative_position_proj = relative_position_proj.view(gated_hidden_states.shape[:-1] + (2, 4)).sum(-1)
+
+ # 3) compute gate for position bias from projected hidden states
+ gate_a, gate_b = torch.sigmoid(relative_position_proj).chunk(2, dim=-1)
+ gate_output = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0
+
+ # 4) apply gate to position bias to compute gated position_bias
+ gated_position_bias = gate_output.view(bsz * self.num_heads, -1, 1) * position_bias
+ gated_position_bias = gated_position_bias.view((-1, tgt_len, tgt_len))
+
+ attn_output, attn_weights = self.torch_multi_head_self_attention(
+ hidden_states, attention_mask, gated_position_bias, output_attentions
+ )
+
+ return attn_output, attn_weights, position_bias
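
The four numbered steps above boil down to a per-head, per-token gate that rescales the shared relative-position bias. A self-contained sketch of the same tensor bookkeeping, with toy dimensions that are assumptions for illustration only:

```python
import torch
import torch.nn as nn

# Toy dimensions (illustrative only)
bsz, tgt_len, num_heads, head_dim = 2, 4, 2, 8
embed_dim = num_heads * head_dim

hidden_states = torch.randn(bsz, tgt_len, embed_dim)
position_bias = torch.randn(bsz * num_heads, tgt_len, tgt_len)
gru_rel_pos_linear = nn.Linear(head_dim, 8)
gru_rel_pos_const = torch.ones(1, num_heads, 1, 1)

# 1) split hidden states into heads: (bsz, num_heads, tgt_len, head_dim)
gated = hidden_states.view(bsz, tgt_len, num_heads, head_dim).permute(0, 2, 1, 3)
# 2) project each head dim to 8 values, then fold them into 2 values per position
proj = gru_rel_pos_linear(gated).view(bsz, num_heads, tgt_len, 2, 4).sum(-1)
# 3) two sigmoid gates per head and position
gate_a, gate_b = torch.sigmoid(proj).chunk(2, dim=-1)
gate = gate_a * (gate_b * gru_rel_pos_const - 1.0) + 2.0
# 4) scale the shared bias per query token
gated_position_bias = (gate.view(bsz * num_heads, -1, 1) * position_bias).view(-1, tgt_len, tgt_len)
print(gated_position_bias.shape)  # torch.Size([4, 4, 4])
```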
+
+ def torch_multi_head_self_attention(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Union[torch.LongTensor, torch.BoolTensor],
+ gated_position_bias: torch.FloatTensor,
+ output_attentions: bool,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+ """simple wrapper around torch's multi_head_attention_forward function"""
+ # self-attention assumes q = k = v
+ query = key = value = hidden_states.transpose(0, 1)
+ key_padding_mask = attention_mask.ne(1) if attention_mask is not None else None
+
+ # disable bias and add_zero_attn
+ bias_k = bias_v = None
+ add_zero_attn = False
+
+ # PyTorch 1.3.0 has F.multi_head_attention_forward defined
+ # so no problem with backwards compatibility
+ attn_output, attn_weights = F.multi_head_attention_forward(
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ torch.empty([0]),
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
+ bias_k,
+ bias_v,
+ add_zero_attn,
+ self.dropout,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ self.training,
+ key_padding_mask,
+ output_attentions,
+ gated_position_bias,
+ use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ )
+
+ # [Seq_Len, Batch Size, ...] -> [Batch Size, Seq_Len, ...]
+ attn_output = attn_output.transpose(0, 1)
+
+ if attn_weights is not None:
+ # IMPORTANT: Attention weights are averaged weights
+ # here which should not be the case. This is an open issue
+ # on PyTorch: https://github.com/pytorch/pytorch/issues/32590
+ attn_weights = attn_weights[:, None].broadcast_to(
+ attn_weights.shape[:1] + (self.num_heads,) + attn_weights.shape[1:]
+ )
+
+ return attn_output, attn_weights
+
+ def compute_bias(self, query_length: int, key_length: int) -> torch.FloatTensor:
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
+ relative_position = memory_position - context_position
+ relative_position_bucket = self._relative_positions_bucket(relative_position)
+ relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device)
+ values = self.rel_attn_embed(relative_position_bucket)
+ values = values.permute([2, 0, 1])
+ return values
+
+ def _relative_positions_bucket(self, relative_positions: torch.FloatTensor) -> torch.FloatTensor:
+ num_buckets = self.num_buckets // 2
+
+ relative_buckets = (relative_positions > 0).to(torch.long) * num_buckets
+ relative_positions = torch.abs(relative_positions)
+
+ max_exact = num_buckets // 2
+ is_small = relative_positions < max_exact
+
+ relative_positions_if_large = torch.log(relative_positions.float() / max_exact)
+ relative_positions_if_large = relative_positions_if_large / math.log(self.max_distance / max_exact)
+ relative_positions_if_large = relative_positions_if_large * (num_buckets - max_exact)
+        relative_positions_if_large = (max_exact + relative_positions_if_large).to(torch.long)
+        relative_positions_if_large = torch.min(
+            relative_positions_if_large, torch.full_like(relative_positions_if_large, num_buckets - 1)
+        )
+
+        relative_buckets += torch.where(is_small, relative_positions, relative_positions_if_large)
+ return relative_buckets
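
The bucketing mirrors the T5 scheme: positive offsets occupy the upper half of the buckets, offsets below `max_exact` map linearly, and larger offsets are spaced logarithmically up to `max_distance`. A standalone re-implementation (not the module itself) for a quick sanity check with the defaults `num_buckets=320`, `max_distance=800`:

```python
import math
import torch


def relative_positions_bucket(relative_positions, num_buckets=320, max_distance=800):
    # Standalone re-implementation of WavLMAttention._relative_positions_bucket, for illustration
    num_buckets = num_buckets // 2
    buckets = (relative_positions > 0).to(torch.long) * num_buckets
    rel = torch.abs(relative_positions)
    max_exact = num_buckets // 2
    is_small = rel < max_exact
    if_large = max_exact + (
        torch.log(rel.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
    ).to(torch.long)
    if_large = torch.min(if_large, torch.full_like(if_large, num_buckets - 1))
    return buckets + torch.where(is_small, rel, if_large)


print(relative_positions_bucket(torch.tensor([-200, -5, 0, 5, 200])))
# tensor([111,   5,   0, 165, 271]) -- small offsets map linearly, distant ones logarithmically,
# and positive offsets live in the upper half of the bucket range
```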
+
+
+class WavLMFeedForward(Wav2Vec2FeedForward):
+ pass
+
+
+class WavLMEncoderLayer(nn.Module):
+ def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
+ super().__init__()
+ self.attention = WavLMAttention(
+ embed_dim=config.hidden_size,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ num_buckets=config.num_buckets,
+ max_distance=config.max_bucket_distance,
+ has_relative_position_bias=has_relative_position_bias,
+ )
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.feed_forward = WavLMFeedForward(config)
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
+ attn_residual = hidden_states
+ hidden_states, attn_weights, position_bias = self.attention(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ output_attentions=output_attentions,
+ index=index,
+ )
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = attn_residual + hidden_states
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ hidden_states = hidden_states + self.feed_forward(hidden_states)
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states, position_bias)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class WavLMEncoderLayerStableLayerNorm(nn.Module):
+ def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
+ super().__init__()
+ self.attention = WavLMAttention(
+ embed_dim=config.hidden_size,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ num_buckets=config.num_buckets,
+ max_distance=config.max_bucket_distance,
+ has_relative_position_bias=has_relative_position_bias,
+ )
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.feed_forward = WavLMFeedForward(config)
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
+ attn_residual = hidden_states
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states, attn_weights, position_bias = self.attention(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ output_attentions=output_attentions,
+ )
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = attn_residual + hidden_states
+ hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
+
+ outputs = (hidden_states, position_bias)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class WavLMEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layers = nn.ModuleList(
+ [WavLMEncoderLayer(config, has_relative_position_bias=(i == 0)) for i in range(config.num_hidden_layers)]
+ )
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if attention_mask is not None:
+ # make sure padded tokens output 0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
+
+ position_embeddings = self.pos_conv_embed(hidden_states)
+ hidden_states = hidden_states + position_embeddings
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
+ position_bias = None
+
+ for i, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = torch.rand([])
+
+ skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
+ if not skip_the_layer or synced_gpus:
+ # under fsdp or deepspeed zero3 all gpus must run in sync
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_bias,
+ output_attentions,
+ )
+ else:
+ layer_outputs = layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ output_attentions=output_attentions,
+ index=i,
+ )
+
+ hidden_states, position_bias = layer_outputs[:2]
+
+ if skip_the_layer:
+ layer_outputs = (None, None, None)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[2],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class WavLMEncoderStableLayerNorm(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layers = nn.ModuleList(
+ [
+ WavLMEncoderLayerStableLayerNorm(config, has_relative_position_bias=(i == 0))
+ for i in range(config.num_hidden_layers)
+ ]
+ )
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if attention_mask is not None:
+ # make sure padded tokens are not attended to
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
+
+ position_embeddings = self.pos_conv_embed(hidden_states)
+ hidden_states = hidden_states + position_embeddings
+ hidden_states = self.dropout(hidden_states)
+
+ synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
+ position_bias = None
+
+ for i, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = torch.rand([])
+
+ skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
+ if not skip_the_layer or synced_gpus:
+ # under fsdp or deepspeed zero3 all gpus must run in sync
+ # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_bias,
+ output_attentions,
+ )
+ else:
+ layer_outputs = layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ position_bias=position_bias,
+ )
+ hidden_states, position_bias = layer_outputs[:2]
+
+ if skip_the_layer:
+ layer_outputs = (None, None, None)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[2],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
+ )
+
+
+class WavLMGumbelVectorQuantizer(nn.Module):
+ """
+ Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
+ GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.num_groups = config.num_codevector_groups
+ self.num_vars = config.num_codevectors_per_group
+
+ if config.codevector_dim % self.num_groups != 0:
+ raise ValueError(
+ f"`config.codevector_dim {config.codevector_dim} must be divisible"
+ f" by `config.num_codevector_groups` {self.num_groups} "
+ "for concatenation."
+ )
+
+ # storage for codebook variables (codewords)
+ self.codevectors = nn.Parameter(
+ torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
+ )
+ self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
+
+ # can be decayed for training
+ self.temperature = 2
+
+ @staticmethod
+ def _compute_perplexity(probs):
+ marginal_probs = probs.mean(dim=0)
+ perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
+ return perplexity
+
+ def forward(self, hidden_states):
+ batch_size, sequence_length, hidden_size = hidden_states.shape
+
+ # project to codevector dim
+ hidden_states = self.weight_proj(hidden_states)
+ hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
+
+ if self.training:
+            # sample code vector probs via gumbel in a differentiable way
+ codevector_probs = nn.functional.gumbel_softmax(hidden_states.float(), tau=self.temperature, hard=True)
+ codevector_probs = codevector_probs.type_as(hidden_states)
+
+ # compute perplexity
+ codevector_soft_dist = torch.softmax(
+ hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
+ )
+ perplexity = self._compute_perplexity(codevector_soft_dist)
+ else:
+ # take argmax in non-differentiable way
+            # compute hard codevector distribution (one hot)
+ codevector_idx = hidden_states.argmax(dim=-1)
+ codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
+ -1, codevector_idx.view(-1, 1), 1.0
+ )
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
+
+ perplexity = self._compute_perplexity(codevector_probs)
+
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
+ # use probs to retrieve codevectors
+ codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
+ codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
+ codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
+
+ return codevectors, perplexity
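
During training the quantizer draws a hard (straight-through) Gumbel-softmax sample per group and concatenates the selected codewords. A toy sketch of the shape bookkeeping; all dimensions below are made up for illustration:

```python
import torch
import torch.nn as nn

# Made-up toy dimensions
batch, seq, groups, vars_per_group, codevector_dim = 2, 5, 2, 8, 16

logits = torch.randn(batch * seq * groups, vars_per_group)
# hard Gumbel-softmax: one-hot selection per group, differentiable via the straight-through trick
probs = nn.functional.gumbel_softmax(logits, tau=2.0, hard=True)

codevectors = torch.randn(1, groups * vars_per_group, codevector_dim // groups)
probs = probs.view(batch * seq, -1)            # (10, 16) selection mask over all codewords
picked = probs.unsqueeze(-1) * codevectors     # (10, 16, 8) only the selected rows survive
picked = picked.view(batch * seq, groups, vars_per_group, -1)
quantized = picked.sum(-2).view(batch, seq, -1)  # concatenate groups
print(quantized.shape)  # torch.Size([2, 5, 16])
```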
+
+
+class WavLMPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = WavLMConfig
+ base_model_prefix = "wavlm"
+ main_input_name = "input_values"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = False
+ _supports_sdpa = False
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ # gumbel softmax requires special init
+ if isinstance(module, WavLMGumbelVectorQuantizer):
+ module.weight_proj.weight.data.normal_(mean=0.0, std=1)
+ module.weight_proj.bias.data.zero_()
+ nn.init.uniform_(module.codevectors)
+ elif isinstance(module, WavLMPositionalConvEmbedding):
+ nn.init.normal_(
+ module.conv.weight,
+ mean=0,
+ std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
+ )
+ nn.init.constant_(module.conv.bias, 0)
+ elif isinstance(module, WavLMFeatureProjection):
+ k = math.sqrt(1 / module.projection.in_features)
+ nn.init.uniform_(module.projection.weight, a=-k, b=k)
+ nn.init.uniform_(module.projection.bias, a=-k, b=k)
+ elif isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ nn.init.kaiming_normal_(module.weight)
+
+ if module.bias is not None:
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+ nn.init.uniform_(module.bias, a=-k, b=k)
+
+ def _get_adapters(self):
+ raise AttributeError("Not needed for WavLM")
+
+ def init_adapter_layers(self):
+ raise AttributeError("Not needed for WavLM")
+
+ def load_adapter(self):
+ raise AttributeError("Not needed for WavLM")
+
+
+WAVLM_START_DOCSTRING = r"""
+ WavLM was proposed in [WavLM: Unified Speech Representation Learning with Labeled and Unlabeled
+ Data](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo
+ Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian,
+ Jian Wu, Michael Zeng, Xiangzhan Yu, Furu Wei.
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, etc.).
+
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`WavLMConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+WAVLM_INPUTS_DOCSTRING = r"""
+ Args:
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+ into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
+ soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
+ conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+ 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+            <Tip warning={true}>
+
+ `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
+ True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should
+ **not** be passed to avoid degraded performance when doing batched inference. For such models
+ `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these
+ models also yield slightly different results depending on whether `input_values` is padded or not.
+
+            </Tip>
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+WavLMBaseModelOutput = Wav2Vec2BaseModelOutput
+
+
+@add_start_docstrings(
+ "The bare WavLM Model transformer outputting raw hidden-states without any specific head on top.",
+ WAVLM_START_DOCSTRING,
+)
+class WavLMModel(Wav2Vec2Model):
+ @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=WavLMBaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+ """WavLM Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+ WAVLM_START_DOCSTRING,
+)
+class WavLMForCTC(Wav2Vec2ForCTC):
+ @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=CausalLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_CTC_EXPECTED_OUTPUT,
+ expected_loss=_CTC_EXPECTED_LOSS,
+ )
+ def forward(self, **super_kwargs):
+        return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+ """
+ WavLM Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
+ SUPERB Keyword Spotting.
+ """,
+ WAVLM_START_DOCSTRING,
+)
+class WavLMForSequenceClassification(Wav2Vec2ForSequenceClassification):
+ @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ def forward(self, **super_kwargs):
+        return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+ """
+ WavLM Model with a frame classification head on top for tasks like Speaker Diarization.
+ """,
+ WAVLM_START_DOCSTRING,
+)
+class WavLMForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification):
+ @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_FRAME_CLASS_CHECKPOINT,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_FRAME_EXPECTED_OUTPUT,
+ )
+ def forward(self, **super_kwargs):
+        return super().forward(**super_kwargs)
+
+
+@add_start_docstrings(
+ """
+ WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification.
+ """,
+ WAVLM_START_DOCSTRING,
+)
+class WavLMForXVector(Wav2Vec2ForXVector):
+ @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_XVECTOR_CHECKPOINT,
+ output_type=XVectorOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_XVECTOR_EXPECTED_OUTPUT,
+ )
+ def forward(self, **super_kwargs):
+        return super().forward(**super_kwargs)
+
+
+__all__ = [
+ "WavLMForAudioFrameClassification",
+ "WavLMForCTC",
+ "WavLMForSequenceClassification",
+ "WavLMForXVector",
+ "WavLMModel",
+ "WavLMPreTrainedModel",
+]
diff --git a/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py b/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py
index 4cce51d7ace..7f1e1b92242 100644
--- a/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py
+++ b/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py
@@ -50,9 +50,9 @@ if is_torch_available():
Wav2Vec2BertForXVector,
Wav2Vec2BertModel,
)
+ from transformers.models.wav2vec2.modeling_wav2vec2 import _sample_negative_indices
from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import (
_compute_mask_indices,
- _sample_negative_indices,
)
diff --git a/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py b/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py
index f276b13d7be..6a884ba36ba 100644
--- a/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py
+++ b/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py
@@ -55,10 +55,10 @@ if is_torch_available():
Wav2Vec2FeatureExtractor,
Wav2Vec2Processor,
)
+ from transformers.models.wav2vec2.modeling_wav2vec2 import _sample_negative_indices
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
Wav2Vec2ConformerGumbelVectorQuantizer,
_compute_mask_indices,
- _sample_negative_indices,
)
diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py
index 55e71c4cc91..0962056270d 100644
--- a/utils/modular_model_converter.py
+++ b/utils/modular_model_converter.py
@@ -79,8 +79,11 @@ def preserve_case_replace(text, patterns: dict, default_name: str):
def get_cased_name(lowercase_name: str) -> str:
"""From a model name in lowercase in the format `my_model`, return the cased name in the format `MyModel`."""
+ alt_lowercase_name = lowercase_name.replace("_", "-")
if lowercase_name in CONFIG_MAPPING_NAMES:
return CONFIG_MAPPING_NAMES[lowercase_name].replace("Config", "")
+ elif alt_lowercase_name in CONFIG_MAPPING_NAMES:
+ return CONFIG_MAPPING_NAMES[alt_lowercase_name].replace("Config", "")
else:
return "".join(x.title() for x in lowercase_name.split("_"))
@@ -106,6 +109,8 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
def __init__(self, old_name: str, new_name: str, original_new_model_name: str = "", only_doc: bool = False):
super().__init__()
+ old_name = old_name.replace("-", "_")
+ new_name = new_name.replace("-", "_")
self.old_name = old_name
self.new_name = new_name
self.cased_new_name = get_cased_name(self.new_name)
@@ -535,7 +540,7 @@ def find_all_dependencies(
# Top-level variables that match the following patterns will always use the value in the `modular_xxx.py` file
-ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC"]
+ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC", r"_HIDDEN_STATES_START_POSITION"]
# Top-level variables that match the following patterns will use the value in the `modular_xxx.py` file only if they are not None
ASSIGNMENTS_REGEX_TO_KEEP_IF_NOT_NONE = [r"_DOCSTRING"]
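
With this addition, an assignment such as `_HIDDEN_STATES_START_POSITION = 2` in a `modular_xxx.py` file now always wins over the value in the source modeling file. A tiny illustration of the pattern matching, assuming a plain `re.search` against the assignment name:

```python
import re

ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC", r"_HIDDEN_STATES_START_POSITION"]

name = "_HIDDEN_STATES_START_POSITION"
print(any(re.search(pattern, name) for pattern in ASSIGNMENTS_REGEX_TO_KEEP))                   # True
print(any(re.search(pattern, "SOME_OTHER_CONSTANT") for pattern in ASSIGNMENTS_REGEX_TO_KEEP))  # False
```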
@@ -616,6 +621,7 @@ class ModuleMapper(CSTVisitor, ABC):
self.object_dependency_mapping = defaultdict(set) # immediate function/assignment dependency mapping (i.e. dependencies immediately in the function/assignment definition)
self.assignments: Dict[str, cst.SimpleStatementLine] = {} # mapping of global assignments names to Nodes
self.current_function = None # this keeps track of the current module-scope function
+ self.current_class = None # this keeps track of the current module-scope class
self.current_assignment = None # this keeps track of the current module-scope assignment
# this keeps track of objects imported from modeling files (`from .configuration import Config`) -> `Config` should not be a dependency
self.objects_imported_from_modeling = set()
@@ -672,7 +678,7 @@ class ModuleMapper(CSTVisitor, ABC):
def visit_If(self, node):
# If we are inside a function, do not add the import to the list of imports
- if self.current_function is None:
+ if self.current_function is None and self.current_class is None:
for stmt in node.body.body:
if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
self.imports.append(node)
@@ -680,6 +686,10 @@ class ModuleMapper(CSTVisitor, ABC):
def visit_ClassDef(self, node: ClassDef) -> None:
"""Record class nodes to create their dependencies at the end."""
self.classes[node.name.value] = node
+ self.current_class = node.name.value
+
+ def leave_ClassDef(self, node):
+ self.current_class = None
def visit_Name(self, node: cst.Call):
"""This is used to create a mapping from module-scope functions and assignments to objects used inside them."""
@@ -1024,11 +1034,20 @@ def replace_class_node(
new_decorators = (
updated_methods[name].decorators if len(updated_methods[name].decorators) > 0 else func.decorators
)
+
+    # Keep the return annotation from `modular_xxx.py` if there is one, otherwise fall back to the original one
+ new_return_annotation = updated_methods[name].returns if updated_methods[name].returns else func.returns
+
if not re.match(
r"\ndef .*\(.*\):\n raise.*Error\(.*",
mapper.python_module.code_for_node(updated_methods[name]),
):
- func = func.with_changes(body=updated_methods[name].body, params=new_params, decorators=new_decorators)
+ func = func.with_changes(
+ body=updated_methods[name].body,
+ params=new_params,
+ decorators=new_decorators,
+ returns=new_return_annotation,
+ )
else:
continue
@@ -1136,7 +1155,7 @@ def append_new_import_node(
import_node = node.body[0]
names_to_keep = []
for name in import_node.names:
- name_value = name.evaluated_name
+ name_value = name.evaluated_alias or name.evaluated_name
if name_value not in unused_imports and name_value not in added_names:
names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT))
added_names.add(name_value)
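
Using `evaluated_alias or evaluated_name` means an aliased import is tracked under the name it actually binds in the module. A small libcst sketch of the two properties; the example import is made up:

```python
import libcst as cst

module = cst.parse_module("from torch import nn as neural_nets\n")
import_alias = module.body[0].body[0].names[0]  # cst.ImportAlias inside the ImportFrom

print(import_alias.evaluated_name)   # "nn"          -> the imported symbol
print(import_alias.evaluated_alias)  # "neural_nets" -> the name bound in the module
```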