Add kyutai stt (#38909)

* first draft

* cleaner version

* update tests + modeling

* add tests

* init

* update test_modeling_common

* fix tests

* csm Processor draft

* conversion update

* mimi cache padding convolutions draft

* mimi streaming updates

* update mimi padding cache test

* update cache padding mimi test

* make style mimi

* updates generate moshi asr

* moshi asr integration tests (single + batched)

* update tests

* update conversion script

* good default sliding window value

* update generate

* update test checkpoint

* nit

* fix mimi

* fix codec prefix

* revert

* revert

* update config

* update config

* unnecessary mimi input restriction

* remove delay in tokens

* remove _prepare_4d_causal_attention_mask_with_cache_position and _update_causal_mask

* test update

* modular update

* make style

* nit

* rename

* create codec model generation config at init

* remove delay

* max_new_tokens/length warning

* correct conv1 padding cache import for modular

* nit

* fix on encoder_past_key_values

* convert modular

* move frame_size to config

* move frame_size to config

* update test name

* handle first token is bos

* better handling of max_new_tokens

* fix

* fix batch size in test input prep

* update docstring

* convert modular

* make style

* make style

* add feature extractor

* correct modular convention name for feature_extraction file

* update conversion script

* doc processor

* update doc

* update init

* update model type

* fixes

* update tests

* fix

* make

* add doc

* nit

* fix

* doc

* auto mappings

* doc

* nit

* convert modular

* doc

* nit

* extend _keep_in_fp32_modules to enforce fp32

* renaming to stt

* doc update + test update

* doc fixes

* doc fix

* doc fix

* fix musicgen tests

* fix musicgen tests

* make style

* fix musicgen tests

* correct frame_rate config param for mimi

* update mimi test

* revert update mimi test

* enforce cpu test

* move cache init in cache class

* convert modular

* docstring update

* update model id

* feature_extractor -> feature_extraction (SEW)

* convert modular

* update model id
eustlb 2025-06-24 18:01:15 +02:00 committed by GitHub
parent 08bf7f1afe
commit 6bdd4ec952
23 changed files with 4000 additions and 200 deletions

View File

@ -843,6 +843,8 @@
title: GraniteSpeech
- local: model_doc/hubert
title: Hubert
- local: model_doc/stt
title: Kyutai Speech-To-Text
- local: model_doc/mctct
title: MCTCT
- local: model_doc/mimi

View File

@ -0,0 +1,122 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Kyutai Speech-To-Text
## Overview
Kyutai STT is a speech-to-text model architecture based on the [Mimi codec](https://huggingface.co/docs/transformers/en/model_doc/mimi), which encodes audio into discrete tokens in a streaming fashion, and a [Moshi-like](https://huggingface.co/docs/transformers/en/model_doc/moshi) autoregressive decoder. Kyutai's lab has released two model checkpoints:
- [kyutai/stt-1b-en_fr](https://huggingface.co/kyutai/stt-1b-en_fr): a 1B-parameter model capable of transcribing both English and French
- [kyutai/stt-2.6b-en](https://huggingface.co/kyutai/stt-2.6b-en): a 2.6B-parameter model focused solely on English, optimized for maximum transcription accuracy
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/eustlb/documentation-images/resolve/main/kyutai_stt.png"/>
</div>
## Usage Tips
### Inference
```python
import torch
from datasets import load_dataset, Audio
from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
# 1. load the model and the processor
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "kyutai/stt-2.6b-en"
processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
# 2. load audio samples
ds = load_dataset(
"hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
ds = ds.cast_column("audio", Audio(sampling_rate=24000))
# 3. prepare the model inputs
inputs = processor(
ds[0]["audio"]["array"],
)
inputs.to(torch_device)
# 4. infer the model
output_tokens = model.generate(**inputs)
# 5. decode the generated tokens
print(processor.batch_decode(output_tokens, skip_special_tokens=True))
```
### Batched Inference
```python
import torch
from datasets import load_dataset, Audio
from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
# 1. load the model and the processor
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "kyutai/stt-2.6b-en"
processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
# 2. load audio samples
ds = load_dataset(
"hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
ds = ds.cast_column("audio", Audio(sampling_rate=24000))
# 3. prepare the model inputs
audio_arrays = [ds[i]["audio"]["array"] for i in range(4)]
inputs = processor(audio_arrays, return_tensors="pt", padding=True)
inputs = inputs.to(torch_device)
# 4. infer the model
output_tokens = model.generate(**inputs)
# 5. decode the generated tokens
decoded_outputs = processor.batch_decode(output_tokens, skip_special_tokens=True)
for output in decoded_outputs:
print(output)
```
This model was contributed by [Eustache Le Bihan](https://huggingface.co/eustlb).
The original code can be found [here](https://github.com/kyutai-labs/moshi).
## KyutaiSpeechToTextConfig
[[autodoc]] KyutaiSpeechToTextConfig
## KyutaiSpeechToTextProcessor
[[autodoc]] KyutaiSpeechToTextProcessor
- __call__
## KyutaiSpeechToTextFeatureExtractor
[[autodoc]] KyutaiSpeechToTextFeatureExtractor
## KyutaiSpeechToTextForConditionalGeneration
[[autodoc]] KyutaiSpeechToTextForConditionalGeneration
- forward
- generate
## KyutaiSpeechToTextModel
[[autodoc]] KyutaiSpeechToTextModel

View File

@ -4658,8 +4658,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
# in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing
# step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
# Update: extending the _keep_in_fp32_modules feature, the flag can also be used to force modules to stay in fp32
if model._keep_in_fp32_modules is not None and (
torch_dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
torch_dtype == torch.float16
or torch_dtype == torch.bfloat16
or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
):
# We need to match exact layers, so we add either `.` on each side, or start/end of string
keep_in_fp32_regex = re.compile(
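To illustrate the intent of this extension (the motivation behind the `extend _keep_in_fp32_modules to enforce fp32` commit above): with the extra `torch.bfloat16` branch, modules listed in `_keep_in_fp32_modules` are forced back to fp32 even when the model is loaded in bf16, not only fp16. A minimal, self-contained sketch of the matching logic with a hypothetical module list; the actual implementation inside `from_pretrained` may differ in detail:

```python
import re

import torch
import torch.nn as nn

# Hypothetical example: a model whose codec submodule must stay in fp32.
keep_in_fp32_modules = ["codec_model"]
torch_dtype = torch.bfloat16

# With this change, bf16 loading (not only fp16) also triggers the fp32 override.
if keep_in_fp32_modules and torch_dtype in (torch.float16, torch.bfloat16):
    # Match exact module names: bounded by dots or by start/end of the parameter name.
    keep_in_fp32_regex = re.compile(
        "|".join(rf"((^|\.){re.escape(m)}($|\.))" for m in keep_in_fp32_modules)
    )

    model = nn.ModuleDict({"codec_model": nn.Linear(4, 4), "lm_head": nn.Linear(4, 4)}).to(torch_dtype)
    for name, param in model.named_parameters():
        if keep_in_fp32_regex.search(name):
            # e.g. "codec_model.weight" matches, "lm_head.weight" does not
            param.data = param.data.to(torch.float32)
```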

View File

@ -285,6 +285,7 @@ if TYPE_CHECKING:
from .squeezebert import *
from .stablelm import *
from .starcoder2 import *
from .stt import *
from .superglue import *
from .superpoint import *
from .swiftformer import *

View File

@ -322,6 +322,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("squeezebert", "SqueezeBertConfig"),
("stablelm", "StableLmConfig"),
("starcoder2", "Starcoder2Config"),
("stt", "KyutaiSpeechToTextConfig"),
("superglue", "SuperGlueConfig"),
("superpoint", "SuperPointConfig"),
("swiftformer", "SwiftFormerConfig"),
@ -707,6 +708,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("squeezebert", "SqueezeBERT"),
("stablelm", "StableLm"),
("starcoder2", "Starcoder2"),
("stt", "KyutaiSpeechToText"),
("superglue", "SuperGlue"),
("superpoint", "SuperPoint"),
("swiftformer", "SwiftFormer"),

View File

@ -91,6 +91,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("sew-d", "Wav2Vec2FeatureExtractor"),
("speech_to_text", "Speech2TextFeatureExtractor"),
("speecht5", "SpeechT5FeatureExtractor"),
("stt", "KyutaiSpeechToTextFeatureExtractor"),
("swiftformer", "ViTFeatureExtractor"),
("swin", "ViTFeatureExtractor"),
("swinv2", "ViTFeatureExtractor"),

View File

@ -300,6 +300,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("squeezebert", "SqueezeBertModel"),
("stablelm", "StableLmModel"),
("starcoder2", "Starcoder2Model"),
("stt", "KyutaiSpeechToTextModel"),
("superglue", "SuperGlueForKeypointMatching"),
("swiftformer", "SwiftFormerModel"),
("swin", "SwinModel"),
@ -1055,6 +1056,7 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
("speech_to_text", "Speech2TextForConditionalGeneration"),
("speecht5", "SpeechT5ForSpeechToText"),
("stt", "KyutaiSpeechToTextForConditionalGeneration"),
("whisper", "WhisperForConditionalGeneration"),
]
)

View File

@ -116,6 +116,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("speech_to_text", "Speech2TextProcessor"),
("speech_to_text_2", "Speech2Text2Processor"),
("speecht5", "SpeechT5Processor"),
("stt", "KyutaiSpeechToTextProcessor"),
("trocr", "TrOCRProcessor"),
("tvlt", "TvltProcessor"),
("tvp", "TvpProcessor"),

View File

@ -38,8 +38,8 @@ class MimiConfig(PretrainedConfig):
Args:
sampling_rate (`int`, *optional*, defaults to 24000):
The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
frame_rate (`float`, *optional*, defaults to 12.5):
Framerate of the model.
frame_rate (`float`, *optional*):
Frame rate of the model. If not provided, it is computed from the other parameters; passing it explicitly is only kept for backward compatibility and overrides the computed value.
audio_channels (`int`, *optional*, defaults to 1):
Number of channels in the audio data. Either 1 for mono or 2 for stereo.
hidden_size (`int`, *optional*, defaults to 512):
@ -111,6 +111,8 @@ class MimiConfig(PretrainedConfig):
use_cache (`bool`, *optional*, defaults to `False`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
use_streaming (`bool`, *optional*, defaults to `False`):
Whether to use streaming mode. If `True`, the `encode` method returns a padding cache that can be passed to subsequent calls to `encode`.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
sliding_window (`int`, *optional*, defaults to 250):
@ -141,7 +143,7 @@ class MimiConfig(PretrainedConfig):
def __init__(
self,
sampling_rate=24_000,
frame_rate=12.5,
frame_rate=None,
audio_channels=1,
hidden_size=512,
num_filters=64,
@ -172,6 +174,7 @@ class MimiConfig(PretrainedConfig):
initializer_range=0.02,
norm_eps=1e-5,
use_cache=False,
use_streaming=False,
rope_theta=10000.0,
sliding_window=250,
attention_dropout=0.0,
@ -180,7 +183,6 @@ class MimiConfig(PretrainedConfig):
**kwargs,
):
self.sampling_rate = sampling_rate
self.frame_rate = frame_rate
self.audio_channels = audio_channels
self.hidden_size = hidden_size
self.num_filters = num_filters
@ -209,6 +211,7 @@ class MimiConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.norm_eps = norm_eps
self.use_cache = use_cache
self.use_streaming = use_streaming
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.attention_dropout = attention_dropout
@ -216,6 +219,14 @@ class MimiConfig(PretrainedConfig):
self.layer_scale_initial_scale = layer_scale_initial_scale
self.attention_bias = attention_bias
# Handle backward compatibility for frame_rate:
# if frame_rate is explicitly provided, it takes precedence (backward compatibility);
# otherwise it is computed from the other parameters in the `frame_rate` property below.
self._frame_rate = frame_rate
if num_semantic_quantizers >= self.num_quantizers:
raise ValueError(
f"The number of semantic quantizers should be lower than the total number of quantizers {self.num_quantizers}, but is currently {num_semantic_quantizers}."
@ -233,5 +244,36 @@ class MimiConfig(PretrainedConfig):
# alias to num_quantizers
return self.num_quantizers
@property
def frame_size(self) -> int:
# 1. we need each encoder conv stride
# first conv
strides = [1]
# layer convs
for ratio in reversed(self.upsampling_ratios):
for j in range(self.num_residual_layers):
len_kernel_sizes = len(self.residual_kernel_size) if isinstance(self.residual_kernel_size, list) else 1
strides.extend([1] * (len_kernel_sizes + 1))
if self.use_conv_shortcut: # skip connection
strides.append(1)
strides.append(ratio)
# last conv
strides.append(1)
# downsampling layer
strides.append(2)
return math.prod(strides)
@property
def frame_rate(self) -> float:
# handle backward compatibility
if self._frame_rate is not None:
return self._frame_rate
return self.sampling_rate / self.frame_size
__all__ = ["MimiConfig"]

View File

@ -23,25 +23,20 @@ import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, is_torch_flex_attn_available, logging
from ...utils import ModelOutput, auto_docstring, logging
from .configuration_mimi import MimiConfig
if is_flash_attn_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@ -78,6 +73,91 @@ class MimiOutput(ModelOutput):
decoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None
class MimiConv1dPaddingCache:
"""
Padding cache for MimiConv1d causal convolutions in order to support streaming via cache padding.
See: https://arxiv.org/pdf/2005.06720 & https://arxiv.org/pdf/2204.07064
A padding cache is a list of cached partial hidden states for each convolution layer.
Hidden states are cached from the previous call to the MimiConv1d forward pass, given the padding size.
"""
def __init__(
self,
num_layers: int,
per_layer_padding: list[int],
per_layer_padding_mode: list[str],
per_layer_in_channels: list[int],
):
# ensure correct number of layers for each arg
from_args_num_layers = {len(per_layer_padding), len(per_layer_padding_mode), len(per_layer_in_channels)}
if len(from_args_num_layers) != 1 or from_args_num_layers.pop() != num_layers:
raise ValueError(
f"Expected `num_layers` ({num_layers}) values in `per_layer_padding`, `per_layer_padding_mode` and `per_layer_in_channels`"
)
elif not all(mode in ["constant", "replicate"] for mode in per_layer_padding_mode):
raise NotImplementedError(
"`padding_cache` is not supported for convolutions using other than `constant` or `replicate` padding mode"
)
self.per_layer_padding = per_layer_padding
self.per_layer_padding_mode = per_layer_padding_mode
self.per_layer_in_channels = per_layer_in_channels
self.per_layer_is_init = [True] * num_layers
self.padding_cache = [None] * num_layers
def update(self, hidden_states: torch.Tensor, layer_idx: int):
"""
Updates the padding cache with the new padding states for the layer `layer_idx` and returns the current cache.
Parameters:
hidden_states (`torch.Tensor`):
The hidden states to be partially cached.
layer_idx (`int`):
The index of the layer to cache the states for.
Returns:
`torch.Tensor`, the cached padding states to prepend to the current hidden states for layer `layer_idx`.
"""
batch_size, dtype, device = hidden_states.shape[0], hidden_states.dtype, hidden_states.device
padding = self.per_layer_padding[layer_idx]
padding_mode = self.per_layer_padding_mode[layer_idx]
in_channels = self.per_layer_in_channels[layer_idx]
if self.padding_cache[layer_idx] is None:
if padding_mode == "constant":
current_cache = torch.zeros(
batch_size,
in_channels,
padding,
device=device,
dtype=dtype,
)
elif padding_mode == "replicate":
current_cache = (
torch.ones(
batch_size,
in_channels,
padding,
device=device,
dtype=dtype,
)
* hidden_states[..., :1]
)
else:
current_cache = self.padding_cache[layer_idx]
# update the cache
if padding > 0:
padding_states = hidden_states[:, :, -padding:]
else:
padding_states = torch.empty(batch_size, in_channels, padding, dtype=dtype, device=device)
self.padding_cache[layer_idx] = padding_states
return current_cache
@dataclass
@auto_docstring
class MimiEncoderOutput(ModelOutput):
@ -96,6 +176,7 @@ class MimiEncoderOutput(ModelOutput):
audio_codes: Optional[torch.LongTensor] = None
encoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None
padding_cache: Optional[MimiConv1dPaddingCache] = None
@dataclass
@ -130,12 +211,15 @@ class MimiConv1d(nn.Module):
stride: int = 1,
dilation: int = 1,
groups: int = 1,
pad_mode=None,
pad_mode: Optional[str] = None,
bias: bool = True,
layer_idx: Optional[int] = None,
):
super().__init__()
self.causal = config.use_causal_conv
self.pad_mode = config.pad_mode if pad_mode is None else pad_mode
self.layer_idx = layer_idx
self.in_channels = in_channels
# warn user on unusual setup between dilation and stride
if stride > 1 and dilation > 1:
@ -232,12 +316,20 @@ class MimiConv1d(nn.Module):
) // self.conv.stride[0] + 1
return output_lenght
def forward(self, hidden_states):
def forward(self, hidden_states, padding_cache=None):
extra_padding = self._get_extra_padding_for_conv1d(hidden_states)
if self.causal:
if not self.causal and padding_cache is not None:
raise ValueError("`padding_cache` is not supported for non-causal convolutions.")
if self.causal and padding_cache is not None:
layer_padding_cache = padding_cache.update(hidden_states, self.layer_idx)
hidden_states = torch.cat([layer_padding_cache, hidden_states], dim=2)
elif self.causal:
# Left padding for causal
hidden_states = self._pad1d(hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode)
else:
hidden_states = self._pad1d(
hidden_states, (self.padding_left, self.padding_right + extra_padding), mode=self.pad_mode
@ -305,7 +397,6 @@ class MimiConvTranspose1d(nn.Module):
return hidden_states
# Copied from transformers.models.encodec.modeling_encodec.EncodecResnetBlock with Encodec->Mimi,EnCodec->Mimi
class MimiResnetBlock(nn.Module):
"""
Residual block from SEANet model as used by Mimi.
@ -331,12 +422,21 @@ class MimiResnetBlock(nn.Module):
else:
self.shortcut = nn.Identity()
def forward(self, hidden_states):
def forward(self, hidden_states, padding_cache=None):
residual = hidden_states
for layer in self.block:
hidden_states = layer(hidden_states)
return self.shortcut(residual) + hidden_states
for layer in self.block:
if isinstance(layer, MimiConv1d):
hidden_states = layer(hidden_states, padding_cache=padding_cache)
else:
hidden_states = layer(hidden_states)
if isinstance(self.shortcut, MimiConv1d):
residual = self.shortcut(residual, padding_cache=padding_cache)
else:
residual = self.shortcut(residual)
return residual + hidden_states
class MimiEncoder(nn.Module):
@ -370,10 +470,17 @@ class MimiEncoder(nn.Module):
self.layers = nn.ModuleList(model)
self._mimiconv1d_layer_names = mimiconv1d_layer_names
# Copied from transformers.models.encodec.modeling_encodec.EncodecEncoder.forward
def forward(self, hidden_states):
# initialize layer_idx for MimiConv1d submodules, necessary for padding_cache
for layer_idx, layername in enumerate(self._mimiconv1d_layer_names):
conv_layer = self.get_submodule(layername)
setattr(conv_layer, "layer_idx", layer_idx)
def forward(self, hidden_states, padding_cache=None):
for layer in self.layers:
hidden_states = layer(hidden_states)
if isinstance(layer, (MimiConv1d, MimiResnetBlock)):
hidden_states = layer(hidden_states, padding_cache=padding_cache)
else:
hidden_states = layer(hidden_states)
return hidden_states
@ -1005,11 +1112,13 @@ class MimiTransformerModel(nn.Module):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = None
if attention_mask is not None:
causal_mask = self._update_causal_mask(
attention_mask, hidden_states, cache_position, past_key_values, output_attentions
)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=hidden_states,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
)
# decoder layers
all_hidden_states = () if output_hidden_states else None
@ -1054,163 +1163,6 @@ class MimiTransformerModel(nn.Module):
attentions=all_self_attns,
)
# Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Mimi
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and past_key_values is not None:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mimi. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
if using_sliding_window_cache or using_static_cache:
target_length = past_key_values.get_max_cache_shape()
# DynamicCache or no cache
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
past_key_values=past_key_values,
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
# Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->Mimi
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
config: MimiConfig,
past_key_values: Cache,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
config (`MimiConfig`):
The model's configuration class
past_key_values (`Cache`):
The cache class that is being used currently to generate
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
text_config = config.get_text_config()
if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - text_config.sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.shape[-1] > target_length:
attention_mask = attention_mask[:, :target_length]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class MimiDecoder(nn.Module):
"""SEANet decoder as used by Mimi."""
@ -1269,7 +1221,7 @@ class MimiEuclideanCodebook(nn.Module):
def quantize(self, hidden_states):
# Projects each vector in `hidden_states` over the nearest centroid and return its index.
# `hidden_states` should be `[N, D]` with `N` the number of input vectors and `D` the dimension.
dists = torch.cdist(hidden_states[None], self.embed[None], p=2)[0]
dists = torch.cdist(hidden_states[None].float(), self.embed[None].float(), p=2)[0]
embed_ind = dists.argmin(dim=-1)
return embed_ind
@ -1476,6 +1428,7 @@ class MimiModel(MimiPreTrainedModel):
stride=2,
bias=False,
pad_mode="replicate",
layer_idx=len(self.encoder._mimiconv1d_layer_names),
)
self.upsample = MimiConvTranspose1d(
@ -1512,12 +1465,17 @@ class MimiModel(MimiPreTrainedModel):
num_quantizers: int,
padding_mask: int,
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
padding_cache: Optional[MimiConv1dPaddingCache] = None,
return_dict: Optional[bool] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Encodes the given input using the underlying VQVAE. The padding mask is required to compute the correct scale.
"""
embeddings = self.encoder(input_values)
# TODO: @eustlb, let's make the encoder support padding_mask so that batched inputs are supported.
embeddings = self.encoder(input_values, padding_cache=padding_cache)
# TODO: @eustlb, convert the padding mask to attention mask.
encoder_outputs = self.encoder_transformer(
embeddings.transpose(1, 2), past_key_values=past_key_values, return_dict=return_dict
)
@ -1526,11 +1484,11 @@ class MimiModel(MimiPreTrainedModel):
elif len(encoder_outputs) > 1:
past_key_values = encoder_outputs[1]
embeddings = encoder_outputs[0].transpose(1, 2)
embeddings = self.downsample(embeddings)
embeddings = self.downsample(embeddings, padding_cache=padding_cache)
codes = self.quantizer.encode(embeddings, num_quantizers)
codes = codes.transpose(0, 1)
return codes, past_key_values
return codes, past_key_values, padding_cache
def get_encoded_length(self, input_length: torch.LongTensor) -> torch.LongTensor:
"""
@ -1570,6 +1528,8 @@ class MimiModel(MimiPreTrainedModel):
padding_mask: Optional[torch.Tensor] = None,
num_quantizers: Optional[float] = None,
encoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
padding_cache: Optional[MimiConv1dPaddingCache] = None,
use_streaming: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple[torch.Tensor, Optional[torch.Tensor]], MimiEncoderOutput]:
"""
@ -1598,6 +1558,7 @@ class MimiModel(MimiPreTrainedModel):
`codebook` of shape `[batch_size, num_codebooks, frames]`, the discrete encoded codes for the input audio waveform.
"""
return_dict = return_dict if return_dict is not None else self.config.return_dict
use_streaming = use_streaming if use_streaming is not None else self.config.use_streaming
num_quantizers = self.config.num_quantizers if num_quantizers is None else num_quantizers
@ -1614,11 +1575,31 @@ class MimiModel(MimiPreTrainedModel):
if padding_mask is None:
padding_mask = torch.ones_like(input_values).bool()
encoded_frames, encoder_past_key_values = self._encode_frame(
if use_streaming and padding_cache is None:
per_layer_padding, per_layer_padding_mode, per_layer_in_channels = [], [], []
for layer_name in self.encoder._mimiconv1d_layer_names:
per_layer_padding.append(self.encoder.get_submodule(layer_name).padding_total)
per_layer_padding_mode.append(self.encoder.get_submodule(layer_name).pad_mode)
per_layer_in_channels.append(self.encoder.get_submodule(layer_name).in_channels)
# downsample layer
per_layer_padding.append(self.downsample.padding_total)
per_layer_padding_mode.append(self.downsample.pad_mode)
per_layer_in_channels.append(self.downsample.in_channels)
padding_cache = MimiConv1dPaddingCache(
num_layers=len(self.encoder._mimiconv1d_layer_names) + 1,
per_layer_padding=per_layer_padding,
per_layer_padding_mode=per_layer_padding_mode,
per_layer_in_channels=per_layer_in_channels,
)
encoded_frames, encoder_past_key_values, padding_cache = self._encode_frame(
input_values,
num_quantizers,
padding_mask.bool(),
past_key_values=encoder_past_key_values,
padding_cache=padding_cache,
return_dict=return_dict,
)
@ -1626,9 +1607,10 @@ class MimiModel(MimiPreTrainedModel):
return (
encoded_frames,
encoder_past_key_values,
padding_cache,
)
return MimiEncoderOutput(encoded_frames, encoder_past_key_values)
return MimiEncoderOutput(encoded_frames, encoder_past_key_values, padding_cache)
def _decode_frame(
self,
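The `use_streaming` path above lets long audio be encoded chunk by chunk: the convolutions keep their left context in the returned `padding_cache`, and the encoder transformer keeps its KV cache in `encoder_past_key_values`. A rough usage sketch (chunk size and checkpoint are illustrative; for results matching offline encoding, chunks should be aligned to the frame size):

```python
import torch
from transformers import MimiModel

model = MimiModel.from_pretrained("kyutai/mimi").eval()
waveform = torch.randn(1, 1, 10 * 24000)  # 10 s of dummy mono audio at 24 kHz

frame_size = model.config.frame_size      # new config property, 1920 samples per frame
chunk_size = 25 * frame_size              # 2-second chunks

padding_cache = None
encoder_past_key_values = None
all_codes = []

with torch.no_grad():
    for start in range(0, waveform.shape[-1], chunk_size):
        out = model.encode(
            waveform[..., start : start + chunk_size],
            use_streaming=True,
            padding_cache=padding_cache,
            encoder_past_key_values=encoder_past_key_values,
        )
        all_codes.append(out.audio_codes)
        # carry the convolution padding cache and the transformer KV cache to the next chunk
        padding_cache = out.padding_cache
        encoder_past_key_values = out.encoder_past_key_values

codes = torch.cat(all_codes, dim=-1)      # [batch, num_quantizers, total_frames]
```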

View File

@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_kyutai_speech_to_text import *
from .feature_extraction_kyutai_speech_to_text import *
from .modeling_kyutai_speech_to_text import *
from .processing_kyutai_speech_to_text import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@ -0,0 +1,188 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
logger = logging.get_logger(__name__)
class KyutaiSpeechToTextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`KyutaiSpeechToTextForConditionalGeneration`].
It is used to instantiate a Kyutai Speech-to-Text model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
2.6b-en model.
e.g. [kyutai/stt-2.6b-en](https://huggingface.co/kyutai/stt-2.6b-en)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
codebook_vocab_size (`int`, *optional*, defaults to 2049):
Vocabulary size of the codebook. Defines the number of different audio tokens that can be represented by each codebook.
vocab_size (`int`, *optional*, defaults to 4001):
Vocabulary size of the model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling the model.
hidden_size (`int`, *optional*, defaults to 2048):
Dimensionality of the layers and the pooler layer of the main decoder.
num_hidden_layers (`int`, *optional*, defaults to 48):
Number of decoder layers.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the main decoder block.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`.
max_position_embeddings (`int`, *optional*, defaults to 750):
The maximum sequence length that this model might ever be used with. Typically, set this to something large
just in case (e.g., 512 or 1024 or 2048).
rope_theta (`float`, *optional*, defaults to 100000.0):
The base period of the RoPE embeddings.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
The attention head dimension.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
sliding_window (`int`, *optional*, defaults to 375):
Sliding window attention window size.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
ffn_dim (`int`, *optional*, defaults to 11264):
Dimensionality of the "intermediate" (often named feed-forward) layer in the main decoder block. Must be even.
rms_norm_eps (`float`, *optional*, defaults to 1e-08):
The epsilon used by the rms normalization layers.
num_codebooks (`int`, *optional*, defaults to 32):
The number of audio codebooks for each audio channel.
audio_bos_token_id (`int`, *optional*, defaults to 2048):
Beginning of stream token id for codebook tokens.
audio_pad_token_id (`int`, *optional*, defaults to 69569):
Padding token id for codebook tokens.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings.
pad_token_id (`int`, *optional*, defaults to 3):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 48000):
Beginning of stream token id for text tokens.
codec_config (`PretrainedConfig`, *optional*):
Configuration for the codec.
kwargs (*optional*):
Dictionary of additional keyword arguments, passed along to [`PretrainedConfig`].
Example:
```python
>>> from transformers import KyutaiSpeechToTextConfig, KyutaiSpeechToTextForConditionalGeneration
>>> # Initializing a KyutaiSpeechToTextConfig
>>> configuration = KyutaiSpeechToTextConfig()
>>> # Initializing a model
>>> model = KyutaiSpeechToTextForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
# not the best naming for `model_type`, but the original codebase already uses the model type `stt` in its config, so we keep it for simplicity
model_type = "stt"
keys_to_ignore_at_inference = ["past_key_values"]
sub_configs = {"codec_config": AutoConfig}
def __init__(
self,
codebook_vocab_size=2049,
vocab_size=4001,
hidden_size=2048,
num_hidden_layers=48,
num_attention_heads=32,
num_key_value_heads=None,
max_position_embeddings=750,
rope_theta=100000.0,
hidden_act="silu",
head_dim=None,
initializer_range=0.02,
use_cache=True,
sliding_window=375,
attention_dropout=0.0,
ffn_dim=11264,
rms_norm_eps=1e-8,
num_codebooks=32,
audio_bos_token_id=2048,
audio_pad_token_id=69569,
tie_word_embeddings=False,
pad_token_id=3,
bos_token_id=48000,
codec_config=None,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id, bos_token_id=bos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
)
if codec_config is None:
self.codec_config = AutoConfig.for_model("mimi")
logger.info("codec_config is None, using default audio encoder config.")
elif isinstance(codec_config, dict):
self.codec_config = AutoConfig.for_model(**codec_config)
elif isinstance(codec_config, PretrainedConfig):
self.codec_config = codec_config
self.num_codebooks = num_codebooks
self.frame_size = self.codec_config.frame_size
self.audio_bos_token_id = audio_bos_token_id
self.audio_pad_token_id = audio_pad_token_id
self.codebook_vocab_size = codebook_vocab_size
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
if ffn_dim % 2 == 1:
raise ValueError(f"`ffn_dim={ffn_dim}` must be even.")
self.ffn_dim = ffn_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
self.sliding_window = sliding_window
__all__ = ["KyutaiSpeechToTextConfig"]

View File

@ -0,0 +1,377 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import os
import re
import safetensors.torch
import sentencepiece
import torch
from transformers import (
KyutaiSpeechToTextConfig,
KyutaiSpeechToTextFeatureExtractor,
KyutaiSpeechToTextForConditionalGeneration,
KyutaiSpeechToTextProcessor,
PreTrainedTokenizerFast,
)
from transformers.convert_slow_tokenizer import MoshiConverter
from transformers.utils.hub import cached_file
# fmt: off
MOSHI_ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
r"out_norm": r"norm",
r"gating\.linear_in": r"mlp.fc1",
r"gating\.linear_out": r"mlp.fc2",
r"self_attn\.out_proj": r"self_attn.o_proj.linear",
r"norm1": r"input_layernorm",
r"norm2": r"post_attention_layernorm",
r"layer_scale_1": r"self_attn_layer_scale",
r"layer_scale_2": r"mlp_layer_scale",
r"alpha": r"weight",
}
# fmt: on
# fmt: off
MIMI_ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
r"conv\.conv\.conv": "conv",
r"convtr\.convtr\.convtr": "conv",
r"conv\.conv": "conv",
r"convtr\.convtr": "conv",
r"quantizer\.rvq_first\.vq": "quantizer.semantic_residual_vector_quantizer",
r"quantizer\.rvq_first": "quantizer.semantic_residual_vector_quantizer",
r"quantizer\.rvq_rest\.vq": "quantizer.acoustic_residual_vector_quantizer",
r"quantizer\.rvq_rest": "quantizer.acoustic_residual_vector_quantizer",
r"_codebook": "codebook",
r"_initialized": "initialized",
r"embedding_sum": "embed_sum",
r"encoder\.model": "encoder.layers",
r"decoder\.model": "decoder.layers",
r"encoder_transformer\.transformer": "encoder_transformer",
r"decoder_transformer\.transformer": "decoder_transformer",
r"linear1": "mlp.fc1",
r"linear2": "mlp.fc2",
r"self_attn\.out_proj": "self_attn.o_proj",
r"norm1": "input_layernorm",
r"norm2": "post_attention_layernorm",
r"layer_scale_1": "self_attn_layer_scale",
r"layer_scale_2": "mlp_layer_scale",
}
# fmt: on
def permute_for_rope(input_tensor, n_heads, dim1, dim2):
"""
When you go from the complex ROPE formulation to sin and cos one, you need
to permute the query and key weights (to avoid doing it on the fly)
"""
return input_tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
def convert_key(key, mapping):
for pattern, replacement in mapping.items():
key = re.sub(pattern, replacement, key)
return key
def convert_kyutai_speech_to_text_state_dict(state_dict, config, unwanted_prefix="transformer."):
hidden_size = config.hidden_size
head_dim = config.head_dim
num_heads = int(config.hidden_size // config.head_dim)
num_key_value_heads = config.num_key_value_heads
key_value_head_dim = config.num_key_value_heads * head_dim
# concat embeddings
embed_tokens_weight = []
for i in range(32):
embed_tokens_weight.append(state_dict.pop(f"emb.{i}.weight"))
embed_tokens_weight = torch.cat(embed_tokens_weight, dim=0)
embed_tokens_weight = torch.cat([state_dict.pop("text_emb.weight"), embed_tokens_weight])
embed_tokens_weight = torch.cat([embed_tokens_weight, torch.zeros(1, config.hidden_size)], dim=0)
state_dict["embed_tokens.embed_tokens.weight"] = embed_tokens_weight
for key, value in list(state_dict.items()):
if unwanted_prefix is not None and unwanted_prefix in key:
new_key = key[len(unwanted_prefix) :]
else:
new_key = key
new_key = convert_key(new_key, MOSHI_ORIGINAL_TO_CONVERTED_KEY_MAPPING)
# Post-process the current_parameter.
if "alpha" in key:
state_dict[key] = state_dict[key].squeeze()
if "in_proj_weight" in new_key:
# split qkv into query key and value
mixed_qkv = state_dict.pop(key)
qkv_dim = mixed_qkv.size(0) // 3
query_layer = mixed_qkv[:qkv_dim]
key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
value_layer = mixed_qkv[qkv_dim * 2 :]
state_dict[new_key.replace("in_proj_weight", "q_proj.linear.weight")] = permute_for_rope(
query_layer, num_heads, hidden_size, hidden_size
)
state_dict[new_key.replace("in_proj_weight", "k_proj.linear.weight")] = permute_for_rope(
key_layer, num_key_value_heads, key_value_head_dim, hidden_size
)
state_dict[new_key.replace("in_proj_weight", "v_proj.linear.weight")] = value_layer
else:
state_dict[new_key] = state_dict.pop(key)
return state_dict
def convert_mimi_state_dict(state_dict, config, unwanted_prefix=None):
hidden_size = config.hidden_size
head_dim = config.head_dim
num_heads = int(config.hidden_size // config.head_dim)
num_key_value_heads = config.num_key_value_heads
key_value_head_dim = config.num_key_value_heads * head_dim
for key, value in list(state_dict.items()):
if unwanted_prefix is not None and unwanted_prefix in key:
new_key = key[len(unwanted_prefix) :]
else:
new_key = key
new_key = convert_key(new_key, MIMI_ORIGINAL_TO_CONVERTED_KEY_MAPPING)
if "in_proj_weight" in new_key:
# split qkv into query key and value
mixed_qkv = state_dict.pop(key)
qkv_dim = mixed_qkv.size(0) // 3
query_layer = mixed_qkv[:qkv_dim]
key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
value_layer = mixed_qkv[qkv_dim * 2 :]
state_dict[new_key.replace("in_proj_weight", "q_proj.weight")] = permute_for_rope(
query_layer, num_heads, hidden_size, hidden_size
)
state_dict[new_key.replace("in_proj_weight", "k_proj.weight")] = permute_for_rope(
key_layer, num_key_value_heads, key_value_head_dim, hidden_size
)
state_dict[new_key.replace("in_proj_weight", "v_proj.weight")] = value_layer
else:
state_dict[new_key] = state_dict.pop(key)
return state_dict
def write_model(
input_path_or_repo,
model_name,
codec_model_path_or_repo,
codec_model_name,
output_dir,
safe_serialization=True,
unwanted_prefix="transformer.",
):
print("Converting the model.")
os.makedirs(output_dir, exist_ok=True)
config = KyutaiSpeechToTextConfig()
config.use_cache = True
config.codec_config.sliding_window = 250
model_path = cached_file(
input_path_or_repo,
model_name,
)
codec_path = cached_file(
codec_model_path_or_repo,
codec_model_name,
)
print(f"Fetching all parameters from the checkpoint at {model_path}...")
state_dict = safetensors.torch.load_file(model_path)
print(f"Fetching all parameters from the checkpoint at {codec_path}...")
codec_state_dict = safetensors.torch.load_file(codec_path)
print("Converting model...")
# -----------------------
# convert parameter names
# -----------------------
state_dict = convert_kyutai_speech_to_text_state_dict(state_dict, config, unwanted_prefix=unwanted_prefix)
codec_state_dict = convert_mimi_state_dict(codec_state_dict, config.codec_config, unwanted_prefix=None)
# -------------------------
# load the weights and save
# -------------------------
print("Loading the checkpoint in a Moshi ASR model.")
with torch.device("meta"):
model = KyutaiSpeechToTextForConditionalGeneration(config)
linear_weight = state_dict.pop("text_linear.weight")
model.model.load_state_dict(state_dict, strict=True, assign=True)
linear_weight = torch.cat([linear_weight, torch.zeros(1, config.hidden_size)])
model.lm_head.load_state_dict({"weight": linear_weight}, strict=True, assign=True)
model.codec_model.load_state_dict(codec_state_dict, strict=True, assign=True)
print("Checkpoint loaded successfully.")
del model.config._name_or_path
del model.config.codec_config._name_or_path
# default generation config
model.generation_config._from_model_config = False
model.generation_config.audio_window_size = 1
model.generation_config.cache_implementation = "sliding_window"
model.codec_model.generation_config._from_model_config = False
model.codec_model.generation_config.cache_implementation = "sliding_window"
model.codec_model.generation_config.use_cache = True
print("Saving the model.")
model.save_pretrained(output_dir, safe_serialization=safe_serialization)
del state_dict, model
# Safety check: reload the converted model
gc.collect()
print("Reloading the model to check if it's saved correctly.")
KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
output_dir, torch_dtype=torch.bfloat16, device_map="auto"
)
print("Model reloaded successfully.")
def write_processor(
input_path_or_repo,
tokenizer_model_name,
codec_model_path_or_repo,
output_dir,
audio_delay_seconds,
audio_silence_prefix_seconds,
):
tokenizer_path = cached_file(
input_path_or_repo,
tokenizer_model_name,
)
tokenizer = MoshiConverter(tokenizer_path).converted()
original_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_path)
tokenizer = PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
chat_template=None,
unk_token="<unk>",
model_input_names=["input_ids", "attention_mask"],
clean_up_tokenization_spaces=False,
bos_token_id=original_tokenizer.bos_id(),
eos_token_id=original_tokenizer.eos_id(),
pad_token_id=original_tokenizer.pad_id(),
)
feature_extractor = KyutaiSpeechToTextFeatureExtractor(
audio_delay_seconds=audio_delay_seconds,
audio_silence_prefix_seconds=audio_silence_prefix_seconds,
)
processor = KyutaiSpeechToTextProcessor(feature_extractor, tokenizer)
processor.save_pretrained(output_dir)
print(f"Processor saved successfully to {output_dir}")
def main():
parser = argparse.ArgumentParser(description="Convert Moshi ASR weights to HuggingFace format")
parser.add_argument(
"--input_path_or_repo",
type=str,
required=True,
help="Path or repo containing Moshi ASR weights",
)
parser.add_argument(
"--model_name",
type=str,
required=True,
help="Name of the model in input_path_or_repo",
)
parser.add_argument(
"--tokenizer_model_name",
type=str,
required=True,
help="Name of the tokenizer model in input_path_or_repo",
)
parser.add_argument(
"--codec_model_path_or_repo",
type=str,
required=True,
help="Path or repo containing the Mimi weights",
)
parser.add_argument(
"--mimi_name",
type=str,
required=True,
help="Name of the Mimi model in codec_model_path_or_repo",
)
parser.add_argument(
"--preprocessor_model_path_or_repo",
type=str,
required=True,
help="Path or repo containing the preprocessor config",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--safe_serialization", action="store_true", default=True, help="Whether or not to save using `safetensors`."
)
parser.add_argument(
"--audio_delay_seconds",
type=float,
required=True,
help="Audio delay in seconds to add to the right of the input",
)
parser.add_argument(
"--audio_silence_prefix_seconds",
type=float,
required=True,
help="Audio silence prefix in seconds to add to the left of the input",
)
args = parser.parse_args()
write_model(
args.input_path_or_repo,
args.model_name,
args.codec_model_path_or_repo,
args.mimi_name,
args.output_dir,
safe_serialization=args.safe_serialization,
)
write_processor(
args.input_path_or_repo,
args.tokenizer_model_name,
args.preprocessor_model_path_or_repo,
args.output_dir,
args.audio_delay_seconds,
args.audio_silence_prefix_seconds,
)
if __name__ == "__main__":
main()
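For reference, the conversion helpers can also be driven directly from Python along these lines; the repository, weight, and tokenizer file names below are placeholders rather than the actual Kyutai release layout, and the delay/prefix values are illustrative:

```python
# Hypothetical invocation of the helpers defined above.
write_model(
    input_path_or_repo="kyutai/original-stt-release",  # placeholder repo with the original weights
    model_name="model.safetensors",                    # placeholder weight file name
    codec_model_path_or_repo="kyutai/original-mimi",   # placeholder Mimi repo
    codec_model_name="mimi.safetensors",               # placeholder Mimi weight file name
    output_dir="./stt-2.6b-en-hf",
)

write_processor(
    input_path_or_repo="kyutai/original-stt-release",
    tokenizer_model_name="tokenizer.model",            # placeholder SentencePiece file
    codec_model_path_or_repo="kyutai/original-mimi",
    output_dir="./stt-2.6b-en-hf",
    audio_delay_seconds=2.0,                           # illustrative value
    audio_silence_prefix_seconds=1.0,                  # illustrative value
)
```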

View File

@ -0,0 +1,237 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/stt/modular_kyutai_speech_to_text.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Kyutai and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
import numpy as np
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging
logger = logging.get_logger(__name__)
class KyutaiSpeechToTextFeatureExtractor(SequenceFeatureExtractor):
r"""
Constructs a KyutaiSpeechToText feature extractor.
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
most of the main methods. Users should refer to this superclass for more information regarding those methods.
Args:
feature_size (`int`, *optional*, defaults to 1):
The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
sampling_rate (`int`, *optional*, defaults to 24000):
The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
padding_value (`float`, *optional*, defaults to 0.0):
The value that is used to fill the padding values.
chunk_length_s (`float`, *optional*):
If defined, the audio is pre-processed into chunks of length `chunk_length_s` and then encoded.
overlap (`float`, *optional*):
Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following
formula: `int((1.0 - self.overlap) * self.chunk_length)`.
audio_delay_seconds (`float`, *optional*, defaults to 0.0):
The delay in seconds to add after the audio (right padding).
audio_silence_prefix_seconds (`float`, *optional*, defaults to 0.0):
The silence prefix in seconds to add before the audio (left padding).
"""
model_input_names = ["input_values", "padding_mask"]
def __init__(
self,
feature_size: int = 1,
sampling_rate: int = 24000,
padding_value: float = 0.0,
chunk_length_s: Optional[float] = None,
overlap: Optional[float] = None,
audio_delay_seconds: Optional[float] = 0.0,
audio_silence_prefix_seconds: Optional[float] = 0.0,
**kwargs,
):
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
self.chunk_length_s = chunk_length_s
self.overlap = overlap
self.audio_delay_seconds = audio_delay_seconds
self.audio_silence_prefix_seconds = audio_silence_prefix_seconds
# This is a property because you might want to change the chunk_length_s on the fly
@property
def chunk_length(self) -> Optional[int]:
if self.chunk_length_s is None:
return None
else:
return int(self.chunk_length_s * self.sampling_rate)
# This is a property because you might want to change the chunk_length_s on the fly
@property
def chunk_stride(self) -> Optional[int]:
if self.chunk_length_s is None or self.overlap is None:
return None
else:
return max(1, int((1.0 - self.overlap) * self.chunk_length))
def __call__(
self,
raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
padding: Optional[Union[bool, str, PaddingStrategy]] = None,
truncation: Optional[bool] = False,
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
sampling_rate: Optional[int] = None,
) -> BatchFeature:
"""
Main method to featurize and prepare for the model one or several sequence(s).
Args:
raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
`(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
(`feature_size = 2`).
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
lengths).
truncation (`bool`, *optional*, defaults to `False`):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors instead of list of python integers. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return Numpy `np.ndarray` objects.
sampling_rate (`int`, *optional*):
The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
`sampling_rate` at the forward call to prevent silent errors.
"""
if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
raise ValueError(
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
f" {self.sampling_rate} and not {sampling_rate}."
)
else:
logger.warning(
f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
"Failing to do so can result in silent errors that might be hard to debug."
)
if padding and truncation:
raise ValueError("Both padding and truncation were set. Make sure you only set one.")
elif padding is None:
# by default let's pad the inputs
padding = True
is_batched = bool(
isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
)
if is_batched:
raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
elif not is_batched and not isinstance(raw_audio, np.ndarray):
raw_audio = np.asarray(raw_audio, dtype=np.float32)
elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
raw_audio = raw_audio.astype(np.float32)
# always return batch
if not is_batched:
raw_audio = [np.asarray(raw_audio).T]
# verify inputs are valid
for idx, example in enumerate(raw_audio):
if example.ndim > 2:
raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
if self.feature_size == 1 and example.ndim != 1:
raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
if self.feature_size == 2 and example.shape[-1] != 2:
raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")
padded_inputs = None
input_values = BatchFeature({"input_values": raw_audio})
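# when chunking is configured and no explicit max_length is given, derive max_length from the
# number of whole chunk strides that fit, so the (truncated or padded) audio splits into an
# integer number of overlapping chunks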
if self.chunk_stride is not None and self.chunk_length is not None and max_length is None:
if truncation:
max_length = min(array.shape[0] for array in raw_audio)
nb_step = int(np.floor(max_length / self.chunk_stride))
max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length
elif padding:
max_length = max(array.shape[0] for array in raw_audio)
nb_step = int(np.ceil(max_length / self.chunk_stride))
max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length
padding = "max_length"
else:
padded_inputs = input_values
# normal padding on batch
if padded_inputs is None:
padded_inputs = self.pad(
input_values,
max_length=max_length,
truncation=truncation,
padding=padding,
return_attention_mask=padding,
)
if padding:
padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")
# now let's pad left and right
pad_left = int(self.audio_silence_prefix_seconds * self.sampling_rate)
pad_right = int((self.audio_delay_seconds + 1.0) * self.sampling_rate)
padded_inputs["input_values"] = np.pad(
padded_inputs["input_values"],
((0, 0), (pad_left, pad_right)),
mode="constant",
constant_values=0.0,
)
if padding:
padded_inputs["padding_mask"] = np.pad(
padded_inputs["padding_mask"],
((0, 0), (pad_left, pad_right)),
mode="constant",
constant_values=0,
)
input_values = []
for example in padded_inputs.pop("input_values"):
if self.feature_size == 1:
example = example[..., None]
input_values.append(example.T)
padded_inputs["input_values"] = input_values
if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
return padded_inputs
__all__ = ["KyutaiSpeechToTextFeatureExtractor"]

File diff suppressed because it is too large

View File

@ -0,0 +1,510 @@
# coding=utf-8
# Copyright 2025 Kyutai and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import types
from typing import Optional, Union
import numpy as np
import torch
import torch.nn as nn
from ...cache_utils import Cache
from ...feature_extraction_utils import BatchFeature
from ...generation import GenerationConfig, GenerationMixin
from ...modeling_utils import PreTrainedModel
from ...utils import PaddingStrategy, TensorType, logging
from ..auto import AutoModel
from ..encodec.feature_extraction_encodec import EncodecFeatureExtractor
from ..llama.modeling_llama import LlamaForCausalLM
from ..mimi.modeling_mimi import MimiConv1dPaddingCache
from ..moshi.modeling_moshi import MoshiModel, MoshiPreTrainedModel
logger = logging.get_logger(__name__)
class KyutaiSpeechToTextFeatureExtractor(EncodecFeatureExtractor):
r"""
Constructs a KyutaiSpeechToText feature extractor.
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
most of the main methods. Users should refer to this superclass for more information regarding those methods.
Args:
feature_size (`int`, *optional*, defaults to 1):
The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
sampling_rate (`int`, *optional*, defaults to 24000):
The sampling rate at which the audio waveform should be digitized, expressed in hertz (Hz).
padding_value (`float`, *optional*, defaults to 0.0):
The value that is used to fill the padding values.
chunk_length_s (`float`, *optional*):
If defined, the audio is pre-processed into chunks of length `chunk_length_s` and then encoded.
overlap (`float`, *optional*):
Defines the overlap between each chunk. It is used to compute the `chunk_stride` with the following
formula: `int((1.0 - self.overlap) * self.chunk_length)`.
audio_delay_seconds (`float`, *optional*, defaults to 0.0):
The delay in seconds to add after the audio (right padding).
audio_silence_prefix_seconds (`float`, *optional*, defaults to 0.0):
The silence prefix in seconds to add before the audio (left padding).
"""
def __init__(
self,
audio_delay_seconds: Optional[float] = 0.0,
audio_silence_prefix_seconds: Optional[float] = 0.0,
**super_kwargs,
):
super().__init__(**super_kwargs)
self.audio_delay_seconds = audio_delay_seconds
self.audio_silence_prefix_seconds = audio_silence_prefix_seconds
def __call__(
self,
raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
padding: Optional[Union[bool, str, PaddingStrategy]] = None,
truncation: Optional[bool] = False,
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
sampling_rate: Optional[int] = None,
) -> BatchFeature:
"""
Main method to featurize and prepare for the model one or several sequence(s).
Args:
raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
`(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
(`feature_size = 2`).
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
lengths).
truncation (`bool`, *optional*, defaults to `False`):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors instead of list of python integers. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return Numpy `np.ndarray` objects.
sampling_rate (`int`, *optional*):
The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
`sampling_rate` at the forward call to prevent silent errors.
"""
if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
raise ValueError(
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
f" {self.sampling_rate} and not {sampling_rate}."
)
else:
logger.warning(
f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
"Failing to do so can result in silent errors that might be hard to debug."
)
if padding and truncation:
raise ValueError("Both padding and truncation were set. Make sure you only set one.")
elif padding is None:
# by default let's pad the inputs
padding = True
is_batched = bool(
isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
)
if is_batched:
raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
elif not is_batched and not isinstance(raw_audio, np.ndarray):
raw_audio = np.asarray(raw_audio, dtype=np.float32)
elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
raw_audio = raw_audio.astype(np.float32)
# always return batch
if not is_batched:
raw_audio = [np.asarray(raw_audio).T]
# verify inputs are valid
for idx, example in enumerate(raw_audio):
if example.ndim > 2:
raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
if self.feature_size == 1 and example.ndim != 1:
raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
if self.feature_size == 2 and example.shape[-1] != 2:
raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")
padded_inputs = None
input_values = BatchFeature({"input_values": raw_audio})
if self.chunk_stride is not None and self.chunk_length is not None and max_length is None:
if truncation:
max_length = min(array.shape[0] for array in raw_audio)
nb_step = int(np.floor(max_length / self.chunk_stride))
max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length
elif padding:
max_length = max(array.shape[0] for array in raw_audio)
nb_step = int(np.ceil(max_length / self.chunk_stride))
max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length
padding = "max_length"
else:
padded_inputs = input_values
# normal padding on batch
if padded_inputs is None:
padded_inputs = self.pad(
input_values,
max_length=max_length,
truncation=truncation,
padding=padding,
return_attention_mask=padding,
)
if padding:
padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")
# now let's pad left and right
pad_left = int(self.audio_silence_prefix_seconds * self.sampling_rate)
pad_right = int((self.audio_delay_seconds + 1.0) * self.sampling_rate)
padded_inputs["input_values"] = np.pad(
padded_inputs["input_values"],
((0, 0), (pad_left, pad_right)),
mode="constant",
constant_values=0.0,
)
if padding:
padded_inputs["padding_mask"] = np.pad(
padded_inputs["padding_mask"],
((0, 0), (pad_left, pad_right)),
mode="constant",
constant_values=0,
)
input_values = []
for example in padded_inputs.pop("input_values"):
if self.feature_size == 1:
example = example[..., None]
input_values.append(example.T)
padded_inputs["input_values"] = input_values
if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
return padded_inputs
class KyutaiSpeechToTextPreTrainedModel(MoshiPreTrainedModel):
pass
class KyutaiSpeechToTextConv1dPaddingCache(MimiConv1dPaddingCache):
pass
class KyutaiSpeechToTextEmbeddings(nn.Module):
def __init__(self, config):
super().__init__()
self.embed_tokens = nn.Embedding(
config.vocab_size + (config.num_codebooks * config.codebook_vocab_size) + 1,
config.hidden_size,
padding_idx=config.audio_pad_token_id,
)
audio_tokens_offsets = torch.arange(config.num_codebooks) * config.codebook_vocab_size
audio_tokens_offsets += config.vocab_size
audio_tokens_offsets = nn.functional.pad(
audio_tokens_offsets, (1, 0)
) # pad one 0 to the left for the text token
self.register_buffer("audio_tokens_offsets", audio_tokens_offsets, persistent=False)
def forward(self, input_ids):
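# input_ids has shape (batch_size, sequence_length, 1 + num_codebooks): the text token comes first,
# followed by one token per codebook. Each audio stream gets its own offset so all streams share a
# single embedding table; the per-stream embeddings are then summed and padding positions are left untouched.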
input_ids = torch.where(
input_ids == self.embed_tokens.padding_idx, input_ids, input_ids + self.audio_tokens_offsets
)
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = inputs_embeds.sum(dim=2)
return inputs_embeds
class KyutaiSpeechToTextModel(MoshiModel):
def __init__(self, config):
super().__init__(config)
self.embed_tokens = KyutaiSpeechToTextEmbeddings(config)
class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMixin, PreTrainedModel):
_keep_in_fp32_modules = ["codec_model"]
def __init__(self, config):
super().__init__(config)
self.codec_model = AutoModel.from_config(config.codec_config)
# we are in an edge case where, for the codec_model, can_generate() is False, which leaves
# self.codec_model.generation_config set to None. Yet the codec_model needs a generation config
# to initialize its cache for streaming inference, so we create one from its model config here.
self.codec_model.generation_config = GenerationConfig.from_model_config(config.codec_config)
def forward(self, **super_kwargs):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python
>>> import torch
>>> from datasets import load_dataset, Audio
>>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model_id = "kyutai/stt-2.6b-en"
>>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
>>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
>>> ds = load_dataset(
... "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
... )
>>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
>>> inputs = processor(
... ds[0]["audio"]["array"],
... )
>>> inputs.to(torch_device)
>>> output_tokens = model.generate(**inputs)
>>> print(processor.batch_decode(output_tokens, skip_special_tokens=True))
```"""
return super().forward(**super_kwargs)
def _prepare_generation_config(self, *args, **kwargs):
generation_config, model_kwargs = GenerationMixin._prepare_generation_config(self, *args, **kwargs)
# this should be passed to the model kwargs for the input preparation
model_kwargs["audio_window_size"] = (
generation_config.audio_window_size if hasattr(generation_config, "audio_window_size") else None
)
return generation_config, model_kwargs
def _prepare_model_inputs(
self,
inputs: Optional[torch.Tensor] = None,
bos_token_id: Optional[torch.Tensor] = None,
model_kwargs: Optional[dict[str, torch.Tensor]] = None,
) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
inputs, input_name, model_kwargs = GenerationMixin._prepare_model_inputs(
self,
inputs=inputs,
bos_token_id=bos_token_id,
model_kwargs=model_kwargs,
)
audio_window_size = model_kwargs.get("audio_window_size", None)
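# when no audio_window_size is provided, default to a single window covering the whole input,
# i.e. the number of codec frames the codec model produces for input_values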
if audio_window_size is None:
audio_window_size = self.codec_model.get_encoded_length(model_kwargs["input_values"].shape[-1]).item()
model_kwargs["audio_window_size"] = audio_window_size
batch_size = inputs.shape[0]
device = inputs.device
# initialize audio tokens
model_kwargs["audio_tokens"] = torch.zeros(
(batch_size, audio_window_size, self.config.num_codebooks),
device=device,
dtype=torch.long,
)
model_kwargs["current_window"] = (
torch.tensor([0, 0], device=device, dtype=torch.long).expand(batch_size, -1).contiguous()
)
# let's use generate's cache preparation to prepare the cache for the codec model
temporary_model_kwargs = {}
# monkey patching the codec model with cache preparation methods since we don't want it to inherit fully from GenerationMixin
# Add cache-related methods from GenerationMixin to codec model
cache_methods = [
"_prepare_cache_for_generation",
"_get_cache",
"_supports_default_dynamic_cache",
"_get_layer_device_map_for_cache_init",
]
for method in cache_methods:
setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model))
self.codec_model._prepare_cache_for_generation(
generation_config=self.codec_model.generation_config,
model_kwargs=temporary_model_kwargs,
assistant_model=None,
batch_size=batch_size,
max_cache_length=self.config.codec_config.sliding_window,
device=device,
)
if "past_key_values" in temporary_model_kwargs:
model_kwargs["encoder_past_key_values"] = temporary_model_kwargs["past_key_values"]
# initialize the padding cache for the codec model
per_layer_padding, per_layer_padding_mode, per_layer_in_channels = [], [], []
for layer_name in self.codec_model.encoder._mimiconv1d_layer_names:
per_layer_padding.append(self.codec_model.encoder.get_submodule(layer_name).padding_total)
per_layer_padding_mode.append(self.codec_model.encoder.get_submodule(layer_name).pad_mode)
per_layer_in_channels.append(self.codec_model.encoder.get_submodule(layer_name).in_channels)
# downsample layer
per_layer_padding.append(self.codec_model.downsample.padding_total)
per_layer_padding_mode.append(self.codec_model.downsample.pad_mode)
per_layer_in_channels.append(self.codec_model.downsample.in_channels)
model_kwargs["padding_cache"] = KyutaiSpeechToTextConv1dPaddingCache(
num_layers=len(self.codec_model.encoder._mimiconv1d_layer_names) + 1,
per_layer_padding=per_layer_padding,
per_layer_padding_mode=per_layer_padding_mode,
per_layer_in_channels=per_layer_in_channels,
)
return inputs, input_name, model_kwargs
def prepare_inputs_for_generation(
self,
*args,
audio_tokens: Optional[torch.LongTensor] = None,
input_values: Optional[torch.FloatTensor] = None,
padding_mask: Optional[torch.Tensor] = None,
audio_window_size: Optional[int] = None,
current_window: Optional[tuple[int, int]] = None,
encoder_past_key_values: Optional[Cache] = None,
padding_cache: Optional[KyutaiSpeechToTextConv1dPaddingCache] = None,
**kwargs,
):
model_inputs = GenerationMixin.prepare_inputs_for_generation(self, *args, **kwargs)
if input_values is not None:
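# audio tokens are produced lazily, one window of `audio_window_size` codec frames at a time:
# when generation reaches the end of the current window, the next slice of input_values is
# encoded with the codec model (reusing its streaming caches) to refresh `audio_tokens`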
cache_position = model_inputs["cache_position"]
start, end = current_window[0]
# first cache position is for bos token, so we need to offset by -1
if cache_position[-1] - 1 >= end:
# we need to encode the new audio tokens
with torch.no_grad():
input_values_start_idx = start * self.config.frame_size
input_values_end_idx = (start + audio_window_size) * self.config.frame_size
current_input_values = input_values[..., input_values_start_idx:input_values_end_idx]
codec_model_output = self.codec_model.encode(
current_input_values,
encoder_past_key_values=encoder_past_key_values,
padding_cache=padding_cache,
)
new_audio_tokens = codec_model_output.audio_codes.transpose(1, 2)
audio_tokens.copy_(new_audio_tokens)
start = end.clone()
end = end + audio_window_size
current_window.copy_(
torch.tensor([start, end], device=current_window.device).expand(current_window.shape[0], -1)
)
# first cache position is for bos token, so we need to offset by -1
current_audio_tokens_idxs = (cache_position - start - 1).clamp(min=0)
current_audio_tokens = audio_tokens[:, current_audio_tokens_idxs, :]
current_audio_tokens[:, cache_position == 0, :] = self.config.audio_bos_token_id
input_ids = model_inputs.pop("input_ids")
input_ids = torch.cat(
[input_ids.unsqueeze(2), current_audio_tokens],
dim=2,
)
model_inputs["input_ids"] = input_ids
return model_inputs
# TODO: @eustlb, this should be standardized
@classmethod
def from_pretrained(cls, *args, **kwargs):
if kwargs.get("output_loading_info", False):
model, loading_info = super().from_pretrained(*args, **kwargs)
else:
model = super().from_pretrained(*args, **kwargs)
# copy codec model generation config attributes (prefixed with "codec_") from the model generation config to the codec model generation config
prefix = "codec_"
prefix_len = len(prefix)
codec_model_attrs = {
attr[prefix_len:]: value
for attr, value in vars(model.generation_config).items()
if attr.startswith(prefix)
}
vars(model.codec_model.generation_config).update({"_from_model_config": False, **codec_model_attrs})
# remove the codec model generation config attributes from the model generation config
for attr in codec_model_attrs:
delattr(model.generation_config, prefix + attr)
if "output_loading_info" in kwargs:
return model, loading_info
else:
return model
# TODO: @eustlb, this should be standardized
def save_pretrained(self, *args, **kwargs):
prefix = "codec_"
codec_model_attrs = self.codec_model.generation_config.to_diff_dict()
codec_model_attrs.pop("transformers_version", None)
for attr, value in codec_model_attrs.items():
setattr(self.generation_config, prefix + attr, value)
PreTrainedModel.save_pretrained(self, *args, **kwargs)
def generate(self, *args, **kwargs):
r"""
This method forwards all its arguments to GenerationMixin's [`~GenerationMixin.generate`]. Please refer to the docstring of this method for more information.
"""
max_new_tokens = kwargs.pop("max_new_tokens", None)
input_values = kwargs.get("input_values")
# TODO: @eustlb, we should have per-batch-idx values
# here we do not use padding_mask to be aligned to what's done in the original codebase
max_audio_frames = input_values.shape[-1] // self.config.codec_config.frame_size
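# one text token is generated per codec frame, so the number of audio frames available in
# input_values upper-bounds max_new_tokens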
if max_new_tokens is None or max_new_tokens > max_audio_frames:
if max_new_tokens is not None:
logger.warning(
f"`max_new_tokens` ({max_new_tokens}) is greater than the maximum number of audio frames ({max_audio_frames})."
f"Setting `max_new_tokens` to {max_audio_frames}."
)
max_new_tokens = max_audio_frames
return GenerationMixin.generate(
self,
*args,
max_new_tokens=max_new_tokens,
**kwargs,
)
__all__ = [
"KyutaiSpeechToTextPreTrainedModel",
"KyutaiSpeechToTextModel",
"KyutaiSpeechToTextForConditionalGeneration",
"KyutaiSpeechToTextFeatureExtractor",
]

View File

@ -0,0 +1,104 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from ...audio_utils import AudioInput, make_list_of_audio
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
class KyutaiSpeechToTextProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"audio_kwargs": {
"sampling_rate": 24000,
},
"common_kwargs": {"return_tensors": "pt"},
}
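# the default sampling rate matches the 24 kHz expected by the Mimi codec, and outputs are
# returned as PyTorch tensors unless return_tensors is overridden at call time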
class KyutaiSpeechToTextProcessor(ProcessorMixin):
r"""
Constructs a Kyutai Speech-To-Text processor which wraps [`KyutaiSpeechToTextFeatureExtractor`] and
[`PreTrainedTokenizerFast`] into a single processor that inherits both the audio feature extraction and
tokenizer functionalities. See [`~KyutaiSpeechToTextProcessor.__call__`] for more
information.
"""
feature_extractor_class = "KyutaiSpeechToTextFeatureExtractor"
tokenizer_class = "PreTrainedTokenizerFast"
def __call__(
self,
audio: Optional[AudioInput] = None,
**kwargs: Unpack[KyutaiSpeechToTextProcessorKwargs],
):
r"""
Main method to prepare audio to be fed as input to the model. This method forwards the `audio`
arguments to KyutaiSpeechToTextFeatureExtractor's [`~KyutaiSpeechToTextFeatureExtractor.__call__`]. Please refer
to the docstring of the above method for more information.
Args:
audio (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
The audio or batch of audio to be prepared. Each audio can be a NumPy array or PyTorch
tensor.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_values** -- List of audio values to be fed to a model. Returned when `audio` is not `None`.
- **padding_mask** -- List of indices specifying which input values should be ignored by the model.
"""
if audio is None:
raise ValueError("`audio` is required.")
output_kwargs = self._merge_kwargs(
KyutaiSpeechToTextProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
audio_kwargs = output_kwargs["audio_kwargs"]
# ensure audio in correct format
audio = make_list_of_audio(audio)
inputs = self.feature_extractor(
audio,
**audio_kwargs,
)
return inputs
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to KyutaiSpeechToTextTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to KyutaiSpeechToTextTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
__all__ = ["KyutaiSpeechToTextProcessor"]

View File

@ -0,0 +1,704 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Moshi ASR model."""
import gc
import inspect
import tempfile
import unittest
import datasets
import pytest
from parameterized import parameterized
from transformers import (
KyutaiSpeechToTextConfig,
KyutaiSpeechToTextForConditionalGeneration,
KyutaiSpeechToTextProcessor,
is_torch_available,
)
from transformers.testing_utils import (
cleanup,
require_torch,
require_torch_accelerator,
require_torch_sdpa,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
ModelTesterMixin,
_config_zero_init,
floats_tensor,
ids_tensor,
)
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
KyutaiSpeechToTextForConditionalGeneration,
KyutaiSpeechToTextModel,
)
class KyutaiSpeechToTextModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
text_seq_length=1,
input_values_length=192, # gives 3 audio tokens, corresponding to the default in GenerationTesterMixin
is_training=False,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
codebook_vocab_size=2049,
vocab_size=99,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=None,
max_position_embeddings=512,
rope_theta=10000.0,
hidden_act="silu",
head_dim=None,
initializer_range=0.02,
use_cache=True,
sliding_window=512,
attention_dropout=0.1,
ffn_dim=38,
rms_norm_eps=1e-6,
num_codebooks=8,
frame_size=64,
delay_in_tokens=5,
audio_bos_token_id=2048,
audio_pad_token_id=2048,
tie_word_embeddings=False,
pad_token_id=0,
bos_token_id=1,
codec_config={
"model_type": "mimi",
"num_quantizers": 8,
"audio_channels": 1,
"chunk_in_sec": None,
"hidden_size": 16,
"num_filters": 8,
"num_residual_layers": 1,
"upsampling_ratios": [8, 4],
"codebook_size": 16,
"vector_quantization_hidden_dimension": 16,
"upsample_groups": 16,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"num_key_value_heads": 2,
"sliding_window": 4,
"codebook_dim": 16,
"use_cache": False,
},
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.text_seq_length = text_seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.codebook_vocab_size = codebook_vocab_size
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
self.hidden_act = hidden_act
self.head_dim = head_dim
self.initializer_range = initializer_range
self.use_cache = use_cache
self.sliding_window = sliding_window
self.attention_dropout = attention_dropout
self.ffn_dim = ffn_dim
self.rms_norm_eps = rms_norm_eps
self.num_codebooks = num_codebooks
self.frame_size = frame_size
self.delay_in_tokens = delay_in_tokens
self.audio_bos_token_id = audio_bos_token_id
self.audio_pad_token_id = audio_pad_token_id
self.tie_word_embeddings = tie_word_embeddings
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.codec_config = codec_config
self.scope = scope
self.input_values_length = input_values_length
def get_config(self):
return KyutaiSpeechToTextConfig(
codebook_vocab_size=self.codebook_vocab_size,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_key_value_heads,
max_position_embeddings=self.max_position_embeddings,
rope_theta=self.rope_theta,
hidden_act=self.hidden_act,
head_dim=self.head_dim,
initializer_range=self.initializer_range,
use_cache=self.use_cache,
sliding_window=self.sliding_window,
attention_dropout=self.attention_dropout,
ffn_dim=self.ffn_dim,
rms_norm_eps=self.rms_norm_eps,
num_codebooks=self.num_codebooks,
frame_size=self.frame_size,
delay_in_tokens=self.delay_in_tokens,
audio_bos_token_id=self.audio_bos_token_id,
audio_pad_token_id=self.audio_pad_token_id,
tie_word_embeddings=self.tie_word_embeddings,
pad_token_id=self.pad_token_id,
bos_token_id=self.bos_token_id,
codec_config=self.codec_config,
)
def create_and_check_model(self, config, input_ids, input_mask):
model = KyutaiSpeechToTextModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def prepare_config_and_inputs(self):
config = self.get_config()
text_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size - 1) + 1
codebook_input_ids = (
ids_tensor([self.batch_size, self.seq_length, self.num_codebooks], self.codebook_vocab_size - 1) + 1
)
input_ids = torch.cat([text_input_ids.unsqueeze(2), codebook_input_ids], dim=2)
attention_mask = text_input_ids.ne(1).to(torch_device)
return config, input_ids, attention_mask
def prepare_config_and_inputs_generate(self):
config = self.get_config()
input_ids = torch.ones([self.batch_size, 1], dtype=torch.long, device=torch_device)
input_values = floats_tensor([self.batch_size, 1, self.input_values_length])
padding_mask = torch.ones_like(input_values, dtype=torch.int32, device=torch_device)
return config, input_ids, input_values, padding_mask
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
attention_mask,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
return config, inputs_dict
def prepare_config_and_inputs_for_common_generate(self):
config_and_inputs = self.prepare_config_and_inputs_generate()
(
config,
input_ids,
input_values,
padding_mask,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"input_values": input_values,
"padding_mask": padding_mask,
}
return config, inputs_dict
@require_torch
class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
KyutaiSpeechToTextModel,
KyutaiSpeechToTextForConditionalGeneration,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{
"feature-extraction": KyutaiSpeechToTextModel,
"automatic-speech-recognition": KyutaiSpeechToTextForConditionalGeneration,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]
def setUp(self):
self.model_tester = KyutaiSpeechToTextModelTester(self)
self.config_tester = ConfigTester(self, config_class=KyutaiSpeechToTextConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels)
return inputs_dict
def prepare_config_and_inputs_for_generate(self, batch_size=2):
# monkey patch prepare_config_and_inputs_for_common
prepare_config_and_inputs_for_common = self.model_tester.prepare_config_and_inputs_for_common
original_batch_size = self.model_tester.batch_size
self.model_tester.prepare_config_and_inputs_for_common = (
self.model_tester.prepare_config_and_inputs_for_common_generate
)
self.model_tester.batch_size = batch_size
config, filtered_inputs_dict = super().prepare_config_and_inputs_for_generate()
self.model_tester.prepare_config_and_inputs_for_common = prepare_config_and_inputs_for_common
self.model_tester.batch_size = original_batch_size
return config, filtered_inputs_dict
@pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
def test_model_get_set_embeddings(self):
pass
@pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
def test_tie_model_weights(self):
pass
@pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
def test_resize_embeddings_untied(self):
pass
@pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
def test_resize_tokens_embeddings(self):
pass
@pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
def test_tied_weights_keys(self):
pass
@pytest.mark.skip(reason="Does not apply to Moshi ASR that requires input_values.")
def test_generate_without_input_ids(self):
pass
def test_initialization(self):
"""
Overrides [ModelTesterMixin.test_initialization] because of specificities of Mimi codec model.
See https://github.com/huggingface/transformers/blob/1077603410cd73ba71d64a522033574d66d64b55/tests/models/mimi/test_modeling_mimi.py#L384-L397
"""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
uniform_init_parms = ["conv", "input_proj", "output_proj"]
if param.requires_grad:
if any(x in name for x in uniform_init_parms):
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@require_torch_sdpa
def test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
if use_attention_mask or (not use_attention_mask and torch_dtype == "fp32" and not output_attentions):
self.skipTest("Test is failing, fix me :) ")
parent_parameterized_test = getattr(ModelTesterMixin, self._testMethodName)
parent_parameterized_test(self)
@unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
def test_cpu_offload(self):
pass
@unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
def test_disk_offload_bin(self):
pass
@unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
def test_disk_offload_safetensors(self):
pass
@pytest.mark.generate
def test_left_padding_compatibility(self):
# NOTE: left-padding results in small numerical differences. This is expected.
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
# First, filter out models that don't support left padding
# - The model must have generative capabilities
if len(self.all_generative_model_classes) == 0:
self.skipTest(reason="No generative architecture available for this model.")
# - The model must support padding
if not self.has_attentions:
self.skipTest(reason="This model doesn't support padding.")
# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
decoder_only_classes = []
for model_class in self.all_generative_model_classes:
config, _ = self.prepare_config_and_inputs_for_generate()
if config.is_encoder_decoder:
continue
else:
decoder_only_classes.append(model_class)
if len(decoder_only_classes) == 0:
self.skipTest(reason="No decoder-only architecture available for this model.")
# - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't
# added support for it yet. We skip these models for now.
has_encoder_attributes = any(
attr_name
for attr_name in config.to_dict().keys()
if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size"
)
if has_encoder_attributes:
self.skipTest(
reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding."
)
# Then, test left-padding
def _prepare_model_kwargs(input_ids, attention_mask, signature):
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
if "position_ids" in signature:
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
model_kwargs["position_ids"] = position_ids
if "cache_position" in signature:
cache_position = torch.arange(input_ids.shape[1], device=torch_device)
model_kwargs["cache_position"] = cache_position
return model_kwargs
for model_class in decoder_only_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
attention_mask = inputs_dict.get("attention_mask")
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys()
# no cache as some models require special cache classes to be init outside forward
model.generation_config.use_cache = False
# Without padding
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
# With left-padding (length 32)
# can hardcode pad_token to be 0 as we'll do attn masking anyway
pad_token_id = (
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
)
pad_size = (input_ids.shape[0], 32, *input_ids.shape[2:])
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
padded_input_ids = torch.cat((padding, input_ids), dim=1)
padded_attention_mask = torch.cat(
(torch.zeros(pad_size[:2], dtype=input_ids.dtype, device=torch_device), attention_mask), dim=1
)
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
# They should result in very similar logits
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
def test_generate_continue_from_past_key_values(self):
# Tests that we can continue generating from past key values, returned from a previous `generate` call
for model_class in self.all_generative_model_classes:
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt", "mllama"]):
self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
# Let's make it always:
# 1. use cache (for obvious reasons)
# 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which
# would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
# continuation would force it to generate beyond an EOS token)
# 3. ignore `token_type_ids` for simplicity
# 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
# active by default on some models
# 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When
# we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents
# repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls
# with cache, what is considered a prompt is different in the two cases.
if "token_type_ids" in inputs:
del inputs["token_type_ids"]
model = model_class(config).to(torch_device)
model.eval()
# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
outputs = model(**inputs)
if "past_key_values" not in outputs:
self.skipTest(reason="This model doesn't return `past_key_values`")
generate_kwargs = {
"pad_token_id": -1,
"eos_token_id": -1,
"forced_eos_token_id": None,
"encoder_no_repeat_ngram_size": 0,
"use_cache": True,
"do_sample": False,
"return_dict_in_generate": True,
"output_scores": True,
}
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
_, inputs = self.prepare_config_and_inputs_for_generate()
outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=3)
# Let's generate again, but passing the past key values in between (2 + 1 = 3 tokens). Note that the
# inputs may need to be tweaked across `generate` calls (like the attention mask).
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=2)
# Continue from the tokens generated above, preparing the inputs accordingly
inputs["past_key_values"] = outputs_cached.past_key_values
new_attention_len = outputs_cached.sequences.shape[-1]
if config.is_encoder_decoder:
inputs["decoder_input_ids"] = outputs_cached.sequences
if "decoder_attention_mask" in inputs:
inputs["decoder_attention_mask"] = torch.nn.functional.pad(
inputs["decoder_attention_mask"],
(0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
mode="constant",
value=1,
)
else:
inputs["input_ids"] = outputs_cached.sequences
if "attention_mask" in inputs:
inputs["attention_mask"] = torch.nn.functional.pad(
inputs["attention_mask"],
(0, new_attention_len - inputs["attention_mask"].shape[1]),
mode="constant",
value=1,
)
first_caches_scores = outputs_cached.scores
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1)
full_cached_scores = first_caches_scores + outputs_cached.scores
outputs_cached.scores = full_cached_scores
# The two sets of generated text and past kv should be equal to each other
self._check_similar_generate_outputs(outputs, outputs_cached)
for layer_idx in range(len(outputs_cached.past_key_values)):
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
self.assertTrue(
torch.allclose(
outputs.past_key_values[layer_idx][kv_idx],
outputs_cached.past_key_values[layer_idx][kv_idx],
)
)
# needs to be overridden to avoid casting input_values to float16:
# the codec model is kept in fp32, so input_values must stay in fp32 as well
def _test_attention_implementation(self, attn_implementation):
"""
Compares the output of generate with the eager attention implementation against other implementations.
NOTE: despite the test logic being the same, different implementations actually need different decorators, hence
this separate function.
"""
max_new_tokens = 30
support_flag = {
"sdpa": "_supports_sdpa",
"flash_attention_2": "_supports_flash_attn_2",
}
for model_class in self.all_generative_model_classes:
if not getattr(model_class, support_flag[attn_implementation]):
self.skipTest(f"{model_class.__name__} does not support `attn_implementation={attn_implementation}`")
config, original_inputs_dict = self.prepare_config_and_inputs_for_generate()
inputs_dict = {}
for input_name, input_data in original_inputs_dict.items():
if (
isinstance(input_data, torch.Tensor)
and input_data.dtype in [torch.float32, torch.bfloat16]
and input_name != "input_values"
):
inputs_dict[input_name] = input_data.to(torch.float16)
else:
inputs_dict[input_name] = input_data
main_input = inputs_dict[model_class.main_input_name]
# FA2 doesn't accept masking in the middle of the sequence for now. We usually generate right-padded
# attention masks at test time and, with generate, the mask will be appended with 1s on the right,
# resulting in a mask with holes (not supported properly by FA2).
if attn_implementation == "flash_attention_2":
for input_name in ("attention_mask", "decoder_attention_mask", "encoder_attention_mask"):
if input_name in inputs_dict:
inputs_dict[input_name] = torch.ones_like(inputs_dict[input_name])
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + main_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
del model
gc.collect()
generate_kwargs = {
"max_new_tokens": max_new_tokens,
"do_sample": False,
"return_dict_in_generate": True,
"output_scores": True,
"use_cache": True,
}
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="eager",
).to(torch_device)
res_eager = model_eager.generate(**inputs_dict, **generate_kwargs)
del model_eager
gc.collect()
model_attn = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation=attn_implementation,
).to(torch_device)
res_attn = model_attn.generate(**inputs_dict, **generate_kwargs)
del model_attn
gc.collect()
self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)
class KyutaiSpeechToTextForConditionalGenerationIntegrationTests(unittest.TestCase):
_dataset = None
def setUp(self):
self.model_checkpoint = "kyutai/stt-2.6b-en"
def tearDown(self):
cleanup(torch_device, gc_collect=True)
@classmethod
def _load_dataset(cls):
# Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process.
if cls._dataset is None:
cls._dataset = datasets.load_dataset(
"hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
# using 24000 here for simplicity, should rather be processor.feature_extractor.sampling_rate
cls._dataset = cls._dataset.cast_column("audio", datasets.Audio(sampling_rate=24000))
def _load_datasamples(self, num_samples):
self._load_dataset()
ds = self._dataset
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
@slow
@require_torch_accelerator
def test_generation(self):
"""
Expected outputs were reproduced with the original codebase: https://gist.github.com/eustlb/7a9aa6139d11e0103c6b65bac103da52
DISCLAIMER: we test on fairly short inputs. Reproducing matching outputs for longer inputs is not possible,
as implementation choices (a single fused qkv linear in the original code vs. three separate linears in HF) create a divergence
that grows with context length, ultimately giving different outputs.
"""
processor = KyutaiSpeechToTextProcessor.from_pretrained(self.model_checkpoint)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
self.model_checkpoint, device_map=torch_device
)
samples = self._load_datasamples(1)
inputs = processor(
samples,
).to(torch_device)
out = model.generate(**inputs)
# fmt: off
EXPECTED_TOKENS = torch.tensor([
[48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 0, 277, 3, 0, 265, 0, 267, 1162, 261, 274, 410, 0, 272, 3, 0, 265, 0, 260, 1621, 0, 1174, 371, 262, 3, 3, 3, 0, 269, 0, 281, 0, 304, 0, 2433, 3, 0, 266, 3, 0, 281, 1661, 3, 0, 376, 3, 3, 0, 350, 261, 401, 516, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]],
)
# fmt: on
torch.testing.assert_close(out.cpu(), EXPECTED_TOKENS)
@slow
@require_torch_accelerator
def test_generation_batched(self):
"""
Expected outputs were reproduced with the original codebase: https://gist.github.com/eustlb/b58c217c75124d405ec1c13877c7ece8
DISCLAIMER: we test on fairly short inputs. Reproducing matching outputs for longer inputs is not possible,
as implementation choices (a single fused qkv linear in the original code vs. three separate linears in HF) create a divergence
that grows with context length, ultimately giving different outputs.
"""
processor = KyutaiSpeechToTextProcessor.from_pretrained(self.model_checkpoint)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
self.model_checkpoint, device_map=torch_device
)
samples = self._load_datasamples(4)
inputs = processor(
samples,
).to(torch_device)
out = model.generate(**inputs)
# fmt: off
EXPECTED_TOKENS = torch.tensor([
[48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 0, 277, 3, 0, 265, 0, 267, 1162, 261, 274, 410, 0, 272, 3, 0, 265, 0, 260, 1621, 0, 1174, 371, 262, 3, 3, 3, 0, 269, 0, 281, 0, 304, 0, 2433, 3, 0, 266, 3, 0, 281, 1661, 3, 0, 376, 3, 3, 0, 350, 261, 401, 516, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
[48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 500, 334, 0, 277, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 264, 261, 0, 511, 1109, 3, 0, 1138, 3, 3, 3, 0, 508, 827, 3, 3, 3, 3, 0, 468, 3, 3, 0, 376, 3, 3, 3, 0, 260, 978, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
[48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 414, 0, 527, 261, 3, 0, 409, 3, 3, 3, 0, 271, 3, 0, 309, 3, 0, 285, 3, 0, 521, 371, 609, 3, 3, 0, 260, 959, 3, 3, 3, 0, 272, 3, 0, 265, 0, 546, 262, 3, 3, 3, 3, 3, 3, 0, 291, 3, 0, 975, 2203, 3, 3, 3, 3, 0, 269, 3, 0, 260, 489, 651, 274, 279, 1870, 3, 0, 1084, 873, 273, 3, 0, 260, 531, 3, 3, 0, 409, 262, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1502, 1005, 836, 3, 3, 0, 1666, 306, 3, 0, 340, 3, 0, 260, 3232, 3, 0, 269, 3, 3, 0, 275, 261, 0, 260, 1379, 261, 0, 3324, 3, 3, 3, 3, 0, 549, 3, 3, 0, 693, 405, 323, 3, 0, 266, 3, 3, 0, 265, 0, 699, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
[48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 414, 0, 392, 3, 3, 0, 1269, 314, 0, 2607, 261, 3, 3, 3, 0, 1098, 295, 3, 3, 3, 0, 446, 625, 3, 0, 496, 280, 1205, 485, 1071, 1627, 449, 264, 261, 3, 0, 400, 0, 277, 3, 3, 3, 0, 260, 342, 3, 0, 618, 280, 1866, 3, 3, 0, 554, 3, 3, 3, 3, 0, 317, 262, 3, 3, 3, 3, 3, 3, 3, 3, 0, 269, 0, 303, 3, 0, 573, 2615, 3, 3, 0, 276, 3, 0, 275, 0, 305, 3, 0, 260, 415, 3, 3, 0, 272, 3, 3, 3, 3, 0, 1631, 327, 3, 3, 0, 333, 739, 841, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
])
# fmt: on
torch.testing.assert_close(out.cpu(), EXPECTED_TOKENS)

View File

@ -107,14 +107,21 @@ class MimiModelTester:
self.sliding_window = sliding_window
self.use_cache = use_cache
def prepare_config_and_inputs(self):
input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
def prepare_config_and_inputs(self, input_values_length=None):
input_values = floats_tensor(
[
self.batch_size,
self.num_channels,
self.intermediate_size if input_values_length is None else input_values_length,
],
scale=1.0,
)
config = self.get_config()
inputs_dict = {"input_values": input_values}
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
def prepare_config_and_inputs_for_common(self, input_values_length=None):
config, inputs_dict = self.prepare_config_and_inputs(input_values_length=input_values_length)
return config, inputs_dict
def prepare_config_and_inputs_for_model_class(self, model_class):
@ -508,6 +515,54 @@ class MimiIntegrationTest(unittest.TestCase):
)
self.assertTrue(rmse < 1e-3)
def test_integration_encode_with_padding_cache(self):
"""
We test here the possibility to run Mimi in a streaming manner, i.e. chunk by chunk.
1. we encode a first time the entire audio
2. we encode the audio chunk by chunk, each chunk being the smallest size possible for the model (i.e. the frame size)
This test must be run on CPU since GPU floating point operations accumulate rounding errors that cause test failures.
"""
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
model_id = "kyutai/mimi"
model = MimiModel.from_pretrained(model_id, use_cache=True).to("cpu")
processor = AutoFeatureExtractor.from_pretrained(model_id)
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
audio_sample = librispeech_dummy[-1]["audio"]["array"]
inputs = processor(
raw_audio=audio_sample,
sampling_rate=processor.sampling_rate,
return_tensors="pt",
).to("cpu")
frame_size = model.config.frame_size
audio_codes = model.encode(inputs["input_values"]).audio_codes
# streaming chunk by chunk
encoder_past_key_values = None
padding_cache = None
encoded_frames_list = []
for start in range(0, inputs["input_values"].shape[-1], frame_size):
input_values_chunk = inputs["input_values"][:, :, start : start + frame_size]
encoder_outputs = model.encode(
input_values_chunk,
padding_cache=padding_cache,
encoder_past_key_values=encoder_past_key_values,
use_streaming=True,
)
encoder_past_key_values = encoder_outputs.encoder_past_key_values
padding_cache = encoder_outputs.padding_cache
encoded_frames_list.append(encoder_outputs.audio_codes)
streamed_audio_codes = torch.cat(encoded_frames_list, dim=-1)
torch.testing.assert_close(streamed_audio_codes, audio_codes)
def test_integration(self):
expected_rmses = {
"8": 0.0018785292,

View File

@ -3566,7 +3566,11 @@ class ModelTesterMixin:
# TODO: if we can also check with `batch_size=1` without being flaky?
for batch_size in [7]:
# musicgen decoder models; TODO: find better abstraction
if hasattr(self.model_tester, "num_codebooks") and not hasattr(model_eager, "text_encoder"):
if (
model.__class__.__name__.startswith("Musicgen")
and hasattr(self.model_tester, "num_codebooks")
and not hasattr(model_eager, "text_encoder")
):
input_data_batch_size = batch_size * self.model_tester.num_codebooks
else:
input_data_batch_size = batch_size
@ -3626,7 +3630,7 @@ class ModelTesterMixin:
if is_encoder_decoder:
# musicgen encoder-decoder models; TODO: find better abstraction
if hasattr(self.model_tester, "num_codebooks"):
if model.__class__.__name__.startswith("Musicgen") and hasattr(self.model_tester, "num_codebooks"):
input_data_batch_size = batch_size * self.model_tester.num_codebooks
else:
input_data_batch_size = batch_size

View File

@ -619,7 +619,7 @@ ALL_FILE_TYPES = (
"processing",
"image_processing",
"video_processing",
"feature_extractor",
"feature_extraction",
)
@ -1137,7 +1137,7 @@ TYPE_TO_FILE_TYPE = {
"VideoProcessor": "video_processing",
"VideoProcessorInitKwargs": "video_processing",
"FastImageProcessorKwargs": "image_processing*_fast",
"FeatureExtractor": "feature_extractor",
"FeatureExtractor": "feature_extraction",
"ProcessorKwargs": "processing",
"VideosKwargs": "processing",
"ImagesKwargs": "processing",