Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)
Add SD and SV heads for WavLM (#14847)

* Add converted heads
* Add dummies

This commit is contained in:
parent cd583bdaa5
commit 3883e3a75e
@ -81,3 +81,17 @@ WavLMForSequenceClassification

.. autoclass:: transformers.WavLMForSequenceClassification
    :members: forward


WavLMForAudioFrameClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.WavLMForAudioFrameClassification
    :members: forward


WavLMForXVector
~~~~~~~~~~~~~~~

.. autoclass:: transformers.WavLMForXVector
    :members: forward
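Not part of the diff: a minimal usage sketch for the new speaker-verification head documented above, mirroring the integration test further down (checkpoint name taken from _XVECTOR_CHECKPOINT; the audio arrays here are random placeholders, not real speech).

import torch
from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sv")
model = WavLMForXVector.from_pretrained("microsoft/wavlm-base-plus-sv")

# Two 1-second utterances at 16 kHz (placeholder data).
audio = [torch.randn(16_000).numpy(), torch.randn(16_000).numpy()]
inputs = feature_extractor(audio, sampling_rate=16_000, padding=True, return_tensors="pt")

with torch.no_grad():
    embeddings = model(**inputs).embeddings

# Cosine similarity between the two utterance embeddings, as in the tests below.
embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
similarity = torch.nn.CosineSimilarity(dim=-1)(embeddings[0], embeddings[1])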
@ -1381,8 +1381,10 @@ if is_torch_available():
    _import_structure["models.wavlm"].extend(
        [
            "WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST",
            "WavLMForAudioFrameClassification",
            "WavLMForCTC",
            "WavLMForSequenceClassification",
            "WavLMForXVector",
            "WavLMModel",
            "WavLMPreTrainedModel",
        ]
@ -3230,8 +3232,10 @@ if TYPE_CHECKING:
        )
        from .models.wavlm import (
            WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST,
            WavLMForAudioFrameClassification,
            WavLMForCTC,
            WavLMForSequenceClassification,
            WavLMForXVector,
            WavLMModel,
            WavLMPreTrainedModel,
        )
@ -546,6 +546,7 @@ MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
        # Model for Audio Frame Classification mapping
        ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
        ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
        ("wavlm", "WavLMForAudioFrameClassification"),
    ]
)

@ -554,6 +555,7 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
        # Model for Audio X-Vector mapping
        ("wav2vec2", "Wav2Vec2ForXVector"),
        ("unispeech-sat", "UniSpeechSatForXVector"),
        ("wavlm", "WavLMForXVector"),
    ]
)
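Not part of the diff: with WavLM registered in these mappings, the new heads also become reachable through the auto classes. A short sketch, assuming the matching auto classes (AutoModelForAudioFrameClassification, AutoModelForAudioXVector) are exposed in this version of the library:

from transformers import AutoModelForAudioFrameClassification, AutoModelForAudioXVector

# The auto classes resolve the "wavlm" model type to the heads added in this PR.
diarization_model = AutoModelForAudioFrameClassification.from_pretrained("microsoft/wavlm-base-plus-sd")
verification_model = AutoModelForAudioXVector.from_pretrained("microsoft/wavlm-base-plus-sv")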
@ -27,8 +27,10 @@ _import_structure = {
if is_torch_available():
    _import_structure["modeling_wavlm"] = [
        "WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST",
        "WavLMForAudioFrameClassification",
        "WavLMForCTC",
        "WavLMForSequenceClassification",
        "WavLMForXVector",
        "WavLMModel",
        "WavLMPreTrainedModel",
    ]

@ -39,8 +41,10 @@ if TYPE_CHECKING:
    if is_torch_available():
        from .modeling_wavlm import (
            WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST,
            WavLMForAudioFrameClassification,
            WavLMForCTC,
            WavLMForSequenceClassification,
            WavLMForXVector,
            WavLMModel,
            WavLMPreTrainedModel,
        )
@ -144,6 +144,17 @@ class WavLMConfig(PretrainedConfig):
            instance of :class:`~transformers.WavLMForSequenceClassification`.
        classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
            Dimensionality of the projection before token mean-pooling for classification.
        tdnn_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 1500)`):
            A tuple of integers defining the number of output channels of each 1D convolutional layer in the `TDNN`
            module of the `XVector` model. The length of `tdnn_dim` defines the number of `TDNN` layers.
        tdnn_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 3, 3, 1, 1)`):
            A tuple of integers defining the kernel size of each 1D convolutional layer in the `TDNN` module of the
            `XVector` model. The length of `tdnn_kernel` has to match the length of `tdnn_dim`.
        tdnn_dilation (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(1, 2, 3, 1, 1)`):
            A tuple of integers defining the dilation factor of each 1D convolutional layer in the `TDNN` module of
            the `XVector` model. The length of `tdnn_dilation` has to match the length of `tdnn_dim`.
        xvector_output_dim (:obj:`int`, `optional`, defaults to 512):
            Dimensionality of the `XVector` embedding vectors.
        add_adapter (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether a convolutional network should be stacked on top of the Wav2Vec2 encoder. Can be very useful for
            warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
@ -220,6 +231,10 @@ class WavLMConfig(PretrainedConfig):
        ctc_zero_infinity=False,
        use_weighted_layer_sum=False,
        classifier_proj_size=256,
        tdnn_dim=(512, 512, 512, 512, 1500),
        tdnn_kernel=(5, 3, 3, 1, 1),
        tdnn_dilation=(1, 2, 3, 1, 1),
        xvector_output_dim=512,
        num_ctc_classes=80,
        pad_token_id=0,
        bos_token_id=1,

@ -302,3 +317,12 @@ class WavLMConfig(PretrainedConfig):
        self.adapter_stride = adapter_stride
        self.num_adapter_layers = num_adapter_layers
        self.output_hidden_size = output_hidden_size or hidden_size

        # SequenceClassification-specific parameter. Feel free to ignore for other classes.
        self.classifier_proj_size = classifier_proj_size

        # XVector-specific parameters. Feel free to ignore for other classes.
        self.tdnn_dim = list(tdnn_dim)
        self.tdnn_kernel = list(tdnn_kernel)
        self.tdnn_dilation = list(tdnn_dilation)
        self.xvector_output_dim = xvector_output_dim
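Not part of the diff: a short sketch of how the new TDNN and X-Vector options above fit together. The values are the documented defaults; the three tdnn_* tuples must have the same length, one entry per TDNN layer.

from transformers import WavLMConfig, WavLMForXVector

config = WavLMConfig(
    classifier_proj_size=256,
    tdnn_dim=(512, 512, 512, 512, 1500),
    tdnn_kernel=(5, 3, 3, 1, 1),
    tdnn_dilation=(1, 2, 3, 1, 1),
    xvector_output_dim=512,
)
# The tuples are stored as lists on the config (see the __init__ changes above).
model = WavLMForXVector(config)  # randomly initialized X-Vector head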
@ -0,0 +1,110 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert WavLM checkpoint."""


import argparse

import torch

from transformers import (
    Wav2Vec2FeatureExtractor,
    WavLMConfig,
    WavLMForAudioFrameClassification,
    WavLMForSequenceClassification,
    WavLMForXVector,
    logging,
)


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def convert_classification(base_model_name, hf_config, downstream_dict):
    model = WavLMForSequenceClassification.from_pretrained(base_model_name, config=hf_config)
    model.projector.weight.data = downstream_dict["projector.weight"]
    model.projector.bias.data = downstream_dict["projector.bias"]
    model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"]
    model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"]
    return model


def convert_diarization(base_model_name, hf_config, downstream_dict):
    model = WavLMForAudioFrameClassification.from_pretrained(base_model_name, config=hf_config)
    model.classifier.weight.data = downstream_dict["model.linear.weight"]
    model.classifier.bias.data = downstream_dict["model.linear.bias"]
    return model


def convert_xvector(base_model_name, hf_config, downstream_dict):
    model = WavLMForXVector.from_pretrained(base_model_name, config=hf_config)
    model.projector.weight.data = downstream_dict["connector.weight"]
    model.projector.bias.data = downstream_dict["connector.bias"]
    for i, kernel_size in enumerate(hf_config.tdnn_kernel):
        model.tdnn[i].kernel.weight.data = downstream_dict[
            f"model.framelevel_feature_extractor.module.{i}.kernel.weight"
        ]
        model.tdnn[i].kernel.bias.data = downstream_dict[f"model.framelevel_feature_extractor.module.{i}.kernel.bias"]

    model.feature_extractor.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.weight"]
    model.feature_extractor.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.bias"]
    model.classifier.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.weight"]
    model.classifier.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.bias"]
    model.objective.weight.data = downstream_dict["objective.W"]
    return model


@torch.no_grad()
def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    downstream_dict = checkpoint["Downstream"]

    hf_config = WavLMConfig.from_pretrained(config_path)
    hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        base_model_name, return_attention_mask=True, do_normalize=False
    )

    arch = hf_config.architectures[0]
    if arch.endswith("ForSequenceClassification"):
        hf_model = convert_classification(base_model_name, hf_config, downstream_dict)
    elif arch.endswith("ForAudioFrameClassification"):
        hf_model = convert_diarization(base_model_name, hf_config, downstream_dict)
    elif arch.endswith("ForXVector"):
        hf_model = convert_xvector(base_model_name, hf_config, downstream_dict)
    else:
        raise NotImplementedError(f"S3PRL weights conversion is not supported for {arch}")

    if hf_config.use_weighted_layer_sum:
        hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]

    hf_feature_extractor.save_pretrained(model_dump_path)
    hf_model.save_pretrained(model_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model."
    )
    parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.")
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.")
    parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.")
    args = parser.parse_args()
    convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path)
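Not part of the diff: an illustrative call of the conversion entry point above. All paths and names below are placeholders; the config it points to must have its "architectures" field set to one of the three supported head classes so the script can pick the right converter.

convert_s3prl_checkpoint(
    base_model_name="microsoft/wavlm-base-plus",       # pretrained backbone on the Hub
    config_path="./wavlm_sv_config",                    # placeholder: directory or file with a WavLMConfig
    checkpoint_path="./s3prl_downstream.ckpt",          # placeholder: s3prl downstream checkpoint
    model_dump_path="./wavlm-base-plus-sv-converted",   # placeholder: output directory
)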
@ -33,7 +33,7 @@ from ...file_utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
)
-from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_wavlm import WavLMConfig
@ -48,6 +48,10 @@ _CHECKPOINT_FOR_DOC = "microsoft/wavlm-base"
-_SEQ_CLASS_CHECKPOINT = "microsoft/wavlm-base"
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"

+_SEQ_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus"
+_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd"
+_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv"

_HIDDEN_STATES_START_POSITION = 2

WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
@ -87,6 +91,38 @@ class WavLMBaseModelOutput(ModelOutput):
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class XVectorOutput(ModelOutput):
    """
    Output type of :class:`~transformers.Wav2Vec2ForXVector`.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Classification loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`):
            Classification hidden states before AMSoftmax.
        embeddings (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`):
            Utterance embeddings used for vector similarity-based retrieval.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each
            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    embeddings: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
    shape: Tuple[int, int],
@ -1447,3 +1483,285 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    WavLM Model with a frame classification head on top for tasks like Speaker Diarization.
    """,
    WAVLM_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
class WavLMForAudioFrameClassification(WavLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wavlm = WavLMModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature extractor so that its parameters
        will not be updated during training.
        """
        self.wavlm.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wavlm.parameters():
            param.requires_grad = False

    @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
        checkpoint=_FRAME_CLASS_CHECKPOINT,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square
            loss). If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        logits = self.classifier(hidden_states)

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return output

        return TokenClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
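Not part of the diff: a minimal sketch of running the frame-classification head above for speaker diarization. Thresholding the logits at 0 mirrors the integration test at the bottom of this commit; the checkpoint is the one used there and the input audio is a random placeholder.

import torch
from transformers import Wav2Vec2FeatureExtractor, WavLMForAudioFrameClassification

processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sd")
model = WavLMForAudioFrameClassification.from_pretrained("microsoft/wavlm-base-plus-sd")

# One 10-second clip at 16 kHz (placeholder data instead of real speech).
inputs = processor(torch.randn(160_000).numpy(), sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits  # (batch, num_frames, num_speakers)

# Per-frame, per-speaker activity: 1 where the speaker is predicted active.
speaker_activity = (logits > 0).long()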
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
class AMSoftmaxLoss(nn.Module):
    def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
        super(AMSoftmaxLoss, self).__init__()
        self.scale = scale
        self.margin = margin
        self.num_labels = num_labels
        self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, hidden_states, labels):
        labels = labels.flatten()
        weight = nn.functional.normalize(self.weight, dim=0)
        hidden_states = nn.functional.normalize(hidden_states, dim=1)
        cos_theta = torch.mm(hidden_states, weight)
        psi = cos_theta - self.margin

        onehot = nn.functional.one_hot(labels, self.num_labels)
        logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
        loss = self.loss(logits, labels)

        return loss
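Not part of the diff: a toy forward pass through the AMSoftmaxLoss defined above, assuming the surrounding module's namespace. Because both the weight columns and the embeddings are L2-normalized, the logits are cosine similarities; the margin is subtracted only on the target class before scaling and cross-entropy.

import torch

criterion = AMSoftmaxLoss(input_dim=4, num_labels=3, scale=30.0, margin=0.4)
embeddings = torch.randn(2, 4)        # two utterance-level embeddings (placeholder values)
labels = torch.tensor([0, 2])         # target speaker ids
loss = criterion(embeddings, labels)  # scalar training objective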
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
class TDNNLayer(nn.Module):
    def __init__(self, config, layer_id=0):
        super().__init__()
        self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
        self.out_conv_dim = config.tdnn_dim[layer_id]
        self.kernel_size = config.tdnn_kernel[layer_id]
        self.dilation = config.tdnn_dilation[layer_id]

        self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
        self.activation = nn.ReLU()

    def forward(self, hidden_states):
        hidden_states = hidden_states.unsqueeze(1)
        hidden_states = nn.functional.unfold(
            hidden_states,
            (self.kernel_size, self.in_conv_dim),
            stride=(1, self.in_conv_dim),
            dilation=(self.dilation, 1),
        )
        hidden_states = hidden_states.transpose(1, 2)
        hidden_states = self.kernel(hidden_states)

        hidden_states = self.activation(hidden_states)
        return hidden_states
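Not part of the diff: a quick shape check for the TDNNLayer above, assuming the surrounding module's namespace (TDNNLayer, WavLMConfig). The unfold + linear combination acts as a dilated 1D convolution over time, so layer 0 (kernel 5, dilation 1) trims (5 - 1) * 1 = 4 frames.

import torch

config = WavLMConfig(tdnn_dim=(512, 512, 512, 512, 1500), tdnn_kernel=(5, 3, 3, 1, 1), tdnn_dilation=(1, 2, 3, 1, 1))
layer = TDNNLayer(config, layer_id=0)

frames = torch.randn(1, 100, 512)  # (batch, time, tdnn_dim[0]) - the projector output
out = layer(frames)                # -> torch.Size([1, 96, 512])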
@add_start_docstrings(
    """
    WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification.
    """,
    WAVLM_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
class WavLMForXVector(WavLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wavlm = WavLMModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])

        tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
        self.tdnn = nn.ModuleList(tdnn_layers)

        self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
        self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)

        self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)

        self.init_weights()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature extractor so that its parameters
        will not be updated during training.
        """
        self.wavlm.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wavlm.parameters():
            param.requires_grad = False

    def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the TDNN layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size in self.config.tdnn_kernel:
            input_lengths = _conv_out_length(input_lengths, kernel_size, 1)

        return input_lengths

    @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
        checkpoint=_XVECTOR_CHECKPOINT,
        output_type=XVectorOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square
            loss). If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)

        for tdnn_layer in self.tdnn:
            hidden_states = tdnn_layer(hidden_states)

        # Statistic Pooling
        if attention_mask is None:
            mean_features = hidden_states.mean(dim=1)
            std_features = hidden_states.std(dim=1)
        else:
            feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
            tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
            mean_features = []
            std_features = []
            for i, length in enumerate(tdnn_output_lengths):
                mean_features.append(hidden_states[i, :length].mean(dim=0))
                std_features.append(hidden_states[i, :length].std(dim=0))
            mean_features = torch.stack(mean_features)
            std_features = torch.stack(std_features)
        statistic_pooling = torch.cat([mean_features, std_features], dim=-1)

        output_embeddings = self.feature_extractor(statistic_pooling)
        logits = self.classifier(output_embeddings)

        loss = None
        if labels is not None:
            loss = self.objective(logits, labels)

        if not return_dict:
            output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return XVectorOutput(
            loss=loss,
            logits=logits,
            embeddings=output_embeddings,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
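Not part of the diff: the length bookkeeping done by _get_tdnn_output_lengths above, in plain numbers. With the default kernels (5, 3, 3, 1, 1) and stride 1, the helper trims kernel_size - 1 frames per layer, so 100 feature-extractor frames map to 92 TDNN frames before statistic pooling.

def tdnn_output_length(num_frames, kernels=(5, 3, 3, 1, 1)):
    # Same formula as _conv_out_length with stride 1.
    for kernel_size in kernels:
        num_frames = (num_frames - kernel_size) // 1 + 1
    return num_frames

assert tdnn_output_length(100) == 92  # 100 -> 96 -> 94 -> 92 -> 92 -> 92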
@ -5177,6 +5177,11 @@ class Wav2Vec2PreTrainedModel:
WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST = None


class WavLMForAudioFrameClassification:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class WavLMForCTC:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

@ -5194,6 +5199,11 @@ class WavLMForSequenceClassification:
        requires_backends(self, ["torch"])


class WavLMForXVector:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class WavLMModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
@ -863,10 +863,10 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase):
        self.assertTrue(torch.allclose(outputs.last_hidden_state[:, :2, -2:], expected_hidden_states_slice, atol=1e-3))

    def test_inference_diarization(self):
-        model = UniSpeechSatForAudioFrameClassification.from_pretrained("anton-l/unispeech-sat-base-plus-sd").to(
+        model = UniSpeechSatForAudioFrameClassification.from_pretrained("microsoft/unispeech-sat-base-plus-sd").to(
            torch_device
        )
-        processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sd")
+        processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/unispeech-sat-base-plus-sd")
        input_data = self._load_superb("sd", 4)
        inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)

@ -892,8 +892,8 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase):
        self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-3))

    def test_inference_speaker_verification(self):
-        model = UniSpeechSatForXVector.from_pretrained("anton-l/unispeech-sat-base-plus-sv").to(torch_device)
-        processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sv")
+        model = UniSpeechSatForXVector.from_pretrained("microsoft/unispeech-sat-base-plus-sv").to(torch_device)
+        processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/unispeech-sat-base-plus-sv")
        input_data = self._load_superb("si", 4)

        inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
@ -31,7 +31,14 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
if is_torch_available():
    import torch

-    from transformers import Wav2Vec2FeatureExtractor, WavLMForCTC, WavLMForSequenceClassification, WavLMModel
+    from transformers import (
+        Wav2Vec2FeatureExtractor,
+        WavLMForAudioFrameClassification,
+        WavLMForCTC,
+        WavLMForSequenceClassification,
+        WavLMForXVector,
+        WavLMModel,
+    )

class WavLMModelTester:
@ -60,6 +67,10 @@ class WavLMModelTester:
        initializer_range=0.02,
        vocab_size=32,
        do_stable_layer_norm=False,
        tdnn_dim=(32, 32),
        tdnn_kernel=(3, 3),
        tdnn_dilation=(1, 1),
        xvector_output_dim=32,
        scope=None,
    ):
        self.parent = parent
@ -85,6 +96,10 @@ class WavLMModelTester:
        self.initializer_range = initializer_range
        self.vocab_size = vocab_size
        self.do_stable_layer_norm = do_stable_layer_norm
        self.tdnn_dim = tdnn_dim
        self.tdnn_kernel = tdnn_kernel
        self.tdnn_dilation = tdnn_dilation
        self.xvector_output_dim = xvector_output_dim
        self.scope = scope

        output_seq_length = self.seq_length
@ -121,6 +136,10 @@ class WavLMModelTester:
            hidden_act=self.hidden_act,
            initializer_range=self.initializer_range,
            vocab_size=self.vocab_size,
            tdnn_dim=self.tdnn_dim,
            tdnn_kernel=self.tdnn_kernel,
            tdnn_dilation=self.tdnn_dilation,
            xvector_output_dim=self.xvector_output_dim,
        )

    def create_and_check_model(self, config, input_values, attention_mask):
@ -285,7 +304,11 @@ class WavLMModelTester:

@require_torch
class WavLMModelTest(ModelTesterMixin, unittest.TestCase):
-    all_model_classes = (WavLMForCTC, WavLMModel, WavLMForSequenceClassification) if is_torch_available() else ()
+    all_model_classes = (
+        (WavLMForCTC, WavLMModel, WavLMForAudioFrameClassification, WavLMForSequenceClassification, WavLMForXVector)
+        if is_torch_available()
+        else ()
+    )
    test_pruning = False
    test_headmasking = False
    test_torchscript = False
@ -398,6 +421,7 @@ class WavLMModelTest(ModelTesterMixin, unittest.TestCase):
                    "feature_projection.projection.bias",
                    "label_embeddings_concat",
                    "rel_attn_embed",
                    "objective.weight",
                ]
                if param.requires_grad:
                    if any([x in name for x in uniform_init_parms]):
@ -446,6 +470,11 @@ class WavLMModelIntegrationTest(unittest.TestCase):

        return [x["array"] for x in speech_samples]

    def _load_superb(self, task, num_samples):
        ds = load_dataset("anton-l/superb_dummy", task, split="test")

        return ds[:num_samples]

    def test_inference_base(self):
        model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus").to(torch_device)
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
@ -491,3 +520,54 @@ class WavLMModelIntegrationTest(unittest.TestCase):
            [[[0.1612, 0.4314], [0.1690, 0.4344]], [[0.2086, 0.1396], [0.3014, 0.0903]]]
        )
        self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=1e-2))

    def test_inference_diarization(self):
        model = WavLMForAudioFrameClassification.from_pretrained("microsoft/wavlm-base-plus-sd").to(torch_device)
        processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sd")
        input_data = self._load_superb("sd", 4)
        inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)

        input_values = inputs.input_values.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)
        with torch.no_grad():
            outputs = model(input_values, attention_mask=attention_mask)
        # labels is a one-hot array of shape (num_frames, num_speakers)
        labels = (outputs.logits > 0).long()

        # s3prl logits for the same batch
        expected_logits = torch.tensor(
            [
                [[-5.9566, -8.6554], [-5.7137, -8.9386], [-5.7906, -7.0973], [-5.7829, -5.9999]],
                [[-5.2086, -7.7878], [-4.8890, -7.9312], [-4.2004, -3.9101], [-5.4480, -4.6932]],
                [[-4.6105, -6.7178], [-5.1930, -6.1635], [-2.6228, -4.1123], [-2.7646, -3.1576]],
                [[-4.4477, -7.9206], [-3.9339, -7.3707], [-4.9528, -4.8242], [-3.6921, -2.9687]],
            ],
            device=torch_device,
        )
        self.assertEqual(labels[0, :, 0].sum(), 258)
        self.assertEqual(labels[0, :, 1].sum(), 647)
        self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-3))

    def test_inference_speaker_verification(self):
        model = WavLMForXVector.from_pretrained("microsoft/wavlm-base-plus-sv").to(torch_device)
        processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sv")
        input_data = self._load_superb("si", 4)

        inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
        labels = torch.tensor([5, 1, 1, 3], device=torch_device).T

        with torch.no_grad():
            input_values = inputs.input_values.to(torch_device)
            attention_mask = inputs.attention_mask.to(torch_device)
            outputs = model(input_values, attention_mask=attention_mask, labels=labels)
        embeddings = torch.nn.functional.normalize(outputs.embeddings, dim=-1)

        cosine_sim = torch.nn.CosineSimilarity(dim=-1)
        # id10002 vs id10002
        self.assertAlmostEqual(cosine_sim(embeddings[1], embeddings[2]).item(), 0.9787, 3)
        # id10006 vs id10002
        self.assertAlmostEqual(cosine_sim(embeddings[0], embeddings[1]).item(), 0.5064, 3)
        # id10002 vs id10004
        self.assertAlmostEqual(cosine_sim(embeddings[2], embeddings[3]).item(), 0.4780, 3)

        self.assertAlmostEqual(outputs.loss.item(), 18.4154, 3)