mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
1312 lines
58 KiB
Python
1312 lines
58 KiB
Python
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# This file was automatically generated from src/transformers/models/moonshine/modular_moonshine.py.
|
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
|
# the file from the modular. If any change should be done, please apply the change to the
|
|
# modular_moonshine.py file directly. One of our CI enforces this.
|
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from typing import Callable, Optional, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from ...activations import ACT2FN
|
|
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
|
from ...generation import GenerationMixin
|
|
from ...masking_utils import create_causal_mask
|
|
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
|
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
from ...modeling_outputs import (
|
|
BaseModelOutput,
|
|
BaseModelOutputWithPast,
|
|
BaseModelOutputWithPastAndCrossAttentions,
|
|
Seq2SeqLMOutput,
|
|
Seq2SeqModelOutput,
|
|
)
|
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
from ...processing_utils import Unpack
|
|
from ...utils import auto_docstring, can_return_tuple, logging
|
|
from ...utils.generic import check_model_inputs
|
|
from .configuration_moonshine import MoonshineConfig
|
|
|
|
|
|
# Module-level logger; used below (e.g. by MoonshineDecoder.forward) to emit one-time warnings.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class MoonshineEncoderMLP(nn.Module):
    """
    Position-wise feed-forward block of the Moonshine encoder.

    Computes `fc2(activation(fc1(x)))`: a plain (non-gated) projection from `hidden_size` up to
    `intermediate_size` and back.

    Args:
        config: model configuration providing `hidden_size` and `intermediate_size`.
        hidden_act (`str`): key into `ACT2FN` selecting the activation function.
    """

    def __init__(self, config, hidden_act):
        super().__init__()
        self.config = config
        # Construction order (activation, fc1, fc2) is kept so parameter creation order is unchanged.
        self.activation_fn = ACT2FN[hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the up-projection, the activation and the down-projection over the last dimension."""
        expanded = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        return self.fc2(activated)
|
|
|
|
|
|
class MoonshineDecoderMLP(nn.Module):
    """
    Gated feed-forward block of the Moonshine decoder.

    `fc1` maps `hidden_size` to `2 * intermediate_size` features. These are split in half along the last axis into a
    value part (first half) and a gate part (second half). The result is `fc2(activation(gate) * value)`.

    Args:
        config: model configuration providing `hidden_size` and `intermediate_size`.
        hidden_act (`str`): key into `ACT2FN` selecting the gate activation.
    """

    def __init__(self, config, hidden_act):
        super().__init__()
        self.config = config
        # Construction order (activation, fc1, fc2) is kept so parameter creation order is unchanged.
        self.activation_fn = ACT2FN[hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size * 2)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up, gate the value half with the activated gate half, then project back down."""
        value, gate = self.fc1(hidden_states).chunk(2, dim=-1)
        gated = self.activation_fn(gate) * value
        return self.fc2(gated)
|
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis.

    Equivalent to `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)`. A 4-D input of shape
    `(batch, num_key_value_heads, seq_len, head_dim)` becomes `(batch, num_key_value_heads * n_rep, seq_len, head_dim)`.
    When `n_rep == 1` the input tensor itself is returned (no copy).
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a repeat axis right after the head axis, broadcast it without copying, then fold it into the head axis.
    repeated = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return repeated.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Reference (eager) scaled dot-product attention.

    Computes `softmax(query @ key^T * scaling + mask) @ value` in plain PyTorch. Keys and values are first expanded
    with `repeat_kv(..., module.num_key_value_groups)` so grouped-query inputs line up with the query heads.
    Assumes a `(batch, heads, seq_len, head_dim)` layout, as `repeat_kv` does.

    Args:
        module: the calling attention module; `num_key_value_groups` and `training` are read from it.
        query, key, value: attention inputs.
        attention_mask: optional additive mask; its last axis is cropped to the key length.
        scaling: multiplier applied to the raw scores.
        dropout: dropout probability applied to the attention probabilities (only while `module.training`).
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        `(attn_output, attn_weights)`. `attn_output` has axes 1 and 2 swapped relative to the score layout and is
        made contiguous. `attn_weights` are the post-dropout probabilities in `query.dtype`.
    """
    n_groups = module.num_key_value_groups
    keys = repeat_kv(key, n_groups)
    values = repeat_kv(value, n_groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    # Softmax is computed in float32 for stability, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    output = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return output, probs
|
|
|
|
|
|
def rotate_half(x):
    """
    Rotate interleaved channel pairs of the last dimension.

    Each pair `(x[2i], x[2i+1])` becomes `(-x[2i+1], x[2i])`, e.g. `[a, b, c, d] -> [-b, a, -d, c]`.
    """
    even = x[..., ::2]
    odd = x[..., 1::2]
    return torch.stack((odd.neg(), even), dim=-1).flatten(-2)
|
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply interleaved (partial) Rotary Position Embedding to query and key tensors.

    Only the first half of the last axis of `cos`/`sin` is used. Each value is repeated twice
    (`[c0, c1, ...] -> [c0, c0, c1, c1, ...]`) to match the interleaved pairing of `rotate_half`. The resulting width
    is the rotary dimension: only the leading `rotary_dim` channels of `q`/`k` are rotated, and the remaining
    channels are passed through unchanged.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into `cos`/`sin` so they broadcast against `q` and `k`. Use 1 when `q`/`k` are
            `[batch, heads, seq_len, head_dim]` and `cos`/`sin` are `[batch, seq_len, dim]`. Use 2 when `q`/`k` are
            `[batch, seq_len, heads, head_dim]`.

    Returns:
        `tuple(torch.Tensor)` of the rotated query and key tensors.
    """

    def _to_interleaved(table):
        # Broadcast over heads, keep the first half of the channels and duplicate each entry in place.
        table = table.unsqueeze(unsqueeze_dim)
        return table[..., : table.shape[-1] // 2].repeat_interleave(2, dim=-1)

    cos = _to_interleaved(cos)
    sin = _to_interleaved(sin)
    rotary_dim = cos.shape[-1]

    def _rotate(x):
        # Rotate the leading `rotary_dim` channels and re-attach the untouched remainder.
        x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
        rotated = (x_rot * cos) + (rotate_half(x_rot) * sin)
        return torch.cat([rotated, x_pass], dim=-1)

    return _rotate(q), _rotate(k)
|
|
|
|
|
|
class MoonshineAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper, as used by Moonshine.

    One module type serves encoder self-attention, causal decoder self-attention and decoder cross-attention. Passing
    `key_value_states` to `forward` selects cross-attention. Rotary position embeddings are applied only in
    self-attention. Head counts are passed explicitly so encoder and decoder can use different values.

    Args:
        config (`MoonshineConfig`): model configuration (the same object is passed to every attention module).
        layer_idx (`int`): index of the owning layer; used to address per-layer cache entries.
        is_causal (`bool`): whether the backend should apply causal masking when no explicit mask is given.
        num_attention_heads (`int`): number of query heads.
        num_key_value_heads (`int`): number of key/value heads. Keys/values are expanded by
            `num_attention_heads // num_key_value_heads` (grouped-query attention), so the query head count is
            expected to be a multiple of it.
    """

    # NOTE(review): `{"attentions", 1}` is a set literal, not a dict mapping "attentions" -> output index 1;
    # confirm the type expected by whatever consumes `return_hooks`.
    return_hooks = {"attentions", 1}

    def __init__(
        self,
        config: MoonshineConfig,
        layer_idx: int,
        is_causal: bool,
        num_attention_heads: int,
        num_key_value_heads: int,
    ):
        super().__init__()
        # The config update is kept for backward compatibility, since other code may read the head counts from the
        # config. The config object is shared by all attention modules, though, so after construction it only
        # reflects the last module built. `forward` therefore uses the per-instance copies stored below.
        config.update({"num_attention_heads": num_attention_heads, "num_key_value_heads": num_key_value_heads})
        self.config = config
        self.layer_idx = layer_idx
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = getattr(config, "head_dim", config.hidden_size // num_attention_heads)
        self.num_key_value_groups = num_attention_heads // num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = is_causal

        self.q_proj = nn.Linear(config.hidden_size, num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(num_attention_heads * self.head_dim, config.hidden_size, bias=False)

        # Pad head dimension to the next specified multiple.
        if self.config.pad_head_dim_to_multiple_of is not None:
            target_multiple = self.config.pad_head_dim_to_multiple_of
            target_head_dim = target_multiple * ((self.head_dim + target_multiple - 1) // target_multiple)
            self.head_dim_padding = target_head_dim - self.head_dim
        else:
            self.head_dim_padding = 0

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        key_value_states: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            hidden_states: query-side input; all but its last dimension are read as `(batch, q_len)`.
            position_embeddings: `(cos, sin)` tables. Required for self-attention (unpacked when
                `key_value_states` is None) and ignored for cross-attention.
            attention_mask: mask forwarded to the attention backend.
            past_key_value: a cache exposing `is_updated`, `self_attention_cache` and `cross_attention_cache`
                (as used below), or None.
            cache_position: positions of the current tokens, forwarded to the cache update.
            key_value_states: encoder outputs for cross-attention; None for self-attention.
            **kwargs: forwarded to the attention backend.

        Returns:
            `(attn_output, attn_weights)`. `attn_weights` is whatever the selected backend returns.
        """
        bsz, q_len = hidden_states.shape[:-1]

        # Queries always have `num_attention_heads` heads; keys/values may have fewer (grouped-query attention).
        query_states = (
            self.q_proj(hidden_states).view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
        )

        is_cross_attention = key_value_states is not None
        if past_key_value is not None:
            is_updated = past_key_value.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                past_key_value.is_updated[self.layer_idx] = True
                past_key_value = past_key_value.cross_attention_cache
            else:
                past_key_value = past_key_value.self_attention_cache

        # use key_value_states if cross attention
        current_states = key_value_states if key_value_states is not None else hidden_states
        if is_cross_attention and past_key_value and is_updated:
            key_states = past_key_value.key_cache[self.layer_idx]
            value_states = past_key_value.value_cache[self.layer_idx]
        else:
            key_states = (
                self.k_proj(current_states).view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            )
            value_states = (
                self.v_proj(current_states).view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            )
            if is_cross_attention and past_key_value is not None:
                key_states, value_states = past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )

        if not is_cross_attention:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            if past_key_value is not None:
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(
                    key_states, value_states, self.layer_idx, cache_kwargs
                )

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        # The backend applies causal masking itself only when no explicit mask was given and there is more than one
        # query position.
        is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False

        # Zero-padding q/k leaves their dot products unchanged, and `self.scaling` still uses the unpadded
        # head_dim. The padded value channels are sliced off again after attention.
        if self.head_dim_padding > 0:
            query_states = torch.nn.functional.pad(query_states, (0, self.head_dim_padding))
            key_states = torch.nn.functional.pad(key_states, (0, self.head_dim_padding))
            value_states = torch.nn.functional.pad(value_states, (0, self.head_dim_padding))

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            is_causal=is_causal,
            **kwargs,
        )

        if self.head_dim_padding > 0:
            attn_output = attn_output[..., : -self.head_dim_padding]

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
|
|
|
|
|
class MoonshineRotaryEmbedding(nn.Module):
    """
    Produces rotary position embedding tables `(cos, sin)` for a batch of position ids.

    The RoPE variant comes from `config.rope_scaling` (key "rope_type", or the legacy key "type"). It falls back to
    "default" when no scaling config is present. Inverse frequencies and the attention scaling factor are obtained
    from the matching `ROPE_INIT_FUNCTIONS` entry.
    """

    def __init__(self, config: MoonshineConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is None:
            self.rope_type = "default"
        else:
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent buffer: follows the module across devices but is not written to state dicts.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Presumably read by `dynamic_rope_update` to restore the initial frequencies; confirm against that helper.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """
        Return `(cos, sin)` for `position_ids`, cast to `x.dtype`.

        `x` is only used for its device and dtype. Angles are `inv_freq * position`, computed in float32 with
        autocast disabled. Each angle table is concatenated with itself along the last axis, and the results are
        scaled by `self.attention_scaling`.
        """
        batch_size = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        # Autocast is disabled on the tensor's device type; non-string device types and "mps" fall back to "cpu".
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            angles = torch.matmul(inv_freq.float(), positions.float()).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
|
|
class MoonshineEncoderLayer(GradientCheckpointingLayer):
    """
    Pre-norm Moonshine encoder layer.

    Non-causal self-attention is followed by a feed-forward block, each wrapped in a residual connection with a
    preceding LayerNorm.

    Args:
        config (`MoonshineConfig`): model configuration.
        layer_idx (`int`): index of this layer in the encoder stack.
    """

    # NOTE(review): `{"hidden_states", 0}` is a set literal, not a dict; confirm the type expected by its consumer.
    return_hooks = {"hidden_states", 0}

    def __init__(self, config: MoonshineConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = MoonshineAttention(
            config=config,
            layer_idx=layer_idx,
            is_causal=False,
            num_attention_heads=config.encoder_num_attention_heads,
            num_key_value_heads=config.encoder_num_key_value_heads,
        )

        self.mlp = MoonshineEncoderMLP(config, config.encoder_hidden_act)
        self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, ...]:
        """
        Run one encoder layer.

        Returns:
            A tuple whose first element is the updated hidden states. When `output_attentions` is truthy, the
            self-attention weights are appended (they may be `None`, depending on the attention backend). This
            mirrors `MoonshineDecoderLayer`, and callers index the result with `[0]` / `[1]`.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        return outputs
|
|
|
|
|
|
class MoonshineDecoderLayer(GradientCheckpointingLayer):
    """
    Pre-norm Moonshine decoder layer.

    The layer has three residual sub-blocks, each preceded by its own LayerNorm:
    causal self-attention, cross-attention over `encoder_hidden_states` (skipped when they are None), and a gated
    feed-forward block.

    Args:
        config (`MoonshineConfig`): model configuration.
        layer_idx (`int`, *optional*): index of this layer; used by the attention modules to address cache entries.
    """

    def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.hidden_size = config.hidden_size

        def build_attention(causal):
            # Both attention blocks use the decoder head counts; only causality differs.
            return MoonshineAttention(
                config=config,
                layer_idx=layer_idx,
                is_causal=causal,
                num_attention_heads=config.decoder_num_attention_heads,
                num_key_value_heads=config.decoder_num_key_value_heads,
            )

        # Construction order (self-attn, cross-attn, MLP, norms) matches the parameter creation order.
        self.self_attn = build_attention(True)
        self.encoder_attn = build_attention(False)

        self.mlp = MoonshineDecoderMLP(config, config.decoder_hidden_act)
        self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
        self.final_layernorm = nn.LayerNorm(config.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        encoder_position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        encoder_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Run one decoder layer.

        `encoder_position_ids` and `encoder_position_embeddings` are accepted but not used here.

        Returns:
            `(hidden_states,)`, or `(hidden_states, self_attn_weights, cross_attn_weights)` when `output_attentions`
            is truthy. `cross_attn_weights` is `None` when no encoder states were given.
        """
        # Causal self-attention sub-block.
        attn_out, self_attn_weights = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Cross-attention sub-block; only runs when encoder outputs are available.
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            cross_out, cross_attn_weights = self.encoder_attn(
                hidden_states=self.post_attention_layernorm(hidden_states),
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = hidden_states + cross_out

        # Feed-forward sub-block.
        mlp_out = self.mlp(self.final_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        if output_attentions:
            return (hidden_states, self_attn_weights, cross_attn_weights)
        return (hidden_states,)
|
|
|
|
|
|
@auto_docstring
class MoonshinePreTrainedModel(PreTrainedModel):
    # Shared base for the Moonshine encoder/decoder: declares config class, capabilities and weight initialisation.
    config_class = MoonshineConfig
    base_model_prefix = "model"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MoonshineEncoderLayer", "MoonshineDecoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = True

    def _init_weights(self, module):
        """
        Initialise `module`'s parameters in place.

        - `nn.Linear` / `nn.Conv1d`: weights ~ N(0, initializer_range), biases (if any) set to zero.
        - `nn.GroupNorm` / `nn.LayerNorm`: weights set to one, biases (if any) set to zero.
        - `nn.Embedding`: weights ~ N(0, initializer_range), with the `padding_idx` row (if any) zeroed.
        Other module types are left untouched.
        """
        init_std = self.config.initializer_range

        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, (nn.GroupNorm, nn.LayerNorm)):
            module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers

        Each stage applies `int((length - kernel_size) / stride + 1)`, for (kernel_size, stride) = (127, 64), (7, 3)
        and (3, 2) in that order. These are the encoder's three convolutions.
        """
        length = input_lengths
        for kernel_size, stride in ((127, 64), (7, 3), (3, 2)):
            length = int((length - kernel_size) / stride + 1)
        return length
|
|
|
|
|
|
class MoonshineEncoder(MoonshinePreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_num_hidden_layers* layers. Each layer is a
    [`MoonshineEncoderLayer`].

    Before the transformer layers, the raw waveform is downsampled by three 1-D convolutions (strides 64, 3 and 2)
    with tanh / GroupNorm / GELU non-linearities.

    Args:
        config: MoonshineConfig
    """

    main_input_name = "input_values"

    def __init__(self, config: MoonshineConfig):
        super().__init__(config)
        self.config = config
        embed_dim = config.hidden_size

        self.conv1 = nn.Conv1d(1, embed_dim, kernel_size=127, stride=64, bias=False)
        self.conv2 = nn.Conv1d(embed_dim, 2 * embed_dim, kernel_size=7, stride=3)
        self.conv3 = nn.Conv1d(2 * embed_dim, embed_dim, kernel_size=3, stride=2)
        self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=1e-5)

        self.rotary_emb = MoonshineRotaryEmbedding(config=config)

        self.layers = nn.ModuleList(
            [MoonshineEncoderLayer(config, idx) for idx in range(config.encoder_num_hidden_layers)]
        )
        self.layer_norm = nn.LayerNorm(embed_dim, bias=False)

        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        # The first convolution plays the role of the input embedding for audio input.
        return self.conv1

    def set_input_embeddings(self, value: nn.Module):
        self.conv1 = value

    @can_return_tuple
    def forward(
        self,
        input_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        r"""
        Args:
            input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
                Float values of the raw speech waveform. Raw speech waveform can be
                obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                `input_values`, the [`AutoFeatureExtractor`] should be used for padding
                and conversion into a tensor of type `torch.FloatTensor`.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
                tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
                more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Raises:
            ValueError: if `input_values` is None.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if input_values is None:
            raise ValueError("You must specify input_values.")

        # conv downsampling: (batch, audio_length) -> (batch, channels=1, audio_length) -> ... -> (batch, frames, dim)
        input_values = input_values.unsqueeze(1)
        hidden_states = nn.functional.tanh(self.conv1(input_values))
        hidden_states = self.groupnorm(hidden_states)
        hidden_states = nn.functional.gelu(self.conv2(hidden_states))
        hidden_states = nn.functional.gelu(self.conv3(hidden_states))
        hidden_states = hidden_states.permute(0, 2, 1)

        # attention mask downsampling: subsample the sample-level mask by the product of the conv strides, then crop
        # it to the number of frames produced by the convolutions.
        if attention_mask is not None:
            mask_len = self._get_feat_extract_output_lengths(attention_mask.shape[-1])
            downsample_stride = 64 * 3 * 2  # conv strides
            attention_mask = attention_mask[..., ::downsample_stride][..., :mask_len]
            if self.config._attn_implementation == "flash_attention_2":
                # Flash attention only needs a mask when some position is actually masked.
                attention_mask = attention_mask if (attention_mask == 0.0).any() else None

            # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
            elif self.config._attn_implementation == "sdpa" and not output_attentions:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, hidden_states.dtype)
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

        position_ids = torch.arange(0, hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)

        # create position embeddings to be shared across the encoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # encoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for encoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            # A layer may return either a bare tensor or a tuple `(hidden_states, attn_weights?)`. Indexing a bare
            # tensor with `[0]` would silently keep only the first batch element, so branch on the output type.
            if isinstance(layer_outputs, tuple):
                hidden_states = layer_outputs[0]
                layer_attn_weights = layer_outputs[1] if len(layer_outputs) > 1 else None
            else:
                hidden_states = layer_outputs
                layer_attn_weights = None

            if output_attentions:
                all_self_attns += (layer_attn_weights,)

        hidden_states = self.layer_norm(hidden_states)

        # add hidden states from the last encoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
|
|
|
|
|
@auto_docstring
class MoonshineDecoder(MoonshinePreTrainedModel):
    # Text decoder: token embeddings -> stack of `MoonshineDecoderLayer` (causal self-attention, cross-attention over
    # encoder outputs, gated MLP) -> final LayerNorm. No LM head is defined in this module.
    main_input_name = "input_ids"

    def __init__(self, config: MoonshineConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # One decoder layer per `config.decoder_num_hidden_layers`; each receives its index (used for cache lookups).
        self.layers = nn.ModuleList(
            [MoonshineDecoderLayer(config, idx) for idx in range(config.decoder_num_hidden_layers)]
        )
        self.norm = nn.LayerNorm(config.hidden_size, bias=False)
        self.rotary_emb = MoonshineRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Token embedding table, used when `inputs_embeds` is not supplied to `forward`.
        return self.embed_tokens

    def set_input_embeddings(self, value):
        # Replace the token embedding table.
        self.embed_tokens = value

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            of the decoder.
        encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding indices in `encoder_hidden_states`. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)
        """
        # NOTE(review): the function actually builds a `BaseModelOutputWithPastAndCrossAttentions` (see the return
        # statement); the annotation above names the parent `BaseModelOutputWithPast`.

        # Unspecified output/caching flags fall back to the config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Exactly one of `input_ids` / `inputs_embeds` must be given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Lazily create a combined cache: one DynamicCache for self-attention and one for cross-attention.
        if use_cache and past_key_values is None:
            self_attention_cache = DynamicCache()
            cross_attention_cache = DynamicCache()
            past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)

        # Default cache positions continue right after the tokens already stored in the cache.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Mask for the decoder self-attention, derived from `attention_mask` and the cache state.
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # attention mask downsampling
        # The mask is subsampled by the product of the encoder conv strides and cropped to the encoder sequence
        # length. This presumably expects a mask at raw `input_values` resolution; confirm with callers.
        # NOTE(review): reads `encoder_hidden_states.shape` without a None check, so passing a mask without encoder
        # states raises AttributeError.
        if encoder_attention_mask is not None:
            mask_len = encoder_hidden_states.shape[-2]
            downsample_stride = 64 * 3 * 2  # conv strides
            encoder_attention_mask = encoder_attention_mask[..., ::downsample_stride][..., :mask_len]
            if self.config._attn_implementation == "flash_attention_2":
                # Flash attention only needs a mask when some position is actually masked.
                encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None

            # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
            elif self.config._attn_implementation == "sdpa" and not output_attentions:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
                )
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
                )

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                causal_mask,
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]

            # With output_attentions, each layer returns (hidden_states, self_attn_weights, cross_attn_weights).
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
|
|
|
|
|
|
def _compute_mask_indices(
|
|
shape: tuple[int, int],
|
|
mask_prob: float,
|
|
mask_length: int,
|
|
attention_mask: Optional[torch.LongTensor] = None,
|
|
min_masks: int = 0,
|
|
) -> np.ndarray:
|
|
"""
|
|
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
|
|
ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
|
|
CPU as part of the preprocessing during training.
|
|
|
|
Args:
|
|
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
|
|
the first element is the batch size and the second element is the length of the axis to span.
|
|
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
|
|
independently generated mask spans of length `mask_length` is computed by
|
|
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
|
|
actual percentage will be smaller.
|
|
mask_length: size of the mask
|
|
min_masks: minimum number of masked spans
|
|
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
|
|
each batch dimension.
|
|
"""
|
|
batch_size, sequence_length = shape
|
|
|
|
if mask_length < 1:
|
|
raise ValueError("`mask_length` has to be bigger than 0.")
|
|
|
|
if mask_length > sequence_length:
|
|
raise ValueError(
|
|
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
|
|
f" and `sequence_length`: {sequence_length}`"
|
|
)
|
|
|
|
# epsilon is used for probabilistic rounding
|
|
epsilon = np.random.rand(1).item()
|
|
|
|
def compute_num_masked_span(input_length):
|
|
"""Given input length, compute how many spans should be masked"""
|
|
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
|
|
num_masked_span = max(num_masked_span, min_masks)
|
|
|
|
# make sure num masked span <= sequence_length
|
|
if num_masked_span * mask_length > sequence_length:
|
|
num_masked_span = sequence_length // mask_length
|
|
|
|
# make sure num_masked span is also <= input_length - (mask_length - 1)
|
|
if input_length - (mask_length - 1) < num_masked_span:
|
|
num_masked_span = max(input_length - (mask_length - 1), 0)
|
|
|
|
return num_masked_span
|
|
|
|
# compute number of masked spans in batch
|
|
input_lengths = (
|
|
attention_mask.detach().sum(-1).tolist()
|
|
if attention_mask is not None
|
|
else [sequence_length for _ in range(batch_size)]
|
|
)
|
|
|
|
# SpecAugment mask to fill
|
|
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
|
|
spec_aug_mask_idxs = []
|
|
|
|
max_num_masked_span = compute_num_masked_span(sequence_length)
|
|
|
|
if max_num_masked_span == 0:
|
|
return spec_aug_mask
|
|
|
|
for input_length in input_lengths:
|
|
# compute num of masked spans for this input
|
|
num_masked_span = compute_num_masked_span(input_length)
|
|
|
|
# get random indices to mask
|
|
spec_aug_mask_idx = np.random.choice(
|
|
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
|
|
)
|
|
|
|
# pick first sampled index that will serve as a dummy index to pad vector
|
|
# to ensure same dimension for all batches due to probabilistic rounding
|
|
# Picking first sample just pads those vectors twice.
|
|
if len(spec_aug_mask_idx) == 0:
|
|
# this case can only happen if `input_length` is strictly smaller then
|
|
# `sequence_length` in which case the last token has to be a padding
|
|
# token which we can use as a dummy mask id
|
|
dummy_mask_idx = sequence_length - 1
|
|
else:
|
|
dummy_mask_idx = spec_aug_mask_idx[0]
|
|
|
|
spec_aug_mask_idx = np.concatenate(
|
|
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
|
|
)
|
|
spec_aug_mask_idxs.append(spec_aug_mask_idx)
|
|
|
|
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
|
|
|
|
# expand masked indices to masked spans
|
|
spec_aug_mask_idxs = np.broadcast_to(
|
|
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
|
|
)
|
|
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
|
|
|
|
# add offset to the starting indexes so that indexes now create a span
|
|
offsets = np.arange(mask_length)[None, None, :]
|
|
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
|
|
batch_size, max_num_masked_span * mask_length
|
|
)
|
|
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
|
|
|
|
# ensure that we cannot have indices larger than sequence_length
|
|
if spec_aug_mask_idxs.max() > sequence_length - 1:
|
|
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
|
|
|
|
# scatter indices to mask
|
|
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
|
|
|
|
return spec_aug_mask
|
|
|
|
|
|
@auto_docstring
class MoonshineModel(MoonshinePreTrainedModel):
    # Bare Moonshine encoder-decoder: `self.encoder` consumes the audio `input_values` and `self.decoder` attends to
    # the encoder's last hidden state. There is no language-modeling head on this class.
    def __init__(self, config: MoonshineConfig):
        super().__init__(config)

        self.encoder = MoonshineEncoder(config)
        self.decoder = MoonshineDecoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # The token embedding table belongs to the decoder.
        return self.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.decoder.embed_tokens = value

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def freeze_encoder(self):
        """
        Calling this function will disable the gradient computation for the Moonshine encoder so that its parameters will
        not be updated during training.
        """
        self.encoder._freeze_parameters()

    def _mask_input_features(
        self,
        input_features: torch.FloatTensor,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://huggingface.co/papers/1904.08779).

        Masking only happens while `self.training` is true and when `config.apply_spec_augment` is set (treated as
        `True` if the config lacks the attribute). Masked entries are zeroed **in place** and the same tensor is
        returned.

        NOTE(review): the unpacking below requires a 3-D tensor read as `(batch_size, hidden_size, sequence_length)`,
        whereas `forward` documents 2-D `input_values` and never calls this method. Confirm whether any caller still
        relies on it.
        """

        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return input_features

        # generate indices & apply SpecAugment along time axis
        batch_size, hidden_size, sequence_length = input_features.size()

        if self.config.mask_time_prob > 0 and self.training:
            # generate indices & apply SpecAugment along time axis
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool)
            # (batch, seq) -> (batch, hidden, seq): a masked time step is zeroed across every feature channel
            mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1)
            input_features[mask_time_indices] = 0

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool)
            # A (batch, hidden) boolean index selects whole rows of the last axis, zeroing a masked
            # feature channel at every time step.
            input_features[mask_feature_indices] = 0

        return input_features

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Union[EncoderDecoderCache, tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_position_ids: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Seq2SeqModelOutput:
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
            Float values of the raw speech waveform. Raw speech waveform can be
            obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
            `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
            `input_values`, the [`AutoFeatureExtractor`] should be used for padding
            and conversion into a tensor of type `torch.FloatTensor`.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        decoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
            information on the default strategy.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoFeatureExtractor, MoonshineModel
        >>> from datasets import load_dataset

        >>> model = MoonshineModel.from_pretrained("UsefulSensors/moonshine-tiny")
        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("UsefulSensors/moonshine-tiny")
        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
        >>> input_values = inputs.input_values
        >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
        >>> last_hidden_state = model(input_values, decoder_input_ids=decoder_input_ids).last_hidden_state
        >>> list(last_hidden_state.shape)
        [1, 2, 288]
        ```
        """
        # Unset flags fall back to the model config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if encoder_outputs is None:
            encoder_outputs: BaseModelOutput = self.encoder(
                input_values,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
        # If the user passed a tuple for encoder_outputs, wrap it in a BaseModelOutput so the attribute
        # accesses below (`.last_hidden_state`, `.hidden_states`, `.attentions`) work uniformly.
        elif not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # The decoder returns a BaseModelOutputWithPastAndCrossAttentions (last_hidden_state, past_key_values,
        # hidden_states, attentions, cross_attentions). The encoder-side `attention_mask` is reused as the
        # cross-attention mask.
        decoder_outputs: BaseModelOutputWithPastAndCrossAttentions = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_attention_mask=attention_mask,
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
        )

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
|
|
|
|
|
|
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Build decoder inputs by moving every row of `input_ids` one position to the right.

    Column 0 is filled with `decoder_start_token_id` and the final token of each row is dropped. Any `-100`
    (the label ignore index) that survives the shift is replaced with `pad_token_id`. `input_ids` itself is not
    modified; a new tensor with the same shape, dtype and device is returned.

    Raises:
        ValueError: if `pad_token_id` is None.
    """
    result = input_ids.new_zeros(input_ids.shape)
    # Start token in front, then everything except the last token of each row.
    result[:, 0] = decoder_start_token_id
    result[:, 1:] = input_ids[:, :-1].clone()

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    # -100 marks ignored label positions; they cannot be fed to an embedding, so map them to padding.
    ignored = result == -100
    result.masked_fill_(ignored, pad_token_id)
    return result
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    The Moonshine Model with a language modeling head. Can be used for automatic speech recognition.
    """
)
class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixin):
    # Wraps `MoonshineModel` and projects the decoder's last hidden state to vocabulary logits via `proj_out`.
    # NOTE(review): presumably `proj_out.weight` is tied to the decoder token embeddings by the base-class
    # weight-tying logic (declared via `_tied_weights_keys`) — confirm against `MoonshinePreTrainedModel`.
    _tied_weights_keys = ["proj_out.weight"]

    def __init__(self, config: MoonshineConfig):
        super().__init__(config)
        self.model = MoonshineModel(config)
        # Bias-free output projection: hidden_size -> vocab_size
        self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    def get_output_embeddings(self):
        return self.proj_out

    def set_output_embeddings(self, new_embeddings):
        self.proj_out = new_embeddings

    def get_input_embeddings(self) -> nn.Module:
        return self.model.get_input_embeddings()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Union[EncoderDecoderCache, tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_position_ids: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Seq2SeqLMOutput:
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
            Float values of the raw speech waveform. Raw speech waveform can be
            obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
            `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
            `input_values`, the [`AutoFeatureExtractor`] should be used for padding
            and conversion into a tensor of type `torch.FloatTensor`.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        decoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
            information on the default strategy.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
            or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
            only computed for the tokens with labels in `[0, ..., config.vocab_size]`. When neither
            `decoder_input_ids` nor `decoder_inputs_embeds` is given, the decoder inputs are derived from `labels`
            by shifting them one position to the right.

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, MoonshineForConditionalGeneration
        >>> from datasets import load_dataset

        >>> processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-tiny")
        >>> model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny")

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

        >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
        >>> input_values = inputs.input_values

        >>> generated_ids = model.generate(input_values, max_new_tokens=100)

        >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> transcription
        'Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
        ```"""

        if labels is not None:
            # Teacher forcing: when only labels are supplied, build decoder inputs by prepending the
            # decoder start token and dropping the last label (-100 entries become pad tokens).
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs: Seq2SeqModelOutput = self.model(
            input_values,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
        )
        logits = self.proj_out(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            # `loss_function` is provided outside this class (not visible here).
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)

        return Seq2SeqLMOutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
|
|
|
|
|
|
# Public API of this module (order preserved).
__all__ = [
    "MoonshineModel",
    "MoonshinePreTrainedModel",
    "MoonshineForConditionalGeneration",
]
|