mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 21:30:07 +06:00

* deprecate for 1 version * style * fix some tests * fix esm * skip for now, GC requires positional args but we have keyword args * remove transpose for scores in modified models only * skip fx trace tests
1521 lines
63 KiB
Python
Executable File
1521 lines
63 KiB
Python
Executable File
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# This file was automatically generated from src/transformers/models/unispeech/modular_unispeech.py.
|
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
|
# the file from the modular. If any change should be done, please apply the change to the
|
|
# modular_unispeech.py file directly. One of our CI enforces this.
|
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# coding=utf-8
|
|
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import math
|
|
import warnings
|
|
from dataclasses import dataclass
|
|
from typing import Callable, Optional, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import CrossEntropyLoss
|
|
|
|
from ...activations import ACT2FN
|
|
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
|
from ...integrations.fsdp import is_fsdp_managed_module
|
|
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
|
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
from ...modeling_outputs import (
|
|
BaseModelOutput,
|
|
CausalLMOutput,
|
|
ModelOutput,
|
|
SequenceClassifierOutput,
|
|
Wav2Vec2BaseModelOutput,
|
|
)
|
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
from ...processing_utils import Unpack
|
|
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
|
|
from ...utils.deprecation import deprecate_kwarg
|
|
from .configuration_unispeech import UniSpeechConfig
|
|
|
|
|
|
if is_torch_flex_attn_available():
|
|
from ...integrations.flex_attention import make_flex_block_causal_mask
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`UniSpeechForPreTrainingOutput`], with potential hidden states and attentions.
    """
)
class UniSpeechForPreTrainingOutput(ModelOutput):
    r"""
    loss (*optional*, returned when model is in train mode, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
        paper](https://arxiv.org/pdf/2006.11477.pdf).
    projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
        Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
        projected quantized states.
    projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
        Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
        target vectors for contrastive loss.
    codevector_perplexity (`torch.FloatTensor` of shape `(1,)`):
        The perplexity of the codevector distribution, used to measure the diversity of the codebook.
    """

    # Every field defaults to None, so an output can be built with any subset populated.
    loss: Optional[torch.FloatTensor] = None
    projected_states: Optional[torch.FloatTensor] = None
    projected_quantized_states: Optional[torch.FloatTensor] = None
    codevector_perplexity: Optional[torch.FloatTensor] = None
    # Tuples of per-layer tensors (presumably filled only when requested by the caller -- confirm at call site).
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
|
|
|
|
|
|
class UniSpeechSamePadLayer(nn.Module):
    """Trim the trailing element(s) of the last dimension when the positional kernel size is even."""

    def __init__(self, num_conv_pos_embeddings):
        super().__init__()
        # One element is dropped for an even kernel size, none for an odd one.
        kernel_is_even = num_conv_pos_embeddings % 2 == 0
        self.num_pad_remove = 1 if kernel_is_even else 0

    def forward(self, hidden_states):
        """Return `hidden_states` with the last `num_pad_remove` entries of dim 2 removed (no-op when 0)."""
        if self.num_pad_remove <= 0:
            return hidden_states
        return hidden_states[:, :, : -self.num_pad_remove]
|
|
|
|
|
|
class UniSpeechPositionalConvEmbedding(nn.Module):
    """
    Positional embedding computed by a weight-normalized grouped 1D convolution.

    The convolution keeps the channel count at `config.hidden_size`, uses a kernel of
    `config.num_conv_pos_embeddings` with half-kernel padding, and is followed by
    `UniSpeechSamePadLayer` and the `config.feat_extract_activation` activation.
    """

    def __init__(self, config):
        super().__init__()
        kernel_size = config.num_conv_pos_embeddings
        self.conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=config.num_conv_pos_embedding_groups,
        )

        # Use the parametrization-based weight norm when this torch build provides it,
        # otherwise the legacy `nn.utils.weight_norm`.
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm
        else:
            weight_norm = nn.utils.weight_norm

        if not is_deepspeed_zero3_enabled():
            self.conv = weight_norm(self.conv, name="weight", dim=2)
        else:
            import deepspeed

            # Under ZeRO-3 the conv weight is gathered while weight norm is applied.
            with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
                self.conv = weight_norm(self.conv, name="weight", dim=2)

            # The two weight-norm tensors live in different places depending on which API was used.
            if hasattr(self.conv, "parametrizations"):
                parametrized_weight = self.conv.parametrizations.weight
                weight_g = parametrized_weight.original0
                weight_v = parametrized_weight.original1
            else:
                weight_g = self.conv.weight_g
                weight_v = self.conv.weight_v
            deepspeed.zero.register_external_parameter(self, weight_v)
            deepspeed.zero.register_external_parameter(self, weight_g)

        self.padding = UniSpeechSamePadLayer(config.num_conv_pos_embeddings)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        """Swap dims 1 and 2, apply conv -> pad trim -> activation, then swap the dims back."""
        # presumably the input is (batch, time, hidden) and Conv1d wants channels on dim 1 -- confirm with caller
        conv_input = hidden_states.transpose(1, 2)
        embeddings = self.conv(conv_input)
        embeddings = self.padding(embeddings)
        embeddings = self.activation(embeddings)
        return embeddings.transpose(1, 2)
|
|
|
|
|
|
class UniSpeechNoLayerNormConvLayer(GradientCheckpointingLayer):
    """Feature-extractor stage made of a 1D convolution followed by an activation (no normalization)."""

    def __init__(self, config, layer_id=0):
        super().__init__()
        # Stage 0 reads a single input channel; later stages read the previous stage's output channels.
        self.in_conv_dim = 1 if layer_id <= 0 else config.conv_dim[layer_id - 1]
        self.out_conv_dim = config.conv_dim[layer_id]

        conv_kwargs = {
            "kernel_size": config.conv_kernel[layer_id],
            "stride": config.conv_stride[layer_id],
            "bias": config.conv_bias,
        }
        self.conv = nn.Conv1d(self.in_conv_dim, self.out_conv_dim, **conv_kwargs)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        """Apply the convolution and then the activation."""
        return self.activation(self.conv(hidden_states))
|
|
|
|
|
|
class UniSpeechLayerNormConvLayer(GradientCheckpointingLayer):
    """Feature-extractor stage: 1D convolution, LayerNorm over the channel dim, then an activation."""

    def __init__(self, config, layer_id=0):
        super().__init__()
        # Stage 0 reads a single input channel; later stages read the previous stage's output channels.
        self.in_conv_dim = 1 if layer_id <= 0 else config.conv_dim[layer_id - 1]
        self.out_conv_dim = config.conv_dim[layer_id]

        conv_kwargs = {
            "kernel_size": config.conv_kernel[layer_id],
            "stride": config.conv_stride[layer_id],
            "bias": config.conv_bias,
        }
        self.conv = nn.Conv1d(self.in_conv_dim, self.out_conv_dim, **conv_kwargs)
        self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        """Convolve, normalize across channels, and activate."""
        conv_out = self.conv(hidden_states)
        # LayerNorm acts on the last dim, so the last two dims are swapped around the call
        # to put the `out_conv_dim` channels last and then restore the layout.
        normed = self.layer_norm(conv_out.transpose(-2, -1)).transpose(-2, -1)
        return self.activation(normed)
|
|
|
|
|
|
class UniSpeechGroupNormConvLayer(GradientCheckpointingLayer):
    """Feature-extractor stage: 1D convolution, GroupNorm with one group per channel, then an activation."""

    def __init__(self, config, layer_id=0):
        super().__init__()
        # Stage 0 reads a single input channel; later stages read the previous stage's output channels.
        self.in_conv_dim = 1 if layer_id <= 0 else config.conv_dim[layer_id - 1]
        self.out_conv_dim = config.conv_dim[layer_id]

        conv_kwargs = {
            "kernel_size": config.conv_kernel[layer_id],
            "stride": config.conv_stride[layer_id],
            "bias": config.conv_bias,
        }
        self.conv = nn.Conv1d(self.in_conv_dim, self.out_conv_dim, **conv_kwargs)
        self.activation = ACT2FN[config.feat_extract_activation]

        # `num_groups == num_channels`: every channel is normalized on its own.
        channels = self.out_conv_dim
        self.layer_norm = nn.GroupNorm(num_groups=channels, num_channels=channels, affine=True)

    def forward(self, hidden_states):
        """Convolve, group-normalize, and activate."""
        return self.activation(self.layer_norm(self.conv(hidden_states)))
|
|
|
|
|
|
class UniSpeechFeatureEncoder(nn.Module):
    """Construct the features from raw audio waveform"""

    def __init__(self, config):
        super().__init__()

        norm_kind = config.feat_extract_norm
        if norm_kind == "group":
            # Only the first stage is group-normalized; the remaining stages have no normalization.
            conv_layers = [UniSpeechGroupNormConvLayer(config, layer_id=0)]
            conv_layers.extend(
                UniSpeechNoLayerNormConvLayer(config, layer_id=stage)
                for stage in range(1, config.num_feat_extract_layers)
            )
        elif norm_kind == "layer":
            conv_layers = []
            for stage in range(config.num_feat_extract_layers):
                conv_layers.append(UniSpeechLayerNormConvLayer(config, layer_id=stage))
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )

        self.conv_layers = nn.ModuleList(conv_layers)
        self.gradient_checkpointing = False
        self._requires_grad = True

    def _freeze_parameters(self):
        """Disable gradients for every parameter of this encoder and remember that it is frozen."""
        for parameter in self.parameters():
            parameter.requires_grad = False
        self._requires_grad = False

    def forward(self, input_values):
        """Insert a channel axis at dim 1 and run the input through every conv stage in order."""
        features = input_values[:, None]

        # make sure the input requires grad for gradient_checkpointing (skipped once frozen or in eval)
        if self._requires_grad and self.training:
            features.requires_grad = True

        for stage in self.conv_layers:
            features = stage(features)

        return features
|
|
|
|
|
|
class UniSpeechFeatureProjection(nn.Module):
    """LayerNorm the extracted features, then project them to `config.hidden_size` with dropout."""

    def __init__(self, config):
        super().__init__()
        feature_dim = config.conv_dim[-1]
        self.layer_norm = nn.LayerNorm(feature_dim, eps=config.layer_norm_eps)
        self.projection = nn.Linear(feature_dim, config.hidden_size)
        self.dropout = nn.Dropout(config.feat_proj_dropout)

    def forward(self, hidden_states):
        """
        Returns:
            A pair `(projected, normalized)`: the projected (and dropped-out) features, and the
            layer-normalized but non-projected features, which are needed for quantization.
        """
        norm_hidden_states = self.layer_norm(hidden_states)
        projected = self.dropout(self.projection(norm_hidden_states))
        return projected, norm_hidden_states
|
|
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    """
    Plain scaled dot-product attention.

    Args:
        module: Only its `training` flag is read, to decide whether dropout is active.
        query, key, value: Tensors whose dims 2 and 3 are used as (length, head_dim); `key` is
            transposed on those dims before the product.
        attention_mask: Optional additive mask, added to the scaled scores before the softmax.
        scaling: Score multiplier; defaults to `query.size(-1) ** -0.5`.
        dropout: Dropout probability applied to the attention probabilities.
        head_mask: Optional per-head multiplier, reshaped to `(1, -1, 1, 1)` and applied after the softmax.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        `(attn_output, attn_weights)` where `attn_output` has dims 1 and 2 swapped and is contiguous.
    """
    scale = query.size(-1) ** -0.5 if scaling is None else scaling

    scores = torch.matmul(query, key.transpose(2, 3)) * scale
    if attention_mask is not None:
        scores = scores + attention_mask

    attn_weights = nn.functional.softmax(scores, dim=-1)
    if head_mask is not None:
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)

    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()

    return attn_output, attn_weights
|
|
|
|
|
|
class UniSpeechAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[UniSpeechConfig] = None,
    ):
        """
        Args:
            embed_dim: Model width; must be divisible by `num_heads`.
            num_heads: Number of attention heads.
            dropout: Dropout probability handed to the attention function while training.
            is_decoder: Stored on the module; not read inside this class.
            bias: Whether the four linear projections carry a bias.
            is_causal: Stored on the module; not read inside this class.
            config: Optional model config; its `_attn_implementation` selects the attention backend.

        Raises:
            ValueError: If `embed_dim` is not divisible by `num_heads`.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    @deprecate_kwarg("past_key_value", version="4.54.0")
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states: Source of the queries (and of keys/values for self-attention).
            key_value_states: When given, keys and values are projected from it (cross-attention).
            past_key_value: Deprecated and unused.
            attention_mask: Forwarded unchanged to the attention function.
            layer_head_mask: Forwarded to the attention function as `head_mask`.
            output_attentions: Forwarded to the attention function.

        Returns:
            `(attn_output, attn_weights, None)`; the third slot is always `None`.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        # determine input shapes
        bsz, tgt_len = hidden_states.shape[:-1]
        src_len = key_value_states.shape[1] if is_cross_attention else tgt_len

        q_input_shape = (bsz, tgt_len, -1, self.head_dim)
        kv_input_shape = (bsz, src_len, -1, self.head_dim)

        # project, split the channel dim into heads, and move the head dim in front of the time dim
        query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)

        current_states = key_value_states if is_cross_attention else hidden_states
        key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
        value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)

        # `config` is optional in `__init__`; without one there is no backend preference to read, so
        # stay on the eager implementation instead of failing on `None._attn_implementation`.
        attn_implementation = "eager" if self.config is None else self.config._attn_implementation
        attention_interface: Callable = eager_attention_forward
        if attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            output_attentions=output_attentions,
            head_mask=layer_head_mask,
            **kwargs,
        )

        # merge the heads back into a single channel dim before the output projection
        attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights, None
