mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
Add kyutai stt (#38909)
* first draft * cleaner version * udpate tests + modeling * add tests * init * udpate test_modeling_common * fix tests * csm Processor draft * convertion update * mimi cache padding convolutions draft * mimi streaming udpates * update mimi padding cache test * udpate cache padding mimi test * make style mimi * updates generate moshi asr * moshi asr integration tests (single + batched) * update tests * update conversion script * good default sliding window value * udpdate generate * update test checkpoint * nit * fix mimi * fix codec prefix * revert * revert * update config * update config * unnecessary mimi input restriction * remove delay in tokens * remove _prepare_4d_causal_attention_mask_with_cache_position and _update_causal_mask * test update * modular update * make style * nit * rename * create codec model generation config at init * remove delay * max_new_tokens/length warning * correct conv1 padding cache import for modular * nit * fix on encoder_past_key_values * convert modular * move frame_size to config * move frame_size to config * update test name * handle first token is bos * better handling of max_new_tokens * fix * fix batch size in test input prep * update docstring * convert modular * make style * make style * add feature extractor * correct modular convention name for feature_extraction file * update convertion script * doc processor * update doc * udpate init * update model type * fixes * update tests * fix * make * add doc * nit * fix * doc * auto mappings * doc * nit * convert modular * doc * nit * extend _keep_in_fp32_modules to enforce fp32 * renaming to stt * doc update + test update * doc fixes * doc fix * doc fix * fix musicgen tests * fix musicgen tests * make style * fix musicgen tests * correct frame_rate config param for mimi * update mimi test * revert update mimi test * enforce cpu test * move cache init in cache class * convert modular * docstring update * update model id * feature_extractor -> feature_extraction (SEW) * convert modular * update model id
This commit is contained in:
parent
08bf7f1afe
commit
6bdd4ec952
@ -843,6 +843,8 @@
      title: GraniteSpeech
    - local: model_doc/hubert
      title: Hubert
    - local: model_doc/stt
      title: Kyutai Speech-To-Text
    - local: model_doc/mctct
      title: MCTCT
    - local: model_doc/mimi
122 docs/source/en/model_doc/stt.md Normal file
@ -0,0 +1,122 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Kyutai Speech-To-Text

## Overview

Kyutai STT is a speech-to-text model architecture based on the [Mimi codec](https://huggingface.co/docs/transformers/en/model_doc/mimi), which encodes audio into discrete tokens in a streaming fashion, and a [Moshi-like](https://huggingface.co/docs/transformers/en/model_doc/moshi) autoregressive decoder. Kyutai’s lab has released two model checkpoints:
- [kyutai/stt-1b-en_fr](https://huggingface.co/kyutai/stt-1b-en_fr): a 1B-parameter model capable of transcribing both English and French
- [kyutai/stt-2.6b-en](https://huggingface.co/kyutai/stt-2.6b-en): a 2.6B-parameter model focused solely on English, optimized for maximum transcription accuracy

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/eustlb/documentation-images/resolve/main/kyutai_stt.png"/>
</div>

## Usage Tips

### Inference

```python
import torch
from datasets import load_dataset, Audio
from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration

# 1. load the model and the processor
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "kyutai/stt-2.6b-en"

processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)

# 2. load audio samples
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
ds = ds.cast_column("audio", Audio(sampling_rate=24000))

# 3. prepare the model inputs
inputs = processor(
    ds[0]["audio"]["array"],
)
inputs.to(torch_device)

# 4. infer the model
output_tokens = model.generate(**inputs)

# 5. decode the generated tokens
print(processor.batch_decode(output_tokens, skip_special_tokens=True))
```

### Batched Inference

```python
import torch
from datasets import load_dataset, Audio
from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration

# 1. load the model and the processor
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "kyutai/stt-2.6b-en"

processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)

# 2. load audio samples
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
ds = ds.cast_column("audio", Audio(sampling_rate=24000))

# 3. prepare the model inputs
audio_arrays = [ds[i]["audio"]["array"] for i in range(4)]
inputs = processor(audio_arrays, return_tensors="pt", padding=True)
inputs = inputs.to(torch_device)

# 4. infer the model
output_tokens = model.generate(**inputs)

# 5. decode the generated tokens
decoded_outputs = processor.batch_decode(output_tokens, skip_special_tokens=True)
for output in decoded_outputs:
    print(output)
```

This model was contributed by [Eustache Le Bihan](https://huggingface.co/eustlb).
The original code can be found [here](https://github.com/kyutai-labs/moshi).

## KyutaiSpeechToTextConfig

[[autodoc]] KyutaiSpeechToTextConfig

## KyutaiSpeechToTextProcessor

[[autodoc]] KyutaiSpeechToTextProcessor
    - __call__

## KyutaiSpeechToTextFeatureExtractor

[[autodoc]] KyutaiSpeechToTextFeatureExtractor

## KyutaiSpeechToTextForConditionalGeneration

[[autodoc]] KyutaiSpeechToTextForConditionalGeneration
    - forward
    - generate

## KyutaiSpeechToTextModel

[[autodoc]] KyutaiSpeechToTextModel
@ -4658,8 +4658,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
        # The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
        # in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing
        # step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
        # Update: to extend _keep_in_fp32_modules flag feature, it can also be used to force modules that should stay in fp32
        if model._keep_in_fp32_modules is not None and (
            torch_dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
            torch_dtype == torch.float16
            or torch_dtype == torch.bfloat16
            or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
        ):
            # We need to match exact layers, so we add either `.` on each side, or start/end of string
            keep_in_fp32_regex = re.compile(
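For context, `_keep_in_fp32_modules` is a class attribute that a modeling file can set so the listed submodules are kept in `torch.float32` even when the rest of the checkpoint is loaded in half precision; the change above extends the override to `bfloat16` loads. A minimal sketch (the attribute value and the resulting dtype mix are illustrative, not taken from this PR):

```python
import torch
from transformers import AutoModel

# Illustrative only: a model class opts specific submodules out of half precision, e.g.
#
#     class MyModel(PreTrainedModel):
#         _keep_in_fp32_modules = ["codec_model"]  # hypothetical pattern list
#
# With the extended check, the override now also fires when torch_dtype is bfloat16,
# not only float16 or quantized loading.
model = AutoModel.from_pretrained("kyutai/stt-2.6b-en", torch_dtype=torch.bfloat16)
print({p.dtype for p in model.parameters()})  # may contain float32 entries alongside bfloat16
```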
@ -285,6 +285,7 @@ if TYPE_CHECKING:
    from .squeezebert import *
    from .stablelm import *
    from .starcoder2 import *
    from .stt import *
    from .superglue import *
    from .superpoint import *
    from .swiftformer import *
@ -322,6 +322,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
        ("squeezebert", "SqueezeBertConfig"),
        ("stablelm", "StableLmConfig"),
        ("starcoder2", "Starcoder2Config"),
        ("stt", "KyutaiSpeechToTextConfig"),
        ("superglue", "SuperGlueConfig"),
        ("superpoint", "SuperPointConfig"),
        ("swiftformer", "SwiftFormerConfig"),
@ -707,6 +708,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
        ("squeezebert", "SqueezeBERT"),
        ("stablelm", "StableLm"),
        ("starcoder2", "Starcoder2"),
        ("stt", "KyutaiSpeechToText"),
        ("superglue", "SuperGlue"),
        ("superpoint", "SuperPoint"),
        ("swiftformer", "SwiftFormer"),
@ -91,6 +91,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
        ("sew-d", "Wav2Vec2FeatureExtractor"),
        ("speech_to_text", "Speech2TextFeatureExtractor"),
        ("speecht5", "SpeechT5FeatureExtractor"),
        ("stt", "KyutaiSpeechToTextFeatureExtractor"),
        ("swiftformer", "ViTFeatureExtractor"),
        ("swin", "ViTFeatureExtractor"),
        ("swinv2", "ViTFeatureExtractor"),
@ -300,6 +300,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("squeezebert", "SqueezeBertModel"),
        ("stablelm", "StableLmModel"),
        ("starcoder2", "Starcoder2Model"),
        ("stt", "KyutaiSpeechToTextModel"),
        ("superglue", "SuperGlueForKeypointMatching"),
        ("swiftformer", "SwiftFormerModel"),
        ("swin", "SwinModel"),
@ -1055,6 +1056,7 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
        ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
        ("speecht5", "SpeechT5ForSpeechToText"),
        ("stt", "KyutaiSpeechToTextForConditionalGeneration"),
        ("whisper", "WhisperForConditionalGeneration"),
    ]
)
@ -116,6 +116,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
        ("speech_to_text", "Speech2TextProcessor"),
        ("speech_to_text_2", "Speech2Text2Processor"),
        ("speecht5", "SpeechT5Processor"),
        ("stt", "KyutaiSpeechToTextProcessor"),
        ("trocr", "TrOCRProcessor"),
        ("tvlt", "TvltProcessor"),
        ("tvp", "TvpProcessor"),
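Taken together, these registrations tie the `stt` model type to the Kyutai classes in the auto API. A minimal sketch (assuming Hub access) of resolving the checkpoint through the auto classes rather than the dedicated ones:

```python
from transformers import AutoConfig, AutoProcessor, AutoModelForSpeechSeq2Seq

# The "stt" entries added above make the auto classes resolve to the Kyutai implementations.
config = AutoConfig.from_pretrained("kyutai/stt-2.6b-en")
processor = AutoProcessor.from_pretrained("kyutai/stt-2.6b-en")
model = AutoModelForSpeechSeq2Seq.from_pretrained("kyutai/stt-2.6b-en")

print(type(config).__name__)     # KyutaiSpeechToTextConfig
print(type(processor).__name__)  # KyutaiSpeechToTextProcessor
print(type(model).__name__)      # KyutaiSpeechToTextForConditionalGeneration
```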
@ -38,8 +38,8 @@ class MimiConfig(PretrainedConfig):
    Args:
        sampling_rate (`int`, *optional*, defaults to 24000):
            The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
        frame_rate (`float`, *optional*, defaults to 12.5):
            Framerate of the model.
        frame_rate (`float`, *optional*):
            Should be computed from the other parameters, yet kept for backward compatibility.
        audio_channels (`int`, *optional*, defaults to 1):
            Number of channels in the audio data. Either 1 for mono or 2 for stereo.
        hidden_size (`int`, *optional*, defaults to 512):
@ -111,6 +111,8 @@
        use_cache (`bool`, *optional*, defaults to `False`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        use_streaming (`bool`, *optional*, defaults to `False`):
            Whether to use streaming mode. If `True`, the model encode method will return the padding cache that can be used in a subsequent call to the encode method.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        sliding_window (`int`, *optional*, defaults to 250):
@ -141,7 +143,7 @@
    def __init__(
        self,
        sampling_rate=24_000,
        frame_rate=12.5,
        frame_rate=None,
        audio_channels=1,
        hidden_size=512,
        num_filters=64,
@ -172,6 +174,7 @@
        initializer_range=0.02,
        norm_eps=1e-5,
        use_cache=False,
        use_streaming=False,
        rope_theta=10000.0,
        sliding_window=250,
        attention_dropout=0.0,
@ -180,7 +183,6 @@
        **kwargs,
    ):
        self.sampling_rate = sampling_rate
        self.frame_rate = frame_rate
        self.audio_channels = audio_channels
        self.hidden_size = hidden_size
        self.num_filters = num_filters
@ -209,6 +211,7 @@
        self.initializer_range = initializer_range
        self.norm_eps = norm_eps
        self.use_cache = use_cache
        self.use_streaming = use_streaming
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window
        self.attention_dropout = attention_dropout
@ -216,6 +219,14 @@
        self.layer_scale_initial_scale = layer_scale_initial_scale
        self.attention_bias = attention_bias

        # Handle backward compatibility for frame_rate:
        # If frame_rate is explicitly provided, use it (backward compatibility)
        # Otherwise, compute it from other parameters (correctly)
        if frame_rate is not None:
            self._frame_rate = frame_rate
        else:
            self._frame_rate = None

        if num_semantic_quantizers >= self.num_quantizers:
            raise ValueError(
                f"The number of semantic quantizers should be lower than the total number of quantizers {self.num_quantizers}, but is currently {num_semantic_quantizers}."
@ -233,5 +244,36 @@
        # alias to num_quantizers
        return self.num_quantizers

    @property
    def frame_size(self) -> int:
        # 1. we need each encoder conv stride
        # first conv
        strides = [1]

        # layer convs
        for ratio in reversed(self.upsampling_ratios):
            for j in range(self.num_residual_layers):
                len_kernel_sizes = len(self.residual_kernel_size) if isinstance(self.residual_kernel_size, list) else 1
                strides.extend([1] * (len_kernel_sizes + 1))
                if self.use_conv_shortcut:  # skip connection
                    strides.append(1)

            strides.append(ratio)

        # last conv
        strides.append(1)

        # downsampling layer
        strides.append(2)

        return math.prod(strides)

    @property
    def frame_rate(self) -> float:
        # handle backward compatibility
        if self._frame_rate is not None:
            return self._frame_rate
        return self.sampling_rate / self.frame_size


__all__ = ["MimiConfig"]
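As a concrete illustration (a sketch assuming the default Mimi hyperparameters, `sampling_rate=24000` and `upsampling_ratios=[8, 6, 5, 4]` with one residual layer per block), the derived values reproduce the previously hard-coded 12.5 Hz frame rate:

```python
from transformers import MimiConfig

config = MimiConfig()  # default Mimi hyperparameters

# frame_size is the product of the encoder conv strides (8 * 6 * 5 * 4 = 960)
# times the stride-2 downsampling layer, i.e. 1920 samples per frame.
print(config.frame_size)  # 1920

# frame_rate is now derived from sampling_rate / frame_size unless a value
# was passed explicitly for backward compatibility.
print(config.frame_rate)  # 24000 / 1920 = 12.5
```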
@ -23,25 +23,20 @@ import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, is_torch_flex_attn_available, logging
from ...utils import ModelOutput, auto_docstring, logging
from .configuration_mimi import MimiConfig


if is_flash_attn_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward

if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from ...integrations.flex_attention import make_flex_block_causal_mask


logger = logging.get_logger(__name__)
@ -78,6 +73,91 @@ class MimiOutput(ModelOutput):
    decoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None


class MimiConv1dPaddingCache:
    """
    Padding cache for MimiConv1d causal convolutions in order to support streaming via cache padding.
    See: https://arxiv.org/pdf/2005.06720 & https://arxiv.org/pdf/2204.07064

    A padding cache is a list of cached partial hidden states for each convolution layer.
    Hidden states are cached from the previous call to the MimiConv1d forward pass, given the padding size.
    """

    def __init__(
        self,
        num_layers: int,
        per_layer_padding: list[int],
        per_layer_padding_mode: list[str],
        per_layer_in_channels: list[int],
    ):
        # ensure correct number of layers for each arg
        from_args_num_layers = {len(per_layer_padding), len(per_layer_padding_mode), len(per_layer_in_channels)}

        if len(from_args_num_layers) != 1 or from_args_num_layers.pop() != num_layers:
            raise ValueError(
                f"Expected `num_layers` ({num_layers}) values in `per_layer_padding`, `per_layer_padding_mode` and `per_layer_in_channels`"
            )
        elif not all(mode in ["constant", "replicate"] for mode in per_layer_padding_mode):
            raise NotImplementedError(
                "`padding_cache` is not supported for convolutions using other than `constant` or `replicate` padding mode"
            )

        self.per_layer_padding = per_layer_padding
        self.per_layer_padding_mode = per_layer_padding_mode
        self.per_layer_in_channels = per_layer_in_channels
        self.per_layer_is_init = [True] * num_layers

        self.padding_cache = [None] * num_layers

    def update(self, hidden_states: torch.Tensor, layer_idx: int):
        """
        Updates the padding cache with the new padding states for the layer `layer_idx` and returns the current cache.

        Parameters:
            hidden_states (`torch.Tensor`):
                The hidden states to be partially cached.
            layer_idx (`int`):
                The index of the layer to cache the states for.
        Returns:
            `torch.Tensor` or `None`, the current padding cache.
        """
        batch_size, dtype, device = hidden_states.shape[0], hidden_states.dtype, hidden_states.device
        padding = self.per_layer_padding[layer_idx]
        padding_mode = self.per_layer_padding_mode[layer_idx]
        in_channels = self.per_layer_in_channels[layer_idx]

        if self.padding_cache[layer_idx] is None:
            if padding_mode == "constant":
                current_cache = torch.zeros(
                    batch_size,
                    in_channels,
                    padding,
                    device=device,
                    dtype=dtype,
                )
            elif padding_mode == "replicate":
                current_cache = (
                    torch.ones(
                        batch_size,
                        in_channels,
                        padding,
                        device=device,
                        dtype=dtype,
                    )
                    * hidden_states[..., :1]
                )
        else:
            current_cache = self.padding_cache[layer_idx]

        # update the cache
        if padding > 0:
            padding_states = hidden_states[:, :, -padding:]
        else:
            padding_states = torch.empty(batch_size, in_channels, padding, dtype=dtype, device=device)
        self.padding_cache[layer_idx] = padding_states

        return current_cache


@dataclass
@auto_docstring
class MimiEncoderOutput(ModelOutput):
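To make the caching behaviour concrete, here is a small self-contained sketch (layer count, padding size, and shapes are arbitrary) of the cache serving a single causal convolution layer across two consecutive chunks of a stream:

```python
import torch
from transformers.models.mimi.modeling_mimi import MimiConv1dPaddingCache

# One causal conv layer with left padding of 2 samples and 1 input channel.
cache = MimiConv1dPaddingCache(
    num_layers=1,
    per_layer_padding=[2],
    per_layer_padding_mode=["constant"],
    per_layer_in_channels=[1],
)

chunk_1 = torch.randn(1, 1, 5)  # (batch, channels, time)
chunk_2 = torch.randn(1, 1, 5)

# First call: no cache yet, so constant (zero) left padding is returned,
# and the last 2 samples of chunk_1 are stored for the next call.
left_pad_1 = cache.update(chunk_1, layer_idx=0)  # shape (1, 1, 2), all zeros

# Second call: the cached tail of chunk_1 is returned, so the convolution
# sees the two chunks as one seamless signal.
left_pad_2 = cache.update(chunk_2, layer_idx=0)
print(torch.equal(left_pad_2, chunk_1[..., -2:]))  # True
```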
@ -96,6 +176,7 @@ class MimiEncoderOutput(ModelOutput):

    audio_codes: Optional[torch.LongTensor] = None
    encoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None
    padding_cache: Optional[MimiConv1dPaddingCache] = None


@dataclass
@ -130,12 +211,15 @@ class MimiConv1d(nn.Module):
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        pad_mode=None,
        pad_mode: Optional[str] = None,
        bias: bool = True,
        layer_idx: Optional[int] = None,
    ):
        super().__init__()
        self.causal = config.use_causal_conv
        self.pad_mode = config.pad_mode if pad_mode is None else pad_mode
        self.layer_idx = layer_idx
        self.in_channels = in_channels

        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
@ -232,12 +316,20 @@ class MimiConv1d(nn.Module):
        ) // self.conv.stride[0] + 1
        return output_lenght

    def forward(self, hidden_states):
    def forward(self, hidden_states, padding_cache=None):
        extra_padding = self._get_extra_padding_for_conv1d(hidden_states)

        if self.causal:
        if not self.causal and padding_cache is not None:
            raise ValueError("`padding_cache` is not supported for non-causal convolutions.")

        if self.causal and padding_cache is not None:
            layer_padding_cache = padding_cache.update(hidden_states, self.layer_idx)
            hidden_states = torch.cat([layer_padding_cache, hidden_states], dim=2)

        elif self.causal:
            # Left padding for causal
            hidden_states = self._pad1d(hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode)

        else:
            hidden_states = self._pad1d(
                hidden_states, (self.padding_left, self.padding_right + extra_padding), mode=self.pad_mode
@ -305,7 +397,6 @@ class MimiConvTranspose1d(nn.Module):
        return hidden_states


# Copied from transformers.models.encodec.modeling_encodec.EncodecResnetBlock with Encodec->Mimi,EnCodec->Mimi
class MimiResnetBlock(nn.Module):
    """
    Residual block from SEANet model as used by Mimi.
@ -331,12 +422,21 @@ class MimiResnetBlock(nn.Module):
        else:
            self.shortcut = nn.Identity()

    def forward(self, hidden_states):
    def forward(self, hidden_states, padding_cache=None):
        residual = hidden_states
        for layer in self.block:
            hidden_states = layer(hidden_states)

        return self.shortcut(residual) + hidden_states
        for layer in self.block:
            if isinstance(layer, MimiConv1d):
                hidden_states = layer(hidden_states, padding_cache=padding_cache)
            else:
                hidden_states = layer(hidden_states)

        if isinstance(self.shortcut, MimiConv1d):
            residual = self.shortcut(residual, padding_cache=padding_cache)
        else:
            residual = self.shortcut(residual)

        return residual + hidden_states
@ -370,10 +470,17 @@ class MimiEncoder(nn.Module):
        self.layers = nn.ModuleList(model)
        self._mimiconv1d_layer_names = mimiconv1d_layer_names

    # Copied from transformers.models.encodec.modeling_encodec.EncodecEncoder.forward
    def forward(self, hidden_states):
        # initialize layer_idx for MimiConv1d submodules, necessary for padding_cache
        for layer_idx, layername in enumerate(self._mimiconv1d_layer_names):
            conv_layer = self.get_submodule(layername)
            setattr(conv_layer, "layer_idx", layer_idx)

    def forward(self, hidden_states, padding_cache=None):
        for layer in self.layers:
            hidden_states = layer(hidden_states)
            if isinstance(layer, (MimiConv1d, MimiResnetBlock)):
                hidden_states = layer(hidden_states, padding_cache=padding_cache)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states
@ -1005,11 +1112,13 @@ class MimiTransformerModel(nn.Module):
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = None
        if attention_mask is not None:
            causal_mask = self._update_causal_mask(
                attention_mask, hidden_states, cache_position, past_key_values, output_attentions
            )
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=hidden_states,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
        )

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
@ -1054,163 +1163,6 @@ class MimiTransformerModel(nn.Module):
            attentions=all_self_attns,
        )

    # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Mimi
    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and past_key_values is not None:
                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
                if is_padding_right:
                    raise ValueError(
                        "You are attempting to perform batched generation with padding_side='right'"
                        " this may lead to unexpected behaviour for Flash Attention version of Mimi. Make sure to "
                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                    )
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)
        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if (
            self.config._attn_implementation == "sdpa"
            and not (using_static_cache or using_sliding_window_cache)
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                sliding_window=self.config.sliding_window,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # SlidingWindowCache or StaticCache
        if using_sliding_window_cache or using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        # DynamicCache or no cache
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->Mimi
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        config: MimiConfig,
        past_key_values: Cache,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
            config (`MimiConfig`):
                The model's configuration class
            past_key_values (`Cache`):
                The cache class that is being used currently to generate
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                -1, 1
            )
            text_config = config.get_text_config()
            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                # the check is needed to verify is current checkpoint was trained with sliding window or not
                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                    sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
                        cache_position.reshape(-1, 1) - text_config.sliding_window
                    )
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                if attention_mask.shape[-1] > target_length:
                    attention_mask = attention_mask[:, :target_length]
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
        return causal_mask


class MimiDecoder(nn.Module):
    """SEANet decoder as used by Mimi."""
@ -1269,7 +1221,7 @@ class MimiEuclideanCodebook(nn.Module):
    def quantize(self, hidden_states):
        # Projects each vector in `hidden_states` over the nearest centroid and return its index.
        # `hidden_states` should be `[N, D]` with `N` the number of input vectors and `D` the dimension.
        dists = torch.cdist(hidden_states[None], self.embed[None], p=2)[0]
        dists = torch.cdist(hidden_states[None].float(), self.embed[None].float(), p=2)[0]
        embed_ind = dists.argmin(dim=-1)
        return embed_ind
@ -1476,6 +1428,7 @@ class MimiModel(MimiPreTrainedModel):
            stride=2,
            bias=False,
            pad_mode="replicate",
            layer_idx=len(self.encoder._mimiconv1d_layer_names),
        )

        self.upsample = MimiConvTranspose1d(
@ -1512,12 +1465,17 @@
        num_quantizers: int,
        padding_mask: int,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        padding_cache: Optional[MimiConv1dPaddingCache] = None,
        return_dict: Optional[bool] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Encodes the given input using the underlying VQVAE. The padding mask is required to compute the correct scale.
        """
        embeddings = self.encoder(input_values)

        # TODO: @eustlb, let's make the encoder support padding_mask so that batched inputs are supported.
        embeddings = self.encoder(input_values, padding_cache=padding_cache)

        # TODO: @eustlb, convert the padding mask to attention mask.
        encoder_outputs = self.encoder_transformer(
            embeddings.transpose(1, 2), past_key_values=past_key_values, return_dict=return_dict
        )
@ -1526,11 +1484,11 @@
        elif len(encoder_outputs) > 1:
            past_key_values = encoder_outputs[1]
        embeddings = encoder_outputs[0].transpose(1, 2)
        embeddings = self.downsample(embeddings)
        embeddings = self.downsample(embeddings, padding_cache=padding_cache)

        codes = self.quantizer.encode(embeddings, num_quantizers)
        codes = codes.transpose(0, 1)
        return codes, past_key_values
        return codes, past_key_values, padding_cache

    def get_encoded_length(self, input_length: torch.LongTensor) -> torch.LongTensor:
        """
@ -1570,6 +1528,8 @@
        padding_mask: Optional[torch.Tensor] = None,
        num_quantizers: Optional[float] = None,
        encoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        padding_cache: Optional[MimiConv1dPaddingCache] = None,
        use_streaming: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor, Optional[torch.Tensor]], MimiEncoderOutput]:
        """
@ -1598,6 +1558,7 @@
            `codebook` of shape `[batch_size, num_codebooks, frames]`, the discrete encoded codes for the input audio waveform.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        use_streaming = use_streaming if use_streaming is not None else self.config.use_streaming

        num_quantizers = self.config.num_quantizers if num_quantizers is None else num_quantizers

@ -1614,11 +1575,31 @@
        if padding_mask is None:
            padding_mask = torch.ones_like(input_values).bool()

        encoded_frames, encoder_past_key_values = self._encode_frame(
        if use_streaming and padding_cache is None:
            per_layer_padding, per_layer_padding_mode, per_layer_in_channels = [], [], []
            for layer_name in self.encoder._mimiconv1d_layer_names:
                per_layer_padding.append(self.encoder.get_submodule(layer_name).padding_total)
                per_layer_padding_mode.append(self.encoder.get_submodule(layer_name).pad_mode)
                per_layer_in_channels.append(self.encoder.get_submodule(layer_name).in_channels)

            # downsample layer
            per_layer_padding.append(self.downsample.padding_total)
            per_layer_padding_mode.append(self.downsample.pad_mode)
            per_layer_in_channels.append(self.downsample.in_channels)

            padding_cache = MimiConv1dPaddingCache(
                num_layers=len(self.encoder._mimiconv1d_layer_names) + 1,
                per_layer_padding=per_layer_padding,
                per_layer_padding_mode=per_layer_padding_mode,
                per_layer_in_channels=per_layer_in_channels,
            )

        encoded_frames, encoder_past_key_values, padding_cache = self._encode_frame(
            input_values,
            num_quantizers,
            padding_mask.bool(),
            past_key_values=encoder_past_key_values,
            padding_cache=padding_cache,
            return_dict=return_dict,
        )

@ -1626,9 +1607,10 @@
        return (
            encoded_frames,
            encoder_past_key_values,
            padding_cache,
        )

        return MimiEncoderOutput(encoded_frames, encoder_past_key_values)
        return MimiEncoderOutput(encoded_frames, encoder_past_key_values, padding_cache)

    def _decode_frame(
        self,
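The streaming path added above can be exercised end to end. A rough sketch (the checkpoint, chunk size, and dummy waveform are illustrative) of encoding audio in frame-sized chunks while carrying the convolution padding cache and the encoder transformer cache across calls:

```python
import torch
from transformers import MimiModel

model = MimiModel.from_pretrained("kyutai/mimi").eval()
frame_size = model.config.frame_size            # samples consumed per encoded frame
waveform = torch.randn(1, 1, 10 * frame_size)   # dummy mono audio at config.sampling_rate

padding_cache = None
encoder_past_key_values = None
all_codes = []
with torch.no_grad():
    for start in range(0, waveform.shape[-1], frame_size):
        chunk = waveform[..., start : start + frame_size]
        out = model.encode(
            chunk,
            use_streaming=True,
            padding_cache=padding_cache,
            encoder_past_key_values=encoder_past_key_values,
        )
        all_codes.append(out.audio_codes)
        # carry the caches into the next chunk so the stream stays seamless
        padding_cache = out.padding_cache
        encoder_past_key_values = out.encoder_past_key_values

audio_codes = torch.cat(all_codes, dim=-1)
print(audio_codes.shape)  # (batch, num_quantizers, num_frames)
```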
29 src/transformers/models/stt/__init__.py Normal file
@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_kyutai_speech_to_text import *
    from .feature_extraction_kyutai_speech_to_text import *
    from .modeling_kyutai_speech_to_text import *
    from .processing_kyutai_speech_to_text import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
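A quick sanity check (a sketch, nothing model-specific) that the lazy-module wiring exposes the new classes both from the top-level namespace and from the subpackage:

```python
from transformers import KyutaiSpeechToTextForConditionalGeneration
from transformers.models.stt import KyutaiSpeechToTextConfig

# Prints the defining module under transformers.models.stt, resolved lazily on first access.
print(KyutaiSpeechToTextForConditionalGeneration.__module__)
print(KyutaiSpeechToTextConfig.__name__)
```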
@ -0,0 +1,188 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto.configuration_auto import AutoConfig


logger = logging.get_logger(__name__)


class KyutaiSpeechToTextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`KyutaiSpeechToTextForConditionalGeneration`].
    It is used to instantiate a Kyutai Speech-to-Text model according to the specified arguments, defining the model
    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
    2.6b-en model.

    e.g. [kyutai/stt-2.6b-en](https://huggingface.co/kyutai/stt-2.6b-en)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        codebook_vocab_size (`int`, *optional*, defaults to 2049):
            Vocabulary size of the codebook. Defines the number of different audio tokens that can be represented by each codebook.
        vocab_size (`int`, *optional*, defaults to 4001):
            Vocabulary size of the model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling the model.
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimensionality of the layers and the pooler layer of the main decoder.
        num_hidden_layers (`int`, *optional*, defaults to 48):
            Number of decoder layers.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the main decoder block.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`.
        max_position_embeddings (`int`, *optional*, defaults to 750):
            The maximum sequence length that this model might ever be used with. Typically, set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        rope_theta (`float`, *optional*, defaults to 100000.0):
            The base period of the RoPE embeddings.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
            The attention head dimension.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        sliding_window (`int`, *optional*, defaults to 375):
            Sliding window attention window size. If not specified, will default to `3000`.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        ffn_dim (`int`, *optional*, defaults to 11264):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the main decoder block. Must be even.
        rms_norm_eps (`float`, *optional*, defaults to 1e-08):
            The epsilon used by the rms normalization layers.
        num_codebooks (`int`, *optional*, defaults to 32):
            The number of audio codebooks for each audio channels.
        audio_bos_token_id (`int`, *optional*, defaults to 2048):
            Beginning of stream token id for codebook tokens.
        audio_pad_token_id (`int`, *optional*, defaults to 69569):
            Padding token id for codebook tokens.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        pad_token_id (`int`, *optional*, defaults to 3):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 48000):
            Beginning of stream token id for text tokens.
        codec_config (`PretrainedConfig`, *optional*):
            Configuration for the codec.
        kwargs (*optional*):
            Dictionary of keyword arguments. Notably:
                - **audio_encoder_config** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
                  defines the audio encoder config.
                - **depth__config** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
                  defines the depth decoder config.

    Example:
    ```python
    >>> from transformers import KyutaiSpeechToTextConfig, KyutaiSpeechToTextForConditionalGeneration

    >>> # Initializing a KyutaiSpeechToTextConfig
    >>> configuration = KyutaiSpeechToTextConfig()

    >>> # Initializing a model
    >>> model = KyutaiSpeechToTextForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    # not the best naming here for `model_type`, but the original codebase already uses model type `stt` in the config so we keep it to simplify
    model_type = "stt"
    keys_to_ignore_at_inference = ["past_key_values"]
    sub_configs = {"codec_config": AutoConfig}

    def __init__(
        self,
        codebook_vocab_size=2049,
        vocab_size=4001,
        hidden_size=2048,
        num_hidden_layers=48,
        num_attention_heads=32,
        num_key_value_heads=None,
        max_position_embeddings=750,
        rope_theta=100000.0,
        hidden_act="silu",
        head_dim=None,
        initializer_range=0.02,
        use_cache=True,
        sliding_window=375,
        attention_dropout=0.0,
        ffn_dim=11264,
        rms_norm_eps=1e-8,
        num_codebooks=32,
        audio_bos_token_id=2048,
        audio_pad_token_id=69569,
        tie_word_embeddings=False,
        pad_token_id=3,
        bos_token_id=48000,
        codec_config=None,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id, bos_token_id=bos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
        )

        if codec_config is None:
            self.codec_config = AutoConfig.for_model("mimi")
            logger.info("codec_config is None, using default audio encoder config.")
        elif isinstance(codec_config, dict):
            self.codec_config = AutoConfig.for_model(**codec_config)
        elif isinstance(codec_config, PretrainedConfig):
            self.codec_config = codec_config

        self.num_codebooks = num_codebooks
        self.frame_size = self.codec_config.frame_size

        self.audio_bos_token_id = audio_bos_token_id
        self.audio_pad_token_id = audio_pad_token_id
        self.codebook_vocab_size = codebook_vocab_size

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        if ffn_dim % 2 == 1:
            raise ValueError(f"`ffn_dim={ffn_dim}` must be even.")
        self.ffn_dim = ffn_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
        self.sliding_window = sliding_window


__all__ = ["KyutaiSpeechToTextConfig"]
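A short sketch of how the nested codec configuration is resolved at init time (the override values are arbitrary examples):

```python
from transformers import KyutaiSpeechToTextConfig

# Default: the codec sub-config is built with AutoConfig.for_model("mimi").
config = KyutaiSpeechToTextConfig()
print(config.codec_config.model_type)                      # "mimi"
print(config.frame_size == config.codec_config.frame_size) # True, copied at init time

# A plain dict is routed through AutoConfig.for_model(**codec_config), so codec
# hyperparameters can be overridden without building a MimiConfig by hand.
custom = KyutaiSpeechToTextConfig(codec_config={"model_type": "mimi", "sliding_window": 250})
print(custom.codec_config.sliding_window)  # 250
```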
@ -0,0 +1,377 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import os
|
||||
import re
|
||||
|
||||
import safetensors.torch
|
||||
import sentencepiece
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
KyutaiSpeechToTextConfig,
|
||||
KyutaiSpeechToTextFeatureExtractor,
|
||||
KyutaiSpeechToTextForConditionalGeneration,
|
||||
KyutaiSpeechToTextProcessor,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
from transformers.convert_slow_tokenizer import MoshiConverter
|
||||
from transformers.utils.hub import cached_file
|
||||
|
||||
|
||||
# fmt: off
|
||||
MOSHI_ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
|
||||
r"out_norm": r"norm",
|
||||
r"gating\.linear_in": r"mlp.fc1",
|
||||
r"gating\.linear_out": r"mlp.fc2",
|
||||
r"self_attn\.out_proj": r"self_attn.o_proj.linear",
|
||||
r"norm1": r"input_layernorm",
|
||||
r"norm2": r"post_attention_layernorm",
|
||||
r"layer_scale_1": r"self_attn_layer_scale",
|
||||
r"layer_scale_2": r"mlp_layer_scale",
|
||||
r"alpha": r"weight",
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
|
||||
# fmt: off
|
||||
MIMI_ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
|
||||
r"conv\.conv\.conv": "conv",
|
||||
r"convtr\.convtr\.convtr": "conv",
|
||||
r"conv\.conv": "conv",
|
||||
r"convtr\.convtr": "conv",
|
||||
r"quantizer\.rvq_first\.vq": "quantizer.semantic_residual_vector_quantizer",
|
||||
r"quantizer\.rvq_first": "quantizer.semantic_residual_vector_quantizer",
|
||||
r"quantizer\.rvq_rest\.vq": "quantizer.acoustic_residual_vector_quantizer",
|
||||
r"quantizer\.rvq_rest": "quantizer.acoustic_residual_vector_quantizer",
|
||||
r"_codebook": "codebook",
|
||||
r"_initialized": "initialized",
|
||||
r"embedding_sum": "embed_sum",
|
||||
r"encoder\.model": "encoder.layers",
|
||||
r"decoder\.model": "decoder.layers",
|
||||
r"encoder_transformer\.transformer": "encoder_transformer",
|
||||
r"decoder_transformer\.transformer": "decoder_transformer",
|
||||
r"linear1": "mlp.fc1",
|
||||
r"linear2": "mlp.fc2",
|
||||
r"self_attn\.out_proj": "self_attn.o_proj",
|
||||
r"norm1": "input_layernorm",
|
||||
r"norm2": "post_attention_layernorm",
|
||||
r"layer_scale_1": "self_attn_layer_scale",
|
||||
r"layer_scale_2": "mlp_layer_scale",
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
|
||||
def permute_for_rope(input_tensor, n_heads, dim1, dim2):
|
||||
"""
|
||||
When you go from the complex ROPE formulation to sin and cos one, you need
|
||||
to permute the query and key weights (to avoid doing it on the fly)
|
||||
"""
|
||||
return input_tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
|
||||
|
||||
|
||||
def convert_key(key, mapping):
|
||||
for pattern, replacement in mapping.items():
|
||||
key = re.sub(pattern, replacement, key)
|
||||
return key
|
||||
|
||||
|
||||
def convert_kyutai_speech_to_text_state_dict(state_dict, config, unwanted_prefix="transformer."):
|
||||
hidden_size = config.hidden_size
|
||||
head_dim = config.head_dim
|
||||
num_heads = int(config.hidden_size // config.head_dim)
|
||||
num_key_value_heads = config.num_key_value_heads
|
||||
key_value_head_dim = config.num_key_value_heads * head_dim
|
||||
|
||||
# concat embeddings
|
||||
embed_tokens_weight = []
|
||||
for i in range(32):
|
||||
embed_tokens_weight.append(state_dict.pop(f"emb.{i}.weight"))
|
||||
|
||||
embed_tokens_weight = torch.cat(embed_tokens_weight, dim=0)
|
||||
embed_tokens_weight = torch.cat([state_dict.pop("text_emb.weight"), embed_tokens_weight])
|
||||
embed_tokens_weight = torch.cat([embed_tokens_weight, torch.zeros(1, config.hidden_size)], dim=0)
|
||||
state_dict["embed_tokens.embed_tokens.weight"] = embed_tokens_weight
|
||||
|
||||
for key, value in list(state_dict.items()):
|
||||
if unwanted_prefix is not None and unwanted_prefix in key:
|
||||
new_key = key[len(unwanted_prefix) :]
|
||||
else:
|
||||
new_key = key
|
||||
|
||||
new_key = convert_key(new_key, MOSHI_ORIGINAL_TO_CONVERTED_KEY_MAPPING)
|
||||
|
||||
# Post-process the current_parameter.
|
||||
if "alpha" in key:
|
||||
state_dict[key] = state_dict[key].squeeze()
|
||||
|
||||
if "in_proj_weight" in new_key:
|
||||
# split qkv into query key and value
|
||||
mixed_qkv = state_dict.pop(key)
|
||||
qkv_dim = mixed_qkv.size(0) // 3
|
||||
|
||||
query_layer = mixed_qkv[:qkv_dim]
|
||||
key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
|
||||
value_layer = mixed_qkv[qkv_dim * 2 :]
|
||||
state_dict[new_key.replace("in_proj_weight", "q_proj.linear.weight")] = permute_for_rope(
|
||||
query_layer, num_heads, hidden_size, hidden_size
|
||||
)
|
||||
state_dict[new_key.replace("in_proj_weight", "k_proj.linear.weight")] = permute_for_rope(
|
||||
key_layer, num_key_value_heads, key_value_head_dim, hidden_size
|
||||
)
|
||||
|
||||
state_dict[new_key.replace("in_proj_weight", "v_proj.linear.weight")] = value_layer
|
||||
else:
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def convert_mimi_state_dict(state_dict, config, unwanted_prefix=None):
|
||||
hidden_size = config.hidden_size
|
||||
head_dim = config.head_dim
|
||||
num_heads = int(config.hidden_size // config.head_dim)
|
||||
num_key_value_heads = config.num_key_value_heads
|
||||
key_value_head_dim = config.num_key_value_heads * head_dim
|
||||
|
||||
for key, value in list(state_dict.items()):
|
||||
if unwanted_prefix is not None and unwanted_prefix in key:
|
||||
new_key = key[len(unwanted_prefix) :]
|
||||
else:
|
||||
new_key = key
|
||||
|
||||
new_key = convert_key(new_key, MIMI_ORIGINAL_TO_CONVERTED_KEY_MAPPING)
|
||||
|
||||
if "in_proj_weight" in new_key:
|
||||
# split qkv into query key and value
|
||||
mixed_qkv = state_dict.pop(key)
|
||||
qkv_dim = mixed_qkv.size(0) // 3
|
||||
|
||||
query_layer = mixed_qkv[:qkv_dim]
|
||||
key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
|
||||
value_layer = mixed_qkv[qkv_dim * 2 :]
|
||||
|
||||
state_dict[new_key.replace("in_proj_weight", "q_proj.weight")] = permute_for_rope(
|
||||
query_layer, num_heads, hidden_size, hidden_size
|
||||
)
|
||||
state_dict[new_key.replace("in_proj_weight", "k_proj.weight")] = permute_for_rope(
|
||||
key_layer, num_key_value_heads, key_value_head_dim, hidden_size
|
||||
)
|
||||
state_dict[new_key.replace("in_proj_weight", "v_proj.weight")] = value_layer
|
||||
else:
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def write_model(
    input_path_or_repo,
    model_name,
    codec_model_path_or_repo,
    codec_model_name,
    output_dir,
    safe_serialization=True,
    unwanted_prefix="transformer.",
):
    print("Converting the model.")
    os.makedirs(output_dir, exist_ok=True)

    config = KyutaiSpeechToTextConfig()
    config.use_cache = True
    config.codec_config.sliding_window = 250

    model_path = cached_file(
        input_path_or_repo,
        model_name,
    )

    codec_path = cached_file(
        codec_model_path_or_repo,
        codec_model_name,
    )

    print(f"Fetching all parameters from the checkpoint at {model_path}...")
    state_dict = safetensors.torch.load_file(model_path)

    print(f"Fetching all parameters from the checkpoint at {codec_path}...")
    codec_state_dict = safetensors.torch.load_file(codec_path)

    print("Converting model...")
    # -----------------------
    # convert parameter names
    # -----------------------
    state_dict = convert_kyutai_speech_to_text_state_dict(state_dict, config, unwanted_prefix=unwanted_prefix)
    codec_state_dict = convert_mimi_state_dict(codec_state_dict, config.codec_config, unwanted_prefix=None)

    # -------------------------
    # load the weights and save
    # -------------------------
    print("Loading the checkpoint in a Moshi ASR model.")
    with torch.device("meta"):
        model = KyutaiSpeechToTextForConditionalGeneration(config)

    linear_weight = state_dict.pop("text_linear.weight")
    model.model.load_state_dict(state_dict, strict=True, assign=True)

    linear_weight = torch.cat([linear_weight, torch.zeros(1, config.hidden_size)])
    model.lm_head.load_state_dict({"weight": linear_weight}, strict=True, assign=True)

    model.codec_model.load_state_dict(codec_state_dict, strict=True, assign=True)

    print("Checkpoint loaded successfully.")
    del model.config._name_or_path
    del model.config.codec_config._name_or_path

    # default generation config
    model.generation_config._from_model_config = False
    model.generation_config.audio_window_size = 1
    model.generation_config.cache_implementation = "sliding_window"

    model.codec_model.generation_config._from_model_config = False
    model.codec_model.generation_config.cache_implementation = "sliding_window"
    model.codec_model.generation_config.use_cache = True

    print("Saving the model.")
    model.save_pretrained(output_dir, safe_serialization=safe_serialization)
    del state_dict, model

    # Safety check: reload the converted model
    gc.collect()
    print("Reloading the model to check if it's saved correctly.")
    KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
        output_dir, torch_dtype=torch.bfloat16, device_map="auto"
    )
    print("Model reloaded successfully.")

def write_processor(
    input_path_or_repo,
    tokenizer_model_name,
    codec_model_path_or_repo,
    output_dir,
    audio_delay_seconds,
    audio_silence_prefix_seconds,
):
    tokenizer_path = cached_file(
        input_path_or_repo,
        tokenizer_model_name,
    )

    tokenizer = MoshiConverter(tokenizer_path).converted()
    original_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_path)

    tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=tokenizer,
        chat_template=None,
        unk_token="<unk>",
        model_input_names=["input_ids", "attention_mask"],
        clean_up_tokenization_spaces=False,
        bos_token_id=original_tokenizer.bos_id(),
        eos_token_id=original_tokenizer.eos_id(),
        pad_token_id=original_tokenizer.pad_id(),
    )

    feature_extractor = KyutaiSpeechToTextFeatureExtractor(
        audio_delay_seconds=audio_delay_seconds,
        audio_silence_prefix_seconds=audio_silence_prefix_seconds,
    )

    processor = KyutaiSpeechToTextProcessor(feature_extractor, tokenizer)
    processor.save_pretrained(output_dir)
    print(f"Processor saved successfully to {output_dir}")

def main():
    parser = argparse.ArgumentParser(description="Convert Moshi ASR weights to HuggingFace format")
    parser.add_argument(
        "--input_path_or_repo",
        type=str,
        required=True,
        help="Path or repo containing Moshi ASR weights",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="Name of the model in input_path_or_repo",
    )
    parser.add_argument(
        "--tokenizer_model_name",
        type=str,
        required=True,
        help="Name of the tokenizer model in input_path_or_repo",
    )
    parser.add_argument(
        "--codec_model_path_or_repo",
        type=str,
        required=True,
        help="Path or repo containing the Mimi weights",
    )
    parser.add_argument(
        "--mimi_name",
        type=str,
        required=True,
        help="Name of the Mimi model in codec_model_path_or_repo",
    )
    parser.add_argument(
        "--preprocessor_model_path_or_repo",
        type=str,
        required=True,
        help="Path or repo containing the preprocessor config",
    )
    parser.add_argument(
        "--output_dir",
        help="Location to write HF model and tokenizer",
    )
    parser.add_argument(
        "--safe_serialization", action="store_true", default=True, help="Whether or not to save using `safetensors`."
    )
    parser.add_argument(
        "--audio_delay_seconds",
        type=float,
        required=True,
        help="Audio delay in seconds to add to the right of the input",
    )
    parser.add_argument(
        "--audio_silence_prefix_seconds",
        type=float,
        required=True,
        help="Audio silence prefix in seconds to add to the left of the input",
    )
    args = parser.parse_args()

    write_model(
        args.input_path_or_repo,
        args.model_name,
        args.codec_model_path_or_repo,
        args.mimi_name,
        args.output_dir,
        safe_serialization=args.safe_serialization,
    )

    write_processor(
        args.input_path_or_repo,
        args.tokenizer_model_name,
        args.preprocessor_model_path_or_repo,
        args.output_dir,
        args.audio_delay_seconds,
        args.audio_silence_prefix_seconds,
    )


if __name__ == "__main__":
    main()
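For reference, a minimal sketch of driving the two conversion entry points above programmatically rather than through the CLI. The module name, repository ids and file names below are hypothetical placeholders, not official checkpoint locations.

```python
# Sketch only: module name, repo ids and file names are hypothetical placeholders.
from convert_kyutai_speech_to_text_to_hf import write_model, write_processor  # assumed script name

write_model(
    input_path_or_repo="org/kyutai-stt-original",   # hypothetical repo holding the Moshi ASR safetensors
    model_name="model.safetensors",                 # hypothetical file name inside that repo
    codec_model_path_or_repo="org/mimi-original",   # hypothetical repo holding the Mimi safetensors
    codec_model_name="mimi.safetensors",            # hypothetical file name inside that repo
    output_dir="./kyutai-stt-hf",
    safe_serialization=True,
)

write_processor(
    input_path_or_repo="org/kyutai-stt-original",   # hypothetical
    tokenizer_model_name="tokenizer.model",         # hypothetical sentencepiece file
    codec_model_path_or_repo="org/mimi-original",   # hypothetical
    output_dir="./kyutai-stt-hf",
    audio_delay_seconds=0.0,                        # illustrative values only
    audio_silence_prefix_seconds=0.0,
)
```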
@ -0,0 +1,237 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/stt/modular_kyutai_speech_to_text.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Kyutai and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...utils import PaddingStrategy, TensorType, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class KyutaiSpeechToTextFeatureExtractor(SequenceFeatureExtractor):
|
||||
r"""
|
||||
Constructs a KyutaiSpeechToText feature extractor.
|
||||
|
||||
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
|
||||
most of the main methods. Users should refer to this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
feature_size (`int`, *optional*, defaults to 1):
|
||||
The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
|
||||
sampling_rate (`int`, *optional*, defaults to 24000):
|
||||
The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
|
||||
padding_value (`float`, *optional*, defaults to 0.0):
|
||||
The value that is used to fill the padding values.
|
||||
chunk_length_s (`float`, *optional*):
|
||||
If defined the audio is pre-processed into chunks of lengths `chunk_length_s` and then encoded.
|
||||
overlap (`float`, *optional*):
|
||||
Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following
|
||||
formula: `int((1.0 - self.overlap) * self.chunk_length)`.
|
||||
audio_delay_seconds (`float`, *optional*, defaults to 0.0):
|
||||
The delay in seconds to add after the audio (right padding).
|
||||
audio_silence_prefix_seconds (`float`, *optional*, defaults to 0.0):
|
||||
The silence prefix in seconds to add before the audio (left padding).
|
||||
"""
|
||||
|
||||
model_input_names = ["input_values", "padding_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_size: int = 1,
|
||||
sampling_rate: int = 24000,
|
||||
padding_value: float = 0.0,
|
||||
chunk_length_s: Optional[float] = None,
|
||||
overlap: Optional[float] = None,
|
||||
audio_delay_seconds: Optional[float] = 0.0,
|
||||
audio_silence_prefix_seconds: Optional[float] = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
|
||||
self.chunk_length_s = chunk_length_s
|
||||
self.overlap = overlap
|
||||
self.audio_delay_seconds = audio_delay_seconds
|
||||
self.audio_silence_prefix_seconds = audio_silence_prefix_seconds
|
||||
|
||||
# This is a property because you might want to change the chunk_length_s on the fly
|
||||
@property
|
||||
def chunk_length(self) -> Optional[int]:
|
||||
if self.chunk_length_s is None:
|
||||
return None
|
||||
else:
|
||||
return int(self.chunk_length_s * self.sampling_rate)
|
||||
|
||||
# This is a property because you might want to change the chunk_length_s on the fly
|
||||
@property
|
||||
def chunk_stride(self) -> Optional[int]:
|
||||
if self.chunk_length_s is None or self.overlap is None:
|
||||
return None
|
||||
else:
|
||||
return max(1, int((1.0 - self.overlap) * self.chunk_length))
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
|
||||
padding: Optional[Union[bool, str, PaddingStrategy]] = None,
|
||||
truncation: Optional[bool] = False,
|
||||
max_length: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
sampling_rate: Optional[int] = None,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to featurize and prepare for the model one or several sequence(s).
|
||||
|
||||
Args:
|
||||
raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
|
||||
The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
|
||||
values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
|
||||
`(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
|
||||
(`feature_size = 2`).
|
||||
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
||||
index) among:
|
||||
|
||||
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
||||
sequence is provided).
|
||||
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
||||
acceptable input length for the model if that argument is not provided.
|
||||
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
||||
lengths).
|
||||
truncation (`bool`, *optional*, defaults to `False`):
|
||||
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
|
||||
max_length (`int`, *optional*):
|
||||
Maximum length of the returned list and optionally padding length (see above).
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return Numpy `np.ndarray` objects.
|
||||
sampling_rate (`int`, *optional*):
|
||||
The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
|
||||
`sampling_rate` at the forward call to prevent silent errors.
|
||||
"""
|
||||
if sampling_rate is not None:
|
||||
if sampling_rate != self.sampling_rate:
|
||||
raise ValueError(
|
||||
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
|
||||
f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
|
||||
f" {self.sampling_rate} and not {sampling_rate}."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
|
||||
"Failing to do so can result in silent errors that might be hard to debug."
|
||||
)
|
||||
|
||||
if padding and truncation:
|
||||
raise ValueError("Both padding and truncation were set. Make sure you only set one.")
|
||||
elif padding is None:
|
||||
# by default let's pad the inputs
|
||||
padding = True
|
||||
|
||||
is_batched = bool(
|
||||
isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
|
||||
)
|
||||
|
||||
if is_batched:
|
||||
raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
|
||||
elif not is_batched and not isinstance(raw_audio, np.ndarray):
|
||||
raw_audio = np.asarray(raw_audio, dtype=np.float32)
|
||||
elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
|
||||
raw_audio = raw_audio.astype(np.float32)
|
||||
|
||||
# always return batch
|
||||
if not is_batched:
|
||||
raw_audio = [np.asarray(raw_audio).T]
|
||||
|
||||
# verify inputs are valid
|
||||
for idx, example in enumerate(raw_audio):
|
||||
if example.ndim > 2:
|
||||
raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
|
||||
if self.feature_size == 1 and example.ndim != 1:
|
||||
raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
|
||||
if self.feature_size == 2 and example.shape[-1] != 2:
|
||||
raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")
|
||||
|
||||
padded_inputs = None
|
||||
input_values = BatchFeature({"input_values": raw_audio})
|
||||
if self.chunk_stride is not None and self.chunk_length is not None and max_length is None:
|
||||
if truncation:
|
||||
max_length = min(array.shape[0] for array in raw_audio)
|
||||
nb_step = int(np.floor(max_length / self.chunk_stride))
|
||||
max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length
|
||||
elif padding:
|
||||
max_length = max(array.shape[0] for array in raw_audio)
|
||||
nb_step = int(np.ceil(max_length / self.chunk_stride))
|
||||
max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length
|
||||
padding = "max_length"
|
||||
else:
|
||||
padded_inputs = input_values
|
||||
|
||||
# normal padding on batch
|
||||
if padded_inputs is None:
|
||||
padded_inputs = self.pad(
|
||||
input_values,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
padding=padding,
|
||||
return_attention_mask=padding,
|
||||
)
|
||||
|
||||
if padding:
|
||||
padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")
|
||||
|
||||
# now let's pad left and right
|
||||
pad_left = int(self.audio_silence_prefix_seconds * self.sampling_rate)
|
||||
pad_right = int((self.audio_delay_seconds + 1.0) * self.sampling_rate)
|
||||
padded_inputs["input_values"] = np.pad(
|
||||
padded_inputs["input_values"],
|
||||
((0, 0), (pad_left, pad_right)),
|
||||
mode="constant",
|
||||
constant_values=0.0,
|
||||
)
|
||||
if padding:
|
||||
padded_inputs["padding_mask"] = np.pad(
|
||||
padded_inputs["padding_mask"],
|
||||
((0, 0), (pad_left, pad_right)),
|
||||
mode="constant",
|
||||
constant_values=0,
|
||||
)
|
||||
|
||||
input_values = []
|
||||
for example in padded_inputs.pop("input_values"):
|
||||
if self.feature_size == 1:
|
||||
example = example[..., None]
|
||||
input_values.append(example.T)
|
||||
|
||||
padded_inputs["input_values"] = input_values
|
||||
if return_tensors is not None:
|
||||
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
|
||||
|
||||
return padded_inputs
|
||||
|
||||
|
||||
__all__ = ["KyutaiSpeechToTextFeatureExtractor"]
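To make the left/right padding behaviour above concrete, here is a minimal usage sketch, assuming the class is exported as added in this PR; the numbers are illustrative. `audio_silence_prefix_seconds` of silence is prepended, and `audio_delay_seconds` plus one extra second of silence is appended after the waveform.

```python
import numpy as np
from transformers import KyutaiSpeechToTextFeatureExtractor  # exported by this PR

feature_extractor = KyutaiSpeechToTextFeatureExtractor(
    sampling_rate=24000,
    audio_silence_prefix_seconds=1.0,  # 1.0 * 24000 zero samples prepended
    audio_delay_seconds=2.0,           # (2.0 + 1.0) * 24000 zero samples appended
)

audio = np.zeros(24000, dtype=np.float32)  # 1 second of mono audio
features = feature_extractor(audio, sampling_rate=24000, return_tensors="np")

# 24000 (silence prefix) + 24000 (audio) + 72000 (delay) samples
print(features["input_values"].shape)  # (1, 1, 120000)
print(features["padding_mask"].shape)  # (1, 120000)
```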
1434 src/transformers/models/stt/modeling_kyutai_speech_to_text.py (Normal file)
File diff suppressed because it is too large
510 src/transformers/models/stt/modular_kyutai_speech_to_text.py (Normal file)
@ -0,0 +1,510 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Kyutai and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import types
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...cache_utils import Cache
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...generation import GenerationConfig, GenerationMixin
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import PaddingStrategy, TensorType, logging
|
||||
from ..auto import AutoModel
|
||||
from ..encodec.feature_extraction_encodec import EncodecFeatureExtractor
|
||||
from ..llama.modeling_llama import LlamaForCausalLM
|
||||
from ..mimi.modeling_mimi import MimiConv1dPaddingCache
|
||||
from ..moshi.modeling_moshi import MoshiModel, MoshiPreTrainedModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class KyutaiSpeechToTextFeatureExtractor(EncodecFeatureExtractor):
|
||||
r"""
|
||||
Constructs a KyutaiSpeechToText feature extractor.
|
||||
|
||||
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
|
||||
most of the main methods. Users should refer to this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
feature_size (`int`, *optional*, defaults to 1):
|
||||
The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
|
||||
sampling_rate (`int`, *optional*, defaults to 24000):
|
||||
The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
|
||||
padding_value (`float`, *optional*, defaults to 0.0):
|
||||
The value that is used to fill the padding values.
|
||||
chunk_length_s (`float`, *optional*):
|
||||
If defined the audio is pre-processed into chunks of lengths `chunk_length_s` and then encoded.
|
||||
overlap (`float`, *optional*):
|
||||
Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following
|
||||
formula: `int((1.0 - self.overlap) * self.chunk_length)`.
|
||||
audio_delay_seconds (`float`, *optional*, defaults to 0.0):
|
||||
The delay in seconds to add after the audio (right padding).
|
||||
audio_silence_prefix_seconds (`float`, *optional*, defaults to 0.0):
|
||||
The silence prefix in seconds to add before the audio (left padding).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
audio_delay_seconds: Optional[float] = 0.0,
|
||||
audio_silence_prefix_seconds: Optional[float] = 0.0,
|
||||
**super_kwargs,
|
||||
):
|
||||
super().__init__(**super_kwargs)
|
||||
self.audio_delay_seconds = audio_delay_seconds
|
||||
self.audio_silence_prefix_seconds = audio_silence_prefix_seconds
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
|
||||
padding: Optional[Union[bool, str, PaddingStrategy]] = None,
|
||||
truncation: Optional[bool] = False,
|
||||
max_length: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
sampling_rate: Optional[int] = None,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to featurize and prepare for the model one or several sequence(s).
|
||||
|
||||
Args:
|
||||
raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
|
||||
The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
|
||||
values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
|
||||
`(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
|
||||
(`feature_size = 2`).
|
||||
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
||||
index) among:
|
||||
|
||||
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
||||
sequence is provided).
|
||||
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
||||
acceptable input length for the model if that argument is not provided.
|
||||
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
||||
lengths).
|
||||
truncation (`bool`, *optional*, defaults to `False`):
|
||||
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
|
||||
max_length (`int`, *optional*):
|
||||
Maximum length of the returned list and optionally padding length (see above).
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return Numpy `np.ndarray` objects.
|
||||
sampling_rate (`int`, *optional*):
|
||||
The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
|
||||
`sampling_rate` at the forward call to prevent silent errors.
|
||||
"""
|
||||
if sampling_rate is not None:
|
||||
if sampling_rate != self.sampling_rate:
|
||||
raise ValueError(
|
||||
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
|
||||
f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
|
||||
f" {self.sampling_rate} and not {sampling_rate}."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
|
||||
"Failing to do so can result in silent errors that might be hard to debug."
|
||||
)
|
||||
|
||||
if padding and truncation:
|
||||
raise ValueError("Both padding and truncation were set. Make sure you only set one.")
|
||||
elif padding is None:
|
||||
# by default let's pad the inputs
|
||||
padding = True
|
||||
|
||||
is_batched = bool(
|
||||
isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
|
||||
)
|
||||
|
||||
if is_batched:
|
||||
raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
|
||||
elif not is_batched and not isinstance(raw_audio, np.ndarray):
|
||||
raw_audio = np.asarray(raw_audio, dtype=np.float32)
|
||||
elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
|
||||
raw_audio = raw_audio.astype(np.float32)
|
||||
|
||||
# always return batch
|
||||
if not is_batched:
|
||||
raw_audio = [np.asarray(raw_audio).T]
|
||||
|
||||
# verify inputs are valid
|
||||
for idx, example in enumerate(raw_audio):
|
||||
if example.ndim > 2:
|
||||
raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
|
||||
if self.feature_size == 1 and example.ndim != 1:
|
||||
raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
|
||||
if self.feature_size == 2 and example.shape[-1] != 2:
|
||||
raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")
|
||||
|
||||
padded_inputs = None
|
||||
input_values = BatchFeature({"input_values": raw_audio})
|
||||
if self.chunk_stride is not None and self.chunk_length is not None and max_length is None:
|
||||
if truncation:
|
||||
max_length = min(array.shape[0] for array in raw_audio)
|
||||
nb_step = int(np.floor(max_length / self.chunk_stride))
|
||||
max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length
|
||||
elif padding:
|
||||
max_length = max(array.shape[0] for array in raw_audio)
|
||||
nb_step = int(np.ceil(max_length / self.chunk_stride))
|
||||
max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length
|
||||
padding = "max_length"
|
||||
else:
|
||||
padded_inputs = input_values
|
||||
|
||||
# normal padding on batch
|
||||
if padded_inputs is None:
|
||||
padded_inputs = self.pad(
|
||||
input_values,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
padding=padding,
|
||||
return_attention_mask=padding,
|
||||
)
|
||||
|
||||
if padding:
|
||||
padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")
|
||||
|
||||
# now let's pad left and right
|
||||
pad_left = int(self.audio_silence_prefix_seconds * self.sampling_rate)
|
||||
pad_right = int((self.audio_delay_seconds + 1.0) * self.sampling_rate)
|
||||
padded_inputs["input_values"] = np.pad(
|
||||
padded_inputs["input_values"],
|
||||
((0, 0), (pad_left, pad_right)),
|
||||
mode="constant",
|
||||
constant_values=0.0,
|
||||
)
|
||||
if padding:
|
||||
padded_inputs["padding_mask"] = np.pad(
|
||||
padded_inputs["padding_mask"],
|
||||
((0, 0), (pad_left, pad_right)),
|
||||
mode="constant",
|
||||
constant_values=0,
|
||||
)
|
||||
|
||||
input_values = []
|
||||
for example in padded_inputs.pop("input_values"):
|
||||
if self.feature_size == 1:
|
||||
example = example[..., None]
|
||||
input_values.append(example.T)
|
||||
|
||||
padded_inputs["input_values"] = input_values
|
||||
if return_tensors is not None:
|
||||
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
|
||||
|
||||
return padded_inputs
|
||||
|
||||
|
||||
class KyutaiSpeechToTextPreTrainedModel(MoshiPreTrainedModel):
|
||||
pass
|
||||
|
||||
|
||||
class KyutaiSpeechToTextConv1dPaddingCache(MimiConv1dPaddingCache):
|
||||
pass
|
||||
|
||||
|
||||
class KyutaiSpeechToTextEmbeddings(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(
|
||||
config.vocab_size + (config.num_codebooks * config.codebook_vocab_size) + 1,
|
||||
config.hidden_size,
|
||||
padding_idx=config.audio_pad_token_id,
|
||||
)
|
||||
audio_tokens_offsets = torch.arange(config.num_codebooks) * config.codebook_vocab_size
|
||||
audio_tokens_offsets += config.vocab_size
|
||||
audio_tokens_offsets = nn.functional.pad(
|
||||
audio_tokens_offsets, (1, 0)
|
||||
) # pad one 0 to the left for the text token
|
||||
self.register_buffer("audio_tokens_offsets", audio_tokens_offsets, persistent=False)
|
||||
|
||||
def forward(self, input_ids):
|
||||
input_ids = torch.where(
|
||||
input_ids == self.embed_tokens.padding_idx, input_ids, input_ids + self.audio_tokens_offsets
|
||||
)
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
inputs_embeds = inputs_embeds.sum(dim=2)
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
class KyutaiSpeechToTextModel(MoshiModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.embed_tokens = KyutaiSpeechToTextEmbeddings(config)
|
||||
|
||||
|
||||
class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMixin, PreTrainedModel):
|
||||
_keep_in_fp32_modules = ["codec_model"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.codec_model = AutoModel.from_config(config.codec_config)
|
||||
|
||||
# we are in an edge case where self.can_generate is False for the codec_model, which sets self.codec_model.generation_config to None
|
||||
# yet the codec_model needs a generation config to initialize its cache for streaming inference
|
||||
# we therefore initialize a generation config for the codec model
|
||||
self.codec_model.generation_config = GenerationConfig.from_model_config(config.codec_config)
|
||||
|
||||
def forward(self, **super_kwargs):
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from datasets import load_dataset, Audio
|
||||
>>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
|
||||
|
||||
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
>>> model_id = "kyutai/stt-2.6b-en"
|
||||
|
||||
>>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
|
||||
>>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
|
||||
|
||||
>>> ds = load_dataset(
|
||||
... "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
|
||||
... )
|
||||
|
||||
>>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
|
||||
>>> inputs = processor(
|
||||
... ds[0]["audio"]["array"],
|
||||
... )
|
||||
>>> inputs.to(torch_device)
|
||||
|
||||
>>> output_tokens = model.generate(**inputs)
|
||||
>>> print(processor.batch_decode(output_tokens, skip_special_tokens=True))
|
||||
```"""
|
||||
super().forward(**super_kwargs)
|
||||
|
||||
def _prepare_generation_config(self, *args, **kwargs):
|
||||
generation_config, model_kwargs = GenerationMixin._prepare_generation_config(*args, **kwargs)
|
||||
# this should be passed to the model kwargs for the input preparation
|
||||
model_kwargs["audio_window_size"] = (
|
||||
generation_config.audio_window_size if hasattr(generation_config, "audio_window_size") else None
|
||||
)
|
||||
return generation_config, model_kwargs
|
||||
|
||||
def _prepare_model_inputs(
|
||||
self,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
bos_token_id: Optional[torch.Tensor] = None,
|
||||
model_kwargs: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
|
||||
inputs, input_name, model_kwargs = GenerationMixin._prepare_model_inputs(
|
||||
inputs=inputs,
|
||||
bos_token_id=bos_token_id,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
|
||||
audio_window_size = model_kwargs.get("audio_window_size", None)
|
||||
if audio_window_size is None:
|
||||
audio_window_size = self.codec_model.get_encoded_length(model_kwargs["input_values"].shape[-1]).item()
|
||||
model_kwargs["audio_window_size"] = audio_window_size
|
||||
|
||||
batch_size = inputs.shape[0]
|
||||
device = inputs.device
|
||||
|
||||
# initialize audio tokens
|
||||
model_kwargs["audio_tokens"] = torch.zeros(
|
||||
(batch_size, audio_window_size, self.config.num_codebooks),
|
||||
device=device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
model_kwargs["current_window"] = (
|
||||
torch.tensor([0, 0], device=device, dtype=torch.long).expand(batch_size, -1).contiguous()
|
||||
)
|
||||
|
||||
# let's use generate's cache preparation to prepare the cache for the codec model
|
||||
temporary_model_kwargs = {}
|
||||
|
||||
# monkey patching the codec model with cache preparation methods since we don't want it to inherit fully from GenerationMixin
|
||||
# Add cache-related methods from GenerationMixin to codec model
|
||||
cache_methods = [
|
||||
"_prepare_cache_for_generation",
|
||||
"_get_cache",
|
||||
"_supports_default_dynamic_cache",
|
||||
"_get_layer_device_map_for_cache_init",
|
||||
]
|
||||
for method in cache_methods:
|
||||
setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model))
|
||||
|
||||
self.codec_model._prepare_cache_for_generation(
|
||||
generation_config=self.codec_model.generation_config,
|
||||
model_kwargs=temporary_model_kwargs,
|
||||
assistant_model=None,
|
||||
batch_size=batch_size,
|
||||
max_cache_length=self.config.codec_config.sliding_window,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if "past_key_values" in temporary_model_kwargs:
|
||||
model_kwargs["encoder_past_key_values"] = temporary_model_kwargs["past_key_values"]
|
||||
|
||||
# initialize the padding cache for the codec model
|
||||
per_layer_padding, per_layer_padding_mode, per_layer_in_channels = [], [], []
|
||||
for layer_name in self.codec_model.encoder._mimiconv1d_layer_names:
|
||||
per_layer_padding.append(self.codec_model.encoder.get_submodule(layer_name).padding_total)
|
||||
per_layer_padding_mode.append(self.codec_model.encoder.get_submodule(layer_name).pad_mode)
|
||||
per_layer_in_channels.append(self.codec_model.encoder.get_submodule(layer_name).in_channels)
|
||||
|
||||
# downsample layer
|
||||
per_layer_padding.append(self.codec_model.downsample.padding_total)
|
||||
per_layer_padding_mode.append(self.codec_model.downsample.pad_mode)
|
||||
per_layer_in_channels.append(self.codec_model.downsample.in_channels)
|
||||
|
||||
model_kwargs["padding_cache"] = KyutaiSpeechToTextConv1dPaddingCache(
|
||||
num_layers=len(self.codec_model.encoder._mimiconv1d_layer_names) + 1,
|
||||
per_layer_padding=per_layer_padding,
|
||||
per_layer_padding_mode=per_layer_padding_mode,
|
||||
per_layer_in_channels=per_layer_in_channels,
|
||||
)
|
||||
|
||||
return inputs, input_name, model_kwargs
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
*args,
|
||||
audio_tokens: Optional[torch.LongTensor] = None,
|
||||
input_values: Optional[torch.FloatTensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
audio_window_size: Optional[int] = None,
|
||||
current_window: Optional[tuple[int, int]] = None,
|
||||
encoder_past_key_values: Optional[Cache] = None,
|
||||
padding_cache: Optional[KyutaiSpeechToTextConv1dPaddingCache] = None,
|
||||
**kwargs,
|
||||
):
|
||||
model_inputs = GenerationMixin.prepare_inputs_for_generation(*args, **kwargs)
|
||||
|
||||
if input_values is not None:
|
||||
cache_position = model_inputs["cache_position"]
|
||||
start, end = current_window[0]
|
||||
|
||||
# first cache position is for bos token, so we need to offset by -1
|
||||
if cache_position[-1] - 1 >= end:
|
||||
# we need to encode the new audio tokens
|
||||
with torch.no_grad():
|
||||
input_values_start_idx = start * self.config.frame_size
|
||||
input_values_end_idx = (start + audio_window_size) * self.config.frame_size
|
||||
current_input_values = input_values[..., input_values_start_idx:input_values_end_idx]
|
||||
codec_model_output = self.codec_model.encode(
|
||||
current_input_values,
|
||||
encoder_past_key_values=encoder_past_key_values,
|
||||
padding_cache=padding_cache,
|
||||
)
|
||||
new_audio_tokens = codec_model_output.audio_codes.transpose(1, 2)
|
||||
|
||||
audio_tokens.copy_(new_audio_tokens)
|
||||
|
||||
start = end.clone()
|
||||
end = end + audio_window_size
|
||||
current_window.copy_(
|
||||
torch.tensor([start, end], device=current_window.device).expand(current_window.shape[0], -1)
|
||||
)
|
||||
|
||||
# first cache position is for bos token, so we need to offset by -1
|
||||
current_audio_tokens_idxs = (cache_position - start - 1).clamp(min=0)
|
||||
current_audio_tokens = audio_tokens[:, current_audio_tokens_idxs, :]
|
||||
|
||||
current_audio_tokens[:, cache_position == 0, :] = self.config.audio_bos_token_id
|
||||
|
||||
input_ids = model_inputs.pop("input_ids")
|
||||
input_ids = torch.cat(
|
||||
[input_ids.unsqueeze(2), current_audio_tokens],
|
||||
dim=2,
|
||||
)
|
||||
model_inputs["input_ids"] = input_ids
|
||||
|
||||
return model_inputs
|
||||
|
||||
# TODO: @eustlb, this should be standardized
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
if kwargs.get("output_loading_info", False):
|
||||
model, loading_info = PreTrainedModel.from_pretrained(*args, **kwargs)
|
||||
else:
|
||||
model = PreTrainedModel.from_pretrained(*args, **kwargs)
|
||||
|
||||
# copy the `codec_`-prefixed attributes of the model generation config to the codec model generation config
|
||||
prefix = "codec_"
|
||||
prefix_len = len(prefix)
|
||||
codec_model_attrs = {
|
||||
attr[prefix_len:]: value
|
||||
for attr, value in vars(model.generation_config).items()
|
||||
if attr.startswith(prefix)
|
||||
}
|
||||
|
||||
vars(model.codec_model.generation_config).update({"_from_model_config": False, **codec_model_attrs})
|
||||
|
||||
# remove the `codec_`-prefixed attributes from the model generation config
|
||||
for attr in codec_model_attrs:
|
||||
delattr(model.generation_config, prefix + attr)
|
||||
|
||||
if "output_loading_info" in kwargs:
|
||||
return model, loading_info
|
||||
else:
|
||||
return model
|
||||
|
||||
# TODO: @eustlb, this should be standardized
|
||||
def save_pretrained(self, *args, **kwargs):
|
||||
prefix = "codec_"
|
||||
codec_model_attrs = self.codec_model.generation_config.to_diff_dict()
|
||||
codec_model_attrs.pop("transformers_version", None)
|
||||
for attr, value in codec_model_attrs.items():
|
||||
setattr(self.generation_config, prefix + attr, value)
|
||||
|
||||
PreTrainedModel.save_pretrained(self, *args, **kwargs)
|
||||
|
||||
def generate(self, *args, **kwargs):
|
||||
r"""
|
||||
This method forwards all its arguments to GenerationMixin's [`~GenerationMixin.generate`]. Please refer to the docstring of this method for more information.
|
||||
"""
|
||||
max_new_tokens = kwargs.pop("max_new_tokens", None)
|
||||
input_values = kwargs.get("input_values")
|
||||
|
||||
# TODO: @eustlb, we should have per-batch-idx values
|
||||
# here we do not use padding_mask to be aligned to what's done in the original codebase
|
||||
max_audio_frames = input_values.shape[-1] // self.config.codec_config.frame_size
|
||||
|
||||
if max_new_tokens is None or max_new_tokens > max_audio_frames:
|
||||
if max_new_tokens is not None:
|
||||
logger.warning(
|
||||
f"`max_new_tokens` ({max_new_tokens}) is greater than the maximum number of audio frames ({max_audio_frames}). "
|
||||
f"Setting `max_new_tokens` to {max_audio_frames}."
|
||||
)
|
||||
max_new_tokens = max_audio_frames
|
||||
|
||||
return GenerationMixin.generate(
|
||||
*args,
|
||||
max_new_tokens=max_new_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"KyutaiSpeechToTextPreTrainedModel",
|
||||
"KyutaiSpeechToTextModel",
|
||||
"KyutaiSpeechToTextForConditionalGeneration",
|
||||
"KyutaiSpeechToTextFeatureExtractor",
|
||||
]
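To illustrate the multi-codebook embedding scheme defined in `KyutaiSpeechToTextEmbeddings` above, here is a small self-contained sketch with illustrative sizes (not the real checkpoint configuration): the text stream keeps its token ids, audio codebook `k` is offset by `vocab_size + k * codebook_vocab_size` so that all streams share a single embedding table, and the per-stream embeddings are summed.

```python
import torch
import torch.nn as nn

# Illustrative sizes only (not the real checkpoint configuration)
vocab_size, codebook_vocab_size, num_codebooks, hidden_size = 1000, 2049, 8, 16
audio_pad_token_id = 2048

# One offset per stream: 0 for the text token, then vocab_size + k * codebook_vocab_size
audio_tokens_offsets = torch.arange(num_codebooks) * codebook_vocab_size + vocab_size
audio_tokens_offsets = nn.functional.pad(audio_tokens_offsets, (1, 0))  # prepend 0 for the text stream

embed_tokens = nn.Embedding(
    vocab_size + num_codebooks * codebook_vocab_size + 1,
    hidden_size,
    padding_idx=audio_pad_token_id,
)

# (batch, seq_len, 1 + num_codebooks): one text id followed by one id per codebook
text_ids = torch.randint(0, vocab_size, (2, 3, 1))
audio_ids = torch.randint(0, codebook_vocab_size, (2, 3, num_codebooks))
input_ids = torch.cat([text_ids, audio_ids], dim=2)

# shift every non-pad id into its stream's slice of the shared embedding table
shifted = torch.where(input_ids == audio_pad_token_id, input_ids, input_ids + audio_tokens_offsets)
inputs_embeds = embed_tokens(shifted).sum(dim=2)  # sum over the text + codebook streams
print(inputs_embeds.shape)  # torch.Size([2, 3, 16])
```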
104 src/transformers/models/stt/processing_kyutai_speech_to_text.py (Normal file)
@ -0,0 +1,104 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ...audio_utils import AudioInput, make_list_of_audio
|
||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
||||
|
||||
|
||||
class KyutaiSpeechToTextProcessorKwargs(ProcessingKwargs, total=False):
|
||||
_defaults = {
|
||||
"audio_kwargs": {
|
||||
"sampling_rate": 24000,
|
||||
},
|
||||
"common_kwargs": {"return_tensors": "pt"},
|
||||
}
|
||||
|
||||
|
||||
class KyutaiSpeechToTextProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Moshi ASR processor which wraps [`EncodecFeatureExtractor`] and
|
||||
[`PreTrainedTokenizerFast`] into a single processor that inherits both the audio feature extraction and
|
||||
tokenizer functionalities. See the [`~KyutaiSpeechToTextProcessor.__call__`] for more
|
||||
information.
|
||||
"""
|
||||
|
||||
feature_extractor_class = "KyutaiSpeechToTextFeatureExtractor"
|
||||
tokenizer_class = "PreTrainedTokenizerFast"
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
audio: Optional[AudioInput] = None,
|
||||
**kwargs: Unpack[KyutaiSpeechToTextProcessorKwargs],
|
||||
):
|
||||
r"""
|
||||
Main method to prepare audio to be fed as input to the model. This method forwards the `audio`
|
||||
arguments to KyutaiSpeechToTextFeatureExtractor's [`~KyutaiSpeechToTextFeatureExtractor.__call__`]. Please refer
|
||||
to the docstring of the above method for more information.
|
||||
|
||||
Args:
|
||||
audio (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
|
||||
The audio or batch of audio to be prepared. Each audio can be a NumPy array or PyTorch
|
||||
tensor.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_values** -- List of audio values to be fed to a model. Returned when `audio` is not `None`.
|
||||
- **padding_mask** -- List of indices specifying which input values should be ignored by the model.
|
||||
"""
|
||||
|
||||
if audio is None:
|
||||
raise ValueError("`audio` is required.")
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
KyutaiSpeechToTextProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
audio_kwargs = output_kwargs["audio_kwargs"]
|
||||
|
||||
# ensure audio in correct format
|
||||
audio = make_list_of_audio(audio)
|
||||
|
||||
inputs = self.feature_extractor(
|
||||
audio,
|
||||
**audio_kwargs,
|
||||
)
|
||||
|
||||
return inputs
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to KyutaiSpeechToTextTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to KyutaiSpeechToTextTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
|
||||
__all__ = ["KyutaiSpeechToTextProcessor"]
0 tests/models/kyutai_speech_to_text/__init__.py (Normal file)
@ -0,0 +1,704 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Moshi ASR model."""
|
||||
|
||||
import gc
|
||||
import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import datasets
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
KyutaiSpeechToTextConfig,
|
||||
KyutaiSpeechToTextForConditionalGeneration,
|
||||
KyutaiSpeechToTextProcessor,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
cleanup,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_sdpa,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
|
||||
ModelTesterMixin,
|
||||
_config_zero_init,
|
||||
floats_tensor,
|
||||
ids_tensor,
|
||||
)
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
KyutaiSpeechToTextForConditionalGeneration,
|
||||
KyutaiSpeechToTextModel,
|
||||
)
|
||||
|
||||
|
||||
class KyutaiSpeechToTextModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
text_seq_length=1,
|
||||
input_values_length=192, # gives 3 audio tokens, corresponding to the default in GenerationTesterMixin
|
||||
is_training=False,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=False,
|
||||
use_labels=True,
|
||||
codebook_vocab_size=2049,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=None,
|
||||
max_position_embeddings=512,
|
||||
rope_theta=10000.0,
|
||||
hidden_act="silu",
|
||||
head_dim=None,
|
||||
initializer_range=0.02,
|
||||
use_cache=True,
|
||||
sliding_window=512,
|
||||
attention_dropout=0.1,
|
||||
ffn_dim=38,
|
||||
rms_norm_eps=1e-6,
|
||||
num_codebooks=8,
|
||||
frame_size=64,
|
||||
delay_in_tokens=5,
|
||||
audio_bos_token_id=2048,
|
||||
audio_pad_token_id=2048,
|
||||
tie_word_embeddings=False,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
codec_config={
|
||||
"model_type": "mimi",
|
||||
"num_quantizers": 8,
|
||||
"audio_channels": 1,
|
||||
"chunk_in_sec": None,
|
||||
"hidden_size": 16,
|
||||
"num_filters": 8,
|
||||
"num_residual_layers": 1,
|
||||
"upsampling_ratios": [8, 4],
|
||||
"codebook_size": 16,
|
||||
"vector_quantization_hidden_dimension": 16,
|
||||
"upsample_groups": 16,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 2,
|
||||
"num_key_value_heads": 2,
|
||||
"sliding_window": 4,
|
||||
"codebook_dim": 16,
|
||||
"use_cache": False,
|
||||
},
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.text_seq_length = text_seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.codebook_vocab_size = codebook_vocab_size
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.rope_theta = rope_theta
|
||||
self.hidden_act = hidden_act
|
||||
self.head_dim = head_dim
|
||||
self.initializer_range = initializer_range
|
||||
self.use_cache = use_cache
|
||||
self.sliding_window = sliding_window
|
||||
self.attention_dropout = attention_dropout
|
||||
self.ffn_dim = ffn_dim
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.num_codebooks = num_codebooks
|
||||
self.frame_size = frame_size
|
||||
self.delay_in_tokens = delay_in_tokens
|
||||
self.audio_bos_token_id = audio_bos_token_id
|
||||
self.audio_pad_token_id = audio_pad_token_id
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
self.codec_config = codec_config
|
||||
self.scope = scope
|
||||
self.input_values_length = input_values_length
|
||||
|
||||
def get_config(self):
|
||||
return KyutaiSpeechToTextConfig(
|
||||
codebook_vocab_size=self.codebook_vocab_size,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
num_key_value_heads=self.num_key_value_heads,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
rope_theta=self.rope_theta,
|
||||
hidden_act=self.hidden_act,
|
||||
head_dim=self.head_dim,
|
||||
initializer_range=self.initializer_range,
|
||||
use_cache=self.use_cache,
|
||||
sliding_window=self.sliding_window,
|
||||
attention_dropout=self.attention_dropout,
|
||||
ffn_dim=self.ffn_dim,
|
||||
rms_norm_eps=self.rms_norm_eps,
|
||||
num_codebooks=self.num_codebooks,
|
||||
frame_size=self.frame_size,
|
||||
delay_in_tokens=self.delay_in_tokens,
|
||||
audio_bos_token_id=self.audio_bos_token_id,
|
||||
audio_pad_token_id=self.audio_pad_token_id,
|
||||
tie_word_embeddings=self.tie_word_embeddings,
|
||||
pad_token_id=self.pad_token_id,
|
||||
bos_token_id=self.bos_token_id,
|
||||
codec_config=self.codec_config,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_ids, input_mask):
|
||||
model = KyutaiSpeechToTextModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
config = self.get_config()
|
||||
|
||||
text_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size - 1) + 1
|
||||
codebook_input_ids = (
|
||||
ids_tensor([self.batch_size, self.seq_length, self.num_codebooks], self.codebook_vocab_size - 1) + 1
|
||||
)
|
||||
|
||||
input_ids = torch.cat([text_input_ids.unsqueeze(2), codebook_input_ids], dim=2)
|
||||
attention_mask = text_input_ids.ne(1).to(torch_device)
|
||||
|
||||
return config, input_ids, attention_mask
|
||||
|
||||
def prepare_config_and_inputs_generate(self):
|
||||
config = self.get_config()
|
||||
|
||||
input_ids = torch.ones([self.batch_size, 1], dtype=torch.long, device=torch_device)
|
||||
input_values = floats_tensor([self.batch_size, 1, self.input_values_length])
|
||||
padding_mask = torch.ones_like(input_values, dtype=torch.int32, device=torch_device)
|
||||
|
||||
return config, input_ids, input_values, padding_mask
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_common_generate(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs_generate()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
input_values,
|
||||
padding_mask,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"input_values": input_values,
|
||||
"padding_mask": padding_mask,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
KyutaiSpeechToTextModel,
|
||||
KyutaiSpeechToTextForConditionalGeneration,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": KyutaiSpeechToTextModel,
|
||||
"automatic-speech-recognition": KyutaiSpeechToTextForConditionalGeneration,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
|
||||
|
||||
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
|
||||
# This is because we are hitting edge cases with the causal_mask buffer
|
||||
model_split_percents = [0.5, 0.7, 0.8]
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = KyutaiSpeechToTextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=KyutaiSpeechToTextConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels)
|
||||
|
||||
return inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_generate(self, batch_size=2):
|
||||
# monkey patch prepare_config_and_inputs_for_common
|
||||
|
||||
prepare_config_and_inputs_for_common = self.model_tester.prepare_config_and_inputs_for_common
|
||||
original_batch_size = self.model_tester.batch_size
|
||||
|
||||
self.model_tester.prepare_config_and_inputs_for_common = (
|
||||
self.model_tester.prepare_config_and_inputs_for_common_generate
|
||||
)
|
||||
self.model_tester.batch_size = batch_size
|
||||
|
||||
config, filtered_inputs_dict = super().prepare_config_and_inputs_for_generate()
|
||||
self.model_tester.prepare_config_and_inputs_for_common = prepare_config_and_inputs_for_common
|
||||
|
||||
self.model_tester.batch_size = original_batch_size
|
||||
return config, filtered_inputs_dict
|
||||
|
||||
@pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
|
||||
def test_tie_model_weights(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
|
||||
def test_resize_embeddings_untied(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
|
||||
def test_tied_weights_keys(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="Does not apply to Moshi ASR that requires input_values.")
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
def test_initialization(self):
|
||||
"""
|
||||
Overrides [ModelTesterMixin.test_initialization] because of specificities of Mimi codec model.
|
||||
See https://github.com/huggingface/transformers/blob/1077603410cd73ba71d64a522033574d66d64b55/tests/models/mimi/test_modeling_mimi.py#L384-L397
|
||||
"""
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
uniform_init_parms = ["conv", "input_proj", "output_proj"]
|
||||
                if param.requires_grad:
                    if any(x in name for x in uniform_init_parms):
                        self.assertTrue(
                            -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
    @require_torch_sdpa
    def test_eager_matches_sdpa_inference(
        self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
    ):
        if use_attention_mask or (not use_attention_mask and torch_dtype == "fp32" and not output_attentions):
            self.skipTest("Test is failing, fix me :) ")
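        # delegate the remaining configurations to the common parameterized test implementation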
        parent_parameterized_test = getattr(ModelTesterMixin, self._testMethodName)
        parent_parameterized_test(self)

    @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
    def test_cpu_offload(self):
        pass

    @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
    def test_disk_offload_bin(self):
        pass

    @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
    def test_disk_offload_safetensors(self):
        pass

    @pytest.mark.generate
    def test_left_padding_compatibility(self):
        # NOTE: left-padding results in small numerical differences. This is expected.
        # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535

        # First, filter out models that don't support left padding
        # - The model must have generative capabilities
        if len(self.all_generative_model_classes) == 0:
            self.skipTest(reason="No generative architecture available for this model.")

        # - The model must support padding
        if not self.has_attentions:
            self.skipTest(reason="This model doesn't support padding.")

        # - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
        decoder_only_classes = []
        for model_class in self.all_generative_model_classes:
            config, _ = self.prepare_config_and_inputs_for_generate()
            if config.is_encoder_decoder:
                continue
            else:
                decoder_only_classes.append(model_class)
        if len(decoder_only_classes) == 0:
            self.skipTest(reason="No decoder-only architecture available for this model.")

        # - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't
        #   added support for it yet. We skip these models for now.
        has_encoder_attributes = any(
            attr_name
            for attr_name in config.to_dict().keys()
            if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size"
        )
        if has_encoder_attributes:
            self.skipTest(
                reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding."
            )

        # Then, test left-padding
        def _prepare_model_kwargs(input_ids, attention_mask, signature):
            model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
            if "position_ids" in signature:
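                # rebuild position ids from the attention mask so that left-padding does not shift
                # the positions of the real tokens (padded positions are masked anyway)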
                position_ids = torch.cumsum(attention_mask, dim=-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                model_kwargs["position_ids"] = position_ids
            if "cache_position" in signature:
                cache_position = torch.arange(input_ids.shape[1], device=torch_device)
                model_kwargs["cache_position"] = cache_position
            return model_kwargs

        for model_class in decoder_only_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            input_ids = inputs_dict["input_ids"]
            attention_mask = inputs_dict.get("attention_mask")
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids)

            model = model_class(config).to(torch_device).eval()
            signature = inspect.signature(model.forward).parameters.keys()

            # no cache as some models require special cache classes to be init outside forward
            model.generation_config.use_cache = False

            # Without padding
            model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
            next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]

            # With left-padding (length 32)
            # can hardcode pad_token to be 0 as we'll do attn masking anyway
            pad_token_id = (
                config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
            )
            pad_size = (input_ids.shape[0], 32, *input_ids.shape[2:])
            padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
            padded_input_ids = torch.cat((padding, input_ids), dim=1)
            padded_attention_mask = torch.cat(
                (torch.zeros(pad_size[:2], dtype=input_ids.dtype, device=torch_device), attention_mask), dim=1
            )
            model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
            next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]

            # They should result in very similar logits
            torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)

    def test_generate_continue_from_past_key_values(self):
        # Tests that we can continue generating from past key values, returned from a previous `generate` call
        for model_class in self.all_generative_model_classes:
            if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt", "mllama"]):
                self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
            if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
                self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")

            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

            if not hasattr(config.get_text_config(), "use_cache"):
                self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

            # Let's make it always:
            # 1. use cache (for obvious reasons)
            # 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which
            #    would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
            #    continuation would force it to generate beyond an EOS token)
            # 3. ignore `token_type_ids` for simplicity
            # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
            #    active by default on some models
            # 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When
            #    we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents
            #    repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls
            #    with cache, what is considered a prompt is different in the two cases.

            if "token_type_ids" in inputs:
                del inputs["token_type_ids"]

            model = model_class(config).to(torch_device)
            model.eval()

            # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
            outputs = model(**inputs)
            if "past_key_values" not in outputs:
                self.skipTest(reason="This model doesn't return `past_key_values`")

            generate_kwargs = {
                "pad_token_id": -1,
                "eos_token_id": -1,
                "forced_eos_token_id": None,
                "encoder_no_repeat_ngram_size": 0,
                "use_cache": True,
                "do_sample": False,
                "return_dict_in_generate": True,
                "output_scores": True,
            }

            # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
            _, inputs = self.prepare_config_and_inputs_for_generate()
            outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=3)

            # Let's generate again, but passing the past key values in between (2 + 1 = 3 tokens). Note that the
            # inputs may need to be tweaked across `generate` calls (like the attention mask).
            outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=2)

            # Continue from the tokens generated above, preparing the inputs accordingly
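            # the sequence generated so far becomes the new prompt; the attention mask is extended
            # with 1s so the continuation attends to every token produced by the first call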
inputs["past_key_values"] = outputs_cached.past_key_values
|
||||
new_attention_len = outputs_cached.sequences.shape[-1]
|
||||
if config.is_encoder_decoder:
|
||||
inputs["decoder_input_ids"] = outputs_cached.sequences
|
||||
if "decoder_attention_mask" in inputs:
|
||||
inputs["decoder_attention_mask"] = torch.nn.functional.pad(
|
||||
inputs["decoder_attention_mask"],
|
||||
(0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
|
||||
mode="constant",
|
||||
value=1,
|
||||
)
|
||||
else:
|
||||
inputs["input_ids"] = outputs_cached.sequences
|
||||
if "attention_mask" in inputs:
|
||||
inputs["attention_mask"] = torch.nn.functional.pad(
|
||||
inputs["attention_mask"],
|
||||
(0, new_attention_len - inputs["attention_mask"].shape[1]),
|
||||
mode="constant",
|
||||
value=1,
|
||||
)
|
||||
first_caches_scores = outputs_cached.scores
|
||||
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1)
|
||||
full_cached_scores = first_caches_scores + outputs_cached.scores
|
||||
outputs_cached.scores = full_cached_scores
|
||||
|
||||
# The two sets of generated text and past kv should be equal to each other
|
||||
self._check_similar_generate_outputs(outputs, outputs_cached)
|
||||
for layer_idx in range(len(outputs_cached.past_key_values)):
|
||||
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
outputs.past_key_values[layer_idx][kv_idx],
|
||||
outputs_cached.past_key_values[layer_idx][kv_idx],
|
||||
)
|
||||
)
|
||||
|
||||
    # needs to be overridden to avoid casting input_values to float16
    # indeed, the codec model is kept in fp32, so we need to avoid casting input_values to float16
    def _test_attention_implementation(self, attn_implementation):
        """
        Compares the output of generate with the eager attention implementation against other implementations.
        NOTE: despite the test logic being the same, different implementations actually need different decorators, hence
        this separate function.
        """
        max_new_tokens = 30
        support_flag = {
            "sdpa": "_supports_sdpa",
            "flash_attention_2": "_supports_flash_attn_2",
        }

        for model_class in self.all_generative_model_classes:
            if not getattr(model_class, support_flag[attn_implementation]):
                self.skipTest(f"{model_class.__name__} does not support `attn_implementation={attn_implementation}`")

            config, original_inputs_dict = self.prepare_config_and_inputs_for_generate()
            inputs_dict = {}
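            # cast floating-point inputs to float16 to match the half-precision models loaded below,
            # but leave `input_values` untouched since the codec part of the model stays in float32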
            for input_name, input_data in original_inputs_dict.items():
                if (
                    isinstance(input_data, torch.Tensor)
                    and input_data.dtype in [torch.float32, torch.bfloat16]
                    and input_name != "input_values"
                ):
                    inputs_dict[input_name] = input_data.to(torch.float16)
                else:
                    inputs_dict[input_name] = input_data
            main_input = inputs_dict[model_class.main_input_name]

            # FA2 doesn't accept masking in the middle of the sequence for now. We usually generate right-padded
            # attention masks at test time and, with generate, the mask will be appended with 1s on the right,
            # resulting in a mask with holes (not supported properly by FA2).
            if attn_implementation == "flash_attention_2":
                for input_name in ("attention_mask", "decoder_attention_mask", "encoder_attention_mask"):
                    if input_name in inputs_dict:
                        inputs_dict[input_name] = torch.ones_like(inputs_dict[input_name])

            # make sure that all models have enough positions for generation
            if hasattr(config, "max_position_embeddings"):
                config.max_position_embeddings = max_new_tokens + main_input.shape[1] + 1

            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                del model
                gc.collect()

                generate_kwargs = {
                    "max_new_tokens": max_new_tokens,
                    "do_sample": False,
                    "return_dict_in_generate": True,
                    "output_scores": True,
                    "use_cache": True,
                }

                model_eager = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    attn_implementation="eager",
                ).to(torch_device)
                res_eager = model_eager.generate(**inputs_dict, **generate_kwargs)
                del model_eager
                gc.collect()

                model_attn = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    attn_implementation=attn_implementation,
                ).to(torch_device)
                res_attn = model_attn.generate(**inputs_dict, **generate_kwargs)
                del model_attn
                gc.collect()

                self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)


class KyutaiSpeechToTextForConditionalGenerationIntegrationTests(unittest.TestCase):
    _dataset = None

    def setUp(self):
        self.model_checkpoint = "kyutai/stt-2.6b-en"

    def tearDown(self):
        cleanup(torch_device, gc_collect=True)

    @classmethod
    def _load_dataset(cls):
        # Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process.
        if cls._dataset is None:
            cls._dataset = datasets.load_dataset(
                "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
            )
            # using 24000 here for simplicity, should rather be processor.feature_extractor.sampling_rate
            cls._dataset = cls._dataset.cast_column("audio", datasets.Audio(sampling_rate=24000))

    def _load_datasamples(self, num_samples):
        self._load_dataset()
        ds = self._dataset
        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
        return [x["array"] for x in speech_samples]

    @slow
    @require_torch_accelerator
    def test_generation(self):
        """
        Reproduce the test expected outputs using the original codebase: https://gist.github.com/eustlb/7a9aa6139d11e0103c6b65bac103da52

        DISCLAIMER: we are testing on pretty short inputs. Reproducing correct expected outputs for longer inputs is not possible,
        as implementation choices (qkv matrix in one linear for the original code vs. three for HF) create growing divergence with context length,
        ultimately giving different outputs.
        """
        processor = KyutaiSpeechToTextProcessor.from_pretrained(self.model_checkpoint)
        model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
            self.model_checkpoint, device_map=torch_device
        )

        samples = self._load_datasamples(1)
        inputs = processor(
            samples,
        ).to(torch_device)

        out = model.generate(**inputs)

        # fmt: off
        EXPECTED_TOKENS = torch.tensor([
[48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 0, 277, 3, 0, 265, 0, 267, 1162, 261, 274, 410, 0, 272, 3, 0, 265, 0, 260, 1621, 0, 1174, 371, 262, 3, 3, 3, 0, 269, 0, 281, 0, 304, 0, 2433, 3, 0, 266, 3, 0, 281, 1661, 3, 0, 376, 3, 3, 0, 350, 261, 401, 516, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]],
        )
        # fmt: on

        torch.testing.assert_close(out.cpu(), EXPECTED_TOKENS)

    @slow
    @require_torch_accelerator
    def test_generation_batched(self):
        """
        Reproduce the test expected outputs using the original codebase: https://gist.github.com/eustlb/b58c217c75124d405ec1c13877c7ece8

        DISCLAIMER: we are testing on pretty short inputs. Reproducing correct expected outputs for longer inputs is not possible,
        as implementation choices (qkv matrix in one linear for the original code vs. three for HF) create growing divergence with context length,
        ultimately giving different outputs.
        """
        processor = KyutaiSpeechToTextProcessor.from_pretrained(self.model_checkpoint)
        model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
            self.model_checkpoint, device_map=torch_device
        )

        samples = self._load_datasamples(4)
        inputs = processor(
            samples,
        ).to(torch_device)

        out = model.generate(**inputs)

        # fmt: off
        EXPECTED_TOKENS = torch.tensor([
[48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 0, 277, 3, 0, 265, 0, 267, 1162, 261, 274, 410, 0, 272, 3, 0, 265, 0, 260, 1621, 0, 1174, 371, 262, 3, 3, 3, 0, 269, 0, 281, 0, 304, 0, 2433, 3, 0, 266, 3, 0, 281, 1661, 3, 0, 376, 3, 3, 0, 350, 261, 401, 516, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
[48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 500, 334, 0, 277, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 264, 261, 0, 511, 1109, 3, 0, 1138, 3, 3, 3, 0, 508, 827, 3, 3, 3, 3, 0, 468, 3, 3, 0, 376, 3, 3, 3, 0, 260, 978, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
[48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 414, 0, 527, 261, 3, 0, 409, 3, 3, 3, 0, 271, 3, 0, 309, 3, 0, 285, 3, 0, 521, 371, 609, 3, 3, 0, 260, 959, 3, 3, 3, 0, 272, 3, 0, 265, 0, 546, 262, 3, 3, 3, 3, 3, 3, 0, 291, 3, 0, 975, 2203, 3, 3, 3, 3, 0, 269, 3, 0, 260, 489, 651, 274, 279, 1870, 3, 0, 1084, 873, 273, 3, 0, 260, 531, 3, 3, 0, 409, 262, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1502, 1005, 836, 3, 3, 0, 1666, 306, 3, 0, 340, 3, 0, 260, 3232, 3, 0, 269, 3, 3, 0, 275, 261, 0, 260, 1379, 261, 0, 3324, 3, 3, 3, 3, 0, 549, 3, 3, 0, 693, 405, 323, 3, 0, 266, 3, 3, 0, 265, 0, 699, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
[48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 414, 0, 392, 3, 3, 0, 1269, 314, 0, 2607, 261, 3, 3, 3, 0, 1098, 295, 3, 3, 3, 0, 446, 625, 3, 0, 496, 280, 1205, 485, 1071, 1627, 449, 264, 261, 3, 0, 400, 0, 277, 3, 3, 3, 0, 260, 342, 3, 0, 618, 280, 1866, 3, 3, 0, 554, 3, 3, 3, 3, 0, 317, 262, 3, 3, 3, 3, 3, 3, 3, 3, 0, 269, 0, 303, 3, 0, 573, 2615, 3, 3, 0, 276, 3, 0, 275, 0, 305, 3, 0, 260, 415, 3, 3, 0, 272, 3, 3, 3, 3, 0, 1631, 327, 3, 3, 0, 333, 739, 841, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        ])
        # fmt: on

        torch.testing.assert_close(out.cpu(), EXPECTED_TOKENS)
@@ -107,14 +107,21 @@ class MimiModelTester:
        self.sliding_window = sliding_window
        self.use_cache = use_cache

    def prepare_config_and_inputs(self):
        input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
    def prepare_config_and_inputs(self, input_values_length=None):
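        # when provided, input_values_length overrides the default audio length (self.intermediate_size samples)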
        input_values = floats_tensor(
            [
                self.batch_size,
                self.num_channels,
                self.intermediate_size if input_values_length is None else input_values_length,
            ],
            scale=1.0,
        )
        config = self.get_config()
        inputs_dict = {"input_values": input_values}
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
    def prepare_config_and_inputs_for_common(self, input_values_length=None):
        config, inputs_dict = self.prepare_config_and_inputs(input_values_length=input_values_length)
        return config, inputs_dict

    def prepare_config_and_inputs_for_model_class(self, model_class):
@@ -508,6 +515,54 @@ class MimiIntegrationTest(unittest.TestCase):
                )
                self.assertTrue(rmse < 1e-3)

    def test_integration_encode_with_padding_cache(self):
        """
        We test here the possibility to run Mimi in a streaming manner, i.e. chunk by chunk.
        1. we first encode the entire audio in one pass
        2. we encode the audio chunk by chunk, each chunk being the smallest size possible for the model (i.e. the frame size)

        This test must be run on CPU since GPU floating point operations accumulate rounding errors that cause test failures.
        """
        librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

        model_id = "kyutai/mimi"

        model = MimiModel.from_pretrained(model_id, use_cache=True).to("cpu")
        processor = AutoFeatureExtractor.from_pretrained(model_id)

        librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
        audio_sample = librispeech_dummy[-1]["audio"]["array"]

        inputs = processor(
            raw_audio=audio_sample,
            sampling_rate=processor.sampling_rate,
            return_tensors="pt",
        ).to("cpu")

        frame_size = model.config.frame_size
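        # reference: encode the whole waveform in a single forward pass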
        audio_codes = model.encode(inputs["input_values"]).audio_codes

        # streaming chunk by chunk
        encoder_past_key_values = None
        padding_cache = None
        encoded_frames_list = []

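        # each chunk is one frame; the encoder KV cache and the convolution padding cache are
        # carried over between calls so the streamed result matches the one-shot encoding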
        for start in range(0, inputs["input_values"].shape[-1], frame_size):
            input_values_chunk = inputs["input_values"][:, :, start : start + frame_size]
            encoder_outputs = model.encode(
                input_values_chunk,
                padding_cache=padding_cache,
                encoder_past_key_values=encoder_past_key_values,
                use_streaming=True,
            )
            encoder_past_key_values = encoder_outputs.encoder_past_key_values
            padding_cache = encoder_outputs.padding_cache
            encoded_frames_list.append(encoder_outputs.audio_codes)

        streamed_audio_codes = torch.cat(encoded_frames_list, dim=-1)

        torch.testing.assert_close(streamed_audio_codes, audio_codes)

    def test_integration(self):
        expected_rmses = {
            "8": 0.0018785292,
@@ -3566,7 +3566,11 @@ class ModelTesterMixin:
            # TODO: if we can also check with `batch_size=1` without being flaky?
            for batch_size in [7]:
                # musicgen decoder models; TODO: find better abstraction
                if hasattr(self.model_tester, "num_codebooks") and not hasattr(model_eager, "text_encoder"):
                if (
                    model.__class__.__name__.startswith("Musicgen")
                    and hasattr(self.model_tester, "num_codebooks")
                    and not hasattr(model_eager, "text_encoder")
                ):
                    input_data_batch_size = batch_size * self.model_tester.num_codebooks
                else:
                    input_data_batch_size = batch_size
@@ -3626,7 +3630,7 @@ class ModelTesterMixin:

        if is_encoder_decoder:
            # musicgen encoder-decoder models; TODO: find better abstraction
            if hasattr(self.model_tester, "num_codebooks"):
            if model.__class__.__name__.startswith("Musicgen") and hasattr(self.model_tester, "num_codebooks"):
                input_data_batch_size = batch_size * self.model_tester.num_codebooks
            else:
                input_data_batch_size = batch_size

@@ -619,7 +619,7 @@ ALL_FILE_TYPES = (
    "processing",
    "image_processing",
    "video_processing",
    "feature_extractor",
    "feature_extraction",
)


@@ -1137,7 +1137,7 @@ TYPE_TO_FILE_TYPE = {
    "VideoProcessor": "video_processing",
    "VideoProcessorInitKwargs": "video_processing",
    "FastImageProcessorKwargs": "image_processing*_fast",
    "FeatureExtractor": "feature_extractor",
    "FeatureExtractor": "feature_extraction",
    "ProcessorKwargs": "processing",
    "VideosKwargs": "processing",
    "ImagesKwargs": "processing",