transformers/src/transformers/models/unispeech/modeling_unispeech.py
Raushan Turganbay e435574721
🚨 Don't use cache in non-generative models (#38751)
* deprecate for 1 version

* style

* fix some tests

* fix esm

* skip for now, GC requires positional args but we have keyword args

* remove transpose for scores in modified models only

* skip fx trace tests
2025-07-01 09:08:21 +00:00

1521 lines
63 KiB
Python
Executable File

# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/unispeech/modular_unispeech.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_unispeech.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from dataclasses import dataclass
from typing import Callable, Optional, Union
import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
ModelOutput,
SequenceClassifierOutput,
Wav2Vec2BaseModelOutput,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_unispeech import UniSpeechConfig
if is_torch_flex_attn_available():
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`UniSpeechForPreTraining`], with potential hidden states and attentions.
    """
)
class UniSpeechForPreTrainingOutput(ModelOutput):
    r"""
    loss (*optional*, returned when model is in train mode, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
        paper](https://arxiv.org/pdf/2006.11477.pdf).
    projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
        Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
        projected quantized states.
    projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
        Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
        target vectors for contrastive loss.
    codevector_perplexity (`torch.FloatTensor` of shape `(1,)`):
        The perplexity of the codevector distribution, used to measure the diversity of the codebook.
    """

    # Every field is optional and defaults to None.
    loss: Optional[torch.FloatTensor] = None
    projected_states: Optional[torch.FloatTensor] = None
    projected_quantized_states: Optional[torch.FloatTensor] = None
    codevector_perplexity: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
class UniSpeechSamePadLayer(nn.Module):
    """Drops the last element along dim 2 when `num_conv_pos_embeddings` is even; otherwise a no-op."""

    def __init__(self, num_conv_pos_embeddings):
        super().__init__()
        # Odd kernel counts need no trimming; even ones need exactly one element removed.
        self.num_pad_remove = 0 if num_conv_pos_embeddings % 2 else 1

    def forward(self, hidden_states):
        """Return `hidden_states` with the trailing `num_pad_remove` entries of the last axis removed."""
        if self.num_pad_remove <= 0:
            return hidden_states
        return hidden_states[:, :, : -self.num_pad_remove]
class UniSpeechPositionalConvEmbedding(nn.Module):
    """Weight-normalised grouped 1D convolution applied along the time axis, followed by padding trim and activation."""

    def __init__(self, config):
        super().__init__()
        kernel_width = config.num_conv_pos_embeddings
        self.conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=kernel_width,
            padding=kernel_width // 2,
            groups=config.num_conv_pos_embedding_groups,
        )

        # Use the parametrization-based weight norm when this torch build provides it,
        # falling back to the legacy `nn.utils.weight_norm` otherwise.
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm
        else:
            weight_norm = nn.utils.weight_norm

        if not is_deepspeed_zero3_enabled():
            self.conv = weight_norm(self.conv, name="weight", dim=2)
        else:
            import deepspeed

            # The conv weight is gathered while weight norm is applied.
            with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
                self.conv = weight_norm(self.conv, name="weight", dim=2)

            # The two weight-norm tensors live under different attribute names depending on
            # which weight_norm implementation was used.
            if hasattr(self.conv, "parametrizations"):
                norm_params = self.conv.parametrizations.weight
                weight_g, weight_v = norm_params.original0, norm_params.original1
            else:
                weight_g, weight_v = self.conv.weight_g, self.conv.weight_v
            deepspeed.zero.register_external_parameter(self, weight_v)
            deepspeed.zero.register_external_parameter(self, weight_g)

        self.padding = UniSpeechSamePadLayer(config.num_conv_pos_embeddings)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        """Swap dims 1 and 2, run conv -> padding trim -> activation, then swap the dims back."""
        channels_first = hidden_states.transpose(1, 2)
        channels_first = self.activation(self.padding(self.conv(channels_first)))
        return channels_first.transpose(1, 2)
class UniSpeechNoLayerNormConvLayer(GradientCheckpointingLayer):
    """Feature-extractor stage: 1D convolution followed by the configured activation, without normalisation."""

    def __init__(self, config, layer_id=0):
        super().__init__()
        # The first stage takes a single input channel; later stages take the previous stage's width.
        self.in_conv_dim = 1 if layer_id <= 0 else config.conv_dim[layer_id - 1]
        self.out_conv_dim = config.conv_dim[layer_id]

        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        """Apply the convolution and then the activation."""
        return self.activation(self.conv(hidden_states))
class UniSpeechLayerNormConvLayer(GradientCheckpointingLayer):
    """Feature-extractor stage: 1D convolution, LayerNorm over the channel axis, then the configured activation."""

    def __init__(self, config, layer_id=0):
        super().__init__()
        # The first stage takes a single input channel; later stages take the previous stage's width.
        self.in_conv_dim = 1 if layer_id <= 0 else config.conv_dim[layer_id - 1]
        self.out_conv_dim = config.conv_dim[layer_id]

        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        """Convolve, normalise with the channel axis moved last, restore the layout, then activate."""
        conv_out = self.conv(hidden_states)
        # LayerNorm normalises the trailing axis, so the last two axes are swapped around it.
        normed = self.layer_norm(conv_out.transpose(-2, -1)).transpose(-2, -1)
        return self.activation(normed)
class UniSpeechGroupNormConvLayer(GradientCheckpointingLayer):
    """Feature-extractor stage: 1D convolution, GroupNorm with one group per channel, then the configured activation."""

    def __init__(self, config, layer_id=0):
        super().__init__()
        # The first stage takes a single input channel; later stages take the previous stage's width.
        self.in_conv_dim = 1 if layer_id <= 0 else config.conv_dim[layer_id - 1]
        self.out_conv_dim = config.conv_dim[layer_id]

        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.activation = ACT2FN[config.feat_extract_activation]

        # num_groups == num_channels: every channel is normalised on its own.
        self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)

    def forward(self, hidden_states):
        """Apply convolution, group normalisation and activation, in that order."""
        return self.activation(self.layer_norm(self.conv(hidden_states)))
class UniSpeechFeatureEncoder(nn.Module):
    """Construct the features from raw audio waveform"""

    def __init__(self, config):
        super().__init__()
        norm_kind = config.feat_extract_norm

        if norm_kind == "group":
            # Only the first stage is group-normalised; the remaining stages have no normalisation.
            stages = [UniSpeechGroupNormConvLayer(config, layer_id=0)]
            stages.extend(
                UniSpeechNoLayerNormConvLayer(config, layer_id=idx)
                for idx in range(1, config.num_feat_extract_layers)
            )
        elif norm_kind == "layer":
            stages = [
                UniSpeechLayerNormConvLayer(config, layer_id=idx) for idx in range(config.num_feat_extract_layers)
            ]
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )

        self.conv_layers = nn.ModuleList(stages)
        self.gradient_checkpointing = False
        self._requires_grad = True

    def _freeze_parameters(self):
        """Disable gradients for every parameter of this encoder and record that it is frozen."""
        for parameter in self.parameters():
            parameter.requires_grad = False
        self._requires_grad = False

    def forward(self, input_values):
        """Insert a singleton axis at position 1 and run the result through every conv stage in order."""
        features = input_values[:, None]

        # make sure hidden_states require grad for gradient_checkpointing
        if self._requires_grad and self.training:
            features.requires_grad = True

        for stage in self.conv_layers:
            features = stage(features)
        return features
class UniSpeechFeatureProjection(nn.Module):
    """LayerNorm over the last conv width, then a linear map to `config.hidden_size` with dropout."""

    def __init__(self, config):
        super().__init__()
        feature_dim = config.conv_dim[-1]
        self.layer_norm = nn.LayerNorm(feature_dim, eps=config.layer_norm_eps)
        self.projection = nn.Linear(feature_dim, config.hidden_size)
        self.dropout = nn.Dropout(config.feat_proj_dropout)

    def forward(self, hidden_states):
        """Return `(projected_with_dropout, normalised_but_unprojected)`."""
        # non-projected hidden states are needed for quantization
        norm_hidden_states = self.layer_norm(hidden_states)
        projected = self.dropout(self.projection(norm_hidden_states))
        return projected, norm_hidden_states
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    """
    Scaled dot-product attention computed with explicit matmuls.

    Args:
        module: Only its `training` flag is read, to decide whether dropout is active.
        query, key, value: 4D tensors; `key` is transposed over dims 2 and 3 before the score matmul.
        attention_mask: Optional additive mask summed onto the raw scores before the softmax.
        scaling: Score multiplier; defaults to `query.size(-1) ** -0.5` when `None`.
        dropout: Dropout probability applied to the attention probabilities.
        head_mask: Optional per-head multiplier, reshaped to `(1, -1, 1, 1)`.
        **kwargs: Accepted and ignored.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has dims 1 and 2 swapped and is made contiguous.
    """
    scale = query.size(-1) ** -0.5 if scaling is None else scaling

    scores = torch.matmul(query, key.transpose(2, 3)) * scale
    if attention_mask is not None:
        scores = scores + attention_mask

    probs = nn.functional.softmax(scores, dim=-1)
    if head_mask is not None:
        probs = probs * head_mask.view(1, -1, 1, 1)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs
class UniSpeechAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[UniSpeechConfig] = None,
    ):
        """
        Args:
            embed_dim: Model width; must be divisible by `num_heads`.
            num_heads: Number of attention heads.
            dropout: Dropout probability on the attention probabilities (active only in training).
            is_decoder: Stored on the module; not read by `forward`.
            bias: Whether the q/k/v/out projections carry a bias term.
            is_causal: Stored on the module; not read by `forward`.
            config: Model config; its `_attn_implementation` selects the attention backend. When `None`,
                the eager implementation is used.

        Raises:
            ValueError: If `embed_dim` is not divisible by `num_heads`.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    @deprecate_kwarg("past_key_value", version="4.54.0")
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states: Query source; also the key/value source when `key_value_states` is `None`.
            key_value_states: When given, keys and values are projected from it (cross-attention).
            past_key_value: Deprecated; accepted but never read.
            attention_mask: Forwarded unchanged to the attention backend.
            layer_head_mask: Forwarded to the backend as `head_mask`.
            output_attentions: Forwarded to the backend.

        Returns:
            `(attn_output, attn_weights, None)`; the third slot is always `None`.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        # determine input shapes
        bsz, tgt_len = hidden_states.shape[:-1]
        src_len = key_value_states.shape[1] if is_cross_attention else tgt_len

        q_input_shape = (bsz, tgt_len, -1, self.head_dim)
        kv_input_shape = (bsz, src_len, -1, self.head_dim)

        # get query proj
        query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)

        current_states = key_value_states if is_cross_attention else hidden_states
        key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
        value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)

        # `config` is optional in the constructor: without one there is no backend preference
        # to read, so use the eager implementation instead of failing on `None._attn_implementation`.
        attn_implementation = "eager" if self.config is None else self.config._attn_implementation
        attention_interface: Callable = eager_attention_forward
        if attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            output_attentions=output_attentions,
            head_mask=layer_head_mask,
            **kwargs,
        )

        # Merge the per-head outputs back into one feature axis before the output projection.
        attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights, None
