Add speaker diarization (SD) and speaker verification (SV) heads for WavLM (#14847)

* Add converted heads

* Add dummies
Anton Lozhkov 2021-12-20 16:40:56 +03:00 committed by GitHub
parent cd583bdaa5
commit 3883e3a75e
10 changed files with 573 additions and 7 deletions

docs/source/model_doc/wavlm.rst

@@ -81,3 +81,17 @@ WavLMForSequenceClassification
.. autoclass:: transformers.WavLMForSequenceClassification
:members: forward
WavLMForAudioFrameClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.WavLMForAudioFrameClassification
:members: forward
WavLMForXVector
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.WavLMForXVector
:members: forward

src/transformers/__init__.py

@@ -1381,8 +1381,10 @@ if is_torch_available():
_import_structure["models.wavlm"].extend(
[
"WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"WavLMForAudioFrameClassification",
"WavLMForCTC",
"WavLMForSequenceClassification",
"WavLMForXVector",
"WavLMModel",
"WavLMPreTrainedModel",
]
@@ -3230,8 +3232,10 @@ if TYPE_CHECKING:
)
from .models.wavlm import (
WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST,
WavLMForAudioFrameClassification,
WavLMForCTC,
WavLMForSequenceClassification,
WavLMForXVector,
WavLMModel,
WavLMPreTrainedModel,
)

src/transformers/models/auto/modeling_auto.py

@@ -546,6 +546,7 @@ MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
# Model for Audio Frame Classification mapping
("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
("wavlm", "WavLMForAudioFrameClassification"),
]
)
@@ -554,6 +555,7 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
# Model for Audio XVector mapping
("wav2vec2", "Wav2Vec2ForXVector"),
("unispeech-sat", "UniSpeechSatForXVector"),
("wavlm", "WavLMForXVector"),
]
)
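With both mappings registered, the task-specific auto classes can resolve WavLM checkpoints without naming the concrete class. A minimal sketch, assuming the AutoModelForAudioFrameClassification and AutoModelForAudioXVector classes that consume these mappings:

    from transformers import AutoModelForAudioFrameClassification, AutoModelForAudioXVector

    # resolves to WavLMForAudioFrameClassification through the mapping above
    sd_model = AutoModelForAudioFrameClassification.from_pretrained("microsoft/wavlm-base-plus-sd")
    # resolves to WavLMForXVector through the mapping above
    sv_model = AutoModelForAudioXVector.from_pretrained("microsoft/wavlm-base-plus-sv")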

src/transformers/models/wavlm/__init__.py

@@ -27,8 +27,10 @@ _import_structure = {
if is_torch_available():
_import_structure["modeling_wavlm"] = [
"WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"WavLMForAudioFrameClassification",
"WavLMForCTC",
"WavLMForSequenceClassification",
"WavLMForXVector",
"WavLMModel",
"WavLMPreTrainedModel",
]
@@ -39,8 +41,10 @@ if TYPE_CHECKING:
if is_torch_available():
from .modeling_wavlm import (
WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST,
WavLMForAudioFrameClassification,
WavLMForCTC,
WavLMForSequenceClassification,
WavLMForXVector,
WavLMModel,
WavLMPreTrainedModel,
)

src/transformers/models/wavlm/configuration_wavlm.py

@@ -144,6 +144,17 @@ class WavLMConfig(PretrainedConfig):
instance of :class:`~transformers.WavLMForSequenceClassification`.
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
tdnn_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 1500)`):
A tuple of integers defining the number of output channels of each 1D convolutional layer in the `TDNN`
module of the `XVector` model. The length of `tdnn_dim` defines the number of `TDNN` layers.
tdnn_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 3, 3, 1, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the `TDNN` module of the
`XVector` model. The length of `tdnn_kernel` has to match the length of `tdnn_dim`.
tdnn_dilation (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(1, 2, 3, 1, 1)`):
A tuple of integers defining the dilation factor of each 1D convolutional layer in the `TDNN` module of the
`XVector` model. The length of `tdnn_dilation` has to match the length of `tdnn_dim`.
xvector_output_dim (:obj:`int`, `optional`, defaults to 512):
Dimensionality of the `XVector` embedding vectors.
add_adapter (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
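As a quick sanity sketch of how the new options fit together (values are the documented defaults; all three tuples must have the same length, one entry per TDNN layer):

    from transformers import WavLMConfig

    config = WavLMConfig(
        tdnn_dim=(512, 512, 512, 512, 1500),  # output channels per TDNN layer
        tdnn_kernel=(5, 3, 3, 1, 1),          # kernel size per TDNN layer
        tdnn_dilation=(1, 2, 3, 1, 1),        # dilation per TDNN layer
        xvector_output_dim=512,               # size of the final speaker embedding
    )
    assert len(config.tdnn_dim) == len(config.tdnn_kernel) == len(config.tdnn_dilation)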
@@ -220,6 +231,10 @@ class WavLMConfig(PretrainedConfig):
ctc_zero_infinity=False,
use_weighted_layer_sum=False,
classifier_proj_size=256,
tdnn_dim=(512, 512, 512, 512, 1500),
tdnn_kernel=(5, 3, 3, 1, 1),
tdnn_dilation=(1, 2, 3, 1, 1),
xvector_output_dim=512,
num_ctc_classes=80,
pad_token_id=0,
bos_token_id=1,
@@ -302,3 +317,12 @@ class WavLMConfig(PretrainedConfig):
self.adapter_stride = adapter_stride
self.num_adapter_layers = num_adapter_layers
self.output_hidden_size = output_hidden_size or hidden_size
# SequenceClassification-specific parameter. Feel free to ignore for other classes.
self.classifier_proj_size = classifier_proj_size
# XVector-specific parameters. Feel free to ignore for other classes.
self.tdnn_dim = list(tdnn_dim)
self.tdnn_kernel = list(tdnn_kernel)
self.tdnn_dilation = list(tdnn_dilation)
self.xvector_output_dim = xvector_output_dim

src/transformers/models/wavlm/convert_wavlm_original_s3prl_checkpoint.py

@@ -0,0 +1,110 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Hubert checkpoint."""
import argparse
import torch
from transformers import (
Wav2Vec2FeatureExtractor,
WavLMConfig,
WavLMForAudioFrameClassification,
WavLMForSequenceClassification,
WavLMForXVector,
logging,
)
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def convert_classification(base_model_name, hf_config, downstream_dict):
model = WavLMForSequenceClassification.from_pretrained(base_model_name, config=hf_config)
model.projector.weight.data = downstream_dict["projector.weight"]
model.projector.bias.data = downstream_dict["projector.bias"]
model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"]
model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"]
return model
def convert_diarization(base_model_name, hf_config, downstream_dict):
model = WavLMForAudioFrameClassification.from_pretrained(base_model_name, config=hf_config)
model.classifier.weight.data = downstream_dict["model.linear.weight"]
model.classifier.bias.data = downstream_dict["model.linear.bias"]
return model
def convert_xvector(base_model_name, hf_config, downstream_dict):
model = WavLMForXVector.from_pretrained(base_model_name, config=hf_config)
model.projector.weight.data = downstream_dict["connector.weight"]
model.projector.bias.data = downstream_dict["connector.bias"]
for i, kernel_size in enumerate(hf_config.tdnn_kernel):
model.tdnn[i].kernel.weight.data = downstream_dict[
f"model.framelevel_feature_extractor.module.{i}.kernel.weight"
]
model.tdnn[i].kernel.bias.data = downstream_dict[f"model.framelevel_feature_extractor.module.{i}.kernel.bias"]
model.feature_extractor.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.weight"]
model.feature_extractor.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.bias"]
model.classifier.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.weight"]
model.classifier.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.bias"]
model.objective.weight.data = downstream_dict["objective.W"]
return model
@torch.no_grad()
def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path):
"""
Copy/paste/tweak model's weights to transformers design.
"""
checkpoint = torch.load(checkpoint_path, map_location="cpu")
downstream_dict = checkpoint["Downstream"]
hf_config = WavLMConfig.from_pretrained(config_path)
hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
base_model_name, return_attention_mask=True, do_normalize=False
)
arch = hf_config.architectures[0]
if arch.endswith("ForSequenceClassification"):
hf_model = convert_classification(base_model_name, hf_config, downstream_dict)
elif arch.endswith("ForAudioFrameClassification"):
hf_model = convert_diarization(base_model_name, hf_config, downstream_dict)
elif arch.endswith("ForXVector"):
hf_model = convert_xvector(base_model_name, hf_config, downstream_dict)
else:
raise NotImplementedError(f"S3PRL weights conversion is not supported for {arch}")
if hf_config.use_weighted_layer_sum:
hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]
hf_feature_extractor.save_pretrained(model_dump_path)
hf_model.save_pretrained(model_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model."
)
parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.")
parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.")
args = parser.parse_args()
convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path)
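A hedged usage sketch (all paths below are illustrative, not part of this diff); the converter picks the head from hf_config.architectures[0], so the config at config_path must name one of the three supported architectures:

    # equivalent to the CLI entry point above, called directly from Python
    convert_s3prl_checkpoint(
        base_model_name="microsoft/wavlm-base-plus",
        config_path="./config.json",
        checkpoint_path="./s3prl_downstream_best.ckpt",
        model_dump_path="./wavlm-base-plus-sv",
    )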

src/transformers/models/wavlm/modeling_wavlm.py

@@ -33,7 +33,7 @@ from ...file_utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
-from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_wavlm import WavLMConfig
@@ -48,6 +48,10 @@ _CHECKPOINT_FOR_DOC = "microsoft/wavlm-base"
-_SEQ_CLASS_CHECKPOINT = "microsoft/wavlm-base"
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
+_SEQ_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus"
+_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd"
+_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv"
_HIDDEN_STATES_START_POSITION = 2
WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -87,6 +91,38 @@ class WavLMBaseModelOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class XVectorOutput(ModelOutput):
"""
Output type of :class:`~transformers.WavLMForXVector`.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`):
Classification hidden states before AMSoftmax.
embeddings (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`):
Utterance embeddings used for vector similarity-based retrieval.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
embeddings: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
@@ -1447,3 +1483,285 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
WavLM Model with a frame classification head on top for tasks like Speaker Diarization.
""",
WAVLM_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
class WavLMForAudioFrameClassification(WavLMPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.wavlm = WavLMModel(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.init_weights()
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature extractor so that its parameters
will not be updated during training.
"""
self.wavlm.feature_extractor._freeze_parameters()
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.wavlm.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_FRAME_CLASS_CHECKPOINT,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss);
if :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
outputs = self.wavlm(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.config.use_weighted_layer_sum:
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = outputs[0]
logits = self.classifier(hidden_states)
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return output
return TokenClassifierOutput(
loss=None,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
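A hedged inference sketch for the diarization head (random audio stands in for a real 16 kHz waveform; the logit > 0 thresholding mirrors the integration test at the end of this commit):

    import torch
    from transformers import Wav2Vec2FeatureExtractor, WavLMForAudioFrameClassification

    extractor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sd")
    model = WavLMForAudioFrameClassification.from_pretrained("microsoft/wavlm-base-plus-sd")

    audio = torch.randn(16000).numpy()  # placeholder: one second of 16 kHz mono audio
    inputs = extractor(audio, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits  # (batch, num_frames, num_speakers)
    speaker_activity = (logits > 0).long()  # a speaker is active in a frame when its logit is positive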
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
class AMSoftmaxLoss(nn.Module):
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
super(AMSoftmaxLoss, self).__init__()
self.scale = scale
self.margin = margin
self.num_labels = num_labels
self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
self.loss = nn.CrossEntropyLoss()
def forward(self, hidden_states, labels):
labels = labels.flatten()
weight = nn.functional.normalize(self.weight, dim=0)
hidden_states = nn.functional.normalize(hidden_states, dim=1)
cos_theta = torch.mm(hidden_states, weight)
psi = cos_theta - self.margin
onehot = nn.functional.one_hot(labels, self.num_labels)
logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
loss = self.loss(logits, labels)
return loss
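AMSoftmax subtracts a fixed margin from the target-class cosine similarity before scaling, so embeddings must separate speakers by at least that margin to reduce the loss. A toy check of the module above (all shapes illustrative):

    import torch

    loss_fn = AMSoftmaxLoss(input_dim=8, num_labels=4, scale=30.0, margin=0.4)
    hidden_states = torch.randn(2, 8)      # (batch, xvector_output_dim)
    labels = torch.tensor([1, 3])          # one speaker id per utterance
    loss = loss_fn(hidden_states, labels)  # cross-entropy over margin-adjusted, scaled cosines
    print(loss.item())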
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
class TDNNLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
self.out_conv_dim = config.tdnn_dim[layer_id]
self.kernel_size = config.tdnn_kernel[layer_id]
self.dilation = config.tdnn_dilation[layer_id]
self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
self.activation = nn.ReLU()
def forward(self, hidden_states):
hidden_states = hidden_states.unsqueeze(1)
hidden_states = nn.functional.unfold(
hidden_states,
(self.kernel_size, self.in_conv_dim),
stride=(1, self.in_conv_dim),
dilation=(self.dilation, 1),
)
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.kernel(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
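The unfold + linear pair above expresses a dilated 1D convolution over the time axis as a matrix multiply: each output frame gathers kernel_size input frames spaced dilation apart. A shape sketch using the default config values (assumed here, for illustration):

    import torch
    from transformers import WavLMConfig

    config = WavLMConfig()  # default tdnn_dim / tdnn_kernel / tdnn_dilation
    layer = TDNNLayer(config, layer_id=0)
    hidden_states = torch.randn(1, 100, config.tdnn_dim[0])  # (batch, time, channels)
    # kernel_size=5, dilation=1: 100 frames -> 100 - (5 - 1) = 96 frames
    print(layer(hidden_states).shape)  # torch.Size([1, 96, 512])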
@add_start_docstrings(
"""
WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification.
""",
WAVLM_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
class WavLMForXVector(WavLMPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.wavlm = WavLMModel(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
self.tdnn = nn.ModuleList(tdnn_layers)
self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
self.init_weights()
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature extractor so that its parameters
will not be updated during training.
"""
self.wavlm.feature_extractor._freeze_parameters()
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.wavlm.parameters():
param.requires_grad = False
def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the TDNN layers
"""
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
for kernel_size in self.config.tdnn_kernel:
input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
return input_lengths
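A worked example of the length arithmetic with the default kernels (stride is fixed at 1; note the formula does not account for dilation, so layers with dilation > 1 shrink slightly more in practice):

    length = 100
    for kernel_size in (5, 3, 3, 1, 1):  # default tdnn_kernel
        length = (length - kernel_size) // 1 + 1
    print(length)  # 92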
@add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_XVECTOR_CHECKPOINT,
output_type=XVectorOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
)
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss);
if :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
outputs = self.wavlm(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.config.use_weighted_layer_sum:
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = outputs[0]
hidden_states = self.projector(hidden_states)
for tdnn_layer in self.tdnn:
hidden_states = tdnn_layer(hidden_states)
# Statistic Pooling
if attention_mask is None:
mean_features = hidden_states.mean(dim=1)
std_features = hidden_states.std(dim=1)
else:
feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
mean_features = []
std_features = []
for i, length in enumerate(tdnn_output_lengths):
mean_features.append(hidden_states[i, :length].mean(dim=0))
std_features.append(hidden_states[i, :length].std(dim=0))
mean_features = torch.stack(mean_features)
std_features = torch.stack(std_features)
statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
output_embeddings = self.feature_extractor(statistic_pooling)
logits = self.classifier(output_embeddings)
loss = None
if labels is not None:
loss = self.objective(logits, labels)
if not return_dict:
output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
return ((loss,) + output) if loss is not None else output
return XVectorOutput(
loss=loss,
logits=logits,
embeddings=output_embeddings,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
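A hedged end-to-end sketch for speaker verification with this head (random audio stands in for two real utterances; normalized-embedding cosine similarity is the comparison the integration test below relies on):

    import torch
    from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector

    extractor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sv")
    model = WavLMForXVector.from_pretrained("microsoft/wavlm-base-plus-sv")

    audio = [torch.randn(16000).numpy(), torch.randn(16000).numpy()]  # placeholder 16 kHz utterances
    inputs = extractor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        embeddings = model(**inputs).embeddings
    embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
    similarity = torch.nn.functional.cosine_similarity(embeddings[0], embeddings[1], dim=-1)
    print(similarity.item())  # higher means more likely the same speaker; the threshold is application-specific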

src/transformers/utils/dummy_pt_objects.py

@@ -5177,6 +5177,11 @@ class Wav2Vec2PreTrainedModel:
WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST = None
class WavLMForAudioFrameClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class WavLMForCTC:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@@ -5194,6 +5199,11 @@ class WavLMForSequenceClassification:
requires_backends(self, ["torch"])
class WavLMForXVector:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class WavLMModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
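These dummies keep `from transformers import WavLMForXVector` importable in a torch-free environment; construction then fails with a message naming the missing backend. A behavior sketch (error text paraphrased):

    # without PyTorch installed, the import still succeeds...
    from transformers import WavLMForXVector

    # ...but instantiation raises, roughly: "WavLMForXVector requires the PyTorch library ..."
    model = WavLMForXVector()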

tests/test_modeling_unispeech_sat.py

@@ -863,10 +863,10 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(outputs.last_hidden_state[:, :2, -2:], expected_hidden_states_slice, atol=1e-3))
def test_inference_diarization(self):
model = UniSpeechSatForAudioFrameClassification.from_pretrained("anton-l/unispeech-sat-base-plus-sd").to(
model = UniSpeechSatForAudioFrameClassification.from_pretrained("microsoft/unispeech-sat-base-plus-sd").to(
torch_device
)
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sd")
processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/unispeech-sat-base-plus-sd")
input_data = self._load_superb("sd", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)
@@ -892,8 +892,8 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-3))
def test_inference_speaker_verification(self):
model = UniSpeechSatForXVector.from_pretrained("anton-l/unispeech-sat-base-plus-sv").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sv")
model = UniSpeechSatForXVector.from_pretrained("microsoft/unispeech-sat-base-plus-sv").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/unispeech-sat-base-plus-sv")
input_data = self._load_superb("si", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)

tests/test_modeling_wavlm.py

@@ -31,7 +31,14 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
if is_torch_available():
import torch
-from transformers import Wav2Vec2FeatureExtractor, WavLMForCTC, WavLMForSequenceClassification, WavLMModel
+from transformers import (
+    Wav2Vec2FeatureExtractor,
+    WavLMForAudioFrameClassification,
+    WavLMForCTC,
+    WavLMForSequenceClassification,
+    WavLMForXVector,
+    WavLMModel,
+)
class WavLMModelTester:
@@ -60,6 +67,10 @@ class WavLMModelTester:
initializer_range=0.02,
vocab_size=32,
do_stable_layer_norm=False,
tdnn_dim=(32, 32),
tdnn_kernel=(3, 3),
tdnn_dilation=(1, 1),
xvector_output_dim=32,
scope=None,
):
self.parent = parent
@@ -85,6 +96,10 @@ class WavLMModelTester:
self.initializer_range = initializer_range
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.tdnn_dim = tdnn_dim
self.tdnn_kernel = tdnn_kernel
self.tdnn_dilation = tdnn_dilation
self.xvector_output_dim = xvector_output_dim
self.scope = scope
output_seq_length = self.seq_length
@@ -121,6 +136,10 @@ class WavLMModelTester:
hidden_act=self.hidden_act,
initializer_range=self.initializer_range,
vocab_size=self.vocab_size,
tdnn_dim=self.tdnn_dim,
tdnn_kernel=self.tdnn_kernel,
tdnn_dilation=self.tdnn_dilation,
xvector_output_dim=self.xvector_output_dim,
)
def create_and_check_model(self, config, input_values, attention_mask):
@@ -285,7 +304,11 @@ class WavLMModelTester:
@require_torch
class WavLMModelTest(ModelTesterMixin, unittest.TestCase):
-all_model_classes = (WavLMForCTC, WavLMModel, WavLMForSequenceClassification) if is_torch_available() else ()
+all_model_classes = (
+    (WavLMForCTC, WavLMModel, WavLMForAudioFrameClassification, WavLMForSequenceClassification, WavLMForXVector)
+    if is_torch_available()
+    else ()
+)
test_pruning = False
test_headmasking = False
test_torchscript = False
@@ -398,6 +421,7 @@ class WavLMModelTest(ModelTesterMixin, unittest.TestCase):
"feature_projection.projection.bias",
"label_embeddings_concat",
"rel_attn_embed",
"objective.weight",
]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
@@ -446,6 +470,11 @@ class WavLMModelIntegrationTest(unittest.TestCase):
return [x["array"] for x in speech_samples]
def _load_superb(self, task, num_samples):
ds = load_dataset("anton-l/superb_dummy", task, split="test")
return ds[:num_samples]
def test_inference_base(self):
model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus").to(torch_device)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
@@ -491,3 +520,54 @@ class WavLMModelIntegrationTest(unittest.TestCase):
[[[0.1612, 0.4314], [0.1690, 0.4344]], [[0.2086, 0.1396], [0.3014, 0.0903]]]
)
self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=1e-2))
def test_inference_diarization(self):
model = WavLMForAudioFrameClassification.from_pretrained("microsoft/wavlm-base-plus-sd").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sd")
input_data = self._load_superb("sd", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
outputs = model(input_values, attention_mask=attention_mask)
# labels is a binary speaker-activity matrix of shape (num_frames, num_speakers)
labels = (outputs.logits > 0).long()
# s3prl logits for the same batch
expected_logits = torch.tensor(
[
[[-5.9566, -8.6554], [-5.7137, -8.9386], [-5.7906, -7.0973], [-5.7829, -5.9999]],
[[-5.2086, -7.7878], [-4.8890, -7.9312], [-4.2004, -3.9101], [-5.4480, -4.6932]],
[[-4.6105, -6.7178], [-5.1930, -6.1635], [-2.6228, -4.1123], [-2.7646, -3.1576]],
[[-4.4477, -7.9206], [-3.9339, -7.3707], [-4.9528, -4.8242], [-3.6921, -2.9687]],
],
device=torch_device,
)
self.assertEqual(labels[0, :, 0].sum(), 258)
self.assertEqual(labels[0, :, 1].sum(), 647)
self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-3))
def test_inference_speaker_verification(self):
model = WavLMForXVector.from_pretrained("microsoft/wavlm-base-plus-sv").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sv")
input_data = self._load_superb("si", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
labels = torch.tensor([5, 1, 1, 3], device=torch_device).T
with torch.no_grad():
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
outputs = model(input_values, attention_mask=attention_mask, labels=labels)
embeddings = torch.nn.functional.normalize(outputs.embeddings, dim=-1)
cosine_sim = torch.nn.CosineSimilarity(dim=-1)
# id10002 vs id10002
self.assertAlmostEqual(cosine_sim(embeddings[1], embeddings[2]).item(), 0.9787, 3)
# id10006 vs id10002
self.assertAlmostEqual(cosine_sim(embeddings[0], embeddings[1]).item(), 0.5064, 3)
# id10002 vs id10004
self.assertAlmostEqual(cosine_sim(embeddings[2], embeddings[3]).item(), 0.4780, 3)
self.assertAlmostEqual(outputs.loss.item(), 18.4154, 3)