|
|
|
|
|
|
class UniSpeechFeedForward(nn.Module):
    """Two-layer feed-forward block: dense -> activation -> dropout -> dense -> dropout."""

    def __init__(self, config):
        super().__init__()
        self.intermediate_dropout = nn.Dropout(config.activation_dropout)

        self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` is either the name of a registered activation or the callable itself.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

        self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.output_dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states):
        """Expand to `intermediate_size`, activate, and project back to `hidden_size`."""
        expanded = self.intermediate_act_fn(self.intermediate_dense(hidden_states))
        expanded = self.intermediate_dropout(expanded)

        projected = self.output_dense(expanded)
        return self.output_dropout(projected)
|
|
|
|
|
|
class UniSpeechEncoderLayer(GradientCheckpointingLayer):
    """Post-norm transformer layer: each sub-block is followed by its residual add and a LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.attention = UniSpeechAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
            config=config,
        )

        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = UniSpeechFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        """
        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is truthy.
        """
        attn_residual = hidden_states
        attn_out, attn_weights, _ = self.attention(
            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
        )
        hidden_states = attn_residual + self.dropout(attn_out)

        # normalization comes after each residual add
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states + self.feed_forward(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
|
|
|
|
|
|
class UniSpeechEncoder(nn.Module):
    """
    Stack of `UniSpeechEncoderLayer`s. Positional conv embeddings are added to the input, which is
    then layer-normalized and dropped out *before* the layers run.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = UniSpeechPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList([UniSpeechEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    @staticmethod
    def _zero_padded_positions(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Zero, in place, every position of `hidden_states` whose 2D `attention_mask` entry is 0/False."""
        # Cast to bool first: `~` on an integer 0/1 mask is a bitwise NOT (-1/-2), which would then be
        # used as integer indices into the batch dim instead of as a boolean mask.
        keep = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]).bool()
        hidden_states[~keep] = 0
        return hidden_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states: Input features; modified in place when `attention_mask` is given.
            attention_mask: Optional 2D mask; entries equal to 0/False mark padded positions.
            output_attentions: Collect the attention weights of every layer.
            output_hidden_states: Collect the hidden states before every layer and after the last one.
            return_dict: Return a `BaseModelOutput` instead of a plain tuple.

        Returns:
            `BaseModelOutput`, or a tuple of the non-`None` values among
            `(hidden_states, all_hidden_states, all_self_attentions)`.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens output 0
            hidden_states = self._zero_padded_positions(hidden_states, attention_mask)

        attention_mask = self._update_full_mask(attention_mask, hidden_states)

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            dropout_probability = torch.rand([])
            skip_the_layer = bool(self.training and (dropout_probability < self.config.layerdrop))

            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                # NOTE(review): when a layer is "skipped" under synced_gpus its output still replaces
                # `hidden_states`; only the attention weights are discarded. Kept as-is -- confirm intent.
                layer_outputs = layer(
                    hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
                )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def _update_full_mask(
        self,
        attention_mask: Union[torch.Tensor, None],
        inputs_embeds: torch.Tensor,
    ):
        """Convert the 2D padding mask into the form expected by `config._attn_implementation`."""
        if attention_mask is None:
            return attention_mask

        attn_implementation = self.config._attn_implementation
        if attn_implementation == "flash_attention_2":
            # the mask is dropped entirely when it contains no 0 entry
            return attention_mask if 0 in attention_mask else None
        if attn_implementation == "sdpa":
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            return _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
        if attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                return make_flex_block_causal_mask(attention_mask, is_causal=False)
            return attention_mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        return _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
|
|
|
|
|
|
class UniSpeechAttnAdapterLayer(nn.Module):
    """Bottleneck adapter: LayerNorm -> Linear(hidden, adapter) -> ReLU -> Linear(adapter, hidden)."""

    def __init__(self, config):
        """
        Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
        up training throughput.
        """
        super().__init__()
        self.input_dim = config.adapter_attn_dim
        self.hidden_dim = config.hidden_size

        bottleneck, width = self.input_dim, self.hidden_dim
        self.norm = nn.LayerNorm(width)
        self.linear_1 = nn.Linear(width, bottleneck)
        self.act_fn = nn.ReLU()
        self.linear_2 = nn.Linear(bottleneck, width)

    def forward(self, hidden_states: torch.FloatTensor):
        """Return the adapter output; the residual add is left to the caller."""
        normed = self.norm(hidden_states)
        bottlenecked = self.act_fn(self.linear_1(normed))
        return self.linear_2(bottlenecked)
|
|
|
|
|
|
class UniSpeechEncoderLayerStableLayerNorm(GradientCheckpointingLayer):
    """Pre-norm transformer layer (LayerNorm before each sub-block) with an optional attention adapter."""

    def __init__(self, config):
        super().__init__()
        self.attention = UniSpeechAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
            config=config,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = UniSpeechFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # The adapter exists only when the config carries a non-None `adapter_attn_dim`.
        wants_adapter = getattr(config, "adapter_attn_dim", None) is not None
        self.adapter_layer = UniSpeechAttnAdapterLayer(config) if wants_adapter else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        """
        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is truthy.
        """
        attn_residual = hidden_states
        normed = self.layer_norm(hidden_states)
        attn_out, attn_weights, _ = self.attention(
            normed, attention_mask=attention_mask, output_attentions=output_attentions
        )
        hidden_states = attn_residual + self.dropout(attn_out)

        # the feed-forward branch reads a normalized copy; the residual stream stays un-normalized
        hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))

        if self.adapter_layer is not None:
            hidden_states = hidden_states + self.adapter_layer(hidden_states)

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
|
|
|
|
|
|
class UniSpeechEncoderStableLayerNorm(nn.Module):
    """
    Stack of `UniSpeechEncoderLayerStableLayerNorm`s. Positional conv embeddings are added to the
    input and dropped out; the encoder-level LayerNorm is applied once, *after* the last layer.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = UniSpeechPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList(
            [UniSpeechEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    @staticmethod
    def _zero_padded_positions(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Zero, in place, every position of `hidden_states` whose 2D `attention_mask` entry is 0/False."""
        # Cast to bool first: `~` on an integer 0/1 mask is a bitwise NOT (-1/-2), which would then be
        # used as integer indices into the batch dim instead of as a boolean mask.
        keep = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]).bool()
        hidden_states[~keep] = 0
        return hidden_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states: Input features; modified in place when `attention_mask` is given.
            attention_mask: Optional 2D mask; entries equal to 0/False mark padded positions.
            output_attentions: Collect the attention weights of every layer.
            output_hidden_states: Collect the hidden states before every layer and after the final norm.
            return_dict: Return a `BaseModelOutput` instead of a plain tuple.

        Returns:
            `BaseModelOutput`, or a tuple of the non-`None` values among
            `(hidden_states, all_hidden_states, all_self_attentions)`.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens output 0
            hidden_states = self._zero_padded_positions(hidden_states, attention_mask)

        attention_mask = self._update_full_mask(attention_mask, hidden_states)

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.dropout(hidden_states)

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            dropout_probability = torch.rand([])
            skip_the_layer = bool(self.training and (dropout_probability < self.config.layerdrop))

            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
                # NOTE(review): when a layer is "skipped" under synced_gpus its output still replaces
                # `hidden_states`; only the attention weights are discarded. Kept as-is -- confirm intent.
                layer_outputs = layer(
                    hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
                )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def _update_full_mask(
        self,
        attention_mask: Union[torch.Tensor, None],
        inputs_embeds: torch.Tensor,
    ):
        """Convert the 2D padding mask into the form expected by `config._attn_implementation`."""
        if attention_mask is None:
            return attention_mask

        attn_implementation = self.config._attn_implementation
        if attn_implementation == "flash_attention_2":
            # the mask is dropped entirely when it contains no 0 entry
            return attention_mask if 0 in attention_mask else None
        if attn_implementation == "sdpa":
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            return _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
        if attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                return make_flex_block_causal_mask(attention_mask, is_causal=False)
            return attention_mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        return _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
|
|
|
|
|
|
class UniSpeechGumbelVectorQuantizer(nn.Module):
    """
    Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
    GUMBEL-SOFTMAX](https://huggingface.co/papers/1611.01144) for more information.
    """

    def __init__(self, config):
        super().__init__()
        self.num_groups = config.num_codevector_groups
        self.num_vars = config.num_codevectors_per_group

        if config.codevector_dim % self.num_groups != 0:
            raise ValueError(
                f"`config.codevector_dim {config.codevector_dim} must be divisible "
                f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
            )

        # storage for codebook variables (codewords); allocated uninitialized here
        num_entries = self.num_groups * self.num_vars
        self.codevectors = nn.Parameter(
            torch.FloatTensor(1, num_entries, config.codevector_dim // self.num_groups)
        )
        self.weight_proj = nn.Linear(config.conv_dim[-1], num_entries)

        # can be decayed for training
        self.temperature = 2

    @staticmethod
    def _compute_perplexity(probs):
        """Average `probs` over dim 0, take exp(entropy) along the last dim, and sum what remains."""
        marginal_probs = probs.mean(dim=0)
        entropy = -torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)
        return torch.exp(entropy).sum()

    def forward(self, hidden_states):
        """
        Quantize a 3D input `(batch_size, sequence_length, feature_dim)`.

        Returns:
            `(codevectors, perplexity)`: `codevectors` is reshaped to `(batch_size, sequence_length, -1)`
            and concatenates, per position, the selected codeword of every group.
        """
        batch_size, sequence_length, _ = hidden_states.shape
        num_frames = batch_size * sequence_length

        # project to one logit per codebook entry, one row per (frame, group) pair
        logits = self.weight_proj(hidden_states)
        logits = logits.view(num_frames * self.num_groups, -1)

        if self.training:
            # sample code vector probs via gumbel in a differentiable way
            codevector_probs = nn.functional.gumbel_softmax(
                logits.float(), tau=self.temperature, hard=True
            ).type_as(logits)

            # the perplexity is computed from the soft distribution
            soft_dist = torch.softmax(logits.view(num_frames, self.num_groups, -1).float(), dim=-1)
            perplexity = self._compute_perplexity(soft_dist)
        else:
            # take argmax in a non-differentiable way and build the hard (one-hot) distribution
            best_entry = logits.argmax(dim=-1)
            codevector_probs = logits.new_zeros(*logits.shape).scatter_(-1, best_entry.view(-1, 1), 1.0)
            codevector_probs = codevector_probs.view(num_frames, self.num_groups, -1)

            perplexity = self._compute_perplexity(codevector_probs)

        codevector_probs = codevector_probs.view(num_frames, -1)
        # use probs to retrieve codevectors: weight every codebook entry, then sum inside each group
        weighted_entries = codevector_probs.unsqueeze(-1) * self.codevectors
        grouped = weighted_entries.view(num_frames, self.num_groups, self.num_vars, -1)
        codevectors = grouped.sum(-2).view(batch_size, sequence_length, -1)

        return codevectors, perplexity