class UniSpeechFeedForward(nn.Module):
    """Two-layer position-wise MLP: dense -> activation -> dropout -> dense -> dropout."""

    def __init__(self, config):
        super().__init__()
        self.intermediate_dropout = nn.Dropout(config.activation_dropout)

        self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` is either a key into ACT2FN or an already-callable activation.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

        self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.output_dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states):
        """Expand to `intermediate_size`, activate, and project back to `hidden_size`."""
        expanded = self.intermediate_act_fn(self.intermediate_dense(hidden_states))
        expanded = self.intermediate_dropout(expanded)
        return self.output_dropout(self.output_dense(expanded))
class UniSpeechEncoderLayer(GradientCheckpointingLayer):
    """Transformer encoder layer that applies LayerNorm after each residual connection."""

    def __init__(self, config):
        super().__init__()
        self.attention = UniSpeechAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
            config=config,
        )

        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = UniSpeechFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        """Return `(hidden_states,)`, extended with the attention weights when `output_attentions` is truthy."""
        residual = hidden_states
        attn_out, attn_weights, _ = self.attention(
            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
        )

        # Self-attention sub-block: dropout, residual add, then normalise.
        hidden_states = self.layer_norm(residual + self.dropout(attn_out))

        # Feed-forward sub-block: residual add, then normalise.
        hidden_states = self.final_layer_norm(hidden_states + self.feed_forward(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
class UniSpeechEncoder(nn.Module):
    """Stack of `UniSpeechEncoderLayer`s preceded by convolutional positional embeddings, LayerNorm and dropout."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = UniSpeechPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList([UniSpeechEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Run the encoder stack.

        Args:
            hidden_states: Input features. NOTE: when `attention_mask` is given, positions where the mask is
                False are zeroed **in place** on this tensor.
            attention_mask: Optional 2D mask; it is inverted with `~`, so it is presumably boolean — TODO confirm
                against callers.
            output_attentions: Collect each layer's attention weights (`None` for layers skipped by LayerDrop).
            output_hidden_states: Collect the hidden states before every layer and after the last one.
            return_dict: Return a `BaseModelOutput` instead of a tuple of the non-`None` results.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens output 0
            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[~expand_attention_mask] = 0

        # Convert the 2D mask into whatever form the configured attention backend expects.
        attention_mask = self._update_full_mask(
            attention_mask,
            hidden_states,
        )

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = bool(self.training and (dropout_probability < self.config.layerdrop))
            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                layer_outputs = layer(
                    hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
                )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                # A dropped layer reports no attention weights.
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def _update_full_mask(
        self,
        attention_mask: Union[torch.Tensor, None],
        inputs_embeds: torch.Tensor,
    ):
        """Adapt `attention_mask` to `config._attn_implementation`; a `None` mask is returned unchanged."""
        if attention_mask is not None:
            if self.config._attn_implementation == "flash_attention_2":
                # A mask without any 0 entry is dropped entirely.
                attention_mask = attention_mask if 0 in attention_mask else None
            elif self.config._attn_implementation == "sdpa":
                # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
                # the manual implementation that requires a 4D causal mask in all cases.
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
            elif self.config._attn_implementation == "flex_attention":
                if isinstance(attention_mask, torch.Tensor):
                    attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        return attention_mask
class UniSpeechAttnAdapterLayer(nn.Module):
    def __init__(self, config):
        """
        Bottleneck adapter: LayerNorm, a linear map from `config.hidden_size` down to
        `config.adapter_attn_dim`, ReLU, and a linear map back up to `config.hidden_size`.
        """
        super().__init__()
        self.input_dim = config.adapter_attn_dim
        self.hidden_dim = config.hidden_size

        self.norm = nn.LayerNorm(self.hidden_dim)
        self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
        self.act_fn = nn.ReLU()
        self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)

    def forward(self, hidden_states: torch.FloatTensor):
        """Return the adapter output; no residual connection is added here."""
        bottleneck = self.act_fn(self.linear_1(self.norm(hidden_states)))
        return self.linear_2(bottleneck)
class UniSpeechEncoderLayerStableLayerNorm(GradientCheckpointingLayer):
    """Transformer encoder layer that applies LayerNorm before each sub-block, with an optional adapter."""

    def __init__(self, config):
        super().__init__()
        self.attention = UniSpeechAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
            config=config,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = UniSpeechFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # The adapter exists only when the config defines a non-None `adapter_attn_dim`.
        has_adapter = getattr(config, "adapter_attn_dim", None) is not None
        self.adapter_layer = UniSpeechAttnAdapterLayer(config) if has_adapter else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        """Return `(hidden_states,)`, extended with the attention weights when `output_attentions` is truthy."""
        residual = hidden_states
        attn_out, attn_weights, _ = self.attention(
            self.layer_norm(hidden_states), attention_mask=attention_mask, output_attentions=output_attentions
        )
        hidden_states = residual + self.dropout(attn_out)

        # Feed-forward sub-block operates on the normalised stream and is added back residually.
        hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))

        if self.adapter_layer is not None:
            hidden_states = hidden_states + self.adapter_layer(hidden_states)

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
class UniSpeechEncoderStableLayerNorm(nn.Module):
    """Stack of `UniSpeechEncoderLayerStableLayerNorm` layers with a single LayerNorm applied after the last layer."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = UniSpeechPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList(
            [UniSpeechEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """
        Run the encoder stack and return a `BaseModelOutput`, or a tuple of its non-`None` fields when
        `return_dict` is falsy. When `attention_mask` is given, masked-out positions of `hidden_states`
        are zeroed in place.
        """
        collected_states = () if output_hidden_states else None
        collected_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens output 0
            feature_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[~feature_mask] = 0

        attention_mask = self._update_full_mask(attention_mask, hidden_states)

        hidden_states = hidden_states + self.pos_conv_embed(hidden_states)
        hidden_states = self.dropout(hidden_states)

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

        for encoder_layer in self.layers:
            if output_hidden_states:
                collected_states += (hidden_states,)

            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            dropout_probability = torch.rand([])
            skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False

            if skip_the_layer and not synced_gpus:
                layer_outputs = (None, None)
            else:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
                layer_outputs = encoder_layer(
                    hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
                )
                hidden_states = layer_outputs[0]
                if skip_the_layer:
                    # The layer ran only to keep ranks in sync; report no attention weights for it.
                    layer_outputs = (None, None)

            if output_attentions:
                collected_attentions += (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            collected_states += (hidden_states,)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=collected_states,
                attentions=collected_attentions,
            )
        return tuple(v for v in [hidden_states, collected_states, collected_attentions] if v is not None)

    def _update_full_mask(
        self,
        attention_mask: Union[torch.Tensor, None],
        inputs_embeds: torch.Tensor,
    ):
        """Adapt `attention_mask` to `config._attn_implementation`; a `None` mask is returned unchanged."""
        if attention_mask is None:
            return attention_mask

        implementation = self.config._attn_implementation
        if implementation == "flash_attention_2":
            # A mask without any 0 entry is dropped entirely.
            return attention_mask if 0 in attention_mask else None
        if implementation == "sdpa":
            # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
            # the manual implementation that requires a 4D causal mask in all cases.
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            return _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
        if implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
            return attention_mask

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        return _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
class UniSpeechGumbelVectorQuantizer(nn.Module):
    """
    Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
    GUMBEL-SOFTMAX](https://huggingface.co/papers/1611.01144) for more information.
    """

    def __init__(self, config):
        super().__init__()
        self.num_groups = config.num_codevector_groups
        self.num_vars = config.num_codevectors_per_group

        if config.codevector_dim % self.num_groups != 0:
            raise ValueError(
                f"`config.codevector_dim {config.codevector_dim} must be divisible "
                f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
            )

        # storage for codebook variables (codewords)
        total_entries = self.num_groups * self.num_vars
        self.codevectors = nn.Parameter(
            torch.FloatTensor(1, total_entries, config.codevector_dim // self.num_groups)
        )
        self.weight_proj = nn.Linear(config.conv_dim[-1], total_entries)

        # can be decayed for training
        self.temperature = 2

    @staticmethod
    def _compute_perplexity(probs):
        """Average `probs` over dim 0, take exp(entropy) along the last axis, and sum the results."""
        marginal = probs.mean(dim=0)
        entropy = -torch.sum(marginal * torch.log(marginal + 1e-7), dim=-1)
        return torch.exp(entropy).sum()

    def forward(self, hidden_states):
        """
        Quantize `hidden_states` of shape `(batch, sequence, hidden)`.

        Returns:
            `(codevectors, perplexity)` where `codevectors` has shape `(batch, sequence, -1)`.
        """
        batch_size, sequence_length, hidden_size = hidden_states.shape
        num_frames = batch_size * sequence_length

        # project to codevector dim
        logits = self.weight_proj(hidden_states).view(num_frames * self.num_groups, -1)

        if self.training:
            # sample code vector probs via gumbel in a differentiable way
            codevector_probs = nn.functional.gumbel_softmax(logits.float(), tau=self.temperature, hard=True)
            codevector_probs = codevector_probs.type_as(logits)

            # compute perplexity
            soft_dist = torch.softmax(logits.view(num_frames, self.num_groups, -1).float(), dim=-1)
            perplexity = self._compute_perplexity(soft_dist)
        else:
            # take argmax in non-differentiable way
            # compute hard codevector distribution (one hot)
            best_idx = logits.argmax(dim=-1)
            codevector_probs = logits.new_zeros(*logits.shape).scatter_(-1, best_idx.view(-1, 1), 1.0)
            codevector_probs = codevector_probs.view(num_frames, self.num_groups, -1)

            perplexity = self._compute_perplexity(codevector_probs)

        codevector_probs = codevector_probs.view(num_frames, -1)

        # use probs to retrieve codevectors
        weighted = codevector_probs.unsqueeze(-1) * self.codevectors
        weighted = weighted.view(num_frames, self.num_groups, self.num_vars, -1)
        codevectors = weighted.sum(-2).view(batch_size, sequence_length, -1)

        return codevectors, perplexity
@auto_docstring
class UniSpeechPreTrainedModel(PreTrainedModel):
    config_class = UniSpeechConfig
    base_model_prefix = "unispeech"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

    def _init_weights(self, module):
        """Initialize the weights"""
        # The checks run from most to least specific: the project-specific modules must be
        # matched before the generic `nn.Linear` / `nn.Conv1d` fallbacks.

        # gumbel softmax requires special init
        if isinstance(module, UniSpeechGumbelVectorQuantizer):
            module.weight_proj.weight.data.normal_(mean=0.0, std=1)
            module.weight_proj.bias.data.zero_()
            nn.init.uniform_(module.codevectors)
            return

        if isinstance(module, UniSpeechPositionalConvEmbedding):
            conv = module.conv
            nn.init.normal_(
                conv.weight,
                mean=0,
                std=2 * math.sqrt(1 / (conv.kernel_size[0] * conv.in_channels)),
            )
            nn.init.constant_(conv.bias, 0)
            return

        if isinstance(module, UniSpeechFeatureProjection):
            bound = math.sqrt(1 / module.projection.in_features)
            nn.init.uniform_(module.projection.weight, a=-bound, b=bound)
            nn.init.uniform_(module.projection.bias, a=-bound, b=bound)
            return

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                bound = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-bound, b=bound)

    def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the convolutional layers
        """
        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            input_lengths = torch.div(input_lengths - kernel_size, stride, rounding_mode="floor") + 1

        return input_lengths

    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
        """Build a boolean mask of shape `(batch, feature_vector_length)` from a sample-level attention mask."""
        # Effectively attention_mask.sum(-1), but not inplace to be able to run
        # on inference mode.
        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)

        num_rows = attention_mask.shape[0]
        feature_mask = torch.zeros(
            (num_rows, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )

        # these two operations makes sure that all values before the output lengths idxs are attended to
        row_index = torch.arange(feature_mask.shape[0], device=feature_mask.device)
        feature_mask[(row_index, output_lengths - 1)] = 1
        return feature_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
def _compute_mask_indices(
    shape: tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[torch.LongTensor] = None,
    min_masks: int = 0,
) -> np.ndarray:
    """
    Sample random spans to mask on a `(batch_size, sequence_length)` grid, as needed for [SpecAugment: A Simple Data
    Augmentation Method for ASR](https://huggingface.co/papers/1904.08779). The sampling is done with NumPy, is not
    optimized for TPU, and is meant to run on CPU as part of the preprocessing during training.

    Args:
        shape: Tuple `(batch_size, sequence_length)`; the second entry is the length of the axis on which spans are
            placed.
        mask_prob: Fraction (between 0 and 1) of the axis that should be masked. `mask_prob*shape[1]/mask_length`
            spans of length `mask_length` are drawn independently; because spans may overlap, `mask_prob` is an
            upper bound on the fraction that actually ends up masked.
        mask_length: Length of every masked span.
        attention_mask: A (right-padded) attention mask. Its row sums are used as the usable length of each batch
            entry, so that spans are only started inside the non-padded part.
        min_masks: Minimum number of spans to draw per batch entry.

    Returns:
        A boolean array of shape `(batch_size, sequence_length)` that is `True` at every masked position.

    Raises:
        ValueError: If `mask_length` is smaller than 1 or larger than `sequence_length`.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")
    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}`"
        )

    # a single random offset, shared by all rows, turns the integer truncation below into probabilistic rounding
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Return how many spans are drawn for a row with `input_length` usable positions."""
        span_count = max(int(mask_prob * input_length / mask_length + epsilon), min_masks)

        # all spans together may not be longer than the full axis
        if span_count * mask_length > sequence_length:
            span_count = sequence_length // mask_length

        # a span can only start at one of `input_length - (mask_length - 1)` positions
        num_start_positions = input_length - (mask_length - 1)
        if num_start_positions < span_count:
            span_count = max(num_start_positions, 0)

        return span_count

    # usable length of every row: the attention-mask row sum, or the full axis when no mask is given
    if attention_mask is None:
        input_lengths = [sequence_length] * batch_size
    else:
        input_lengths = attention_mask.detach().sum(-1).tolist()

    # SpecAugment mask that gets filled in below
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)

    max_num_masked_span = compute_num_masked_span(sequence_length)
    if max_num_masked_span == 0:
        return spec_aug_mask

    span_starts = []
    for input_length in input_lengths:
        num_masked_span = compute_num_masked_span(input_length)

        # draw distinct start positions for this row
        starts = np.random.choice(np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False)

        # Rows that received fewer spans than `max_num_masked_span` are padded so that all rows have the same
        # number of entries. Repeating the first sampled start only masks that span twice; a row without any
        # sampled start is shorter than `sequence_length`, so its last position is padding and serves as dummy.
        dummy_mask_idx = starts[0] if len(starts) > 0 else sequence_length - 1
        padding = np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx

        span_starts.append(np.concatenate([starts, padding]))

    span_starts = np.array(span_starts)

    # turn every start index into the `mask_length` consecutive positions of its span
    span_offsets = np.arange(mask_length)[None, None, :]
    masked_positions = (span_starts[:, :, None] + span_offsets).reshape(
        batch_size, max_num_masked_span * mask_length
    )

    # positions past the end of the axis are folded onto the last valid index
    masked_positions = np.minimum(masked_positions, sequence_length - 1)

    # scatter the positions into the boolean mask
    np.put_along_axis(spec_aug_mask, masked_positions, 1, -1)

    return spec_aug_mask
# Alias: the output type returned by `UniSpeechModel.forward` is `Wav2Vec2BaseModelOutput` under a UniSpeech name.
UniSpeechBaseModelOutput = Wav2Vec2BaseModelOutput
@auto_docstring
class UniSpeechModel(UniSpeechPreTrainedModel):
    # Base model: feature encoder -> feature projection -> optional SpecAugment masking -> transformer encoder.

    def __init__(self, config: UniSpeechConfig):
        super().__init__(config)
        self.config = config
        self.feature_extractor = UniSpeechFeatureEncoder(config)
        self.feature_projection = UniSpeechFeatureProjection(config)

        # the mask embedding is only created when some form of SpecAugment masking is configured
        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        encoder_cls = UniSpeechEncoderStableLayerNorm if config.do_stable_layer_norm else UniSpeechEncoder
        self.encoder = encoder_cls(config)

        # Initialize weights and apply final processing
        self.post_init()

    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Apply [SpecAugment](https://huggingface.co/papers/1904.08779) masking to `hidden_states` along the time axis
        and/or the feature axis. `hidden_states` is modified in place and returned.

        Time masking uses `mask_time_indices` when they are passed; otherwise spans are sampled, but only in training
        mode and when `config.mask_time_prob > 0`. Feature masking is only sampled in training mode and when
        `config.mask_feature_prob > 0`.
        """
        # masking can be disabled altogether with `config.apply_spec_augment`
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is None and self.config.mask_time_prob > 0 and self.training:
            # no indices were passed in: sample the time spans to mask
            sampled_time_mask = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(sampled_time_mask, device=hidden_states.device, dtype=torch.bool)

        if mask_time_indices is not None:
            # overwrite the selected time steps with the mask embedding
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # sample spans on the feature axis and zero them out at every time step
            feature_mask = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            feature_mask = torch.tensor(feature_mask, device=hidden_states.device, dtype=torch.bool)
            feature_mask = feature_mask[:, None].expand(-1, sequence_length, -1)
            hidden_states[feature_mask] = 0

        return hidden_states

    @auto_docstring
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, UniSpeechBaseModelOutput]:
        r"""
        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.
        """
        # arguments that were not passed fall back to the values stored in the config
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # run the feature encoder and swap dims 1 and 2 of its output
        extract_features = self.feature_extractor(input_values).transpose(1, 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = encoder_outputs[0]

        if return_dict:
            return UniSpeechBaseModelOutput(
                last_hidden_state=hidden_states,
                extract_features=extract_features,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # tuple output: encoder output and extracted features first, then whatever else the encoder returned
        return (hidden_states, extract_features) + encoder_outputs[1:]
@auto_docstring(
    custom_intro="""
    UniSpeech Model with a vector-quantization module and ctc loss for pre-training.
    """
)
class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
    def __init__(self, config: UniSpeechConfig):
        super().__init__(config)

        # base model plus the dropout applied to its extracted features before quantization
        self.unispeech = UniSpeechModel(config)
        self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)

        # quantizer and the two projections applied to its output
        self.quantizer = UniSpeechGumbelVectorQuantizer(config)
        self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
        self.project_hid = nn.Linear(config.proj_codevector_dim, config.hidden_size)

        # CTC projection and the dropout applied right before it
        self.ctc_proj = nn.Linear(config.hidden_size, config.num_ctc_classes)
        self.dropout = nn.Dropout(config.final_dropout)

        # Initialize weights and apply final processing
        self.post_init()

    def set_gumbel_temperature(self, temperature: int):
        """
        Store `temperature` as the Gumbel softmax temperature of the quantizer. Only necessary for training.
        """
        self.quantizer.temperature = temperature

    def freeze_feature_extractor(self):
        """
        Deprecated alias of `freeze_feature_encoder`: emits a `FutureWarning` and then disables the gradient
        computation for the feature encoder so that its parameters will not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Disable the gradient computation for the feature encoder so that its parameters will not be updated during
        training.
        """
        self.unispeech.feature_extractor._freeze_parameters()

    @staticmethod
    def compute_contrastive_logits(
        target_features: torch.FloatTensor,
        negative_features: torch.FloatTensor,
        predicted_features: torch.FloatTensor,
        temperature: int = 1,
    ):
        """
        Compute the logits for the contrastive loss: the cosine similarity (over the last dimension) between
        `predicted_features` and `[target_features, negative_features]` concatenated along dim 0, divided by
        `temperature`.
        """
        candidates = torch.cat([target_features, negative_features], dim=0)

        # the similarity is computed in float32 and then cast back to the dtype of the candidates
        similarity = torch.cosine_similarity(predicted_features.float(), candidates.float(), dim=-1)
        similarity = similarity.type_as(candidates)

        # apply temperature
        return similarity / temperature

    @auto_docstring
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, UniSpeechForPreTrainingOutput]:
        r"""
        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoFeatureExtractor, UniSpeechForPreTraining

        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-large-1500h-cv")
        >>> model = UniSpeechForPreTraining.from_pretrained("microsoft/unispeech-large-1500h-cv")
        >>> # TODO: Add full pretraining example
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.unispeech(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        transformer_features = outputs[0]

        # quantize all (unmasked) extracted features
        extract_features = self.dropout_features(outputs[1])
        quantized_features, codevector_perplexity = self.quantizer(extract_features)

        # project the quantized features twice, the second time to `config.hidden_size`
        quantized_features = self.project_q(quantized_features.to(self.project_q.weight.dtype))
        quantized_features = self.project_hid(quantized_features)

        # Draw, for every position of the first two dims, whether the transformer output is replaced by the
        # quantized feature. The Bernoulli draw is done on the transposed probability matrix and the result is
        # transposed back before a trailing dim is added for broadcasting.
        batch_size, num_frames = transformer_features.size(0), transformer_features.size(1)
        replace_probs = torch.empty(batch_size, num_frames).fill_(self.config.replace_prob)
        replace_mask = torch.bernoulli(replace_probs.transpose(0, 1)).bool().to(transformer_features.device)
        replace_mask = replace_mask.transpose(0, 1).unsqueeze(-1)

        # mix: quantized features where the mask is set, transformer features everywhere else
        mixed_features = transformer_features.masked_fill(replace_mask, 0.0)
        mixed_features = mixed_features + quantized_features.masked_fill(~replace_mask, 0.0)

        # project to ctc units
        logits = self.dropout(mixed_features)
        logits = self.ctc_proj(logits)

        # TODO(PVP) - add negative sampling & loss computation
        loss = None

        if return_dict:
            return UniSpeechForPreTrainingOutput(
                loss=loss,
                projected_states=transformer_features,
                projected_quantized_states=quantized_features,
                codevector_perplexity=codevector_perplexity,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        tuple_outputs = (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
        if loss is not None:
            tuple_outputs = (loss,) + tuple_outputs
        return tuple_outputs
# `UniSpeechModel` returns its tuple output as `(hidden_states, extract_features) + encoder_outputs[1:]`, so the
# entries after the first two start at index 2. The task heads use this index to pick those entries up.
_HIDDEN_STATES_START_POSITION = 2
@auto_docstring(
    custom_intro="""
    UniSpeech Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
    """
)
class UniSpeechForCTC(UniSpeechPreTrainedModel):
    def __init__(self, config, target_lang: Optional[str] = None):
        r"""
        target_lang (`str`, *optional*):
            Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
            adapter.<lang>.bin. Only relevant when using an instance of [`UniSpeechForCTC`] with adapters. Uses 'eng' by
            default.
        """
        super().__init__(config)

        self.unispeech = UniSpeechModel(config)
        self.dropout = nn.Dropout(config.final_dropout)

        # only stored here; it is consumed later by `tie_weights`
        self.target_lang = target_lang

        # the size of the output layer is unknown without a vocabulary size
        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `UniSpeechForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )

        # with an adapter the head reads `config.output_hidden_size` features, otherwise `config.hidden_size`
        output_hidden_size = (
            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
        )
        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def tie_weights(self):
        """
        This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
        passing `target_lang=...` to `from_pretrained(...)`.

        This method is **not** supposed to be called by the user and is prone to be changed in the future.

        Raises:
            ValueError: If `target_lang` was passed but `config.adapter_attn_dim` is not defined.
        """

        # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
        # correctly load adapter layers for UniSpeech so that we do not have to introduce a new API to
        # [`PreTrainedModel`]. While slightly hacky, UniSpeech never has to tie input and output embeddings, so that it is
        # ok to repurpose this function here.
        target_lang = self.target_lang

        if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
            raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
        elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
            logger.info("By default `target_lang` is set to 'eng'.")
        elif target_lang is not None:
            self.load_adapter(target_lang, force_load=True)

    def freeze_feature_extractor(self):
        """
        Deprecated alias of `freeze_feature_encoder`: emits a `FutureWarning` and then disables the gradient
        computation for the feature encoder so that its parameters will not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.unispeech.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.unispeech.parameters():
            param.requires_grad = False

    @auto_docstring
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[tuple, CausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # valid label ids are `0 ... vocab_size - 1`, so `vocab_size` itself is already out of range
        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

        outputs = self.unispeech(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # retrieve loss input_lengths from attention_mask
            # (without an attention mask every input position counts as attended)
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            # the log-probabilities are computed in float32 and dims 0 and 1 are swapped before the loss call
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            # cuDNN is switched off while the loss is computed
            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
@auto_docstring(
    custom_intro="""
    UniSpeech Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
    SUPERB Keyword Spotting.
    """
)
class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # adapters cannot be combined with this head
        if getattr(config, "add_adapter", False):
            raise ValueError(
                "Sequence classification does not support the use of UniSpeech adapters (config.add_adapter=True)"
            )

        self.unispeech = UniSpeechModel(config)

        # transformer layers + input embeddings
        num_layers = config.num_hidden_layers + 1
        if config.use_weighted_layer_sum:
            # one weight per layer output, initialised uniformly
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)

        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Deprecated alias of `freeze_feature_encoder`: emits a `FutureWarning` and then disables the gradient
        computation for the feature encoder so that its parameters will not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Disable the gradient computation for the feature encoder so that its parameters will not be updated during
        training.
        """
        self.unispeech.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Disable the gradient computation for every parameter of the base model so that only the classification head
        will be updated during training.
        """
        for base_param in self.unispeech.parameters():
            base_param.requires_grad = False

    @auto_docstring
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[tuple, SequenceClassifierOutput]:
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
            soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
            conversion into a tensor of type `torch.FloatTensor`. See [`UniSpeechProcessor.__call__`] for details.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. A classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # the weighted layer sum needs the output of every layer, whatever the caller asked for
        if self.config.use_weighted_layer_sum:
            output_hidden_states = True

        outputs = self.unispeech(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # stack all layer outputs on a new dim 1 and combine them with softmax-normalised weights
            stacked_layers = torch.stack(outputs[_HIDDEN_STATES_START_POSITION], dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (stacked_layers * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)

        if attention_mask is None:
            # plain mean over dim 1
            pooled_output = hidden_states.mean(dim=1)
        else:
            # mean over the non-padded positions only: padded positions are zeroed in place before summing
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            expanded_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[~expanded_mask] = 0.0
            num_valid = padding_mask.sum(dim=1).view(-1, 1)
            pooled_output = hidden_states.sum(dim=1) / num_valid

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(logits.view(-1, self.config.num_labels), labels.view(-1))

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        tuple_outputs = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
        if loss is not None:
            tuple_outputs = (loss,) + tuple_outputs
        return tuple_outputs
# Names exported by `from <this module> import *`.
__all__ = [
    "UniSpeechForCTC",
    "UniSpeechForPreTraining",
    "UniSpeechForSequenceClassification",
    "UniSpeechModel",
    "UniSpeechPreTrainedModel",
]