Introduce modular files for speech models (#35902)

* WAV_2_VEC_2 to WAV2VEC2

* added modular files for hubert, wavlm, wav2vec2_bert, data2vec_audio

* remove unnecessary definitions in modulars

* added modular files for UniSpeech, UniSpeechSat, Wav2Vec2Conformer

* docstring fix for UniSpeechForCTC

* removed unnecessary re-definition of modular classes

* reverted lazy-imports change on modular_model_converter; added a type alias for Wav2Vec2BaseModelOutput

* top-level import of deepspeed in seamless_m4t, speecht5

* avoid tracking imports inside classes; relocate lazy deepspeed and peft imports to their original locations

* convert modular

* tiny modular typing fixes

* some more modular fixes

* make style

---------

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
Co-authored-by: Eustache Le Bihan <eulebihan@gmail.com>
Authored by Nikos Antoniou on 2025-04-04 12:46:27 +03:00; committed by GitHub
parent d130cd0e16
commit f74d7da836
43 changed files with 6690 additions and 2216 deletions

View File

@ -307,7 +307,12 @@ extras["hub-kernels"] = deps_list("kernels")
extras["integrations"] = extras["hub-kernels"] + extras["optuna"] + extras["ray"] + extras["sigopt"]
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5")
extras["audio"] = deps_list(
"librosa",
"pyctcdecode",
"phonemizer",
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
)
# `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead
extras["speech"] = deps_list("torchaudio") + extras["audio"]
extras["torch-speech"] = deps_list("torchaudio") + extras["audio"]

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Optional, Tuple
import torch
import torch.nn as nn
@ -465,7 +465,7 @@ class Cohere2Model(Gemma2Model):
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

View File

@ -1,26 +1,15 @@
# coding=utf-8
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Data2VecAudio model."""
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/data2vec/modular_data2vec_audio.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_data2vec_audio.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
import warnings
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
@ -50,141 +39,14 @@ from .configuration_data2vec_audio import Data2VecAudioConfig
if is_flash_attn_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 2
# General docstring
_CONFIG_FOR_DOC = "Data2VecAudioConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/data2vec-audio-base-960h"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 66.95
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.detach().sum(-1).tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller than
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
# General docstring
_CONFIG_FOR_DOC = "Data2VecAudioConfig"
class Data2VecAudioConvLayer(nn.Module):
@ -214,7 +76,6 @@ class Data2VecAudioConvLayer(nn.Module):
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Data2VecAudio
class Data2VecAudioPadLayer(nn.Module):
def __init__(self, num_conv_pos_embeddings):
super().__init__()
@ -279,13 +140,11 @@ class Data2VecAudioFeatureEncoder(nn.Module):
self.gradient_checkpointing = False
self._requires_grad = True
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder._freeze_parameters
def _freeze_parameters(self):
for param in self.parameters():
param.requires_grad = False
self._requires_grad = False
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder.forward
def forward(self, input_values):
hidden_states = input_values[:, None]
@ -305,7 +164,6 @@ class Data2VecAudioFeatureEncoder(nn.Module):
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Data2VecAudio
class Data2VecAudioFeatureProjection(nn.Module):
def __init__(self, config):
super().__init__()
@ -321,7 +179,6 @@ class Data2VecAudioFeatureProjection(nn.Module):
return hidden_states, norm_hidden_states
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Data2VecAudio
class Data2VecAudioAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@ -480,7 +337,6 @@ class Data2VecAudioAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Data2VecAudio
class Data2VecAudioFlashAttention2(Data2VecAudioAttention):
"""
Data2VecAudio flash attention module. This module inherits from `Data2VecAudioAttention` as the weights of the module stays
@ -608,7 +464,6 @@ class Data2VecAudioFlashAttention2(Data2VecAudioAttention):
class Data2VecAudioSdpaAttention(Data2VecAudioAttention):
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Data2VecAudio
def forward(
self,
hidden_states: torch.Tensor,
@ -714,14 +569,6 @@ class Data2VecAudioSdpaAttention(Data2VecAudioAttention):
return attn_output, None, past_key_value
DATA2VEC2AUDIO_ATTENTION_CLASSES = {
"eager": Data2VecAudioAttention,
"sdpa": Data2VecAudioSdpaAttention,
"flash_attention_2": Data2VecAudioFlashAttention2,
}
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Data2VecAudio
class Data2VecAudioFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
@ -746,11 +593,17 @@ class Data2VecAudioFeedForward(nn.Module):
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->Data2VecAudio, WAV2VEC2->DATA2VEC2AUDIO
DATA2VEC_AUDIO_ATTENTION_CLASSES = {
"eager": Data2VecAudioAttention,
"sdpa": Data2VecAudioSdpaAttention,
"flash_attention_2": Data2VecAudioFlashAttention2,
}
class Data2VecAudioEncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.attention = DATA2VEC2AUDIO_ATTENTION_CLASSES[config._attn_implementation](
self.attention = DATA2VEC_AUDIO_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
@ -782,7 +635,6 @@ class Data2VecAudioEncoderLayer(nn.Module):
return outputs
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder with Wav2Vec2->Data2VecAudio
class Data2VecAudioEncoder(nn.Module):
def __init__(self, config):
super().__init__()
@ -868,7 +720,24 @@ class Data2VecAudioEncoder(nn.Module):
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Data2VecAudio
class Data2VecAudioAdapterLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.conv = nn.Conv1d(
config.output_hidden_size,
2 * config.output_hidden_size,
config.adapter_kernel_size,
stride=config.adapter_stride,
padding=1,
)
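# The convolution doubles the channel count (output_hidden_size -> 2 * output_hidden_size);
# the GLU in forward then gates the features and halves the channels back to output_hidden_size.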
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = nn.functional.glu(hidden_states, dim=1)
return hidden_states
class Data2VecAudioAdapter(nn.Module):
def __init__(self, config):
super().__init__()
@ -900,25 +769,6 @@ class Data2VecAudioAdapter(nn.Module):
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Data2VecAudio
class Data2VecAudioAdapterLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.conv = nn.Conv1d(
config.output_hidden_size,
2 * config.output_hidden_size,
config.adapter_kernel_size,
stride=config.adapter_stride,
padding=1,
)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = nn.functional.glu(hidden_states, dim=1)
return hidden_states
class Data2VecAudioPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@ -957,7 +807,6 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
nn.init.uniform_(module.bias, a=-k, b=k)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel._get_feat_extract_output_lengths with
def _get_feat_extract_output_lengths(
self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
):
@ -981,7 +830,6 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
return input_lengths
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel._get_feature_vector_attention_mask
def _get_feature_vector_attention_mask(
self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
):
@ -1003,6 +851,128 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
return attention_mask
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.detach().sum(-1).tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller than
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
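# Illustrative call (editorial example, not part of the generated file): for a batch of 2
# sequences of length 10, mask_prob=0.5 and mask_length=2 give a boolean (2, 10) array in
# which each row is covered by a handful of 2-frame spans:
#   mask = _compute_mask_indices(shape=(2, 10), mask_prob=0.5, mask_length=2)
#   mask.shape, mask.dtype  # (2, 10), bool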
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
DATA2VEC_AUDIO_START_DOCSTRING = r"""
Data2VecAudio was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and
Language](https://arxiv.org/pdf/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu and
@ -1021,7 +991,6 @@ DATA2VEC_AUDIO_START_DOCSTRING = r"""
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
DATA2VEC_AUDIO_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@ -1058,6 +1027,8 @@ DATA2VEC_AUDIO_INPUTS_DOCSTRING = r"""
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
Data2VecAudioBaseModelOutput = Wav2Vec2BaseModelOutput
@add_start_docstrings(
"The bare Data2VecAudio Model transformer outputting raw hidden-states without any specific head on top.",
@ -1137,7 +1108,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
@add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Wav2Vec2BaseModelOutput,
output_type=Data2VecAudioBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@ -1150,7 +1121,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
) -> Union[Tuple, Data2VecAudioBaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -1187,7 +1158,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return Wav2Vec2BaseModelOutput(
return Data2VecAudioBaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
@ -1195,6 +1166,13 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
)
_HIDDEN_STATES_START_POSITION = 2
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 66.95
@add_start_docstrings(
"""Data2VecAudio Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
DATA2VEC_AUDIO_START_DOCSTRING,
@ -1248,7 +1226,6 @@ class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel):
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with wav2vec2->data2vec_audio
def forward(
self,
input_values: Optional[torch.Tensor],
@ -1379,7 +1356,6 @@ class Data2VecAudioForSequenceClassification(Data2VecAudioPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with wav2vec2->data2vec_audio
def forward(
self,
input_values: Optional[torch.Tensor],
@ -1455,8 +1431,7 @@ class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel):
if hasattr(config, "add_adapter") and config.add_adapter:
raise ValueError(
"Audio frame classification does not support the use of Data2VecAudio adapters"
" (config.add_adapter=True)"
"Audio frame classification does not support the use of Data2VecAudio adapters (config.add_adapter=True)"
)
self.data2vec_audio = Data2VecAudioModel(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
@ -1501,7 +1476,6 @@ class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->data2vec_audio
def forward(
self,
input_values: Optional[torch.Tensor],
@ -1556,7 +1530,6 @@ class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel):
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
class AMSoftmaxLoss(nn.Module):
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
super(AMSoftmaxLoss, self).__init__()
@ -1580,7 +1553,6 @@ class AMSoftmaxLoss(nn.Module):
return loss
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
class TDNNLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
@ -1596,6 +1568,7 @@ class TDNNLayer(nn.Module):
if is_peft_available():
from peft.tuners.lora import LoraLayer
if is_peft_available():
if isinstance(self.kernel, LoraLayer):
warnings.warn(
"Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
@ -1687,7 +1660,6 @@ class Data2VecAudioForXVector(Data2VecAudioPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with wav2vec2->data2vec_audio
def forward(
self,
input_values: Optional[torch.Tensor],

View File

@ -0,0 +1,400 @@
import math
import torch
from torch import nn
from ...activations import ACT2FN
from ...modeling_outputs import (
CausalLMOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
Wav2Vec2BaseModelOutput,
XVectorOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from ..wav2vec2.modeling_wav2vec2 import (
Wav2Vec2Adapter,
Wav2Vec2Encoder,
Wav2Vec2FeatureEncoder,
Wav2Vec2FeatureProjection,
Wav2Vec2ForAudioFrameClassification,
Wav2Vec2ForCTC,
Wav2Vec2ForSequenceClassification,
Wav2Vec2ForXVector,
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
Wav2Vec2SamePadLayer,
)
from .configuration_data2vec_audio import Data2VecAudioConfig
_HIDDEN_STATES_START_POSITION = 2
# General docstring
_CONFIG_FOR_DOC = "Data2VecAudioConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/data2vec-audio-base-960h"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 66.95
class Data2VecAudioConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.activation(hidden_states)
return hidden_states
class Data2VecAudioPadLayer(Wav2Vec2SamePadLayer):
pass
class Data2VecAudioPositionalConvLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.conv = nn.Conv1d(
config.hidden_size,
config.hidden_size,
kernel_size=config.conv_pos_kernel_size,
padding=config.conv_pos_kernel_size // 2,
groups=config.num_conv_pos_embedding_groups,
)
self.padding = Data2VecAudioPadLayer(config.conv_pos_kernel_size)
self.activation = ACT2FN[config.feat_extract_activation]
# no learnable parameters
self.layer_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.padding(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.activation(hidden_states)
return hidden_states
class Data2VecAudioPositionalConvEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
self.layers = nn.ModuleList(
[Data2VecAudioPositionalConvLayer(config) for _ in range(config.num_conv_pos_embeddings)]
)
def forward(self, hidden_states):
hidden_states = hidden_states.transpose(1, 2)
for layer in self.layers:
hidden_states = layer(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
return hidden_states
class Data2VecAudioFeatureEncoder(Wav2Vec2FeatureEncoder, nn.Module):
def __init__(self, config):
nn.Module.__init__(self)
self.conv_layers = nn.ModuleList(
[Data2VecAudioConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
)
self.gradient_checkpointing = False
self._requires_grad = True
class Data2VecAudioFeatureProjection(Wav2Vec2FeatureProjection):
pass
class Data2VecAudioEncoder(Wav2Vec2Encoder):
pass
class Data2VecAudioAdapter(Wav2Vec2Adapter):
pass
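# Editorial note: a bare `pass` subclass (as for the pad layer, feature projection, encoder
# and adapter above) tells the modular converter to copy the Wav2Vec2 implementation into
# modeling_data2vec_audio.py under the Data2VecAudio* name, so nothing is re-implemented here.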
class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Data2VecAudioConfig
base_model_prefix = "data2vec_audio"
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, Data2VecAudioFeatureProjection):
k = math.sqrt(1 / module.projection.in_features)
nn.init.uniform_(module.projection.weight, a=-k, b=k)
nn.init.uniform_(module.projection.bias, a=-k, b=k)
elif isinstance(module, Data2VecAudioPositionalConvLayer):
nn.init.constant_(module.conv.bias, 0)
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
if module.bias is not None:
module.bias.data.zero_()
if module.weight is not None:
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
nn.init.kaiming_normal_(module.weight)
if module.bias is not None:
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
nn.init.uniform_(module.bias, a=-k, b=k)
def _get_adapters(self):
raise AttributeError("Not needed for Data2VecAudio")
def init_adapter_layers(self):
raise AttributeError("Not needed for Data2VecAudio")
def load_adapter(self):
raise AttributeError("Not needed for Data2VecAudio")
DATA2VEC_AUDIO_START_DOCSTRING = r"""
Data2VecAudio was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and
Language](https://arxiv.org/pdf/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu and
Michael Auli.
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, etc.).
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`Data2VecAudioConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
DATA2VEC_AUDIO_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
soundfile*). To prepare the array into *input_values*, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type *torch.FloatTensor*. See [`Wav2Vec2Processor.__call__`] for details.
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
<Tip warning={true}>
`attention_mask` should be passed if the corresponding processor has `config.return_attention_mask ==
True`, which is the case for all pre-trained Data2Vec Audio models. Be aware that even with
`attention_mask`, zero-padded inputs will have slightly different outputs compared to non-padded inputs
because there are more than one convolutional layer in the positional encodings. For a more detailed
explanation, see [here](https://github.com/huggingface/transformers/issues/25621#issuecomment-1713759349).
</Tip>
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
Data2VecAudioBaseModelOutput = Wav2Vec2BaseModelOutput
@add_start_docstrings(
"The bare Data2VecAudio Model transformer outputting raw hidden-states without any specific head on top.",
DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioModel(Data2VecAudioPreTrainedModel, Wav2Vec2Model):
def __init__(self, config: Data2VecAudioConfig):
Data2VecAudioPreTrainedModel.__init__(self, config)
self.config = config
self.feature_extractor = Data2VecAudioFeatureEncoder(config)
self.feature_projection = Data2VecAudioFeatureProjection(config)
# model only needs masking vector if mask prob is > 0.0
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
self.encoder = Data2VecAudioEncoder(config)
self.adapter = Data2VecAudioAdapter(config) if config.add_adapter else None
# Initialize weights and apply final processing
self.post_init()
def freeze_feature_extractor(self):
raise AttributeError("Not needed for Data2VecAudio")
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
not be updated during training.
"""
self.feature_extractor._freeze_parameters()
@add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Data2VecAudioBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(self, **super_kwargs):
return super().forward(**super_kwargs)
@add_start_docstrings(
"""Data2VecAudio Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel, Wav2Vec2ForCTC):
def __init__(self, config):
Data2VecAudioPreTrainedModel.__init__(self, config)
self.data2vec_audio = Data2VecAudioModel(config)
self.dropout = nn.Dropout(config.final_dropout)
if config.vocab_size is None:
raise ValueError(
f"You are trying to instantiate {self.__class__} with a configuration that "
"does not define the vocabulary size of the language model head. Please "
"instantiate the model as follows: `Data2VecAudioForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
"or define `vocab_size` of your model's configuration."
)
output_hidden_size = (
config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
)
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
# Initialize weights and apply final processing
self.post_init()
def freeze_base_model(self):
raise AttributeError("Not needed for Data2VecAudio")
def tie_weights(self):
raise AttributeError("Not needed for Data2VecAudio")
@add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
)
def forward(self, **super_kwargs):
return super().forward(**super_kwargs)
@add_start_docstrings(
"""
Data2VecAudio Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
like SUPERB Keyword Spotting.
""",
DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioForSequenceClassification(Wav2Vec2ForSequenceClassification):
@add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
def forward(self, **super_kwargs):
return super().forward(**super_kwargs)
@add_start_docstrings(
"""
Data2VecAudio Model with a frame classification head on top for tasks like Speaker Diarization.
""",
DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification):
@add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
def forward(self, **super_kwargs):
return super().forward(**super_kwargs)
@add_start_docstrings(
"""
Data2VecAudio Model with an XVector feature extraction head on top for tasks like Speaker Verification.
""",
DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioForXVector(Wav2Vec2ForXVector):
@add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=XVectorOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
def forward(self, **super_kwargs):
return super().forward(**super_kwargs)
__all__ = [
"Data2VecAudioForAudioFrameClassification",
"Data2VecAudioForCTC",
"Data2VecAudioForSequenceClassification",
"Data2VecAudioForXVector",
"Data2VecAudioModel",
"Data2VecAudioPreTrainedModel",
]
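For context, a minimal usage sketch of the classes re-exported here, assuming the standard CTC inference pattern and the `facebook/data2vec-audio-base-960h` checkpoint referenced in `_CHECKPOINT_FOR_DOC` (the silent waveform is only a placeholder for real audio):
import torch
from transformers import AutoProcessor, Data2VecAudioForCTC
processor = AutoProcessor.from_pretrained("facebook/data2vec-audio-base-960h")
model = Data2VecAudioForCTC.from_pretrained("facebook/data2vec-audio-base-960h")
# one second of silence at 16 kHz stands in for a real waveform
waveform = torch.zeros(16000)
inputs = processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # (batch, frames, vocab_size)
predicted_ids = logits.argmax(dim=-1)
transcription = processor.batch_decode(predicted_ids)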

View File

@ -19,7 +19,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig

View File

@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import sentencepiece as spm
import torch
@ -379,7 +379,7 @@ class GemmaModel(LlamaModel):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # NOOP kwarg for now
) -> Union[Tuple, BaseModelOutputWithPast]:
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

View File

@ -415,7 +415,7 @@ class Gemma2Model(GemmaModel):
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -588,7 +588,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,

View File

@ -326,7 +326,7 @@ class Gemma3Attention(nn.Module):
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
@ -1203,7 +1203,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
return causal_mask
def get_image_features(self, pixel_values: torch.Tensor):
def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""
Projects the last hidden state from the vision model into language model space.

View File

@ -597,7 +597,7 @@ class Gemma3TextModel(Gemma2Model):
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[tuple, BaseModelOutputWithPast]:
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

View File

@ -753,7 +753,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]:
) -> GotOcr2CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,

View File

@ -397,7 +397,7 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
) -> LlavaCausalLMOutputWithPast:
) -> GotOcr2CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,

View File

@ -343,7 +343,7 @@ class GPTNeoXModel(LlamaModel, nn.Module):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

View File

@ -132,7 +132,7 @@ class GraniteModel(LlamaModel):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -244,7 +244,7 @@ class GraniteForCausalLM(LlamaForCausalLM):
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
) -> CausalLMOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

View File

@ -1,26 +1,15 @@
# coding=utf-8
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Hubert model."""
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/hubert/modular_hubert.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_hubert.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import warnings
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
@ -45,219 +34,12 @@ if is_flash_attn_available():
logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 1
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
# General docstring
_CONFIG_FOR_DOC = "HubertConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 22.68
# Audio class docstring
_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
_SEQ_CLASS_EXPECTED_LOSS = 8.53
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.detach().sum(-1).tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller than
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Hubert
class HubertNoLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Hubert
class HubertLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Hubert
class HubertGroupNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class HubertPositionalConvEmbedding(nn.Module):
def __init__(self, config):
@ -309,7 +91,6 @@ class HubertPositionalConvEmbedding(nn.Module):
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Hubert
class HubertSamePadLayer(nn.Module):
def __init__(self, num_conv_pos_embeddings):
super().__init__()
@ -321,7 +102,78 @@ class HubertSamePadLayer(nn.Module):
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Hubert
class HubertNoLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class HubertLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.activation(hidden_states)
return hidden_states
class HubertGroupNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class HubertFeatureEncoder(nn.Module):
"""Construct the features from raw audio waveform"""
@ -366,17 +218,6 @@ class HubertFeatureEncoder(nn.Module):
return hidden_states
class HubertFeatureExtractor(HubertFeatureEncoder):
def __init__(self, config):
super().__init__(config)
warnings.warn(
f"The class `{self.__class__.__name__}` has been depreciated "
"and will be removed in Transformers v5. "
f"Use `{self.__class__.__bases__[0].__name__}` instead.",
FutureWarning,
)
class HubertFeatureProjection(nn.Module):
def __init__(self, config):
super().__init__()
@ -395,7 +236,6 @@ class HubertFeatureProjection(nn.Module):
return hidden_states
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Hubert
class HubertAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@ -554,7 +394,6 @@ class HubertAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Hubert
class HubertFlashAttention2(HubertAttention):
"""
Hubert flash attention module. This module inherits from `HubertAttention` as the weights of the module stays
@ -682,7 +521,6 @@ class HubertFlashAttention2(HubertAttention):
class HubertSdpaAttention(HubertAttention):
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Hubert
def forward(
self,
hidden_states: torch.Tensor,
@ -788,14 +626,6 @@ class HubertSdpaAttention(HubertAttention):
return attn_output, None, past_key_value
HUBERT_ATTENTION_CLASSES = {
"eager": HubertAttention,
"sdpa": HubertSdpaAttention,
"flash_attention_2": HubertFlashAttention2,
}
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Hubert
class HubertFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
@ -820,7 +650,13 @@ class HubertFeedForward(nn.Module):
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->Hubert, WAV2VEC2->HUBERT
HUBERT_ATTENTION_CLASSES = {
"eager": HubertAttention,
"sdpa": HubertSdpaAttention,
"flash_attention_2": HubertFlashAttention2,
}
class HubertEncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
@ -856,79 +692,6 @@ class HubertEncoderLayer(nn.Module):
return outputs
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AttnAdapterLayer with Wav2Vec2->Hubert
class HubertAttnAdapterLayer(nn.Module):
def __init__(self, config):
"""
Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
up training throughput.
"""
super().__init__()
self.input_dim = config.adapter_attn_dim
self.hidden_dim = config.hidden_size
self.norm = nn.LayerNorm(self.hidden_dim)
self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
self.act_fn = nn.ReLU()
self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
def forward(self, hidden_states: torch.FloatTensor):
hidden_states = self.norm(hidden_states)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->Hubert, WAV2VEC2->HUBERT
class HubertEncoderLayerStableLayerNorm(nn.Module):
def __init__(self, config):
super().__init__()
self.attention = HUBERT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.feed_forward = HubertFeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
if getattr(config, "adapter_attn_dim", None) is not None:
self.adapter_layer = HubertAttnAdapterLayer(config)
else:
self.adapter_layer = None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.attention(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states = self.dropout(hidden_states)
hidden_states = attn_residual + hidden_states
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
if self.adapter_layer is not None:
hidden_states = hidden_states + self.adapter_layer(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder with Wav2Vec2->Hubert
class HubertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
@ -1014,7 +777,76 @@ class HubertEncoder(nn.Module):
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderStableLayerNorm with Wav2Vec2->Hubert
class HubertAttnAdapterLayer(nn.Module):
def __init__(self, config):
"""
Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
up training throughput.
"""
super().__init__()
self.input_dim = config.adapter_attn_dim
self.hidden_dim = config.hidden_size
self.norm = nn.LayerNorm(self.hidden_dim)
self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
self.act_fn = nn.ReLU()
self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
def forward(self, hidden_states: torch.FloatTensor):
hidden_states = self.norm(hidden_states)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class HubertEncoderLayerStableLayerNorm(nn.Module):
def __init__(self, config):
super().__init__()
self.attention = HUBERT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.feed_forward = HubertFeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
if getattr(config, "adapter_attn_dim", None) is not None:
self.adapter_layer = HubertAttnAdapterLayer(config)
else:
self.adapter_layer = None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.attention(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states = self.dropout(hidden_states)
hidden_states = attn_residual + hidden_states
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
if self.adapter_layer is not None:
hidden_states = hidden_states + self.adapter_layer(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class HubertEncoderStableLayerNorm(nn.Module):
def __init__(self, config):
super().__init__()
@ -1170,6 +1002,125 @@ class HubertPreTrainedModel(PreTrainedModel):
return attention_mask
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.detach().sum(-1).tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller than
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
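# Illustrative usage sketch, not part of this diff; the shape and probabilities are arbitrary
# assumptions chosen only to show what `_compute_mask_indices` returns:
#
# >>> mask = _compute_mask_indices(shape=(2, 100), mask_prob=0.05, mask_length=10)
# >>> mask.shape  # boolean numpy array, True marks the SpecAugment-masked frames
# (2, 100)
# >>> torch.tensor(mask, dtype=torch.bool)  # can be passed to the model as `mask_time_indices`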
HUBERT_START_DOCSTRING = r"""
Hubert was proposed in [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden
Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia,
@ -1238,6 +1189,7 @@ class HubertModel(HubertPreTrainedModel):
self.feature_extractor = HubertFeatureEncoder(config)
self.feature_projection = HubertFeatureProjection(config)
# model only needs masking vector if mask prob is > 0.0
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
@ -1249,7 +1201,6 @@ class HubertModel(HubertPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_hidden_states(
self,
hidden_states: torch.FloatTensor,
@ -1370,11 +1321,17 @@ class HubertModel(HubertPreTrainedModel):
)
_HIDDEN_STATES_START_POSITION = 1
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 22.68
@add_start_docstrings(
"""Hubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
HUBERT_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT
class HubertForCTC(HubertPreTrainedModel):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
@ -1526,6 +1483,11 @@ class HubertForCTC(HubertPreTrainedModel):
)
_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
_SEQ_CLASS_EXPECTED_LOSS = 8.53
@add_start_docstrings(
"""
Hubert Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
@ -1533,7 +1495,6 @@ class HubertForCTC(HubertPreTrainedModel):
""",
HUBERT_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT
class HubertForSequenceClassification(HubertPreTrainedModel):
def __init__(self, config):
super().__init__(config)

View File

@ -0,0 +1,400 @@
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ..wav2vec2.modeling_wav2vec2 import (
Wav2Vec2Encoder,
Wav2Vec2EncoderStableLayerNorm,
Wav2Vec2FeatureEncoder,
Wav2Vec2ForCTC,
Wav2Vec2ForSequenceClassification,
Wav2Vec2Model,
Wav2Vec2SamePadLayer,
)
from .configuration_hubert import HubertConfig
_HIDDEN_STATES_START_POSITION = 1
# General docstring
_CONFIG_FOR_DOC = "HubertConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 22.68
_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
_SEQ_CLASS_EXPECTED_LOSS = 8.53
class HubertPositionalConvEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
self.conv = nn.Conv1d(
config.hidden_size,
config.hidden_size,
kernel_size=config.num_conv_pos_embeddings,
padding=config.num_conv_pos_embeddings // 2,
groups=config.num_conv_pos_embedding_groups,
)
self.batch_norm = None
if config.conv_pos_batch_norm:
self.batch_norm = nn.BatchNorm1d(config.hidden_size)
else:
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = weight_norm(self.conv, name="weight", dim=2)
if hasattr(self.conv, "parametrizations"):
weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
else:
weight_g = self.conv.weight_g
weight_v = self.conv.weight_v
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else:
self.conv = weight_norm(self.conv, name="weight", dim=2)
self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = hidden_states.transpose(1, 2)
if self.batch_norm is not None:
hidden_states = self.batch_norm(hidden_states)
hidden_states = self.conv(hidden_states)
hidden_states = self.padding(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
return hidden_states
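# Worked example, not part of this diff, assuming the default config value
# num_conv_pos_embeddings=128 and stride 1: the convolution above pads by 128 // 2 = 64 on both
# sides, so a length-L input comes out with length L + 1; HubertSamePadLayer below trims that
# single extra frame (num_pad_remove == 1) whenever the kernel size is even.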
class HubertSamePadLayer(Wav2Vec2SamePadLayer):
pass
class HubertFeatureEncoder(Wav2Vec2FeatureEncoder):
pass
class HubertFeatureProjection(nn.Module):
def __init__(self, config):
super().__init__()
self.feat_proj_layer_norm = config.feat_proj_layer_norm
if self.feat_proj_layer_norm:
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
self.dropout = nn.Dropout(config.feat_proj_dropout)
def forward(self, hidden_states):
# non-projected hidden states are needed for quantization
if self.feat_proj_layer_norm:
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.projection(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
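# Shape sketch, not part of this diff, assuming the default Hubert config
# (conv_dim[-1] == 512, hidden_size == 768): the projection above maps extracted features of
# shape (batch, frames, 512) to encoder inputs of shape (batch, frames, 768).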
class HubertEncoder(Wav2Vec2Encoder):
pass
class HubertEncoderStableLayerNorm(Wav2Vec2EncoderStableLayerNorm):
pass
class HubertPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = HubertConfig
base_model_prefix = "hubert"
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
if is_deepspeed_zero3_enabled():
import deepspeed
if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
nn.init.kaiming_normal_(module.weight.data)
else:
with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
nn.init.kaiming_normal_(module.weight.data)
else:
nn.init.kaiming_normal_(module.weight.data)
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_()
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the convolutional layers
"""
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
return input_lengths
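# Worked example, not part of this diff, with the default conv_kernel=(10, 3, 3, 3, 3, 2, 2) and
# conv_stride=(5, 2, 2, 2, 2, 2, 2): a 1-second, 16000-sample input shrinks layer by layer via
# floor((length - kernel) / stride) + 1, i.e. 16000 -> 3199 -> 1599 -> 799 -> 399 -> 199 -> 99 -> 49,
# so `_get_feat_extract_output_lengths(16000)` returns 49 feature frames.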
def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
batch_size = attention_mask.shape[0]
attention_mask = torch.zeros(
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
)
# these two operations makes sure that all values before the output lengths idxs are attended to
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask
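# Toy illustration, not part of this diff, of the flip/cumsum/flip trick above: with assumed
# feature_vector_length=5 and output_lengths=[3, 5], a 1 is first written at index
# output_length - 1 of each zero row, and the double flip with a cumulative sum then extends it
# back to position 0:
#     [[0, 0, 1, 0, 0],          [[1, 1, 1, 0, 0],
#      [0, 0, 0, 0, 1]]   --->    [1, 1, 1, 1, 1]]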
HUBERT_START_DOCSTRING = r"""
Hubert was proposed in [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden
Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia,
Ruslan Salakhutdinov, Abdelrahman Mohamed.
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving etc.).
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`HubertConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
HUBERT_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
<Tip warning={true}>
`attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
True`. For all models whose processor has `config.return_attention_mask == False`, such as
[hubert-base](https://huggingface.co/facebook/hubert-base-ls960), `attention_mask` should **not** be passed
to avoid degraded performance when doing batched inference. For such models `input_values` should simply be
padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly different
results depending on whether `input_values` is padded or not.
</Tip>
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare Hubert Model transformer outputting raw hidden-states without any specific head on top.",
HUBERT_START_DOCSTRING,
)
class HubertModel(Wav2Vec2Model, HubertPreTrainedModel):
def __init__(self, config: HubertConfig):
super().__init__(config)
self.config = config
self.feature_extractor = HubertFeatureEncoder(config)
self.feature_projection = HubertFeatureProjection(config)
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
if config.do_stable_layer_norm:
self.encoder = HubertEncoderStableLayerNorm(config)
else:
self.encoder = HubertEncoder(config)
# Initialize weights and apply final processing
self.post_init()
del self.adapter
def freeze_feature_extractor(self):
raise AttributeError("Not needed for Hubert")
def freeze_feature_encoder(self):
raise AttributeError("Not needed for Hubert")
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
mask_time_indices: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
"""
Returns:
Example:
```python
>>> from transformers import AutoProcessor, HubertModel
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
>>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
>>> hidden_states = model(input_values).last_hidden_state
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
extract_features = self.feature_extractor(input_values)
extract_features = extract_features.transpose(1, 2)
if attention_mask is not None:
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
hidden_states = self.feature_projection(extract_features)
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if not return_dict:
return (hidden_states,) + encoder_outputs[1:]
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@add_start_docstrings(
"""Hubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
HUBERT_START_DOCSTRING,
)
class HubertForCTC(Wav2Vec2ForCTC):
pass
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
)
def forward(self, **super_kwargs):
super().forward(**super_kwargs)
@add_start_docstrings(
"""
Hubert Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
SUPERB Keyword Spotting.
""",
HUBERT_START_DOCSTRING,
)
class HubertForSequenceClassification(Wav2Vec2ForSequenceClassification):
pass
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_SEQ_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
)
def forward(self, **super_kwargs):
super().forward(**super_kwargs)
__all__ = ["HubertForCTC", "HubertForSequenceClassification", "HubertModel", "HubertPreTrainedModel"]
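# Note, not part of this diff: the `pass`-through subclasses above (e.g. `HubertForCTC(Wav2Vec2ForCTC)`)
# and the re-decorated `forward` stubs are only templates; the modular converter (roughly
# `python utils/modular_model_converter.py --files_to_parse src/transformers/models/hubert/modular_hubert.py`,
# exact flags assumed) expands them into the standalone classes in modeling_hubert.py, and CI
# enforces that the generated file stays in sync with this one.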

View File

@ -303,7 +303,7 @@ class MistralForQuestionAnswering(LlamaForQuestionAnswering):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
) -> QuestionAnsweringModelOutput:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.

View File

@ -620,7 +620,7 @@ class MixtralModel(MixtralPreTrainedModel):
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
) -> MoeModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
@ -1013,7 +1013,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin):
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:
) -> MoeCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,

View File

@ -853,7 +853,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Args:
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
@ -1415,7 +1415,7 @@ class MoonshineModel(MoonshinePreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
) -> Seq2SeqModelOutput:
r"""
Returns:

View File

@ -1,5 +1,5 @@
from functools import partial
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Optional, Tuple
import torch
import torch.nn as nn
@ -191,7 +191,7 @@ class PhiModel(LlamaModel):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

View File

@ -14,7 +14,7 @@
# limitations under the License.
"""PyTorch Qwen3 model."""
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Optional, Tuple
import torch
import torch.utils.checkpoint
@ -146,7 +146,7 @@ class Qwen3ForCausalLM(LlamaForCausalLM):
def forward(
self,
**super_kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,

View File

@ -633,7 +633,7 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel):
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
) -> MoeModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
@ -1026,7 +1026,7 @@ class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin):
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:
) -> MoeCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,

View File

@ -257,7 +257,7 @@ class Qwen3MoeForCausalLM(MixtralForCausalLM):
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
) -> MoeCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,

View File

@ -1218,7 +1218,7 @@ class SEWModel(SEWPreTrainedModel):
"""SEW Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
SEW_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV_2_VEC_2->SEW
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV2VEC2->SEW
class SEWForCTC(SEWPreTrainedModel):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
@ -1377,7 +1377,7 @@ class SEWForCTC(SEWPreTrainedModel):
""",
SEW_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEW, wav2vec2->sew, WAV_2_VEC_2->SEW
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEW, wav2vec2->sew, WAV2VEC2->SEW
class SEWForSequenceClassification(SEWPreTrainedModel):
def __init__(self, config):
super().__init__(config)

View File

@ -1468,7 +1468,7 @@ class SEWDModel(SEWDPreTrainedModel):
"""SEW-D Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
SEWD_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV_2_VEC_2->SEWD
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV2VEC2->SEWD
class SEWDForCTC(SEWDPreTrainedModel):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
@ -1627,7 +1627,7 @@ class SEWDForCTC(SEWDPreTrainedModel):
""",
SEWD_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV_2_VEC_2->SEWD
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV2VEC2->SEWD
class SEWDForSequenceClassification(SEWDPreTrainedModel):
def __init__(self, config):
super().__init__(config)

View File

@ -999,7 +999,7 @@ class Siglip2MultiheadAttentionPoolingHead(nn.Module):
self.mlp = Siglip2MLP(config)
self.num_heads = config.num_attention_heads
def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size = hidden_state.shape[0]
probe = self.probe.repeat(batch_size, 1, 1)

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
from typing import Optional
import torch
import torch.nn as nn
@ -243,7 +243,7 @@ class Siglip2VisionTransformer(SiglipVisionTransformer):
spatial_shapes: torch.LongTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
) -> BaseModelOutputWithPooling:
r"""
Returns:

View File

@ -19,7 +19,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING, AutoConfig

View File

@ -1,19 +1,9 @@
# coding=utf-8
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch UniSpeech model."""
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/unispeech/modular_unispeech.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_unispeech.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
import warnings
from dataclasses import dataclass
@ -21,18 +11,22 @@ from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, Wav2Vec2BaseModelOutput
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
ModelOutput,
SequenceClassifierOutput,
Wav2Vec2BaseModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
@ -48,20 +42,12 @@ if is_flash_attn_available():
logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 2
# Base docstring
_CHECKPOINT_FOR_DOC = "patrickvonplaten/unispeech-large-1500h-cv-timit"
# General docstring
_CONFIG_FOR_DOC = "UniSpeechConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "patrickvonplaten/unispeech-large-1500h-cv-timit"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'mister quilter is the apposl of the midle classes and weare glad to welcom his gosepl'"
_CTC_EXPECTED_LOSS = 17.17
@dataclass
class UniSpeechForPreTrainingOutput(ModelOutput):
@ -99,202 +85,17 @@ class UniSpeechForPreTrainingOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.detach().sum(-1).tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller than
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->UniSpeech
class UniSpeechNoLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
class UniSpeechSamePadLayer(nn.Module):
def __init__(self, num_conv_pos_embeddings):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.activation(hidden_states)
if self.num_pad_remove > 0:
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->UniSpeech
class UniSpeechLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->UniSpeech
class UniSpeechGroupNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->UniSpeech
class UniSpeechPositionalConvEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
@ -340,19 +141,78 @@ class UniSpeechPositionalConvEmbedding(nn.Module):
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->UniSpeech
class UniSpeechSamePadLayer(nn.Module):
def __init__(self, num_conv_pos_embeddings):
class UniSpeechNoLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
if self.num_pad_remove > 0:
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
hidden_states = self.conv(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class UniSpeechLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.activation(hidden_states)
return hidden_states
class UniSpeechGroupNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->UniSpeech
class UniSpeechFeatureEncoder(nn.Module):
"""Construct the features from raw audio waveform"""
@ -400,18 +260,6 @@ class UniSpeechFeatureEncoder(nn.Module):
return hidden_states
class UniSpeechFeatureExtractor(UniSpeechFeatureEncoder):
def __init__(self, config):
super().__init__(config)
warnings.warn(
f"The class `{self.__class__.__name__}` has been depreciated "
"and will be removed in Transformers v5. "
f"Use `{self.__class__.__bases__[0].__name__}` instead.",
FutureWarning,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->UniSpeech
class UniSpeechFeatureProjection(nn.Module):
def __init__(self, config):
super().__init__()
@ -427,7 +275,6 @@ class UniSpeechFeatureProjection(nn.Module):
return hidden_states, norm_hidden_states
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->UniSpeech
class UniSpeechAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@ -586,7 +433,6 @@ class UniSpeechAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->UniSpeech
class UniSpeechFlashAttention2(UniSpeechAttention):
"""
UniSpeech flash attention module. This module inherits from `UniSpeechAttention` as the weights of the module stay
@ -714,7 +560,6 @@ class UniSpeechFlashAttention2(UniSpeechAttention):
class UniSpeechSdpaAttention(UniSpeechAttention):
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->UniSpeech
def forward(
self,
hidden_states: torch.Tensor,
@ -820,14 +665,6 @@ class UniSpeechSdpaAttention(UniSpeechAttention):
return attn_output, None, past_key_value
UNISPEECH_ATTENTION_CLASSES = {
"eager": UniSpeechAttention,
"sdpa": UniSpeechSdpaAttention,
"flash_attention_2": UniSpeechFlashAttention2,
}
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->UniSpeech
class UniSpeechFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
@ -852,7 +689,13 @@ class UniSpeechFeedForward(nn.Module):
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->UniSpeech, WAV2VEC2->UNISPEECH
UNISPEECH_ATTENTION_CLASSES = {
"eager": UniSpeechAttention,
"sdpa": UniSpeechSdpaAttention,
"flash_attention_2": UniSpeechFlashAttention2,
}
class UniSpeechEncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
@ -888,79 +731,6 @@ class UniSpeechEncoderLayer(nn.Module):
return outputs
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AttnAdapterLayer with Wav2Vec2->UniSpeech
class UniSpeechAttnAdapterLayer(nn.Module):
def __init__(self, config):
"""
Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
up training throughput.
"""
super().__init__()
self.input_dim = config.adapter_attn_dim
self.hidden_dim = config.hidden_size
self.norm = nn.LayerNorm(self.hidden_dim)
self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
self.act_fn = nn.ReLU()
self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
def forward(self, hidden_states: torch.FloatTensor):
hidden_states = self.norm(hidden_states)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->UniSpeech, WAV2VEC2->UNISPEECH
class UniSpeechEncoderLayerStableLayerNorm(nn.Module):
def __init__(self, config):
super().__init__()
self.attention = UNISPEECH_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.feed_forward = UniSpeechFeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
if getattr(config, "adapter_attn_dim", None) is not None:
self.adapter_layer = UniSpeechAttnAdapterLayer(config)
else:
self.adapter_layer = None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.attention(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states = self.dropout(hidden_states)
hidden_states = attn_residual + hidden_states
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
if self.adapter_layer is not None:
hidden_states = hidden_states + self.adapter_layer(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder with Wav2Vec2->UniSpeech
class UniSpeechEncoder(nn.Module):
def __init__(self, config):
super().__init__()
@ -1046,7 +816,76 @@ class UniSpeechEncoder(nn.Module):
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderStableLayerNorm with Wav2Vec2->UniSpeech
class UniSpeechAttnAdapterLayer(nn.Module):
def __init__(self, config):
"""
Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
up training throughput.
"""
super().__init__()
self.input_dim = config.adapter_attn_dim
self.hidden_dim = config.hidden_size
self.norm = nn.LayerNorm(self.hidden_dim)
self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
self.act_fn = nn.ReLU()
self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
def forward(self, hidden_states: torch.FloatTensor):
hidden_states = self.norm(hidden_states)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class UniSpeechEncoderLayerStableLayerNorm(nn.Module):
def __init__(self, config):
super().__init__()
self.attention = UNISPEECH_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.feed_forward = UniSpeechFeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
if getattr(config, "adapter_attn_dim", None) is not None:
self.adapter_layer = UniSpeechAttnAdapterLayer(config)
else:
self.adapter_layer = None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.attention(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states = self.dropout(hidden_states)
hidden_states = attn_residual + hidden_states
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
if self.adapter_layer is not None:
hidden_states = hidden_states + self.adapter_layer(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class UniSpeechEncoderStableLayerNorm(nn.Module):
def __init__(self, config):
super().__init__()
@ -1138,7 +977,7 @@ class UniSpeechEncoderStableLayerNorm(nn.Module):
class UniSpeechGumbelVectorQuantizer(nn.Module):
"""
Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
"""
@ -1149,8 +988,8 @@ class UniSpeechGumbelVectorQuantizer(nn.Module):
if config.codevector_dim % self.num_groups != 0:
raise ValueError(
f"`config.codevector_dim {config.codevector_dim} must be divisible by `config.num_codevector_groups`"
f" {self.num_groups} for concatenation"
f"`config.codevector_dim {config.codevector_dim} must be divisible "
f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
)
# storage for codebook variables (codewords)
@ -1283,6 +1122,128 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
return attention_mask
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.detach().sum(-1).tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller than
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
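As a rough sketch of how the mask returned by `_compute_mask_indices` is typically consumed (the tensor sizes and the zero fill value are placeholders; the model itself replaces masked frames with a learned `masked_spec_embed`):

```python
import torch

# Hypothetical batch: 2 sequences of 10 feature frames with hidden size 4.
hidden_states = torch.randn(2, 10, 4)

# Boolean SpecAugment mask: spans of length 2 covering roughly half of each sequence.
mask = _compute_mask_indices(shape=(2, 10), mask_prob=0.5, mask_length=2)
mask = torch.from_numpy(mask)

# Overwrite the masked time steps (zeros here, a learned embedding in the model).
hidden_states[mask] = 0.0
print(int(mask.sum()), "frames masked")
```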
_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
UNISPEECH_START_DOCSTRING = r"""
UniSpeech was proposed in [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled
Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei,
@ -1301,7 +1262,6 @@ UNISPEECH_START_DOCSTRING = r"""
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
UNISPEECH_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@ -1339,6 +1299,9 @@ UNISPEECH_INPUTS_DOCSTRING = r"""
"""
UniSpeechBaseModelOutput = Wav2Vec2BaseModelOutput
@add_start_docstrings(
"The bare UniSpeech Model transformer outputting raw hidden-states without any specific head on top.",
UNISPEECH_START_DOCSTRING,
@ -1361,7 +1324,6 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_hidden_states(
self,
hidden_states: torch.FloatTensor,
@ -1411,7 +1373,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
@add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Wav2Vec2BaseModelOutput,
output_type=UniSpeechBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@ -1424,7 +1386,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
) -> Union[Tuple, UniSpeechBaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -1456,7 +1418,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return Wav2Vec2BaseModelOutput(
return UniSpeechBaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
@ -1610,6 +1572,13 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
)
_HIDDEN_STATES_START_POSITION = 2
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'mister quilter is the apposl of the midle classes and weare glad to welcom his gosepl'"
_CTC_EXPECTED_LOSS = 17.17
@add_start_docstrings(
"""UniSpeech Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
UNISPEECH_START_DOCSTRING,
@ -1620,7 +1589,6 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
by default.
""",
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->UniSpeech, wav2vec2->unispeech, WAV_2_VEC_2->UNISPEECH
class UniSpeechForCTC(UniSpeechPreTrainedModel):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
@ -1797,7 +1765,6 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
@ -1810,7 +1777,6 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
)
self.freeze_feature_encoder()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->unispeech
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
@ -1818,7 +1784,6 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
"""
self.unispeech.feature_extractor._freeze_parameters()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->unispeech
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
@ -1834,7 +1799,6 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->UniSpeech, wav2vec2->unispeech
def forward(
self,
input_values: Optional[torch.Tensor],

View File

@ -0,0 +1,563 @@
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ...modeling_outputs import CausalLMOutput, ModelOutput, SequenceClassifierOutput, Wav2Vec2BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ..wav2vec2.modeling_wav2vec2 import (
Wav2Vec2Encoder,
Wav2Vec2EncoderStableLayerNorm,
Wav2Vec2FeatureEncoder,
Wav2Vec2FeatureProjection,
Wav2Vec2ForCTC,
Wav2Vec2ForSequenceClassification,
Wav2Vec2GumbelVectorQuantizer,
Wav2Vec2Model,
Wav2Vec2PositionalConvEmbedding,
)
from .configuration_unispeech import UniSpeechConfig
logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 2
# General docstring
_CONFIG_FOR_DOC = "UniSpeechConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "patrickvonplaten/unispeech-large-1500h-cv-timit"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'mister quilter is the apposl of the midle classes and weare glad to welcom his gosepl'"
_CTC_EXPECTED_LOSS = 17.17
@dataclass
class UniSpeechForPreTrainingOutput(ModelOutput):
"""
Output type of [`UniSpeechForPreTraining`], with potential hidden states and attentions.
Args:
loss (*optional*, returned when model is in train mode, `torch.FloatTensor` of shape `(1,)`):
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
paper](https://arxiv.org/pdf/2006.11477.pdf).
projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
projected quantized states.
projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
target vectors for contrastive loss.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
projected_states: Optional[torch.FloatTensor] = None
projected_quantized_states: Optional[torch.FloatTensor] = None
codevector_perplexity: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class UniSpeechPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding):
pass
class UniSpeechFeatureEncoder(Wav2Vec2FeatureEncoder):
pass
class UniSpeechFeatureProjection(Wav2Vec2FeatureProjection):
pass
class UniSpeechEncoder(Wav2Vec2Encoder):
pass
class UniSpeechEncoderStableLayerNorm(Wav2Vec2EncoderStableLayerNorm):
pass
class UniSpeechGumbelVectorQuantizer(Wav2Vec2GumbelVectorQuantizer):
@staticmethod
def _compute_perplexity(probs):
marginal_probs = probs.mean(dim=0)
perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
return perplexity
def forward(self, hidden_states):
batch_size, sequence_length, hidden_size = hidden_states.shape
# project to codevector dim
hidden_states = self.weight_proj(hidden_states)
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
if self.training:
# sample code vector probs via gumbel in a differentiable way
codevector_probs = nn.functional.gumbel_softmax(
hidden_states.float(), tau=self.temperature, hard=True
).type_as(hidden_states)
# compute perplexity
codevector_soft_dist = torch.softmax(
hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
)
perplexity = self._compute_perplexity(codevector_soft_dist)
else:
# take argmax in non-differentiable way
# compute hard codevector distribution (one hot)
codevector_idx = hidden_states.argmax(dim=-1)
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
-1, codevector_idx.view(-1, 1), 1.0
)
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
perplexity = self._compute_perplexity(codevector_probs)
codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
# use probs to retrieve codevectors
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
return codevectors, perplexity
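For intuition about the hard Gumbel-softmax used in the quantizer above, a standalone sketch with arbitrary sizes: `hard=True` produces one-hot codeword selections in the forward pass while gradients flow through the soft distribution (straight-through estimator).

```python
import torch
import torch.nn as nn

logits = torch.randn(4, 8, requires_grad=True)  # 4 groups, 8 codewords per group

one_hot = nn.functional.gumbel_softmax(logits, tau=2.0, hard=True)
print(one_hot.sum(dim=-1))    # each row sums to 1.0: exactly one codeword picked per group
print(one_hot.requires_grad)  # True: the selection is still differentiable w.r.t. the logits
```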
class UniSpeechPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = UniSpeechConfig
base_model_prefix = "unispeech"
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
"""Initialize the weights"""
# gumbel softmax requires special init
if isinstance(module, UniSpeechGumbelVectorQuantizer):
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
module.weight_proj.bias.data.zero_()
nn.init.uniform_(module.codevectors)
elif isinstance(module, UniSpeechPositionalConvEmbedding):
nn.init.normal_(
module.conv.weight,
mean=0,
std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
)
nn.init.constant_(module.conv.bias, 0)
elif isinstance(module, UniSpeechFeatureProjection):
k = math.sqrt(1 / module.projection.in_features)
nn.init.uniform_(module.projection.weight, a=-k, b=k)
nn.init.uniform_(module.projection.bias, a=-k, b=k)
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
nn.init.kaiming_normal_(module.weight)
if module.bias is not None:
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
nn.init.uniform_(module.bias, a=-k, b=k)
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the convolutional layers
"""
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
return input_lengths
def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
# Effectively attention_mask.sum(-1), but not inplace to be able to run
# in inference mode.
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
batch_size = attention_mask.shape[0]
attention_mask = torch.zeros(
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
)
# these two operations make sure that all values before the output length indices are attended to
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask
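As a worked example of `_get_feat_extract_output_lengths`, assuming the default Wav2Vec2-style kernels `(10, 3, 3, 3, 3, 2, 2)` and strides `(5, 2, 2, 2, 2, 2, 2)` (other configs will differ), one second of 16 kHz audio reduces to 49 feature frames, i.e. roughly a 20 ms hop:

```python
def conv_out_length(length, kernel_size, stride):
    # same 1D convolution output-length formula as used above
    return (length - kernel_size) // stride + 1

length = 16000  # 1 second of 16 kHz audio
for kernel, stride in zip((10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)):
    length = conv_out_length(length, kernel, stride)
print(length)  # 49
```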
UNISPEECH_START_DOCSTRING = r"""
UniSpeech was proposed in [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled
Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei,
Michael Zeng, Xuedong Huang.
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, etc.).
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`UniSpeechConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
UNISPEECH_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
<Tip warning={true}>
`attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should
**not** be passed to avoid degraded performance when doing batched inference. For such models
`input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these
models also yield slightly different results depending on whether `input_values` is padded or not.
</Tip>
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
UniSpeechBaseModelOutput = Wav2Vec2BaseModelOutput
@add_start_docstrings(
"The bare UniSpeech Model transformer outputting raw hidden-states without any specific head on top.",
UNISPEECH_START_DOCSTRING,
)
class UniSpeechModel(UniSpeechPreTrainedModel, Wav2Vec2Model):
def __init__(self, config: UniSpeechConfig):
UniSpeechPreTrainedModel.__init__(config)
self.config = config
self.feature_extractor = UniSpeechFeatureEncoder(config)
self.feature_projection = UniSpeechFeatureProjection(config)
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
if config.do_stable_layer_norm:
self.encoder = UniSpeechEncoderStableLayerNorm(config)
else:
self.encoder = UniSpeechEncoder(config)
# Initialize weights and apply final processing
self.post_init()
def freeze_feature_extractor(self):
raise AttributeError("Not needed for UniSpeech")
def freeze_feature_encoder(self):
raise AttributeError("Not needed for UniSpeech")
@add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=UniSpeechBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
mask_time_indices: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, UniSpeechBaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
extract_features = self.feature_extractor(input_values)
extract_features = extract_features.transpose(1, 2)
if attention_mask is not None:
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
hidden_states, extract_features = self.feature_projection(extract_features)
hidden_states = self._mask_hidden_states(
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return UniSpeechBaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@add_start_docstrings(
"""UniSpeech Model with a vector-quantization module and ctc loss for pre-training.""", UNISPEECH_START_DOCSTRING
)
class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
def __init__(self, config: UniSpeechConfig):
super().__init__(config)
self.unispeech = UniSpeechModel(config)
self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
self.quantizer = UniSpeechGumbelVectorQuantizer(config)
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
self.project_hid = nn.Linear(config.proj_codevector_dim, config.hidden_size)
self.ctc_proj = nn.Linear(config.hidden_size, config.num_ctc_classes)
self.dropout = nn.Dropout(config.final_dropout)
# Initialize weights and apply final processing
self.post_init()
def set_gumbel_temperature(self, temperature: int):
"""
Set the Gumbel softmax temperature to a given value. Only necessary for training
"""
self.quantizer.temperature = temperature
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
not be updated during training.
"""
warnings.warn(
"The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
"Please use the equivalent `freeze_feature_encoder` method instead.",
FutureWarning,
)
self.freeze_feature_encoder()
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
not be updated during training.
"""
self.unispeech.feature_extractor._freeze_parameters()
@staticmethod
def compute_contrastive_logits(
target_features: torch.FloatTensor,
negative_features: torch.FloatTensor,
predicted_features: torch.FloatTensor,
temperature: int = 1,
):
"""
Compute logits for contrastive loss using cosine similarity as the distance measure between
`[positive_feature, negative_features]` and `[predicted_features]`. Additionally, a temperature can be applied.
"""
target_features = torch.cat([target_features, negative_features], dim=0)
logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1)
logits = logits.type_as(target_features)
# apply temperature
logits = logits / temperature
return logits
@add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=UniSpeechForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, UniSpeechForPreTrainingOutput]:
r"""
mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
masked extracted features in *config.proj_codevector_dim* space.
sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
Required input for pre-training.
Returns:
Example:
```python
>>> import torch
>>> from transformers import AutoFeatureExtractor, UniSpeechForPreTraining
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-large-1500h-cv")
>>> model = UniSpeechForPreTraining.from_pretrained("microsoft/unispeech-large-1500h-cv")
>>> # TODO: Add full pretraining example
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.unispeech(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
transformer_features = outputs[0]
# quantize all (unmasked) extracted features and project to final vq dim
extract_features = self.dropout_features(outputs[1])
quantized_features, codevector_perplexity = self.quantizer(extract_features)
# project quantized features twice
quantized_features = self.project_q(quantized_features.to(self.project_q.weight.dtype))
quantized_features = self.project_hid(quantized_features)
prob_replace_matrix = torch.empty(transformer_features.size(0), transformer_features.size(1)).fill_(
self.config.replace_prob
)
prob_replace_matrix = prob_replace_matrix.transpose(0, 1)
sampled_replace_matrix = torch.bernoulli(prob_replace_matrix).bool().to(transformer_features.device)
sampled_replace_matrix = sampled_replace_matrix.transpose(0, 1)
sampled_replace_matrix = sampled_replace_matrix.unsqueeze(-1)
logits = transformer_features.masked_fill(sampled_replace_matrix, 0.0) + (
quantized_features.masked_fill(~sampled_replace_matrix, 0.0)
)
# project to ctc units
logits = self.dropout(logits)
logits = self.ctc_proj(logits)
# TODO(PVP) - add negative sampling & loss computation
loss = None
if not return_dict:
if loss is not None:
return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
return UniSpeechForPreTrainingOutput(
loss=loss,
projected_states=transformer_features,
projected_quantized_states=quantized_features,
codevector_perplexity=codevector_perplexity,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
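The pre-training forward above mixes transformer features and quantized features per time step with probability `config.replace_prob`; a toy re-creation of just that masking step (shapes and the probability are illustrative):

```python
import torch

replace_prob = 0.5
transformer_features = torch.randn(2, 6, 4)  # (batch, time, hidden)
quantized_features = torch.randn(2, 6, 4)

sampled = torch.bernoulli(torch.full((2, 6), replace_prob)).bool().unsqueeze(-1)
mixed = transformer_features.masked_fill(sampled, 0.0) + quantized_features.masked_fill(~sampled, 0.0)
print(mixed.shape)  # torch.Size([2, 6, 4]): quantized where sampled, transformer features elsewhere
```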
@add_start_docstrings(
"""UniSpeech Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
UNISPEECH_START_DOCSTRING,
"""
target_lang (`str`, *optional*):
Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
adapter.<lang>.bin. Only relevant when using an instance of [`UniSpeechForCTC`] with adapters. Uses 'eng'
by default.
""",
)
class UniSpeechForCTC(Wav2Vec2ForCTC):
@add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
)
def forward(self, **super_kwargs):
super().forward(**super_kwargs)
@add_start_docstrings(
"""
UniSpeech Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
SUPERB Keyword Spotting.
""",
UNISPEECH_START_DOCSTRING,
)
class UniSpeechForSequenceClassification(Wav2Vec2ForSequenceClassification):
@add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
def forward(self, **super_kwargs):
super().forward(**super_kwargs)
__all__ = [
"UniSpeechForCTC",
"UniSpeechForPreTraining",
"UniSpeechForSequenceClassification",
"UniSpeechModel",
"UniSpeechPreTrainedModel",
]

View File

@ -1,19 +1,9 @@
# coding=utf-8
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch UniSpeechSat model."""
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/unispeech_sat/modular_unispeech_sat.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_unispeech_sat.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
import warnings
from dataclasses import dataclass
@ -21,8 +11,7 @@ from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
@ -32,6 +21,7 @@ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask,
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
ModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
Wav2Vec2BaseModelOutput,
@ -39,7 +29,6 @@ from ...modeling_outputs import (
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
@ -56,28 +45,12 @@ if is_flash_attn_available():
logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 2
# Base docstring
_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-100h-libri-ft"
# General docstring
_CONFIG_FOR_DOC = "UniSpeechSatConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-100h-libri-ft"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'MISTER QUILDER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 39.88
# Frame class docstring
_FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd"
_FRAME_EXPECTED_OUTPUT = [0, 0]
# Speaker Verification docstring
_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv"
_XVECTOR_EXPECTED_OUTPUT = 0.97
@dataclass
class UniSpeechSatForPreTrainingOutput(ModelOutput):
@ -116,202 +89,17 @@ class UniSpeechSatForPreTrainingOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.detach().sum(-1).tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller than
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->UniSpeechSat
class UniSpeechSatNoLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
class UniSpeechSatSamePadLayer(nn.Module):
def __init__(self, num_conv_pos_embeddings):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.activation(hidden_states)
if self.num_pad_remove > 0:
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->UniSpeechSat
class UniSpeechSatLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->UniSpeechSat
class UniSpeechSatGroupNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->UniSpeechSat
class UniSpeechSatPositionalConvEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
@ -357,19 +145,78 @@ class UniSpeechSatPositionalConvEmbedding(nn.Module):
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->UniSpeechSat
class UniSpeechSatSamePadLayer(nn.Module):
def __init__(self, num_conv_pos_embeddings):
class UniSpeechSatNoLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
if self.num_pad_remove > 0:
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
hidden_states = self.conv(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class UniSpeechSatLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.activation(hidden_states)
return hidden_states
class UniSpeechSatGroupNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->UniSpeechSat
class UniSpeechSatFeatureEncoder(nn.Module):
"""Construct the features from raw audio waveform"""
@ -417,18 +264,6 @@ class UniSpeechSatFeatureEncoder(nn.Module):
return hidden_states
class UniSpeechSatFeatureExtractor(UniSpeechSatFeatureEncoder):
def __init__(self, config):
super().__init__(config)
warnings.warn(
f"The class `{self.__class__.__name__}` has been depreciated "
"and will be removed in Transformers v5. "
f"Use `{self.__class__.__bases__[0].__name__}` instead.",
FutureWarning,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->UniSpeechSat
class UniSpeechSatFeatureProjection(nn.Module):
def __init__(self, config):
super().__init__()
@ -444,7 +279,6 @@ class UniSpeechSatFeatureProjection(nn.Module):
return hidden_states, norm_hidden_states
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->UniSpeechSat
class UniSpeechSatAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@ -603,7 +437,6 @@ class UniSpeechSatAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->UniSpeechSat
class UniSpeechSatFlashAttention2(UniSpeechSatAttention):
"""
UniSpeechSat flash attention module. This module inherits from `UniSpeechSatAttention` as the weights of the module stays
@ -731,7 +564,6 @@ class UniSpeechSatFlashAttention2(UniSpeechSatAttention):
class UniSpeechSatSdpaAttention(UniSpeechSatAttention):
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->UniSpeechSat
def forward(
self,
hidden_states: torch.Tensor,
@ -837,14 +669,6 @@ class UniSpeechSatSdpaAttention(UniSpeechSatAttention):
return attn_output, None, past_key_value
UNISPEECHSAT_ATTENTION_CLASSES = {
"eager": UniSpeechSatAttention,
"sdpa": UniSpeechSatSdpaAttention,
"flash_attention_2": UniSpeechSatFlashAttention2,
}
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->UniSpeechSat
class UniSpeechSatFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
@ -869,11 +693,17 @@ class UniSpeechSatFeedForward(nn.Module):
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->UniSpeechSat, WAV2VEC2->UNISPEECHSAT
UNISPEECH_SAT_ATTENTION_CLASSES = {
"eager": UniSpeechSatAttention,
"sdpa": UniSpeechSatSdpaAttention,
"flash_attention_2": UniSpeechSatFlashAttention2,
}
class UniSpeechSatEncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.attention = UNISPEECHSAT_ATTENTION_CLASSES[config._attn_implementation](
self.attention = UNISPEECH_SAT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
@ -905,79 +735,6 @@ class UniSpeechSatEncoderLayer(nn.Module):
return outputs
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AttnAdapterLayer with Wav2Vec2->UniSpeechSat
class UniSpeechSatAttnAdapterLayer(nn.Module):
def __init__(self, config):
"""
Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
up training throughput.
"""
super().__init__()
self.input_dim = config.adapter_attn_dim
self.hidden_dim = config.hidden_size
self.norm = nn.LayerNorm(self.hidden_dim)
self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
self.act_fn = nn.ReLU()
self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
def forward(self, hidden_states: torch.FloatTensor):
hidden_states = self.norm(hidden_states)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->UniSpeechSat, WAV2VEC2->UNISPEECHSAT
class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module):
def __init__(self, config):
super().__init__()
self.attention = UNISPEECHSAT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.feed_forward = UniSpeechSatFeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
if getattr(config, "adapter_attn_dim", None) is not None:
self.adapter_layer = UniSpeechSatAttnAdapterLayer(config)
else:
self.adapter_layer = None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.attention(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states = self.dropout(hidden_states)
hidden_states = attn_residual + hidden_states
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
if self.adapter_layer is not None:
hidden_states = hidden_states + self.adapter_layer(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder with Wav2Vec2->UniSpeechSat
class UniSpeechSatEncoder(nn.Module):
def __init__(self, config):
super().__init__()
@ -1063,7 +820,76 @@ class UniSpeechSatEncoder(nn.Module):
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderStableLayerNorm with Wav2Vec2->UniSpeechSat
class UniSpeechSatAttnAdapterLayer(nn.Module):
def __init__(self, config):
"""
Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
up training throughput.
"""
super().__init__()
self.input_dim = config.adapter_attn_dim
self.hidden_dim = config.hidden_size
self.norm = nn.LayerNorm(self.hidden_dim)
self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
self.act_fn = nn.ReLU()
self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
def forward(self, hidden_states: torch.FloatTensor):
hidden_states = self.norm(hidden_states)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module):
def __init__(self, config):
super().__init__()
self.attention = UNISPEECH_SAT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.feed_forward = UniSpeechSatFeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
if getattr(config, "adapter_attn_dim", None) is not None:
self.adapter_layer = UniSpeechSatAttnAdapterLayer(config)
else:
self.adapter_layer = None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.attention(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states = self.dropout(hidden_states)
hidden_states = attn_residual + hidden_states
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
if self.adapter_layer is not None:
hidden_states = hidden_states + self.adapter_layer(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class UniSpeechSatEncoderStableLayerNorm(nn.Module):
def __init__(self, config):
super().__init__()
@ -1155,7 +981,7 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module):
class UniSpeechSatGumbelVectorQuantizer(nn.Module):
"""
Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
"""
@ -1166,8 +992,8 @@ class UniSpeechSatGumbelVectorQuantizer(nn.Module):
if config.codevector_dim % self.num_groups != 0:
raise ValueError(
f"`config.codevector_dim {config.codevector_dim} must be divisible by `config.num_codevector_groups`"
f" {self.num_groups} for concatenation"
f"`config.codevector_dim {config.codevector_dim} must be divisible "
f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
)
# storage for codebook variables (codewords)
@ -1300,6 +1126,128 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
return attention_mask
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.detach().sum(-1).tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller than
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
UNISPEECH_SAT_START_DOCSTRING = r"""
UniSpeechSat was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
@ -1318,7 +1266,6 @@ UNISPEECH_SAT_START_DOCSTRING = r"""
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
UNISPEECH_SAT_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@ -1338,12 +1285,10 @@ UNISPEECH_SAT_INPUTS_DOCSTRING = r"""
<Tip warning={true}>
`attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
True`. For all models whose processor has `config.return_attention_mask == False`, such as
[microsoft/unispeech-sat-base-100h-libri-ft](https://huggingface.co/microsoft/unispeech-sat-base-100h-libri-ft),
`attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
that these models also yield slightly different results depending on whether `input_values` is padded or
not.
True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should
**not** be passed to avoid degraded performance when doing batched inference. For such models
`input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these
models also yield slightly different results depending on whether `input_values` is padded or not.
</Tip>
@ -1357,6 +1302,8 @@ UNISPEECH_SAT_INPUTS_DOCSTRING = r"""
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
UniSpeechSatBaseModelOutput = Wav2Vec2BaseModelOutput
@add_start_docstrings(
"The bare UniSpeechSat Model transformer outputting raw hidden-states without any specific head on top.",
@ -1379,7 +1326,6 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_hidden_states(
self,
hidden_states: torch.FloatTensor,
@ -1429,7 +1375,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
@add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Wav2Vec2BaseModelOutput,
output_type=UniSpeechSatBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@ -1442,7 +1388,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
) -> Union[Tuple, UniSpeechSatBaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -1474,7 +1420,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return Wav2Vec2BaseModelOutput(
return UniSpeechSatBaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
@ -1482,7 +1428,10 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
)
@add_start_docstrings("""UniSpeechSat Model with a quantizer and `VQ` head on top.""", UNISPEECH_SAT_START_DOCSTRING)
@add_start_docstrings(
"""UniSpeechSat Model with a vector-quantization module and ctc loss for pre-training.""",
UNISPEECH_SAT_START_DOCSTRING,
)
class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
def __init__(self, config: UniSpeechSatConfig):
super().__init__(config)
@ -1529,7 +1478,7 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
not be updated during training.
"""
self.wav2vec2.feature_extractor._freeze_parameters()
self.unispeech_sat.feature_extractor._freeze_parameters()
@staticmethod
def compute_contrastive_logits(
@ -1594,16 +1543,6 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
logits = extract_features
loss = quantized_features = codevector_perplexity = None
# layer normalization (has no effect when `config.do_stable_layer_norm == False`)
# extract_features = self.layer_norm_for_extract(extract_features)
# quantized_features, codevector_perplexity = self.quantizer(extract_features)
#
# project quantized features twice
# quantized_features = self.project_q(quantized_features)
# quantized_features = self.project_hid(quantized_features)
#
# loss = None
# logits = quantized_features
if not return_dict:
if loss is not None:
return (loss, logits, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
@ -1620,6 +1559,13 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
)
_HIDDEN_STATES_START_POSITION = 2
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'MISTER QUILDER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 39.88
@add_start_docstrings(
"""UniSpeechSat Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
UNISPEECH_SAT_START_DOCSTRING,
@ -1630,7 +1576,6 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
'eng' by default.
""",
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
@ -1784,8 +1729,8 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
@add_start_docstrings(
"""
UniSpeechSat Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
like SUPERB Keyword Spotting.
UniSpeechSat Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
SUPERB Keyword Spotting.
""",
UNISPEECH_SAT_START_DOCSTRING,
)
@ -1807,7 +1752,6 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
@ -1820,7 +1764,6 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
)
self.freeze_feature_encoder()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->unispeech_sat
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
@ -1828,7 +1771,6 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
"""
self.unispeech_sat.feature_extractor._freeze_parameters()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->unispeech_sat
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
@ -1844,7 +1786,6 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat
def forward(
self,
input_values: Optional[torch.Tensor],
@ -1908,13 +1849,17 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
)
# Frame class docstring
_FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd"
_FRAME_EXPECTED_OUTPUT = [0, 0]
@add_start_docstrings(
"""
UniSpeech-SAT Model with a frame classification head on top for tasks like Speaker Diarization.
UniSpeechSat Model with a frame classification head on top for tasks like Speaker Diarization.
""",
UNISPEECH_SAT_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@ -2021,7 +1966,6 @@ class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel):
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
class AMSoftmaxLoss(nn.Module):
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
super(AMSoftmaxLoss, self).__init__()
@ -2045,7 +1989,6 @@ class AMSoftmaxLoss(nn.Module):
return loss
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
class TDNNLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
@ -2061,6 +2004,7 @@ class TDNNLayer(nn.Module):
if is_peft_available():
from peft.tuners.lora import LoraLayer
if is_peft_available():
if isinstance(self.kernel, LoraLayer):
warnings.warn(
"Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
@ -2077,13 +2021,17 @@ class TDNNLayer(nn.Module):
return hidden_states
# Speaker Verification docstring
_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv"
_XVECTOR_EXPECTED_OUTPUT = 0.97
@add_start_docstrings(
"""
UniSpeech-SAT Model with an XVector feature extraction head on top for tasks like Speaker Verification.
UniSpeechSat Model with an XVector feature extraction head on top for tasks like Speaker Verification.
""",
UNISPEECH_SAT_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
class UniSpeechSatForXVector(UniSpeechSatPreTrainedModel):
def __init__(self, config):
super().__init__(config)

View File

@ -0,0 +1,610 @@
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ...modeling_outputs import (
CausalLMOutput,
ModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
Wav2Vec2BaseModelOutput,
XVectorOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ..wav2vec2.modeling_wav2vec2 import (
Wav2Vec2Encoder,
Wav2Vec2EncoderStableLayerNorm,
Wav2Vec2FeatureEncoder,
Wav2Vec2FeatureProjection,
Wav2Vec2ForAudioFrameClassification,
Wav2Vec2ForCTC,
Wav2Vec2ForSequenceClassification,
Wav2Vec2ForXVector,
Wav2Vec2GumbelVectorQuantizer,
Wav2Vec2Model,
Wav2Vec2PositionalConvEmbedding,
)
from .configuration_unispeech_sat import UniSpeechSatConfig
logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 2
# General docstring
_CONFIG_FOR_DOC = "UniSpeechSatConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-100h-libri-ft"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'MISTER QUILDER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 39.88
# Frame class docstring
_FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd"
_FRAME_EXPECTED_OUTPUT = [0, 0]
# Speaker Verification docstring
_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv"
_XVECTOR_EXPECTED_OUTPUT = 0.97
@dataclass
class UniSpeechSatForPreTrainingOutput(ModelOutput):
"""
Output type of [`UniSpeechSatForPreTraining`], with potential hidden states and attentions.
Args:
loss (*optional*, returned when model is in train mode, `torch.FloatTensor` of shape `(1,)`):
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
paper](https://arxiv.org/pdf/2006.11477.pdf).
projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
projected quantized states.
projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
target vectors for contrastive loss.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
projected_states: Optional[torch.FloatTensor] = None
projected_quantized_states: Optional[torch.FloatTensor] = None
codevector_perplexity: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class UniSpeechSatPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding):
pass
class UniSpeechSatFeatureEncoder(Wav2Vec2FeatureEncoder):
pass
class UniSpeechSatFeatureProjection(Wav2Vec2FeatureProjection):
pass
class UniSpeechSatEncoder(Wav2Vec2Encoder):
pass
class UniSpeechSatEncoderStableLayerNorm(Wav2Vec2EncoderStableLayerNorm):
pass
class UniSpeechSatGumbelVectorQuantizer(Wav2Vec2GumbelVectorQuantizer):
def __init__(self, config):
super().__init__(config)
self.weight_proj = nn.Linear(config.hidden_size, self.num_groups * self.num_vars)
@staticmethod
def _compute_perplexity(probs, mask=None):
marginal_probs = probs.mean(dim=0)
perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
return perplexity
def forward(self, hidden_states):
batch_size, sequence_length, hidden_size = hidden_states.shape
# project to codevector dim
hidden_states = self.weight_proj(hidden_states)
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
if self.training:
# sample code vector probs via gumbel in differentiable way
codevector_probs = nn.functional.gumbel_softmax(
hidden_states.float(), tau=self.temperature, hard=True
).type_as(hidden_states)
# compute perplexity
codevector_soft_dist = torch.softmax(
hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
)
perplexity = self._compute_perplexity(codevector_soft_dist)
else:
# take argmax in non-differentiable way
# compute hard codevector distribution (one hot)
codevector_idx = hidden_states.argmax(dim=-1)
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
-1, codevector_idx.view(-1, 1), 1.0
)
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
perplexity = self._compute_perplexity(codevector_probs)
codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
# use probs to retrieve codevectors
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
return codevectors, perplexity
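# A minimal sketch of the perplexity term computed above, assuming illustrative shapes
# (8 frames, 2 codebook groups, 320 codes per group); it measures how evenly codes are used.
_probs = torch.softmax(torch.randn(8, 2, 320), dim=-1)  # soft code assignments per frame and group
_marginal = _probs.mean(dim=0)                           # average usage of each code over frames
_perplexity = torch.exp(-torch.sum(_marginal * torch.log(_marginal + 1e-7), dim=-1)).sum()
# _perplexity approaches 2 * 320 when every code is used uniformly and drops toward 2 when the
# quantizer collapses onto a single code per group.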
class UniSpeechSatPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = UniSpeechSatConfig
base_model_prefix = "unispeech_sat"
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
"""Initialize the weights"""
# gumbel softmax requires special init
if isinstance(module, UniSpeechSatGumbelVectorQuantizer):
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
module.weight_proj.bias.data.zero_()
nn.init.uniform_(module.codevectors)
elif isinstance(module, UniSpeechSatPositionalConvEmbedding):
nn.init.normal_(
module.conv.weight,
mean=0,
std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
)
nn.init.constant_(module.conv.bias, 0)
elif isinstance(module, UniSpeechSatFeatureProjection):
k = math.sqrt(1 / module.projection.in_features)
nn.init.uniform_(module.projection.weight, a=-k, b=k)
nn.init.uniform_(module.projection.bias, a=-k, b=k)
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
nn.init.kaiming_normal_(module.weight)
if module.bias is not None:
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
nn.init.uniform_(module.bias, a=-k, b=k)
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the convolutional layers
"""
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
return input_lengths
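# A worked sketch of the formula above, assuming the usual wav2vec2-style feature encoder
# (conv_kernel=(10, 3, 3, 3, 3, 2, 2), conv_stride=(5, 2, 2, 2, 2, 2, 2)); other configs differ.
_length = 16000  # one second of 16 kHz audio
for _kernel, _stride in zip((10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)):
    _length = (_length - _kernel) // _stride + 1
# _length == 49, i.e. the encoder emits roughly one feature frame every 20 ms.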
def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
# Effectively attention_mask.sum(-1), but not inplace to be able to run
# on inference mode.
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
batch_size = attention_mask.shape[0]
attention_mask = torch.zeros(
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
)
# these two operations makes sure that all values before the output lengths idxs are attended to
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask
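# A minimal sketch of the flip/cumsum/flip trick above, assuming a feature axis of length 6 and
# output lengths 4 and 6: marking only the last valid frame yields the full boolean mask.
_mask = torch.zeros(2, 6, dtype=torch.long)
_mask[(torch.arange(2), torch.tensor([3, 5]))] = 1       # index of the last valid frame per sample
_mask = _mask.flip([-1]).cumsum(-1).flip([-1]).bool()
# _mask == [[True, True, True, True, False, False],
#           [True, True, True, True, True,  True ]]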
UNISPEECH_SAT_START_DOCSTRING = r"""
UniSpeechSat was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
Auli.
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving etc.).
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`UniSpeechSatConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
UNISPEECH_SAT_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
<Tip warning={true}>
`attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should
**not** be passed to avoid degraded performance when doing batched inference. For such models
`input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these
models also yield slightly different results depending on whether `input_values` is padded or not.
</Tip>
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
UniSpeechSatBaseModelOutput = Wav2Vec2BaseModelOutput
@add_start_docstrings(
"The bare UniSpeechSat Model transformer outputting raw hidden-states without any specific head on top.",
UNISPEECH_SAT_START_DOCSTRING,
)
class UniSpeechSatModel(UniSpeechSatPreTrainedModel, Wav2Vec2Model):
def __init__(self, config: UniSpeechSatConfig):
UniSpeechSatPreTrainedModel.__init__(self, config)
self.config = config
self.feature_extractor = UniSpeechSatFeatureEncoder(config)
self.feature_projection = UniSpeechSatFeatureProjection(config)
self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
if config.do_stable_layer_norm:
self.encoder = UniSpeechSatEncoderStableLayerNorm(config)
else:
self.encoder = UniSpeechSatEncoder(config)
# Initialize weights and apply final processing
self.post_init()
def freeze_feature_extractor(self):
raise AttributeError("Not needed for UniSpeechSat")
def freeze_feature_encoder(self):
raise AttributeError("Not needed for UniSpeechSat")
@add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=UniSpeechSatBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
mask_time_indices: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, UniSpeechSatBaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
extract_features = self.feature_extractor(input_values)
extract_features = extract_features.transpose(1, 2)
if attention_mask is not None:
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
hidden_states, extract_features = self.feature_projection(extract_features)
hidden_states = self._mask_hidden_states(
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return UniSpeechSatBaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
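# A hedged usage sketch for the model above, assuming the docstring checkpoint and a silent
# one-second 16 kHz waveform as a stand-in for real audio.
from transformers import AutoFeatureExtractor

_feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-sat-base-100h-libri-ft")
_model = UniSpeechSatModel.from_pretrained("microsoft/unispeech-sat-base-100h-libri-ft")
_inputs = _feature_extractor(torch.zeros(16000).numpy(), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    _outputs = _model(**_inputs)
# _outputs.last_hidden_state.shape is e.g. torch.Size([1, 49, 768]) for the base architecture.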
@add_start_docstrings(
"""UniSpeechSat Model with a vector-quantization module and ctc loss for pre-training.""",
UNISPEECH_SAT_START_DOCSTRING,
)
class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
def __init__(self, config: UniSpeechSatConfig):
super().__init__(config)
self.unispeech_sat = UniSpeechSatModel(config)
self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
self.quantizer = UniSpeechSatGumbelVectorQuantizer(config)
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
self.dropout = nn.Dropout(config.final_dropout)
self.speaker_proj = nn.Linear(config.hidden_size, config.codevector_dim)
self.label_embeddings_concat = nn.Parameter(torch.FloatTensor(config.num_clusters, config.codevector_dim))
self.label_embeddings_concat.data.zero_()
self.layer_norm_for_extract = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
if self.config.do_stable_layer_norm:
self.layer_norm_for_extract.requires_grad = False
# Initialize weights and apply final processing
self.post_init()
def set_gumbel_temperature(self, temperature: int):
"""
Set the Gumbel softmax temperature to a given value. Only necessary for training
"""
self.quantizer.temperature = temperature
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
not be updated during training.
"""
warnings.warn(
"The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
"Please use the equivalent `freeze_feature_encoder` method instead.",
FutureWarning,
)
self.freeze_feature_encoder()
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
not be updated during training.
"""
self.unispeech_sat.feature_extractor._freeze_parameters()
@staticmethod
def compute_contrastive_logits(
target_features: torch.FloatTensor,
negative_features: torch.FloatTensor,
predicted_features: torch.FloatTensor,
temperature: int = 1,
):
"""
Compute logits for contrastive loss using cosine similarity as the distance measure between
`[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
"""
target_features = torch.cat([target_features, negative_features], dim=0)
logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1)
logits = logits.type_as(target_features)
# apply temperature
logits = logits / temperature
return logits
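# A minimal sketch of the contrastive logits above, assuming one positive and ten negative
# targets per frame and a temperature of 0.1 (all values are illustrative only).
_predicted = torch.randn(1, 4, 256)                      # (1, frames, proj_codevector_dim)
_positive = torch.randn(1, 4, 256)
_negatives = torch.randn(10, 4, 256)                     # distractors drawn from the same utterance
_logits = UniSpeechSatForPreTraining.compute_contrastive_logits(_positive, _negatives, _predicted, 0.1)
# _logits has shape (11, frames): row 0 scores the positive target, the remaining rows the negatives.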
@add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=UniSpeechSatForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, UniSpeechSatForPreTrainingOutput]:
r"""
Returns:
Example:
```python
>>> import torch
>>> from transformers import AutoFeatureExtractor, UniSpeechSatForPreTraining
>>> from transformers.models.unispeech_sat.modeling_unispeech_sat import _compute_mask_indices
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-sat-base")
>>> model = UniSpeechSatForPreTraining.from_pretrained("microsoft/unispeech-sat-base")
>>> # TODO: Add full pretraining example
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.unispeech_sat(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
transformer_features = outputs[0]
# quantize all (unmasked) extracted features and project to final vq dim
extract_features = self.dropout_features(outputs[1])
# TODO(PVP) - add pretraining logic and add to tests
logits = extract_features
loss = quantized_features = codevector_perplexity = None
if not return_dict:
if loss is not None:
return (loss, logits, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
return (logits, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
return UniSpeechSatForPreTrainingOutput(
loss=loss,
logits=logits,
projected_states=transformer_features,
projected_quantized_states=quantized_features,
codevector_perplexity=codevector_perplexity,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""UniSpeechSat Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
UNISPEECH_SAT_START_DOCSTRING,
"""
target_lang (`str`, *optional*):
Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
adapter.<lang>.bin. Only relevant when using an instance of [`UniSpeechSatForCTC`] with adapters. Uses
'eng' by default.
""",
)
class UniSpeechSatForCTC(Wav2Vec2ForCTC):
@add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
)
def forward(self, **super_kwargs):
return super().forward(**super_kwargs)
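# A hedged end-to-end sketch for the CTC head above, assuming the docstring checkpoint and a
# placeholder waveform; real audio and sampling-rate handling would come from a dataset.
from transformers import AutoProcessor

_processor = AutoProcessor.from_pretrained("microsoft/unispeech-sat-base-100h-libri-ft")
_ctc_model = UniSpeechSatForCTC.from_pretrained("microsoft/unispeech-sat-base-100h-libri-ft")
_ctc_inputs = _processor(torch.zeros(16000).numpy(), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    _ctc_logits = _ctc_model(**_ctc_inputs).logits
_transcription = _processor.batch_decode(torch.argmax(_ctc_logits, dim=-1))
# _transcription is a list with one decoded string; see the docstring constants above for the
# expected output on the LibriSpeech sample.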
@add_start_docstrings(
"""
UniSpeechSat Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
SUPERB Keyword Spotting.
""",
UNISPEECH_SAT_START_DOCSTRING,
)
class UniSpeechSatForSequenceClassification(Wav2Vec2ForSequenceClassification):
@add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
def forward(self, **super_kwargs):
return super().forward(**super_kwargs)
@add_start_docstrings(
"""
UniSpeechSat Model with a frame classification head on top for tasks like Speaker Diarization.
""",
UNISPEECH_SAT_START_DOCSTRING,
)
class UniSpeechSatForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification):
@add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_FRAME_CLASS_CHECKPOINT,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_FRAME_EXPECTED_OUTPUT,
)
def forward(self, **super_kwargs):
return super().forward(**super_kwargs)
@add_start_docstrings(
"""
UniSpeechSat Model with an XVector feature extraction head on top for tasks like Speaker Verification.
""",
UNISPEECH_SAT_START_DOCSTRING,
)
class UniSpeechSatForXVector(Wav2Vec2ForXVector):
@add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_XVECTOR_CHECKPOINT,
output_type=XVectorOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_XVECTOR_EXPECTED_OUTPUT,
)
def forward(self, **super_kwargs):
return super().forward(**super_kwargs)
__all__ = [
"UniSpeechSatForAudioFrameClassification",
"UniSpeechSatForCTC",
"UniSpeechSatForPreTraining",
"UniSpeechSatForSequenceClassification",
"UniSpeechSatForXVector",
"UniSpeechSatModel",
"UniSpeechSatPreTrainedModel",
]

View File

@ -212,7 +212,7 @@ def _sample_negative_indices(features_shape: Tuple, num_negatives: int, attentio
return sampled_negative_indices
WAV_2_VEC_2_START_DOCSTRING = r"""
WAV2VEC2_START_DOCSTRING = r"""
Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
Auli.
@ -251,7 +251,7 @@ WAV_2_VEC_2_START_DOCSTRING = r"""
"""
WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
WAV2VEC2_INPUTS_DOCSTRING = r"""
Args:
input_values (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
@ -885,7 +885,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
else:
return random_params
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
def __call__(
self,
input_values,
@ -1050,7 +1050,7 @@ class FlaxWav2Vec2Module(nn.Module):
@add_start_docstrings(
"The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.",
WAV_2_VEC_2_START_DOCSTRING,
WAV2VEC2_START_DOCSTRING,
)
class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
module_class = FlaxWav2Vec2Module
@ -1088,7 +1088,7 @@ FLAX_WAV2VEC2_MODEL_DOCSTRING = """
overwrite_call_docstring(
FlaxWav2Vec2Model,
WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_MODEL_DOCSTRING,
WAV2VEC2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_MODEL_DOCSTRING,
)
append_replace_return_docstrings(
FlaxWav2Vec2Model, output_type=FlaxWav2Vec2BaseModelOutput, config_class=Wav2Vec2Config
@ -1168,7 +1168,7 @@ class FlaxWav2Vec2ForCTCModule(nn.Module):
@add_start_docstrings(
"Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).",
WAV_2_VEC_2_START_DOCSTRING,
WAV2VEC2_START_DOCSTRING,
)
class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
module_class = FlaxWav2Vec2ForCTCModule
@ -1211,7 +1211,7 @@ FLAX_WAV2VEC2_FOR_CTC_DOCSTRING = """
overwrite_call_docstring(
FlaxWav2Vec2ForCTC,
WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_CTC_DOCSTRING,
WAV2VEC2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_CTC_DOCSTRING,
)
append_replace_return_docstrings(FlaxWav2Vec2ForCTC, output_type=FlaxCausalLMOutput, config_class=Wav2Vec2Config)
@ -1315,11 +1315,11 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
return input_lengths
@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top.""", WAV_2_VEC_2_START_DOCSTRING)
@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top.""", WAV2VEC2_START_DOCSTRING)
class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel):
module_class = FlaxWav2Vec2ForPreTrainingModule
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
# overwrite since has `gumbel_temperature` input
def __call__(
self,
@ -1418,7 +1418,7 @@ FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING = """
overwrite_call_docstring(
FlaxWav2Vec2ForPreTraining,
WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING,
WAV2VEC2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING,
)
append_replace_return_docstrings(
FlaxWav2Vec2ForPreTraining, output_type=FlaxWav2Vec2ForPreTrainingOutput, config_class=Wav2Vec2Config

View File

@ -1397,7 +1397,7 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel):
return attention_mask
WAV_2_VEC_2_START_DOCSTRING = r"""
WAV2VEC2_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@ -1439,7 +1439,7 @@ WAV_2_VEC_2_START_DOCSTRING = r"""
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
WAV2VEC2_INPUTS_DOCSTRING = r"""
Args:
input_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `({0})`):
Indices of input sequence tokens in the vocabulary.
@ -1497,7 +1497,7 @@ WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"The bare TFWav2Vec2 Model transformer outputing raw hidden-states without any specific head on top.",
WAV_2_VEC_2_START_DOCSTRING,
WAV2VEC2_START_DOCSTRING,
)
class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
def __init__(self, config: Wav2Vec2Config, *inputs, **kwargs):
@ -1505,7 +1505,7 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
self.config = config
self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2")
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
@unpack_inputs
def call(
@ -1579,7 +1579,7 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
@add_start_docstrings(
"""TFWav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
WAV_2_VEC_2_START_DOCSTRING,
WAV2VEC2_START_DOCSTRING,
)
class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
def __init__(self, config: Wav2Vec2Config, *inputs, **kwargs):
@ -1612,7 +1612,7 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
self.wav2vec2.feature_extractor.trainable = False
@unpack_inputs
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,

View File

@ -63,6 +63,7 @@ if is_safetensors_available():
if is_flash_attn_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
@ -1633,7 +1634,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
self.target_lang = target_lang
WAV_2_VEC_2_START_DOCSTRING = r"""
WAV2VEC2_START_DOCSTRING = r"""
Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
Auli.
@ -1652,7 +1653,7 @@ WAV_2_VEC_2_START_DOCSTRING = r"""
"""
WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
WAV2VEC2_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
@ -1692,7 +1693,7 @@ WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.",
WAV_2_VEC_2_START_DOCSTRING,
WAV2VEC2_START_DOCSTRING,
)
class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
def __init__(self, config: Wav2Vec2Config):
@ -1780,7 +1781,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
return hidden_states
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Wav2Vec2BaseModelOutput,
@ -1841,7 +1842,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
)
@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top.""", WAV_2_VEC_2_START_DOCSTRING)
@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top.""", WAV2VEC2_START_DOCSTRING)
class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
def __init__(self, config: Wav2Vec2Config):
super().__init__(config)
@ -1902,7 +1903,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
logits = logits / temperature
return logits
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Wav2Vec2ForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
@ -2065,7 +2066,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
)
@add_start_docstrings("""Wav2Vec2 Model with a `language modeling` head on top.""", WAV_2_VEC_2_START_DOCSTRING)
@add_start_docstrings("""Wav2Vec2 Model with a `language modeling` head on top.""", WAV2VEC2_START_DOCSTRING)
class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
@ -2081,7 +2082,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
def forward(
self,
input_values: torch.FloatTensor,
@ -2113,7 +2114,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
@add_start_docstrings(
"""Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
WAV_2_VEC_2_START_DOCSTRING,
WAV2VEC2_START_DOCSTRING,
"""
target_lang (`str`, *optional*):
Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
@ -2193,7 +2194,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
for param in self.wav2vec2.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput,
@ -2277,7 +2278,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
Wav2Vec2 Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
SUPERB Keyword Spotting.
""",
WAV_2_VEC_2_START_DOCSTRING,
WAV2VEC2_START_DOCSTRING,
)
class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
def __init__(self, config):
@ -2324,7 +2325,7 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
for param in self.wav2vec2.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_SEQ_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
@ -2400,7 +2401,7 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
"""
Wav2Vec2 Model with a frame classification head on top for tasks like Speaker Diarization.
""",
WAV_2_VEC_2_START_DOCSTRING,
WAV2VEC2_START_DOCSTRING,
)
class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
def __init__(self, config):
@ -2446,7 +2447,7 @@ class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
for param in self.wav2vec2.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_FRAME_CLASS_CHECKPOINT,
output_type=TokenClassifierOutput,
@ -2546,6 +2547,7 @@ class TDNNLayer(nn.Module):
if is_peft_available():
from peft.tuners.lora import LoraLayer
if is_peft_available():
if isinstance(self.kernel, LoraLayer):
warnings.warn(
"Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
@ -2566,7 +2568,7 @@ class TDNNLayer(nn.Module):
"""
Wav2Vec2 Model with an XVector feature extraction head on top for tasks like Speaker Verification.
""",
WAV_2_VEC_2_START_DOCSTRING,
WAV2VEC2_START_DOCSTRING,
)
class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel):
def __init__(self, config):
@ -2630,7 +2632,7 @@ class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel):
return input_lengths
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_XVECTOR_CHECKPOINT,
output_type=XVectorOutput,

View File

@ -1,26 +1,15 @@
# coding=utf-8
# Copyright 2024 The Seamless Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Wav2Vec2-BERT model."""
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_wav2vec2_bert.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
import warnings
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
@ -42,213 +31,14 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_peft_available,
logging,
)
from .configuration_wav2vec2_bert import Wav2Vec2BertConfig
logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 2
# General docstring
_CONFIG_FOR_DOC = "Wav2Vec2BertConfig"
# Base docstring
_BASE_CHECKPOINT_FOR_DOC = "facebook/w2v-bert-2.0"
_PRETRAINED_CHECKPOINT_FOR_DOC = "hf-audio/wav2vec2-bert-CV16-en"
_EXPECTED_OUTPUT_SHAPE = [1, 146, 1024]
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'mr quilter is the apostle of the middle classes and we are glad to welcome his gospel'"
_CTC_EXPECTED_LOSS = 17.04
# Copied from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2._compute_new_attention_mask
def _compute_new_attention_mask(hidden_states: torch.Tensor, seq_lens: torch.Tensor):
"""
Computes an attention mask of the form `(batch, seq_len)` with an attention for each element in the batch that
stops at the corresponding element in `seq_lens`.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, *)`):
The sequences to mask, where `*` is any number of sequence-specific dimensions including none.
seq_lens (`torch.Tensor` of shape `(batch)`):
Each element represents the length of the sequence at the same index in `hidden_states`
Returns:
`torch.FloatTensor`: The float attention mask of shape `(batch, seq_len)`
"""
batch_size, mask_seq_len = hidden_states.shape[:2]
indices = torch.arange(mask_seq_len, device=seq_lens.device).expand(batch_size, -1)
bool_mask = indices >= seq_lens.unsqueeze(1).expand(-1, mask_seq_len)
mask = hidden_states.new_ones((batch_size, mask_seq_len))
mask = mask.masked_fill(bool_mask, 0)
return mask
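# A minimal sketch of the helper above, assuming a batch of two sequences padded to length 5
# with true lengths 3 and 5.
_hidden = torch.zeros(2, 5, 8)
_new_mask = _compute_new_attention_mask(_hidden, torch.tensor([3, 5]))
# _new_mask == [[1., 1., 1., 0., 0.],
#               [1., 1., 1., 1., 1.]]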
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.detach().sum(-1).tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller than
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
def _sample_negative_indices(
features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
):
"""
Sample `num_negatives` vectors from feature vectors.
"""
batch_size, sequence_length = features_shape
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
sequence_length_range = np.arange(sequence_length)
# get `num_negatives` random vector indices from the same utterance
sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
mask_time_indices = (
mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
)
for batch_idx in range(batch_size):
high = mask_time_indices[batch_idx].sum() - 1
mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
# avoid sampling the same positive vector, but keep the distribution uniform
sampled_indices[sampled_indices >= feature_indices] += 1
# remap to actual indices
sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
# correct for batch size
sampled_negative_indices[batch_idx] += batch_idx * sequence_length
return sampled_negative_indices
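# A minimal usage sketch of the sampler above, assuming one utterance of 10 frames and 4 negatives
# per position; with no mask provided, every frame is treated as a candidate.
_negative_idx = _sample_negative_indices((1, 10), num_negatives=4)
# _negative_idx is an int32 array of shape (1, 10, 4) indexing into the flattened
# (batch_size * sequence_length) feature matrix, never pointing at the positive itself.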
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerRotaryPositionalEmbedding with Wav2Vec2Conformer->Wav2Vec2Bert
class Wav2Vec2BertRotaryPositionalEmbedding(nn.Module):
"""Rotary positional embedding
Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
@ -284,7 +74,6 @@ class Wav2Vec2BertRotaryPositionalEmbedding(nn.Module):
return self.cached_rotary_positional_embedding
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerRelPositionalEmbedding with Wav2Vec2Conformer->Wav2Vec2Bert
class Wav2Vec2BertRelPositionalEmbedding(nn.Module):
"""Relative positional encoding module."""
@ -363,7 +152,6 @@ class Wav2Vec2BertFeedForward(nn.Module):
self.output_dense = nn.Linear(config.intermediate_size, hidden_size)
self.output_dropout = nn.Dropout(config.hidden_dropout)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward.forward
def forward(self, hidden_states):
hidden_states = self.intermediate_dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
@ -556,7 +344,6 @@ class Wav2Vec2BertSelfAttention(nn.Module):
return hidden_states, probs
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention._apply_rotary_embedding
def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
batch_size, sequence_length, hidden_size = hidden_states.size()
hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
@ -576,7 +363,6 @@ class Wav2Vec2BertSelfAttention(nn.Module):
return hidden_states
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention._apply_relative_embeddings
def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
# 1. project positional embeddings
# => (batch, head, 2*time1-1, d_k)
@ -823,6 +609,32 @@ class Wav2Vec2BertAdapter(nn.Module):
return hidden_states
# Copied from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2._compute_new_attention_mask
def _compute_new_attention_mask(hidden_states: torch.Tensor, seq_lens: torch.Tensor):
"""
Computes an attention mask of the form `(batch, seq_len)` with an attention for each element in the batch that
stops at the corresponding element in `seq_lens`.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, *)`):
The sequences to mask, where `*` is any number of sequence-specific dimensions including none.
seq_lens (`torch.Tensor` of shape `(batch)`):
Each element represents the length of the sequence at the same index in `hidden_states`
Returns:
`torch.FloatTensor`: The float attention mask of shape `(batch, seq_len)`
"""
batch_size, mask_seq_len = hidden_states.shape[:2]
indices = torch.arange(mask_seq_len, device=seq_lens.device).expand(batch_size, -1)
bool_mask = indices >= seq_lens.unsqueeze(1).expand(-1, mask_seq_len)
mask = hidden_states.new_ones((batch_size, mask_seq_len))
mask = mask.masked_fill(bool_mask, 0)
return mask
class Wav2Vec2BertAdapterLayer(nn.Module):
def __init__(self, config):
super().__init__()
@ -911,7 +723,6 @@ class Wav2Vec2BertAdapterLayer(nn.Module):
return hidden_states
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerPreTrainedModel with Wav2Vec2Conformer->Wav2Vec2Bert,wav2vec2_conformer->wav2vec2_bert, input_values->input_features
class Wav2Vec2BertPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@ -995,6 +806,129 @@ class Wav2Vec2BertPreTrainedModel(PreTrainedModel):
return attention_mask
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked_span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.detach().sum(-1).tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller than
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
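For context, a rough usage sketch of this SpecAugment helper (not part of the diff; shapes and probabilities are illustrative, and the import path points at the original wav2vec2 implementation the copy derives from):

```python
import torch
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices

batch_size, sequence_length = 2, 100
mask_time_indices = _compute_mask_indices(
    shape=(batch_size, sequence_length),
    mask_prob=0.05,   # upper bound on the fraction of masked frames (spans may overlap)
    mask_length=10,   # every sampled span masks 10 consecutive frames
    min_masks=1,
)
# boolean numpy array of shape (batch, seq_len); the models convert it to a tensor
mask_time_indices = torch.from_numpy(mask_time_indices)
print(mask_time_indices.shape, mask_time_indices.sum(dim=-1))
```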
_PRETRAINED_CHECKPOINT_FOR_DOC = "hf-audio/wav2vec2-bert-CV16-en"
_EXPECTED_OUTPUT_SHAPE = [1, 146, 1024]
WAV2VEC2_BERT_START_DOCSTRING = r"""
Wav2Vec2Bert was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
@ -1003,8 +937,9 @@ WAV2VEC2_BERT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving etc.).
This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`Wav2Vec2BertConfig`]): Model configuration class with all the parameters of the model.
@ -1012,7 +947,6 @@ WAV2VEC2_BERT_START_DOCSTRING = r"""
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
WAV2VEC2_BERT_INPUTS_DOCSTRING = r"""
Args:
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@ -1039,6 +973,9 @@ WAV2VEC2_BERT_INPUTS_DOCSTRING = r"""
"""
Wav2Vec2BertBaseModelOutput = Wav2Vec2BaseModelOutput
@add_start_docstrings(
"The bare Wav2Vec2Bert Model transformer outputting raw hidden-states without any specific head on top.",
WAV2VEC2_BERT_START_DOCSTRING,
@ -1064,7 +1001,6 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_hidden_states(
self,
hidden_states: torch.FloatTensor,
@ -1114,7 +1050,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
@add_start_docstrings_to_model_forward(WAV2VEC2_BERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_PRETRAINED_CHECKPOINT_FOR_DOC,
output_type=Wav2Vec2BaseModelOutput,
output_type=Wav2Vec2BertBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@ -1127,7 +1063,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
) -> Union[Tuple, Wav2Vec2BertBaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -1159,7 +1095,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return Wav2Vec2BaseModelOutput(
return Wav2Vec2BertBaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
@ -1167,12 +1103,18 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
)
_HIDDEN_STATES_START_POSITION = 2
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'mr quilter is the apostle of the middle classes and we are glad to welcome his gospel'"
_CTC_EXPECTED_LOSS = 17.04
@add_start_docstrings(
"""Wav2Vec2Bert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
WAV2VEC2_BERT_START_DOCSTRING,
)
class Wav2Vec2BertForCTC(Wav2Vec2BertPreTrainedModel):
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForCTC.__init__ with Wav2Vec2Conformer->Wav2Vec2Bert,WAV2VEC2_CONFORMER->WAV2VEC2_BERT,wav2vec2_conformer->wav2vec2_bert
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
@ -1277,6 +1219,10 @@ class Wav2Vec2BertForCTC(Wav2Vec2BertPreTrainedModel):
)
# Base docstring
_BASE_CHECKPOINT_FOR_DOC = "facebook/w2v-bert-2.0"
@add_start_docstrings(
"""
Wav2Vec2Bert Model with a sequence classification head on top (a linear layer over the pooled output) for
@ -1285,7 +1231,6 @@ class Wav2Vec2BertForCTC(Wav2Vec2BertPreTrainedModel):
WAV2VEC2_BERT_START_DOCSTRING,
)
class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Bert,wav2vec2->wav2vec2_bert
def __init__(self, config):
super().__init__(config)
@ -1318,7 +1263,6 @@ class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Bert,wav2vec2->wav2vec2_bert,WAV_2_VEC_2->WAV2VEC2_BERT, input_values->input_features
def forward(
self,
input_features: Optional[torch.Tensor],
@ -1389,7 +1333,6 @@ class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
WAV2VEC2_BERT_START_DOCSTRING,
)
class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2BertPreTrainedModel):
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForAudioFrameClassification.__init__ with Wav2Vec2Conformer->Wav2Vec2Bert,WAV2VEC2_CONFORMER->WAV2VEC2_BERT,wav2vec2_conformer->wav2vec2_bert
def __init__(self, config):
super().__init__(config)
@ -1406,7 +1349,6 @@ class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2BertPreTrainedModel):
self.init_weights()
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForAudioFrameClassification.freeze_base_model with wav2vec2_conformer->wav2vec2_bert
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
@ -1422,7 +1364,6 @@ class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2BertPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForAudioFrameClassification.forward with wav2vec2_conformer->wav2vec2_bert, input_values->input_features
def forward(
self,
input_features: Optional[torch.Tensor],
@ -1477,7 +1418,6 @@ class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2BertPreTrainedModel):
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
class AMSoftmaxLoss(nn.Module):
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
super(AMSoftmaxLoss, self).__init__()
@ -1501,7 +1441,6 @@ class AMSoftmaxLoss(nn.Module):
return loss
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
class TDNNLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
@ -1517,6 +1456,7 @@ class TDNNLayer(nn.Module):
if is_peft_available():
from peft.tuners.lora import LoraLayer
if is_peft_available():
if isinstance(self.kernel, LoraLayer):
warnings.warn(
"Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
@ -1540,7 +1480,6 @@ class TDNNLayer(nn.Module):
WAV2VEC2_BERT_START_DOCSTRING,
)
class Wav2Vec2BertForXVector(Wav2Vec2BertPreTrainedModel):
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector.__init__ with Wav2Vec2Conformer->Wav2Vec2Bert,WAV2VEC2_CONFORMER->WAV2VEC2_BERT,wav2vec2_conformer->wav2vec2_bert
def __init__(self, config):
super().__init__(config)
@ -1560,7 +1499,6 @@ class Wav2Vec2BertForXVector(Wav2Vec2BertPreTrainedModel):
self.init_weights()
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector.freeze_base_model with wav2vec2_conformer->wav2vec2_bert
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
@ -1569,7 +1507,6 @@ class Wav2Vec2BertForXVector(Wav2Vec2BertPreTrainedModel):
for param in self.wav2vec2_bert.parameters():
param.requires_grad = False
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector._get_tdnn_output_lengths
def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the TDNN layers
@ -1592,7 +1529,6 @@ class Wav2Vec2BertForXVector(Wav2Vec2BertPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerForXVector.forward with wav2vec2_conformer->wav2vec2_bert, input_values->input_features
def forward(
self,
input_features: Optional[torch.Tensor],

File diff suppressed because it is too large Load Diff

View File

@ -1,19 +1,9 @@
# coding=utf-8
# Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Wav2Vec2-Conformer model."""
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_wav2vec2_conformer.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
import warnings
from dataclasses import dataclass
@ -21,7 +11,6 @@ from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
@ -43,31 +32,19 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_peft_available,
logging,
replace_return_docstrings,
)
from .configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 2
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
# General docstring
_CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 64.21
@dataclass
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
"""
Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
@ -109,239 +86,17 @@ class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
diversity_loss: Optional[torch.FloatTensor] = None
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked_span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.detach().sum(-1).tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller than
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
def _sample_negative_indices(
features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
):
"""
Sample `num_negatives` vectors from feature vectors.
"""
batch_size, sequence_length = features_shape
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
sequence_length_range = np.arange(sequence_length)
# get `num_negatives` random vector indices from the same utterance
sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
mask_time_indices = (
mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
)
for batch_idx in range(batch_size):
high = mask_time_indices[batch_idx].sum() - 1
mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
# avoid sampling the same positive vector, but keep the distribution uniform
sampled_indices[sampled_indices >= feature_indices] += 1
# remap to actual indices
sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
# correct for batch size
sampled_negative_indices[batch_idx] += batch_idx * sequence_length
return sampled_negative_indices
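A companion sketch for negative sampling (not part of the diff; sizes are illustrative and follow the pretraining docstring example further down in this file):

```python
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
    _compute_mask_indices,
    _sample_negative_indices,
)

batch_size, sequence_length, num_negatives = 1, 50, 10
mask_time_indices = _compute_mask_indices(
    shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
)
sampled_negative_indices = _sample_negative_indices(
    features_shape=(batch_size, sequence_length),
    num_negatives=num_negatives,
    mask_time_indices=mask_time_indices,
)
# one row of `num_negatives` distractor positions per time step, drawn from the
# masked positions of the same utterance
print(sampled_negative_indices.shape)  # (1, 50, 10)
```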
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
class Wav2Vec2ConformerSamePadLayer(nn.Module):
def __init__(self, num_conv_pos_embeddings):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.activation(hidden_states)
if self.num_pad_remove > 0:
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
@ -471,19 +226,78 @@ class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
return relative_position_embeddings
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerSamePadLayer(nn.Module):
def __init__(self, num_conv_pos_embeddings):
class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
if self.num_pad_remove > 0:
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
hidden_states = self.conv(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.activation(hidden_states)
return hidden_states
class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerFeatureEncoder(nn.Module):
"""Construct the features from raw audio waveform"""
@ -531,7 +345,6 @@ class Wav2Vec2ConformerFeatureEncoder(nn.Module):
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerFeatureProjection(nn.Module):
def __init__(self, config):
super().__init__()
@ -547,7 +360,6 @@ class Wav2Vec2ConformerFeatureProjection(nn.Module):
return hidden_states, norm_hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
@ -943,7 +755,6 @@ class Wav2Vec2ConformerEncoder(nn.Module):
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
"""
Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
@ -1020,7 +831,6 @@ class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
return codevectors, perplexity
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerAdapter(nn.Module):
def __init__(self, config):
super().__init__()
@ -1052,7 +862,6 @@ class Wav2Vec2ConformerAdapter(nn.Module):
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer
class Wav2Vec2ConformerAdapterLayer(nn.Module):
def __init__(self, config):
super().__init__()
@ -1170,6 +979,128 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
return attention_mask
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.detach().sum(-1).tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller than
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
@ -1178,8 +1109,9 @@ WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving etc.).
This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model.
@ -1187,7 +1119,6 @@ WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@ -1226,6 +1157,8 @@ WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
Wav2Vec2ConformerBaseModelOutput = Wav2Vec2BaseModelOutput
@add_start_docstrings(
"The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
@ -1249,7 +1182,6 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@ -1257,7 +1189,6 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
"""
self.feature_extractor._freeze_parameters()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_hidden_states(
self,
hidden_states: torch.FloatTensor,
@ -1307,12 +1238,11 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Wav2Vec2BaseModelOutput,
output_type=Wav2Vec2ConformerBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer
def forward(
self,
input_values: Optional[torch.Tensor],
@ -1321,7 +1251,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
) -> Union[Tuple, Wav2Vec2ConformerBaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -1358,7 +1288,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return Wav2Vec2BaseModelOutput(
return Wav2Vec2ConformerBaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
@ -1370,7 +1300,6 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
"""Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
)
class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
def __init__(self, config: Wav2Vec2ConformerConfig):
super().__init__(config)
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
@ -1384,14 +1313,12 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
def set_gumbel_temperature(self, temperature: int):
"""
Set the Gumbel softmax temperature to a given value. Only necessary for training
"""
self.quantizer.temperature = temperature
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@ -1400,7 +1327,6 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
@staticmethod
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits
def compute_contrastive_logits(
target_features: torch.FloatTensor,
negative_features: torch.FloatTensor,
@ -1423,7 +1349,6 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
def forward(
self,
input_values: Optional[torch.Tensor],
@ -1452,8 +1377,8 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices, _sample_negative_indices
>>> from datasets import load_dataset
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
>>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2_conformer-base")
>>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2_conformer-base")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
@ -1585,12 +1510,18 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
)
_HIDDEN_STATES_START_POSITION = 2
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 64.21
@add_start_docstrings(
"""Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
@ -1614,7 +1545,6 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@ -1630,7 +1560,6 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
def forward(
self,
input_values: Optional[torch.Tensor],
@ -1710,7 +1639,6 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
def __init__(self, config):
super().__init__(config)
@ -1728,7 +1656,6 @@ class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedMode
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@ -1751,7 +1678,6 @@ class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedMode
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
def forward(
self,
input_values: Optional[torch.Tensor],
@ -1822,7 +1748,6 @@ class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedMode
WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
def __init__(self, config):
super().__init__(config)
@ -1839,7 +1764,6 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedMo
self.init_weights()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@ -1847,7 +1771,6 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedMo
"""
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
@ -1863,7 +1786,6 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedMo
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
def forward(
self,
input_values: Optional[torch.Tensor],
@ -1918,7 +1840,6 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedMo
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
class AMSoftmaxLoss(nn.Module):
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
super(AMSoftmaxLoss, self).__init__()
@ -1942,7 +1863,6 @@ class AMSoftmaxLoss(nn.Module):
return loss
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
class TDNNLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
@ -1958,6 +1878,7 @@ class TDNNLayer(nn.Module):
if is_peft_available():
from peft.tuners.lora import LoraLayer
if is_peft_available():
if isinstance(self.kernel, LoraLayer):
warnings.warn(
"Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
@ -2000,7 +1921,6 @@ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
self.init_weights()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@ -2008,7 +1928,6 @@ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
"""
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
@ -2017,7 +1936,6 @@ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
for param in self.wav2vec2_conformer.parameters():
param.requires_grad = False
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer
def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the TDNN layers
@ -2040,7 +1958,6 @@ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
def forward(
self,
input_values: Optional[torch.Tensor],

View File

@ -0,0 +1,892 @@
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from torch import nn
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
Wav2Vec2BaseModelOutput,
XVectorOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ..wav2vec2.modeling_wav2vec2 import (
Wav2Vec2Adapter,
Wav2Vec2AdapterLayer,
Wav2Vec2FeatureEncoder,
Wav2Vec2FeatureProjection,
Wav2Vec2FeedForward,
Wav2Vec2ForAudioFrameClassification,
Wav2Vec2ForCTC,
Wav2Vec2ForPreTraining,
Wav2Vec2ForSequenceClassification,
Wav2Vec2ForXVector,
Wav2Vec2GumbelVectorQuantizer,
Wav2Vec2Model,
Wav2Vec2PositionalConvEmbedding,
)
from .configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 2
# General docstring
_CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 64.21
@dataclass
class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
"""
Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
Args:
loss (*optional*, returned when `sampled_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
paper](https://arxiv.org/pdf/2006.11477.pdf).
projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
projected quantized states.
projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
target vectors for contrastive loss.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
contrastive_loss (*optional*, returned when `sampled_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf).
diversity_loss (*optional*, returned when `sampled_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf).
"""
loss: Optional[torch.FloatTensor] = None
projected_states: Optional[torch.FloatTensor] = None
projected_quantized_states: Optional[torch.FloatTensor] = None
codevector_perplexity: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
contrastive_loss: Optional[torch.FloatTensor] = None
diversity_loss: Optional[torch.FloatTensor] = None
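For orientation, a hypothetical construction of this output (not part of the diff; the numbers and shapes are made up) showing that the fields are plain `ModelOutput` attributes, with `loss` combining the contrastive and diversity terms when negatives are passed to the forward:

```python
import torch
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
    Wav2Vec2ConformerForPreTrainingOutput,
)

output = Wav2Vec2ConformerForPreTrainingOutput(
    loss=torch.tensor(64.9),
    projected_states=torch.zeros(1, 292, 256),
    projected_quantized_states=torch.zeros(1, 292, 256),
    contrastive_loss=torch.tensor(64.2),
    diversity_loss=torch.tensor(0.7),
)
print(output.loss, output.projected_states.shape)
# also accessible by key, like any ModelOutput
print(output["contrastive_loss"])
```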
class Wav2Vec2ConformerPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding):
pass
class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
"""Rotary positional embedding
Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
"""
def __init__(self, config):
super().__init__()
dim = config.hidden_size // config.num_attention_heads
base = config.rotary_embedding_base
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.cached_sequence_length = None
self.cached_rotary_positional_embedding = None
def forward(self, hidden_states):
sequence_length = hidden_states.shape[1]
if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
return self.cached_rotary_positional_embedding
self.cached_sequence_length = sequence_length
# Embeddings are computed in the dtype of the inv_freq constant
time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
embeddings = torch.cat((freqs, freqs), dim=-1)
cos_embeddings = embeddings.cos()[:, None, None, :]
sin_embeddings = embeddings.sin()[:, None, None, :]
# Computed embeddings are cast to the dtype of the hidden state inputs
self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states)
return self.cached_rotary_positional_embedding
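A small sanity sketch for the rotary module (not part of the diff; the config values are arbitrary illustrations): it returns stacked cos/sin tables of shape `(2, seq_len, 1, 1, head_size)` and caches them per sequence length.

```python
import torch
from transformers import Wav2Vec2ConformerConfig
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
    Wav2Vec2ConformerRotaryPositionalEmbedding,
)

config = Wav2Vec2ConformerConfig(hidden_size=64, num_attention_heads=4)
rotary = Wav2Vec2ConformerRotaryPositionalEmbedding(config)

hidden_states = torch.randn(2, 10, 64)   # (batch, seq_len, hidden_size)
cos_sin = rotary(hidden_states)
print(cos_sin.shape)                     # torch.Size([2, 10, 1, 1, 16]) -> head_size = 64 / 4
print(rotary(hidden_states) is cos_sin)  # True: cached for repeated sequence lengths
```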
class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
"""Relative positional encoding module."""
def __init__(self, config):
super().__init__()
self.max_len = config.max_source_positions
self.d_model = config.hidden_size
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
def extend_pe(self, x):
# Reset the positional encodings
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` is the position of query vector and `j` is the
# position of key vector. We use positive relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.int64).float().unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.int64).float() * -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reverse the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in https://arxiv.org/abs/1901.02860
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, hidden_states: torch.Tensor):
self.extend_pe(hidden_states)
start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
relative_position_embeddings = self.pe[:, start_idx:end_idx]
return relative_position_embeddings
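Similarly, a shape sketch for the relative variant (not part of the diff; config values are arbitrary): it returns one embedding per relative offset, i.e. `2 * seq_len - 1` positions, to support the Transformer-XL shifting trick referenced above.

```python
import torch
from transformers import Wav2Vec2ConformerConfig
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
    Wav2Vec2ConformerRelPositionalEmbedding,
)

config = Wav2Vec2ConformerConfig(hidden_size=64, max_source_positions=50)
rel_pos = Wav2Vec2ConformerRelPositionalEmbedding(config)

hidden_states = torch.randn(1, 10, 64)                # (batch, seq_len, hidden_size)
relative_position_embeddings = rel_pos(hidden_states)
print(relative_position_embeddings.shape)             # torch.Size([1, 19, 64]) -> 2 * 10 - 1
```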
class Wav2Vec2ConformerFeatureEncoder(Wav2Vec2FeatureEncoder):
pass
class Wav2Vec2ConformerFeatureProjection(Wav2Vec2FeatureProjection):
pass
class Wav2Vec2ConformerFeedForward(Wav2Vec2FeedForward):
pass
class Wav2Vec2ConformerConvolutionModule(nn.Module):
"""Convolution block used in the conformer block"""
def __init__(self, config):
super().__init__()
if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding")
self.layer_norm = nn.LayerNorm(config.hidden_size)
self.pointwise_conv1 = nn.Conv1d(
config.hidden_size,
2 * config.hidden_size,
kernel_size=1,
stride=1,
padding=0,
bias=False,
)
self.glu = nn.GLU(dim=1)
self.depthwise_conv = nn.Conv1d(
config.hidden_size,
config.hidden_size,
config.conv_depthwise_kernel_size,
stride=1,
padding=(config.conv_depthwise_kernel_size - 1) // 2,
groups=config.hidden_size,
bias=False,
)
self.batch_norm = nn.BatchNorm1d(config.hidden_size)
self.activation = ACT2FN[config.hidden_act]
self.pointwise_conv2 = nn.Conv1d(
config.hidden_size,
config.hidden_size,
kernel_size=1,
stride=1,
padding=0,
bias=False,
)
self.dropout = nn.Dropout(config.conformer_conv_dropout)
def forward(self, hidden_states):
hidden_states = self.layer_norm(hidden_states)
# exchange the temporal dimension and the feature dimension
hidden_states = hidden_states.transpose(1, 2)
# GLU mechanism
# => (batch, 2*channel, dim)
hidden_states = self.pointwise_conv1(hidden_states)
# => (batch, channel, dim)
hidden_states = self.glu(hidden_states)
# 1D Depthwise Conv
hidden_states = self.depthwise_conv(hidden_states)
hidden_states = self.batch_norm(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.pointwise_conv2(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
return hidden_states
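# A minimal, illustrative sketch (not part of the model): the pointwise -> GLU -> depthwise pattern
# above preserves the time dimension because the depthwise convolution pads by
# (kernel_size - 1) // 2, which is why an odd `conv_depthwise_kernel_size` is required. The sizes
# below are assumptions chosen only for this demo.
def _conformer_conv_shape_demo():
    import torch
    import torch.nn as nn

    batch, hidden, time, kernel = 2, 8, 50, 31
    x = torch.randn(batch, hidden, time)
    pointwise = nn.Conv1d(hidden, 2 * hidden, kernel_size=1, bias=False)
    depthwise = nn.Conv1d(hidden, hidden, kernel, padding=(kernel - 1) // 2, groups=hidden, bias=False)
    y = nn.functional.glu(pointwise(x), dim=1)  # (batch, hidden, time)
    y = depthwise(y)                            # (batch, hidden, time), length preserved
    assert y.shape == (batch, hidden, time)
    return y.shape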
class Wav2Vec2ConformerSelfAttention(nn.Module):
"""Construct an Wav2Vec2ConformerSelfAttention object.
Can be enhanced with rotary or relative position embeddings.
"""
def __init__(self, config):
super().__init__()
self.head_size = config.hidden_size // config.num_attention_heads
self.num_heads = config.num_attention_heads
self.position_embeddings_type = config.position_embeddings_type
self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(p=config.attention_dropout)
if self.position_embeddings_type == "relative":
# linear transformation for positional encoding
self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
# these two learnable biases are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
relative_position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# self-attention mechanism
batch_size, sequence_length, hidden_size = hidden_states.size()
# make sure query/key states can be != value states
query_key_states = hidden_states
value_states = hidden_states
if self.position_embeddings_type == "rotary":
if relative_position_embeddings is None:
raise ValueError(
"`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
)
query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
# project query_key_states and value_states
query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
# => (batch, head, time1, d_k)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if self.position_embeddings_type == "relative":
if relative_position_embeddings is None:
raise ValueError(
"`relative_position_embeddings` has to be defined when `self.position_embeddings_type =="
" 'relative'"
)
# apply relative_position_embeddings to qk scores
# as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860
scores = self._apply_relative_embeddings(
query=query, key=key, relative_position_embeddings=relative_position_embeddings
)
else:
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size)
# apply attention_mask if necessary
if attention_mask is not None:
scores = scores + attention_mask
# => (batch, head, time1, time2)
probs = torch.softmax(scores, dim=-1)
probs = self.dropout(probs)
# => (batch, head, time1, d_k)
hidden_states = torch.matmul(probs, value)
# => (batch, time1, hidden_size)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
hidden_states = self.linear_out(hidden_states)
return hidden_states, probs
def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
batch_size, sequence_length, hidden_size = hidden_states.size()
hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
cos = relative_position_embeddings[0, :sequence_length, ...]
sin = relative_position_embeddings[1, :sequence_length, ...]
# rotate hidden_states with rotary embeddings
hidden_states = hidden_states.transpose(0, 1)
rotated_states_begin = hidden_states[..., : self.head_size // 2]
rotated_states_end = hidden_states[..., self.head_size // 2 :]
rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
hidden_states = (hidden_states * cos) + (rotated_states * sin)
hidden_states = hidden_states.transpose(0, 1)
hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
return hidden_states
def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
# 1. project positional embeddings
# => (batch, head, 2*time1-1, d_k)
proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
proj_relative_position_embeddings = proj_relative_position_embeddings.view(
relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
)
proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
# 2. Add bias to query
# => (batch, head, time1, d_k)
query = query.transpose(1, 2)
q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
# 3. attention score: first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# => (batch, head, time1, time2)
scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
# 4. then compute matrix b and matrix d
# => (batch, head, time1, 2*time1-1)
scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
# 5. shift matrix b and matrix d
zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
# 6. sum matrices
# => (batch, head, time1, time2)
scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
return scores
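# A minimal, illustrative sketch (not part of the model): the pad-reshape-slice in step 5 above is
# the Transformer-XL "relative shift", turning scores indexed by relative position (2*T - 1 columns)
# into scores indexed by key position (T columns). The sizes below are assumptions for this demo.
def _relative_shift_demo():
    import torch

    batch, heads, time = 1, 1, 3
    scores_bd = torch.arange(batch * heads * time * (2 * time - 1), dtype=torch.float).view(
        batch, heads, time, 2 * time - 1
    )
    zero_pad = torch.zeros((batch, heads, time, 1))
    padded = torch.cat([zero_pad, scores_bd], dim=-1)   # (b, h, T, 2*T)
    padded = padded.view(batch, heads, 2 * time, time)  # (b, h, 2*T, T)
    shifted = padded[:, :, 1:].view_as(scores_bd)[:, :, :, :time]
    return shifted                                      # (b, h, T, T)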
class Wav2Vec2ConformerEncoderLayer(nn.Module):
"""Conformer block based on https://arxiv.org/abs/2005.08100."""
def __init__(self, config):
super().__init__()
embed_dim = config.hidden_size
dropout = config.attention_dropout
# Feed-forward 1
self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
self.ffn1 = Wav2Vec2ConformerFeedForward(config)
# Self-Attention
self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
self.self_attn_dropout = nn.Dropout(dropout)
self.self_attn = Wav2Vec2ConformerSelfAttention(config)
# Conformer Convolution
self.conv_module = Wav2Vec2ConformerConvolutionModule(config)
# Feed-forward 2
self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
self.ffn2 = Wav2Vec2ConformerFeedForward(config)
self.final_layer_norm = nn.LayerNorm(embed_dim)
def forward(
self,
hidden_states,
attention_mask: Optional[torch.Tensor] = None,
relative_position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
hidden_states = hidden_states
# 1. Feed-Forward 1 layer
residual = hidden_states
hidden_states = self.ffn1_layer_norm(hidden_states)
hidden_states = self.ffn1(hidden_states)
hidden_states = hidden_states * 0.5 + residual
residual = hidden_states
# 2. Self-Attention layer
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
relative_position_embeddings=relative_position_embeddings,
output_attentions=output_attentions,
)
hidden_states = self.self_attn_dropout(hidden_states)
hidden_states = hidden_states + residual
# 3. Convolutional Layer
residual = hidden_states
hidden_states = self.conv_module(hidden_states)
hidden_states = residual + hidden_states
# 4. Feed-Forward 2 Layer
residual = hidden_states
hidden_states = self.ffn2_layer_norm(hidden_states)
hidden_states = self.ffn2(hidden_states)
hidden_states = hidden_states * 0.5 + residual
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states, attn_weights
class Wav2Vec2ConformerEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
if config.position_embeddings_type == "relative":
self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
elif config.position_embeddings_type == "rotary":
self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
else:
self.embed_positions = None
self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0.0
# extend attention_mask
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
hidden_states = self.dropout(hidden_states)
if self.embed_positions is not None:
relative_position_embeddings = self.embed_positions(hidden_states)
else:
relative_position_embeddings = None
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
relative_position_embeddings,
output_attentions,
)
else:
layer_outputs = layer(
hidden_states,
attention_mask=attention_mask,
relative_position_embeddings=relative_position_embeddings,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if skip_the_layer:
layer_outputs = (None, None)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
hidden_states = self.layer_norm(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
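# A minimal, illustrative sketch (not part of the model): LayerDrop as applied in the encoder loop
# above - during training each layer is skipped with probability `layerdrop`, except under
# DeepSpeed ZeRO-3 / FSDP where every rank must run every layer. The values below are assumptions
# chosen only for this demo.
def _layerdrop_demo(num_layers: int = 12, layerdrop: float = 0.1, training: bool = True):
    import torch

    executed = []
    for i in range(num_layers):
        skip_the_layer = training and bool(torch.rand([]) < layerdrop)
        if not skip_the_layer:
            executed.append(i)
    return executed  # indices of the layers that actually ran in this forward pass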
class Wav2Vec2ConformerGumbelVectorQuantizer(Wav2Vec2GumbelVectorQuantizer):
pass
class Wav2Vec2ConformerAdapter(Wav2Vec2Adapter):
pass
class Wav2Vec2ConformerAdapterLayer(Wav2Vec2AdapterLayer):
pass
class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Wav2Vec2ConformerConfig
base_model_prefix = "wav2vec2_conformer"
main_input_name = "input_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
# Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
if isinstance(module, Wav2Vec2ConformerForPreTraining):
module.project_hid.reset_parameters()
module.project_q.reset_parameters()
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
# gumbel softmax requires special init
elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
module.weight_proj.bias.data.zero_()
nn.init.uniform_(module.codevectors)
elif isinstance(module, Wav2Vec2ConformerSelfAttention):
if hasattr(module, "pos_bias_u"):
nn.init.xavier_uniform_(module.pos_bias_u)
if hasattr(module, "pos_bias_v"):
nn.init.xavier_uniform_(module.pos_bias_v)
elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding):
nn.init.normal_(
module.conv.weight,
mean=0,
std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
)
nn.init.constant_(module.conv.bias, 0)
elif isinstance(module, Wav2Vec2ConformerFeatureProjection):
k = math.sqrt(1 / module.projection.in_features)
nn.init.uniform_(module.projection.weight, a=-k, b=k)
nn.init.uniform_(module.projection.bias, a=-k, b=k)
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
nn.init.kaiming_normal_(module.weight)
if module.bias is not None:
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
nn.init.uniform_(module.bias, a=-k, b=k)
def _get_feat_extract_output_lengths(
self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
):
"""
Computes the output length of the convolutional layers
"""
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
if add_adapter:
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
return input_lengths
def _get_feature_vector_attention_mask(
self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
):
# Effectively attention_mask.sum(-1), but not inplace to be able to run
# in inference mode.
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
output_lengths = output_lengths.to(torch.long)
batch_size = attention_mask.shape[0]
attention_mask = torch.zeros(
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
)
# these two operations make sure that all values before the output lengths idxs are attended to
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask
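# A minimal, illustrative sketch (not part of the model): worked example of the convolutional
# output-length formula used in `_get_feat_extract_output_lengths` above. The kernel/stride values
# below mirror a typical wav2vec2-style feature encoder and are assumptions for this demo only.
def _feat_extract_length_demo(num_samples: int = 16000):
    import torch

    conv_kernel = (10, 3, 3, 3, 3, 2, 2)
    conv_stride = (5, 2, 2, 2, 2, 2, 2)
    length = torch.tensor(num_samples)
    for kernel_size, stride in zip(conv_kernel, conv_stride):
        length = torch.div(length - kernel_size, stride, rounding_mode="floor") + 1
    return int(length)  # 49 frames for one second of 16 kHz audio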
WAV2VEC2_CONFORMER_START_DOCSTRING = None # will be automatically redefined
WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
<Tip warning={true}>
`attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
True`. For all models whose processor has `config.return_attention_mask == False`, such as
[wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large),
`attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
that these models also yield slightly different results depending on whether `input_values` is padded or
not.
</Tip>
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
Wav2Vec2ConformerBaseModelOutput = Wav2Vec2BaseModelOutput
@add_start_docstrings(
"The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel, Wav2Vec2Model):
def __init__(self, config: Wav2Vec2ConformerConfig):
Wav2Vec2ConformerPreTrainedModel.__init__(self, config)
self.config = config
self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)
# model only needs masking vector if mask prob is > 0.0
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
self.encoder = Wav2Vec2ConformerEncoder(config)
self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None
# Initialize weights and apply final processing
self.post_init()
def freeze_feature_extractor(self):
raise AttributeError("Not needed for Wav2Vec2Conformer")
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Wav2Vec2ConformerBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(self, **super_kwargs):
return super().forward(**super_kwargs)
@add_start_docstrings(
"""Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
)
class Wav2Vec2ConformerForPreTraining(Wav2Vec2ForPreTraining):
def __init__(self, config: Wav2Vec2ConformerConfig):
super().__init__(config)
def freeze_feature_extractor(self):
raise AttributeError("Not needed for Wav2Vec2Conformer")
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(self, **super_kwargs) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
return super().forward(**super_kwargs)
@add_start_docstrings(
"""Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForCTC(Wav2Vec2ForCTC):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
def tie_weights(self):
raise AttributeError("Not needed for Wav2Vec2Conformer")
def freeze_feature_extractor(self):
raise AttributeError("Not needed for Wav2Vec2Conformer")
def freeze_base_model(self):
raise AttributeError("Not needed for Wav2Vec2Conformer")
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
)
def forward(self, **super_kwargs):
return super().forward(**super_kwargs)
@add_start_docstrings(
"""
Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for
tasks like SUPERB Keyword Spotting.
""",
WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ForSequenceClassification):
def __init__(self, config):
super().__init__(config)
def freeze_feature_extractor(self):
raise AttributeError("Not needed for Wav2Vec2Conformer")
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
def forward(self, **super_kwargs):
return super().forward(**super_kwargs)
@add_start_docstrings(
"""
Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization.
""",
WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification):
def __init__(self, config):
super().__init__(config)
def freeze_feature_extractor(self):
raise AttributeError("Not needed for Wav2Vec2Conformer")
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
def forward(self, **super_kwargs):
return super().forward(**super_kwargs)
@add_start_docstrings(
"""
Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.
""",
WAV2VEC2_CONFORMER_START_DOCSTRING,
)
class Wav2Vec2ConformerForXVector(Wav2Vec2ForXVector):
def __init__(self, config):
super().__init__(config)
def freeze_feature_extractor(self):
raise AttributeError("Not needed for Wav2Vec2Conformer")
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=XVectorOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
def forward(self, **super_kwargs):
return super().forward(**super_kwargs)
__all__ = [
"Wav2Vec2ConformerForAudioFrameClassification",
"Wav2Vec2ConformerForCTC",
"Wav2Vec2ConformerForPreTraining",
"Wav2Vec2ConformerForSequenceClassification",
"Wav2Vec2ConformerForXVector",
"Wav2Vec2ConformerModel",
"Wav2Vec2ConformerPreTrainedModel",
]

View File

@ -1,28 +1,17 @@
# coding=utf-8
# Copyright 2021 The Fairseq Authors, Microsoft Research, and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch WavLM model."""
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/wavlm/modular_wavlm.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_wavlm.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
import warnings
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
@ -49,225 +38,22 @@ from .configuration_wavlm import WavLMConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "patrickvonplaten/wavlm-libri-clean-100h-base-plus"
_HIDDEN_STATES_START_POSITION = 2
# General docstring
_CONFIG_FOR_DOC = "WavLMConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "patrickvonplaten/wavlm-libri-clean-100h-base-plus"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
# CTC docstring
_CTC_EXPECTED_OUTPUT = "'mister quilter is the aposle of the middle classes and we are glad to welcome his gospel'"
_CTC_EXPECTED_LOSS = 12.51
# Frame class docstring
_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd"
_FRAME_EXPECTED_OUTPUT = [0, 0]
# Speaker Verification docstring
_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv"
_XVECTOR_EXPECTED_OUTPUT = 0.97
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.detach().sum(-1).tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller than
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->WavLM
class WavLMNoLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
class WavLMSamePadLayer(nn.Module):
def __init__(self, num_conv_pos_embeddings):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.activation(hidden_states)
if self.num_pad_remove > 0:
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->WavLM
class WavLMLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->WavLM
class WavLMGroupNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->WavLM
class WavLMPositionalConvEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
@ -313,75 +99,6 @@ class WavLMPositionalConvEmbedding(nn.Module):
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->WavLM
class WavLMSamePadLayer(nn.Module):
def __init__(self, num_conv_pos_embeddings):
super().__init__()
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
def forward(self, hidden_states):
if self.num_pad_remove > 0:
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->WavLM
class WavLMFeatureEncoder(nn.Module):
"""Construct the features from raw audio waveform"""
def __init__(self, config):
super().__init__()
if config.feat_extract_norm == "group":
conv_layers = [WavLMGroupNormConvLayer(config, layer_id=0)] + [
WavLMNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
]
elif config.feat_extract_norm == "layer":
conv_layers = [WavLMLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
else:
raise ValueError(
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
)
self.conv_layers = nn.ModuleList(conv_layers)
self.gradient_checkpointing = False
self._requires_grad = True
def _freeze_parameters(self):
for param in self.parameters():
param.requires_grad = False
self._requires_grad = False
def forward(self, input_values):
hidden_states = input_values[:, None]
# make sure hidden_states require grad for gradient_checkpointing
if self._requires_grad and self.training:
hidden_states.requires_grad = True
for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
conv_layer.__call__,
hidden_states,
)
else:
hidden_states = conv_layer(hidden_states)
return hidden_states
class WavLMFeatureExtractor(WavLMFeatureEncoder):
def __init__(self, config):
super().__init__(config)
warnings.warn(
f"The class `{self.__class__.__name__}` has been depreciated "
"and will be removed in Transformers v5. "
f"Use `{self.__class__.__bases__[0].__name__}` instead.",
FutureWarning,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->WavLM
class WavLMFeatureProjection(nn.Module):
def __init__(self, config):
super().__init__()
@ -563,7 +280,6 @@ class WavLMAttention(nn.Module):
return relative_buckets
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->WavLM
class WavLMFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
@ -903,57 +619,6 @@ class WavLMGumbelVectorQuantizer(nn.Module):
return codevectors, perplexity
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->WavLM
class WavLMAdapter(nn.Module):
def __init__(self, config):
super().__init__()
# feature dim might need to be down-projected
if config.output_hidden_size != config.hidden_size:
self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
else:
self.proj = self.proj_layer_norm = None
self.layers = nn.ModuleList(WavLMAdapterLayer(config) for _ in range(config.num_adapter_layers))
self.layerdrop = config.layerdrop
def forward(self, hidden_states):
# down project hidden_states if necessary
if self.proj is not None and self.proj_layer_norm is not None:
hidden_states = self.proj(hidden_states)
hidden_states = self.proj_layer_norm(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
for layer in self.layers:
layerdrop_prob = np.random.random()
if not self.training or (layerdrop_prob > self.layerdrop):
hidden_states = layer(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->WavLM
class WavLMAdapterLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.conv = nn.Conv1d(
config.output_hidden_size,
2 * config.output_hidden_size,
config.adapter_kernel_size,
stride=config.adapter_stride,
padding=1,
)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = nn.functional.glu(hidden_states, dim=1)
return hidden_states
class WavLMPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@ -964,6 +629,8 @@ class WavLMPreTrainedModel(PreTrainedModel):
base_model_prefix = "wavlm"
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = False
_supports_sdpa = False
def _init_weights(self, module):
"""Initialize the weights"""
@ -1042,6 +709,293 @@ class WavLMPreTrainedModel(PreTrainedModel):
return attention_mask
class WavLMNoLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class WavLMLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.activation(hidden_states)
return hidden_states
class WavLMGroupNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.activation = ACT2FN[config.feat_extract_activation]
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class WavLMFeatureEncoder(nn.Module):
"""Construct the features from raw audio waveform"""
def __init__(self, config):
super().__init__()
if config.feat_extract_norm == "group":
conv_layers = [WavLMGroupNormConvLayer(config, layer_id=0)] + [
WavLMNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
]
elif config.feat_extract_norm == "layer":
conv_layers = [WavLMLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
else:
raise ValueError(
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
)
self.conv_layers = nn.ModuleList(conv_layers)
self.gradient_checkpointing = False
self._requires_grad = True
def _freeze_parameters(self):
for param in self.parameters():
param.requires_grad = False
self._requires_grad = False
def forward(self, input_values):
hidden_states = input_values[:, None]
# make sure hidden_states require grad for gradient_checkpointing
if self._requires_grad and self.training:
hidden_states.requires_grad = True
for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
conv_layer.__call__,
hidden_states,
)
else:
hidden_states = conv_layer(hidden_states)
return hidden_states
class WavLMAdapterLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.conv = nn.Conv1d(
config.output_hidden_size,
2 * config.output_hidden_size,
config.adapter_kernel_size,
stride=config.adapter_stride,
padding=1,
)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = nn.functional.glu(hidden_states, dim=1)
return hidden_states
class WavLMAdapter(nn.Module):
def __init__(self, config):
super().__init__()
# feature dim might need to be down-projected
if config.output_hidden_size != config.hidden_size:
self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
else:
self.proj = self.proj_layer_norm = None
self.layers = nn.ModuleList(WavLMAdapterLayer(config) for _ in range(config.num_adapter_layers))
self.layerdrop = config.layerdrop
def forward(self, hidden_states):
# down project hidden_states if necessary
if self.proj is not None and self.proj_layer_norm is not None:
hidden_states = self.proj(hidden_states)
hidden_states = self.proj_layer_norm(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
for layer in self.layers:
layerdrop_prob = np.random.random()
if not self.training or (layerdrop_prob > self.layerdrop):
hidden_states = layer(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
return hidden_states
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.detach().sum(-1).tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller than
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
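# A minimal, illustrative sketch (not part of the model): drawing SpecAugment time-mask indices
# with the helper above. The shape and probabilities below are assumptions chosen only for this
# demo.
def _spec_augment_mask_demo():
    masks = _compute_mask_indices(shape=(2, 100), mask_prob=0.05, mask_length=10, min_masks=1)
    # `masks` is a boolean numpy array of shape (2, 100); True positions form whole spans of
    # `mask_length` consecutive frames that the model replaces with its learned mask embedding.
    return masks.sum(axis=-1)  # number of masked frames per batch element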
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
WAVLM_START_DOCSTRING = r"""
WavLM was proposed in [WavLM: Unified Speech Representation Learning with Labeled and Unlabeled
Data](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo
@ -1061,7 +1015,6 @@ WAVLM_START_DOCSTRING = r"""
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
WAVLM_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@ -1098,12 +1051,13 @@ WAVLM_INPUTS_DOCSTRING = r"""
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
WavLMBaseModelOutput = Wav2Vec2BaseModelOutput
@add_start_docstrings(
"The bare WavLM Model transformer outputting raw hidden-states without any specific head on top.",
WAVLM_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM, WavLMBaseModelOutput->Wav2Vec2BaseModelOutput
class WavLMModel(WavLMPreTrainedModel):
def __init__(self, config: WavLMConfig):
super().__init__(config)
@ -1193,7 +1147,7 @@ class WavLMModel(WavLMPreTrainedModel):
@add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Wav2Vec2BaseModelOutput,
output_type=WavLMBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@ -1206,7 +1160,7 @@ class WavLMModel(WavLMPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
) -> Union[Tuple, WavLMBaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -1243,7 +1197,7 @@ class WavLMModel(WavLMPreTrainedModel):
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return Wav2Vec2BaseModelOutput(
return WavLMBaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
@ -1251,11 +1205,16 @@ class WavLMModel(WavLMPreTrainedModel):
)
_HIDDEN_STATES_START_POSITION = 2
_CTC_EXPECTED_OUTPUT = "'mister quilter is the aposle of the middle classes and we are glad to welcome his gospel'"
_CTC_EXPECTED_LOSS = 12.51
@add_start_docstrings(
"""WavLM Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
WAVLM_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
class WavLMForCTC(WavLMPreTrainedModel):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
@ -1432,7 +1391,6 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
@ -1445,7 +1403,6 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
)
self.freeze_feature_encoder()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wavlm
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@ -1453,7 +1410,6 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
"""
self.wavlm.feature_extractor._freeze_parameters()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->wavlm
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
@ -1469,7 +1425,6 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->WavLM, wav2vec2->wavlm
def forward(
self,
input_values: Optional[torch.Tensor],
@ -1533,13 +1488,16 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
)
_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd"
_FRAME_EXPECTED_OUTPUT = [0, 0]
@add_start_docstrings(
"""
WavLM Model with a frame classification head on top for tasks like Speaker Diarization.
""",
WAVLM_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
class WavLMForAudioFrameClassification(WavLMPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@ -1646,7 +1604,6 @@ class WavLMForAudioFrameClassification(WavLMPreTrainedModel):
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
class AMSoftmaxLoss(nn.Module):
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
super(AMSoftmaxLoss, self).__init__()
@ -1670,7 +1627,6 @@ class AMSoftmaxLoss(nn.Module):
return loss
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
class TDNNLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
@ -1686,6 +1642,7 @@ class TDNNLayer(nn.Module):
if is_peft_available():
from peft.tuners.lora import LoraLayer
if is_peft_available():
if isinstance(self.kernel, LoraLayer):
warnings.warn(
"Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
@ -1702,13 +1659,16 @@ class TDNNLayer(nn.Module):
return hidden_states
_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv"
_XVECTOR_EXPECTED_OUTPUT = 0.97
@add_start_docstrings(
"""
WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification.
""",
WAVLM_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
class WavLMForXVector(WavLMPreTrainedModel):
def __init__(self, config):
super().__init__(config)

View File

@ -0,0 +1,758 @@
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
Wav2Vec2BaseModelOutput,
XVectorOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from ..wav2vec2.modeling_wav2vec2 import (
Wav2Vec2FeatureProjection,
Wav2Vec2FeedForward,
Wav2Vec2ForAudioFrameClassification,
Wav2Vec2ForCTC,
Wav2Vec2ForSequenceClassification,
Wav2Vec2ForXVector,
Wav2Vec2Model,
Wav2Vec2PositionalConvEmbedding,
Wav2Vec2PreTrainedModel,
)
from .configuration_wavlm import WavLMConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "WavLMConfig"
_CHECKPOINT_FOR_DOC = "patrickvonplaten/wavlm-libri-clean-100h-base-plus"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
_CTC_EXPECTED_OUTPUT = "'mister quilter is the aposle of the middle classes and we are glad to welcome his gospel'"
_CTC_EXPECTED_LOSS = 12.51
_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd"
_FRAME_EXPECTED_OUTPUT = [0, 0]
_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv"
_XVECTOR_EXPECTED_OUTPUT = 0.97
class WavLMPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding):
pass
class WavLMFeatureProjection(Wav2Vec2FeatureProjection):
pass
class WavLMAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
num_buckets: int = 320,
max_distance: int = 800,
has_relative_position_bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.num_buckets = num_buckets
self.max_distance = max_distance
self.gru_rel_pos_const = nn.Parameter(torch.ones(1, self.num_heads, 1, 1))
self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)
if has_relative_position_bias:
self.rel_attn_embed = nn.Embedding(self.num_buckets, self.num_heads)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_bias: Optional[torch.Tensor] = None,
output_attentions: bool = False,
index=0,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Attention layer with relative attention"""
bsz, tgt_len, _ = hidden_states.size()
# first pass of attention layer creates position bias
if position_bias is None:
position_bias = self.compute_bias(tgt_len, tgt_len)
position_bias = (
position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, tgt_len)
)
# Compute relative position bias:
# 1) reshape hidden_states
gated_hidden_states = hidden_states.view(hidden_states.shape[:-1] + (self.num_heads, -1))
gated_hidden_states = gated_hidden_states.permute(0, 2, 1, 3)
# 2) project hidden states
relative_position_proj = self.gru_rel_pos_linear(gated_hidden_states)
relative_position_proj = relative_position_proj.view(gated_hidden_states.shape[:-1] + (2, 4)).sum(-1)
# 3) compute gate for position bias from projected hidden states
gate_a, gate_b = torch.sigmoid(relative_position_proj).chunk(2, dim=-1)
gate_output = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0
# 4) apply gate to position bias to compute gated position_bias
gated_position_bias = gate_output.view(bsz * self.num_heads, -1, 1) * position_bias
gated_position_bias = gated_position_bias.view((-1, tgt_len, tgt_len))
attn_output, attn_weights = self.torch_multi_head_self_attention(
hidden_states, attention_mask, gated_position_bias, output_attentions
)
return attn_output, attn_weights, position_bias
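    # Illustrative shape trace of the gating above (a sketch, not additional model code; B = batch size,
    # T = tgt_len, H = num_heads, D = head_dim): hidden_states (B, T, H*D) -> gated_hidden_states
    # (B, H, T, D) -> gru_rel_pos_linear (B, H, T, 8) -> view(..., 2, 4).sum(-1) (B, H, T, 2) ->
    # sigmoid + chunk -> gate_a, gate_b each (B, H, T, 1) -> gate_output (B, H, T, 1), reshaped to
    # (B*H, T, 1) and broadcast against position_bias (B*H, T, T) to give gated_position_bias (B*H, T, T).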
def torch_multi_head_self_attention(
self,
hidden_states: torch.FloatTensor,
attention_mask: Union[torch.LongTensor, torch.BoolTensor],
gated_position_bias: torch.FloatTensor,
output_attentions: bool,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""simple wrapper around torch's multi_head_attention_forward function"""
# self-attention assumes q = k = v
query = key = value = hidden_states.transpose(0, 1)
key_padding_mask = attention_mask.ne(1) if attention_mask is not None else None
# disable bias and add_zero_attn
bias_k = bias_v = None
add_zero_attn = False
# PyTorch 1.3.0 has F.multi_head_attention_forward defined
# so no problem with backwards compatibility
attn_output, attn_weights = F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
torch.empty([0]),
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
bias_k,
bias_v,
add_zero_attn,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
self.training,
key_padding_mask,
output_attentions,
gated_position_bias,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
)
# [Seq_Len, Batch Size, ...] -> [Batch Size, Seq_Len, ...]
attn_output = attn_output.transpose(0, 1)
if attn_weights is not None:
# IMPORTANT: Attention weights are averaged weights
# here which should not be the case. This is an open issue
# on PyTorch: https://github.com/pytorch/pytorch/issues/32590
attn_weights = attn_weights[:, None].broadcast_to(
attn_weights.shape[:1] + (self.num_heads,) + attn_weights.shape[1:]
)
return attn_output, attn_weights
def compute_bias(self, query_length: int, key_length: int) -> torch.FloatTensor:
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
relative_position = memory_position - context_position
relative_position_bucket = self._relative_positions_bucket(relative_position)
relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device)
values = self.rel_attn_embed(relative_position_bucket)
values = values.permute([2, 0, 1])
return values
def _relative_positions_bucket(self, relative_positions: torch.FloatTensor) -> torch.FloatTensor:
num_buckets = self.num_buckets // 2
relative_buckets = (relative_positions > 0).to(torch.long) * num_buckets
relative_positions = torch.abs(relative_positions)
max_exact = num_buckets // 2
is_small = relative_positions < max_exact
relative_positions_if_large = torch.log(relative_positions.float() / max_exact)
relative_positions_if_large = relative_positions_if_large / math.log(self.max_distance / max_exact)
relative_positions_if_large = relative_positions_if_large * (num_buckets - max_exact)
relative_position_if_large = (max_exact + relative_positions_if_large).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
return relative_buckets
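# Worked example of the bucketing above, assuming the default config values (num_buckets=320,
# max_distance=800): the 320 buckets split into 160 for non-positive and 160 for positive relative
# positions; distances below 80 keep their own bucket, larger ones are binned logarithmically and
# clamped to the last bucket of each half.
#     relative position:  -800   -10   -1    0     1    10   800
#     bucket index:        159    10    1    0   161   170   319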
class WavLMFeedForward(Wav2Vec2FeedForward):
pass
class WavLMEncoderLayer(nn.Module):
def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
super().__init__()
self.attention = WavLMAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
num_buckets=config.num_buckets,
max_distance=config.max_bucket_distance,
has_relative_position_bias=has_relative_position_bias,
)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.feed_forward = WavLMFeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
attn_residual = hidden_states
hidden_states, attn_weights, position_bias = self.attention(
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
output_attentions=output_attentions,
index=index,
)
hidden_states = self.dropout(hidden_states)
hidden_states = attn_residual + hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states + self.feed_forward(hidden_states)
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states, position_bias)
if output_attentions:
outputs += (attn_weights,)
return outputs
class WavLMEncoderLayerStableLayerNorm(nn.Module):
def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
super().__init__()
self.attention = WavLMAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
num_buckets=config.num_buckets,
max_distance=config.max_bucket_distance,
has_relative_position_bias=has_relative_position_bias,
)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.feed_forward = WavLMFeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights, position_bias = self.attention(
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
output_attentions=output_attentions,
)
hidden_states = self.dropout(hidden_states)
hidden_states = attn_residual + hidden_states
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
outputs = (hidden_states, position_bias)
if output_attentions:
outputs += (attn_weights,)
return outputs
class WavLMEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList(
[WavLMEncoderLayer(config, has_relative_position_bias=(i == 0)) for i in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
position_bias = None
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = torch.rand([])
skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
position_bias,
output_attentions,
)
else:
layer_outputs = layer(
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
output_attentions=output_attentions,
index=i,
)
hidden_states, position_bias = layer_outputs[:2]
if skip_the_layer:
layer_outputs = (None, None, None)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
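# Note on the LayerDrop logic above: during training each layer after the first is skipped
# independently with probability config.layerdrop, so e.g. with 12 layers and layerdrop=0.1 about
# 1 + 11 * 0.9 = 10.9 layers contribute on average; under DeepSpeed ZeRO-3 or FSDP (`synced_gpus`)
# the layer is still executed so that all ranks stay in sync.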
class WavLMEncoderStableLayerNorm(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList(
[
WavLMEncoderLayerStableLayerNorm(config, has_relative_position_bias=(i == 0))
for i in range(config.num_hidden_layers)
]
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
# make sure padded tokens are not attended to
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
hidden_states = self.dropout(hidden_states)
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
position_bias = None
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = torch.rand([])
skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
position_bias,
output_attentions,
)
else:
layer_outputs = layer(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
position_bias=position_bias,
)
hidden_states, position_bias = layer_outputs[:2]
if skip_the_layer:
layer_outputs = (None, None, None)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[2],)
hidden_states = self.layer_norm(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
)
class WavLMGumbelVectorQuantizer(nn.Module):
"""
Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
"""
def __init__(self, config):
super().__init__()
self.num_groups = config.num_codevector_groups
self.num_vars = config.num_codevectors_per_group
if config.codevector_dim % self.num_groups != 0:
raise ValueError(
                f"`config.codevector_dim` {config.codevector_dim} must be divisible"
f" by `config.num_codevector_groups` {self.num_groups} "
"for concatenation."
)
# storage for codebook variables (codewords)
self.codevectors = nn.Parameter(
torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
)
self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
# can be decayed for training
self.temperature = 2
@staticmethod
def _compute_perplexity(probs):
marginal_probs = probs.mean(dim=0)
perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
return perplexity
def forward(self, hidden_states):
batch_size, sequence_length, hidden_size = hidden_states.shape
# project to codevector dim
hidden_states = self.weight_proj(hidden_states)
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
if self.training:
            # sample code vector probs via gumbel in a differentiable way
codevector_probs = nn.functional.gumbel_softmax(hidden_states.float(), tau=self.temperature, hard=True)
codevector_probs = codevector_probs.type_as(hidden_states)
# compute perplexity
codevector_soft_dist = torch.softmax(
hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
)
perplexity = self._compute_perplexity(codevector_soft_dist)
else:
# take argmax in non-differentiable way
            # compute hard codevector distribution (one hot)
codevector_idx = hidden_states.argmax(dim=-1)
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
-1, codevector_idx.view(-1, 1), 1.0
)
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
perplexity = self._compute_perplexity(codevector_probs)
codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
# use probs to retrieve codevectors
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
return codevectors, perplexity
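# A minimal sketch of the perplexity diagnostic above, assuming the default codebook layout of
# 2 groups with 320 codes each: uniform code usage yields a perplexity close to
# num_codevector_groups * num_codevectors_per_group, while a collapsed codebook (every frame picks
# the same code) yields a perplexity close to num_codevector_groups.
#
#     probs_uniform = torch.full((8, 2, 320), 1 / 320)              # (batch * seq_len, groups, codes)
#     probs_collapsed = torch.zeros(8, 2, 320)
#     probs_collapsed[..., 0] = 1.0
#     WavLMGumbelVectorQuantizer._compute_perplexity(probs_uniform)    # ~ 640.0
#     WavLMGumbelVectorQuantizer._compute_perplexity(probs_collapsed)  # ~ 2.0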
class WavLMPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = WavLMConfig
base_model_prefix = "wavlm"
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = False
_supports_sdpa = False
def _init_weights(self, module):
"""Initialize the weights"""
# gumbel softmax requires special init
if isinstance(module, WavLMGumbelVectorQuantizer):
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
module.weight_proj.bias.data.zero_()
nn.init.uniform_(module.codevectors)
elif isinstance(module, WavLMPositionalConvEmbedding):
nn.init.normal_(
module.conv.weight,
mean=0,
std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
)
nn.init.constant_(module.conv.bias, 0)
elif isinstance(module, WavLMFeatureProjection):
k = math.sqrt(1 / module.projection.in_features)
nn.init.uniform_(module.projection.weight, a=-k, b=k)
nn.init.uniform_(module.projection.bias, a=-k, b=k)
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
nn.init.kaiming_normal_(module.weight)
if module.bias is not None:
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
nn.init.uniform_(module.bias, a=-k, b=k)
def _get_adapters(self):
raise AttributeError("Not needed for WavLM")
def init_adapter_layers(self):
raise AttributeError("Not needed for WavLM")
def load_adapter(self):
raise AttributeError("Not needed for WavLM")
WAVLM_START_DOCSTRING = r"""
WavLM was proposed in [WavLM: Unified Speech Representation Learning with Labeled and Unlabeled
Data](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo
Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian,
Jian Wu, Michael Zeng, Xiangzhan Yu, Furu Wei.
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, etc.).
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
behavior.
Parameters:
config ([`WavLMConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
WAVLM_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
<Tip warning={true}>
`attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should
**not** be passed to avoid degraded performance when doing batched inference. For such models
`input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these
models also yield slightly different results depending on whether `input_values` is padded or not.
</Tip>
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
WavLMBaseModelOutput = Wav2Vec2BaseModelOutput
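# A minimal usage sketch of the input pipeline described in the docstring above, using the
# checkpoint from the code-sample docstrings (downloading the weights requires network access):
#
#     from transformers import AutoProcessor, WavLMModel
#
#     processor = AutoProcessor.from_pretrained("patrickvonplaten/wavlm-libri-clean-100h-base-plus")
#     model = WavLMModel.from_pretrained("patrickvonplaten/wavlm-libri-clean-100h-base-plus")
#     speech = [0.0] * 16000                                    # one second of 16 kHz audio
#     inputs = processor(speech, sampling_rate=16000, return_tensors="pt")
#     with torch.no_grad():
#         outputs = model(**inputs)
#     outputs.last_hidden_state.shape                           # (1, num_frames, hidden_size)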
@add_start_docstrings(
"The bare WavLM Model transformer outputting raw hidden-states without any specific head on top.",
WAVLM_START_DOCSTRING,
)
class WavLMModel(Wav2Vec2Model):
@add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=WavLMBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(self, **super_kwargs):
return super().forward(**super_kwargs)
@add_start_docstrings(
"""WavLM Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
WAVLM_START_DOCSTRING,
)
class WavLMForCTC(Wav2Vec2ForCTC):
@add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_CTC_EXPECTED_OUTPUT,
expected_loss=_CTC_EXPECTED_LOSS,
)
def forward(self, **super_kwargs):
        return super().forward(**super_kwargs)
@add_start_docstrings(
"""
WavLM Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
SUPERB Keyword Spotting.
""",
WAVLM_START_DOCSTRING,
)
class WavLMForSequenceClassification(Wav2Vec2ForSequenceClassification):
@add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
def forward(self, **super_kwargs):
        return super().forward(**super_kwargs)
@add_start_docstrings(
"""
WavLM Model with a frame classification head on top for tasks like Speaker Diarization.
""",
WAVLM_START_DOCSTRING,
)
class WavLMForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification):
@add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_FRAME_CLASS_CHECKPOINT,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_FRAME_EXPECTED_OUTPUT,
)
def forward(self, **super_kwargs):
        return super().forward(**super_kwargs)
@add_start_docstrings(
"""
WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification.
""",
WAVLM_START_DOCSTRING,
)
class WavLMForXVector(Wav2Vec2ForXVector):
    @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_XVECTOR_CHECKPOINT,
        output_type=XVectorOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_XVECTOR_EXPECTED_OUTPUT,
    )
    def forward(self, **super_kwargs):
        return super().forward(**super_kwargs)
__all__ = [
"WavLMForAudioFrameClassification",
"WavLMForCTC",
"WavLMForSequenceClassification",
"WavLMForXVector",
"WavLMModel",
"WavLMPreTrainedModel",
]

View File

@ -50,9 +50,9 @@ if is_torch_available():
Wav2Vec2BertForXVector,
Wav2Vec2BertModel,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import _sample_negative_indices
from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import (
_compute_mask_indices,
_sample_negative_indices,
)

View File

@ -55,10 +55,10 @@ if is_torch_available():
Wav2Vec2FeatureExtractor,
Wav2Vec2Processor,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import _sample_negative_indices
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
Wav2Vec2ConformerGumbelVectorQuantizer,
_compute_mask_indices,
_sample_negative_indices,
)

View File

@ -79,8 +79,11 @@ def preserve_case_replace(text, patterns: dict, default_name: str):
def get_cased_name(lowercase_name: str) -> str:
"""From a model name in lowercase in the format `my_model`, return the cased name in the format `MyModel`."""
alt_lowercase_name = lowercase_name.replace("_", "-")
if lowercase_name in CONFIG_MAPPING_NAMES:
return CONFIG_MAPPING_NAMES[lowercase_name].replace("Config", "")
elif alt_lowercase_name in CONFIG_MAPPING_NAMES:
return CONFIG_MAPPING_NAMES[alt_lowercase_name].replace("Config", "")
else:
return "".join(x.title() for x in lowercase_name.split("_"))
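# Illustration of the lookup added above (a sketch, assuming the standard CONFIG_MAPPING_NAMES entries):
#     get_cased_name("wavlm")         -> "WavLM"        (from WavLMConfig, not the naive "Wavlm")
#     get_cased_name("unispeech_sat") -> "UniSpeechSat" (via the "unispeech-sat" key)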
@ -106,6 +109,8 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
def __init__(self, old_name: str, new_name: str, original_new_model_name: str = "", only_doc: bool = False):
super().__init__()
old_name = old_name.replace("-", "_")
new_name = new_name.replace("-", "_")
self.old_name = old_name
self.new_name = new_name
self.cased_new_name = get_cased_name(self.new_name)
@ -535,7 +540,7 @@ def find_all_dependencies(
# Top-level variables that match the following patterns will always use the value in the `modular_xxx.py` file
ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC"]
ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC", r"_HIDDEN_STATES_START_POSITION"]
# Top-level variables that match the following patterns will use the value in the `modular_xxx.py` file only if they are not None
ASSIGNMENTS_REGEX_TO_KEEP_IF_NOT_NONE = [r"_DOCSTRING"]
@ -616,6 +621,7 @@ class ModuleMapper(CSTVisitor, ABC):
self.object_dependency_mapping = defaultdict(set) # immediate function/assignment dependency mapping (i.e. dependencies immediately in the function/assignment definition)
self.assignments: Dict[str, cst.SimpleStatementLine] = {} # mapping of global assignments names to Nodes
self.current_function = None # this keeps track of the current module-scope function
self.current_class = None # this keeps track of the current module-scope class
self.current_assignment = None # this keeps track of the current module-scope assignment
# this keeps track of objects imported from modeling files (`from .configuration import Config`) -> `Config` should not be a dependency
self.objects_imported_from_modeling = set()
@ -672,7 +678,7 @@ class ModuleMapper(CSTVisitor, ABC):
def visit_If(self, node):
# If we are inside a function, do not add the import to the list of imports
if self.current_function is None:
if self.current_function is None and self.current_class is None:
for stmt in node.body.body:
if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
self.imports.append(node)
@ -680,6 +686,10 @@ class ModuleMapper(CSTVisitor, ABC):
def visit_ClassDef(self, node: ClassDef) -> None:
"""Record class nodes to create their dependencies at the end."""
self.classes[node.name.value] = node
self.current_class = node.name.value
def leave_ClassDef(self, node):
self.current_class = None
def visit_Name(self, node: cst.Call):
"""This is used to create a mapping from module-scope functions and assignments to objects used inside them."""
@ -1024,11 +1034,20 @@ def replace_class_node(
new_decorators = (
updated_methods[name].decorators if len(updated_methods[name].decorators) > 0 else func.decorators
)
# Keep return annotation in `modular_xxx.py` if any, else original return annotation
new_return_annotation = updated_methods[name].returns if updated_methods[name].returns else func.returns
if not re.match(
r"\ndef .*\(.*\):\n raise.*Error\(.*",
mapper.python_module.code_for_node(updated_methods[name]),
):
func = func.with_changes(body=updated_methods[name].body, params=new_params, decorators=new_decorators)
func = func.with_changes(
body=updated_methods[name].body,
params=new_params,
decorators=new_decorators,
returns=new_return_annotation,
)
else:
continue
@ -1136,7 +1155,7 @@ def append_new_import_node(
import_node = node.body[0]
names_to_keep = []
for name in import_node.names:
name_value = name.evaluated_name
name_value = name.evaluated_alias or name.evaluated_name
if name_value not in unused_imports and name_value not in added_names:
names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT))
added_names.add(name_value)
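# Illustration of the alias handling above (a sketch, assuming libcst's ImportAlias API):
#     import libcst as cst
#     alias = cst.parse_module("import torch.nn as nn").body[0].body[0].names[0]
#     alias.evaluated_name   # "torch.nn"
#     alias.evaluated_alias  # "nn"
# Unused-import analysis reports the bound name ("nn"), so the alias must be checked first,
# falling back to the plain name for un-aliased imports.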