|
|
|
|
|
|
@auto_docstring
class UniSpeechPreTrainedModel(PreTrainedModel):
    config_class = UniSpeechConfig
    base_model_prefix = "unispeech"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

    def _init_weights(self, module):
        """Initialize the weights"""
        # The specific module types are tested before the generic `nn.Linear` / `nn.Conv1d` cases.
        if isinstance(module, UniSpeechGumbelVectorQuantizer):
            # gumbel softmax requires special init
            proj = module.weight_proj
            proj.weight.data.normal_(mean=0.0, std=1)
            proj.bias.data.zero_()
            nn.init.uniform_(module.codevectors)
        elif isinstance(module, UniSpeechPositionalConvEmbedding):
            conv = module.conv
            nn.init.normal_(
                conv.weight,
                mean=0,
                std=2 * math.sqrt(1 / (conv.kernel_size[0] * conv.in_channels)),
            )
            nn.init.constant_(conv.bias, 0)
        elif isinstance(module, UniSpeechFeatureProjection):
            bound = math.sqrt(1 / module.projection.in_features)
            nn.init.uniform_(module.projection.weight, a=-bound, b=bound)
            nn.init.uniform_(module.projection.bias, a=-bound, b=bound)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)

            if module.bias is not None:
                bound = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-bound, b=bound)

    def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the convolutional layers
        """
        # 1D convolutional layer output length formula taken
        # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
        # applied once per (kernel_size, stride) pair of the config
        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = torch.div(input_lengths - kernel_size, stride, rounding_mode="floor") + 1

        return input_lengths

    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
        """Build a boolean `(batch_size, feature_vector_length)` mask from a sample-level attention mask."""
        # Effectively attention_mask.sum(-1), but not inplace to be able to run
        # on inference mode.
        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)

        batch_size = attention_mask.shape[0]
        feature_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )
        # these two operations makes sure that all values before the output lengths idxs are attended to:
        # mark the last valid index of every row, then a reversed cumulative sum fills everything before it
        batch_indices = torch.arange(feature_mask.shape[0], device=feature_mask.device)
        feature_mask[(batch_indices, output_lengths - 1)] = 1
        feature_mask = feature_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return feature_mask
|
|
|
|
|
|
def _compute_mask_indices(
|
|
shape: tuple[int, int],
|
|
mask_prob: float,
|
|
mask_length: int,
|
|
attention_mask: Optional[torch.LongTensor] = None,
|
|
min_masks: int = 0,
|
|
) -> np.ndarray:
|
|
"""
|
|
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
|
|
ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
|
|
CPU as part of the preprocessing during training.
|
|
|
|
Args:
|
|
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
|
|
the first element is the batch size and the second element is the length of the axis to span.
|
|
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
|
|
independently generated mask spans of length `mask_length` is computed by
|
|
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
|
|
actual percentage will be smaller.
|
|
mask_length: size of the mask
|
|
min_masks: minimum number of masked spans
|
|
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
|
|
each batch dimension.
|
|
"""
|
|
batch_size, sequence_length = shape
|
|
|
|
if mask_length < 1:
|
|
raise ValueError("`mask_length` has to be bigger than 0.")
|
|
|
|
if mask_length > sequence_length:
|
|
raise ValueError(
|
|
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
|
|
f" and `sequence_length`: {sequence_length}`"
|
|
)
|
|
|
|
# epsilon is used for probabilistic rounding
|
|
epsilon = np.random.rand(1).item()
|
|
|
|
def compute_num_masked_span(input_length):
|
|
"""Given input length, compute how many spans should be masked"""
|
|
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
|
|
num_masked_span = max(num_masked_span, min_masks)
|
|
|
|
# make sure num masked span <= sequence_length
|
|
if num_masked_span * mask_length > sequence_length:
|
|
num_masked_span = sequence_length // mask_length
|
|
|
|
# make sure num_masked span is also <= input_length - (mask_length - 1)
|
|
if input_length - (mask_length - 1) < num_masked_span:
|
|
num_masked_span = max(input_length - (mask_length - 1), 0)
|
|
|
|
return num_masked_span
|
|
|
|
# compute number of masked spans in batch
|
|
input_lengths = (
|
|
attention_mask.detach().sum(-1).tolist()
|
|
if attention_mask is not None
|
|
else [sequence_length for _ in range(batch_size)]
|
|
)
|
|
|
|
# SpecAugment mask to fill
|
|
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
|
|
spec_aug_mask_idxs = []
|
|
|
|
max_num_masked_span = compute_num_masked_span(sequence_length)
|
|
|
|
if max_num_masked_span == 0:
|
|
return spec_aug_mask
|
|
|
|
for input_length in input_lengths:
|
|
# compute num of masked spans for this input
|
|
num_masked_span = compute_num_masked_span(input_length)
|
|
|
|
# get random indices to mask
|
|
spec_aug_mask_idx = np.random.choice(
|
|
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
|
|
)
|
|
|
|
# pick first sampled index that will serve as a dummy index to pad vector
|
|
# to ensure same dimension for all batches due to probabilistic rounding
|
|
# Picking first sample just pads those vectors twice.
|
|
if len(spec_aug_mask_idx) == 0:
|
|
# this case can only happen if `input_length` is strictly smaller then
|
|
# `sequence_length` in which case the last token has to be a padding
|
|
# token which we can use as a dummy mask id
|
|
dummy_mask_idx = sequence_length - 1
|
|
else:
|
|
dummy_mask_idx = spec_aug_mask_idx[0]
|
|
|
|
spec_aug_mask_idx = np.concatenate(
|
|
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
|
|
)
|
|
spec_aug_mask_idxs.append(spec_aug_mask_idx)
|
|
|
|
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
|
|
|
|
# expand masked indices to masked spans
|
|
spec_aug_mask_idxs = np.broadcast_to(
|
|
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
|
|
)
|
|
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
|
|
|
|
# add offset to the starting indexes so that indexes now create a span
|
|
offsets = np.arange(mask_length)[None, None, :]
|
|
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
|
|
batch_size, max_num_masked_span * mask_length
|
|
)
|
|
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
|
|
|
|
# ensure that we cannot have indices larger than sequence_length
|
|
if spec_aug_mask_idxs.max() > sequence_length - 1:
|
|
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
|
|
|
|
# scatter indices to mask
|
|
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
|
|
|
|
return spec_aug_mask
|
|
|
|
|
|
# UniSpeech does not declare an output class of its own for the base model: the name below is
# an alias of `Wav2Vec2BaseModelOutput`.
UniSpeechBaseModelOutput = Wav2Vec2BaseModelOutput
|
|
|
|
|
|
@auto_docstring
class UniSpeechModel(UniSpeechPreTrainedModel):
    def __init__(self, config: UniSpeechConfig):
        super().__init__(config)
        self.config = config
        self.feature_extractor = UniSpeechFeatureEncoder(config)
        self.feature_projection = UniSpeechFeatureProjection(config)

        # The learned embedding written over masked time steps is only created when some form of
        # SpecAugment masking is enabled in the config.
        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        encoder_cls = UniSpeechEncoderStableLayerNorm if config.do_stable_layer_norm else UniSpeechEncoder
        self.encoder = encoder_cls(config)

        # Initialize weights and apply final processing
        self.post_init()

    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://huggingface.co/papers/1904.08779).

        `hidden_states` is modified in place and also returned. Explicit `mask_time_indices` are always applied;
        random time and feature masks are only drawn while the module is in training mode.
        """
        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is not None:
            # time masking with the indices handed in by the caller
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            # time masking with freshly sampled spans
            sampled_time_mask = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(sampled_time_mask, device=hidden_states.device, dtype=torch.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # feature masking: sample spans over the hidden dimension and zero them at every time step
            sampled_feature_mask = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(sampled_feature_mask, device=hidden_states.device, dtype=torch.bool)
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states

    @auto_docstring
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, UniSpeechBaseModelOutput]:
        r"""
        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.
        """
        # unset flags fall back to the values stored in the config
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # run the feature encoder and swap its last two axes before projection
        extract_features = self.feature_extractor(input_values).transpose(1, 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = encoder_outputs[0]

        if return_dict:
            return UniSpeechBaseModelOutput(
                last_hidden_state=hidden_states,
                extract_features=extract_features,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # tuple layout: (last hidden state, extracted features, *remaining encoder outputs)
        return (hidden_states, extract_features) + encoder_outputs[1:]
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    UniSpeech Model with a vector-quantization module and ctc loss for pre-training.
    """
)
class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
    def __init__(self, config: UniSpeechConfig):
        super().__init__(config)
        self.unispeech = UniSpeechModel(config)
        self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)

        # quantizer plus the two projections applied to its output
        self.quantizer = UniSpeechGumbelVectorQuantizer(config)
        self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
        self.project_hid = nn.Linear(config.proj_codevector_dim, config.hidden_size)

        # CTC head
        self.ctc_proj = nn.Linear(config.hidden_size, config.num_ctc_classes)
        self.dropout = nn.Dropout(config.final_dropout)

        # Initialize weights and apply final processing
        self.post_init()

    def set_gumbel_temperature(self, temperature: int):
        """
        Set the Gumbel softmax temperature to a given value. Only necessary for training
        """
        self.quantizer.temperature = temperature

    def freeze_feature_extractor(self):
        """
        Deprecated alias of `freeze_feature_encoder`: emits a `FutureWarning` and then disables the gradient
        computation for the feature encoder so that its parameters will not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.unispeech.feature_extractor._freeze_parameters()

    @staticmethod
    def compute_contrastive_logits(
        target_features: torch.FloatTensor,
        negative_features: torch.FloatTensor,
        predicted_features: torch.FloatTensor,
        temperature: int = 1,
    ):
        """
        Compute logits for contrastive loss based using cosine similarity as the distance measure between
        `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
        """
        # stack the positive target and the negatives along the first axis
        candidates = torch.cat([target_features, negative_features], dim=0)

        # the similarity is computed on float copies and cast back to the candidates' type
        similarity = torch.cosine_similarity(predicted_features.float(), candidates.float(), dim=-1)
        similarity = similarity.type_as(candidates)

        # apply temperature
        return similarity / temperature

    @auto_docstring
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, UniSpeechForPreTrainingOutput]:
        r"""
        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoFeatureExtractor, UniSpeechForPreTraining

        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-large-1500h-cv")
        >>> model = UniSpeechForPreTraining.from_pretrained("microsoft/unispeech-large-1500h-cv")
        >>> # TODO: Add full pretraining example
        ```"""

        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.unispeech(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        transformer_features = outputs[0]

        # quantize all (unmasked) extracted features and project to final vq dim
        extract_features = self.dropout_features(outputs[1])
        quantized_features, codevector_perplexity = self.quantizer(extract_features)

        # project quantized features twice
        quantized_features = self.project_q(quantized_features.to(self.project_q.weight.dtype))
        quantized_features = self.project_hid(quantized_features)

        # Sample, per (batch, time) position, whether the transformer output is replaced by the
        # quantized feature; every position uses probability `config.replace_prob`. The sampling is
        # done on a transposed view and transposed back afterwards.
        num_rows, num_steps = transformer_features.size(0), transformer_features.size(1)
        prob_replace_matrix = torch.empty(num_rows, num_steps).fill_(self.config.replace_prob)
        prob_replace_matrix = prob_replace_matrix.transpose(0, 1)
        sampled_replace_matrix = torch.bernoulli(prob_replace_matrix).bool().to(transformer_features.device)
        sampled_replace_matrix = sampled_replace_matrix.transpose(0, 1).unsqueeze(-1)

        kept_transformer = transformer_features.masked_fill(sampled_replace_matrix, 0.0)
        swapped_in_quantized = quantized_features.masked_fill(~sampled_replace_matrix, 0.0)
        logits = kept_transformer + swapped_in_quantized

        # project to ctc units
        logits = self.dropout(logits)
        logits = self.ctc_proj(logits)

        # TODO(PVP) - add negative sampling & loss computation
        loss = None

        if return_dict:
            return UniSpeechForPreTrainingOutput(
                loss=loss,
                projected_states=transformer_features,
                projected_quantized_states=quantized_features,
                codevector_perplexity=codevector_perplexity,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # tuple output: the loss is only prepended when one has been computed
        tuple_output = (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
        if loss is not None:
            return (loss,) + tuple_output
        return tuple_output
|
|
|
|
|
|
# Offset into the tuple returned by `UniSpeechModel` (tuple-output mode) that the task heads below use to
# pick up the optional hidden states / attentions: positions 0 and 1 of that tuple hold the last hidden
# state and the extracted features.
_HIDDEN_STATES_START_POSITION = 2
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    UniSpeech Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
    """
)
class UniSpeechForCTC(UniSpeechPreTrainedModel):
    def __init__(self, config, target_lang: Optional[str] = None):
        r"""
        target_lang (`str`, *optional*):
            Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
            adapter.<lang>.bin. Only relevant when using an instance of [`UniSpeechForCTC`] with adapters. Uses 'eng' by
            default.
        """
        super().__init__(config)

        self.unispeech = UniSpeechModel(config)
        self.dropout = nn.Dropout(config.final_dropout)

        self.target_lang = target_lang

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `UniSpeechForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        # the head reads `config.output_hidden_size` when the config enables an adapter, else `config.hidden_size`
        output_hidden_size = (
            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
        )
        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def tie_weights(self):
        """
        This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
        passing `target_lang=...` to `from_pretrained(...)`.

        This method is **not** supposed to be called by the user and is prone to be changed in the future.
        """

        # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
        # correctly load adapter layers for UniSpeech so that we do not have to introduce a new API to
        # [`PreTrainedModel`]. While slightly hacky, UniSpeech never has to tie input and output embeddings, so that it is
        # ok to repurpose this function here.
        target_lang = self.target_lang

        if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
            raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
        elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
            logger.info("By default `target_lang` is set to 'eng'.")
        elif target_lang is not None:
            self.load_adapter(target_lang, force_load=True)

    def freeze_feature_extractor(self):
        """
        Deprecated alias of `freeze_feature_encoder`: emits a `FutureWarning` and then disables the gradient
        computation for the feature encoder so that its parameter will not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.unispeech.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.unispeech.parameters():
            param.requires_grad = False

    @auto_docstring
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[tuple, CausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Valid label ids are strictly below `vocab_size`; the message states the same bound the check enforces.
        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

        outputs = self.unispeech(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # retrieve loss input_lengths from attention_mask
            # (without a mask, every input position counts as non-padded)
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            # (log-probabilities are computed in float32 and the first two axes are swapped for `ctc_loss`)
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            # cuDNN is switched off for the duration of the loss computation
            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            # tuple output: (loss if computed, logits, *base-model outputs from position 2 on)
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    UniSpeech Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
    SUPERB Keyword Spotting.
    """
)
class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # adapters are rejected up front for this head
        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Sequence classification does not support the use of UniSpeech adapters (config.add_adapter=True)"
            )

        self.unispeech = UniSpeechModel(config)

        if config.use_weighted_layer_sum:
            # one weight per transformer layer plus one for the input embeddings, starting out uniform
            num_layers = config.num_hidden_layers + 1
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)

        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Deprecated alias of `freeze_feature_encoder`: emits a `FutureWarning` and then disables the gradient
        computation for the feature encoder so that its parameters will not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.unispeech.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for base_param in self.unispeech.parameters():
            base_param.requires_grad = False

    @auto_docstring
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[tuple, SequenceClassifierOutput]:
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `list[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
            soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
            conversion into a tensor of type `torch.FloatTensor`. See [`UniSpeechProcessor.__call__`] for details.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. A classification loss (Cross-Entropy) is computed over `config.num_labels`
            classes.
        """

        if return_dict is None:
            return_dict = self.config.use_return_dict
        # the weighted layer sum needs every layer's hidden states, whatever the caller asked for
        if self.config.use_weighted_layer_sum:
            output_hidden_states = True

        outputs = self.unispeech(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # softmax-normalized learned weights combine the stacked per-layer hidden states
            stacked_states = torch.stack(outputs[_HIDDEN_STATES_START_POSITION], dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (stacked_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)

        if attention_mask is None:
            # no mask: plain mean over the time axis
            pooled_output = hidden_states.mean(dim=1)
        else:
            # masked mean: zero the padded frames in place, then divide the sum by the number of valid frames
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[~expand_padding_mask] = 0.0
            valid_frames = padding_mask.sum(dim=1).view(-1, 1)
            pooled_output = hidden_states.sum(dim=1) / valid_frames

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # tuple output: (loss if computed, logits, *base-model outputs from position 2 on)
        output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
        if loss is not None:
            return (loss,) + output
        return output
|
|
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = [
    "UniSpeechForCTC",
    "UniSpeechForPreTraining",
    "UniSpeechForSequenceClassification",
    "UniSpeechModel",
    "UniSpeechPreTrainedModel",
]
|