Codec integration (#33565)

* clean mimi commit

* some nits suggestions from Arthur

* make fixup

* rename repo id + change readme

* Update docs/source/en/model_doc/mimi.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add flaky flag to batching equivalence due to audio_codes failing sometimes

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Yoach Lacombe 2024-09-18 19:23:44 +02:00 committed by GitHub
parent 6019f3ff78
commit 5af7d41e49
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 3208 additions and 0 deletions

View File

@ -722,6 +722,8 @@
title: Hubert
- local: model_doc/mctct
title: MCTCT
- local: model_doc/mimi
title: Mimi
- local: model_doc/mms
title: MMS
- local: model_doc/musicgen

View File

@ -210,6 +210,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Megatron-BERT](model_doc/megatron-bert) | ✅ | ❌ | ❌ |
| [Megatron-GPT2](model_doc/megatron_gpt2) | ✅ | ✅ | ✅ |
| [MGP-STR](model_doc/mgp-str) | ✅ | ❌ | ❌ |
| [Mimi](model_doc/mimi) | ✅ | ❌ | ❌ |
| [Mistral](model_doc/mistral) | ✅ | ✅ | ✅ |
| [Mixtral](model_doc/mixtral) | ✅ | ❌ | ❌ |
| [mLUKE](model_doc/mluke) | ✅ | ❌ | ❌ |

View File

@ -0,0 +1,69 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Mimi
## Overview
The Mimi model was proposed in [Moshi: a speech-text foundation model for real-time dialogue](https://kyutai.org/Moshi.pdf) by Alexandre Défossez, Laurent Mazaré, Manu Orsini, Amélie Royer, Patrick Pérez, Hervé Jégou, Edouard Grave and Neil Zeghidour. Mimi is a high-fidelity audio codec model developed by the Kyutai team, that combines semantic and acoustic information into audio tokens running at 12Hz and a bitrate of 1.1kbps. In other words, it can be used to map audio waveforms into “audio tokens”, known as “codebooks”.
The abstract from the paper is the following:
*We introduce Moshi, a speech-text foundation model and full-duplex spoken dialogue framework. Current systems for spoken dialogue rely on pipelines of independent components, namely voice activity detection, speech recognition, textual dialogue and text-to-speech. Such frameworks cannot emulate the experience of real conversations. First, their complexity induces a latency of several seconds between interactions. Second, text being the intermediate modality for dialogue, non-linguistic information that modifies meaning— such as emotion or non-speech sounds— is lost in the interaction. Finally, they rely on a segmentation into speaker turns, which does not take into account overlapping speech, interruptions and interjections. Moshi solves these independent issues altogether by casting spoken dialogue as speech-to-speech generation. Starting from a text language model backbone, Moshi generates speech as tokens from the residual quantizer of a neural audio codec, while modeling separately its own speech and that of the user into parallel streams. This allows for the removal of explicit speaker turns, and the modeling of arbitrary conversational dynamics. We moreover extend the hierarchical semantic-to-acoustic token generation of previous work to first predict time-aligned text tokens as a prefix to audio tokens. Not only this “Inner Monologue” method significantly improves the linguistic quality of generated speech, but we also illustrate how it can provide streaming speech recognition and text-to-speech. Our resulting model is the first real-time full-duplex spoken large language model, with a theoretical latency of 160ms, 200ms in practice, and is available at github.com/kyutai-labs/moshi.*
Its architecture is based on [Encodec](model_doc/encodec) with several major differences:
* it uses a much lower frame-rate.
* it uses additional transformers for encoding and decoding for better latent contextualization
* it uses a different quantization scheme: one codebook is dedicated to semantic projection.
## Usage example
Here is a quick example of how to encode and decode an audio using this model:
```python
>>> from datasets import load_dataset, Audio
>>> from transformers import MimiModel, AutoFeatureExtractor
>>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> # load model and feature extractor
>>> model = MimiModel.from_pretrained("kyutai/mimi")
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi")
>>> # load audio sample
>>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
>>> audio_sample = librispeech_dummy[-1]["audio"]["array"]
>>> inputs = feature_extractor(raw_audio=audio_sample, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
>>> encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"])
>>> audio_values = model.decode(encoder_outputs.audio_codes, inputs["padding_mask"])[0]
>>> # or the equivalent with a forward pass
>>> audio_values = model(inputs["input_values"], inputs["padding_mask"]).audio_values
```
This model was contributed by [Yoach Lacombe (ylacombe)](https://huggingface.co/ylacombe).
The original code can be found [here](https://github.com/kyutai-labs/moshi).
## MimiConfig
[[autodoc]] MimiConfig
## MimiModel
[[autodoc]] MimiModel
- decode
- encode
- forward

View File

@ -61,6 +61,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
* [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video)
* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision)
* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi)
* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
* [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava)
* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)
@ -228,6 +229,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision)
* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)

View File

@ -573,6 +573,7 @@ _import_structure = {
"MgpstrProcessor",
"MgpstrTokenizer",
],
"models.mimi": ["MimiConfig"],
"models.mistral": ["MistralConfig"],
"models.mixtral": ["MixtralConfig"],
"models.mluke": [],
@ -2666,6 +2667,12 @@ else:
"MgpstrPreTrainedModel",
]
)
_import_structure["models.mimi"].extend(
[
"MimiModel",
"MimiPreTrainedModel",
]
)
_import_structure["models.mistral"].extend(
[
"MistralForCausalLM",
@ -5345,6 +5352,9 @@ if TYPE_CHECKING:
MgpstrProcessor,
MgpstrTokenizer,
)
from .models.mimi import (
MimiConfig,
)
from .models.mistral import MistralConfig
from .models.mixtral import MixtralConfig
from .models.mobilebert import (
@ -7212,6 +7222,10 @@ if TYPE_CHECKING:
MgpstrModel,
MgpstrPreTrainedModel,
)
from .models.mimi import (
MimiModel,
MimiPreTrainedModel,
)
from .models.mistral import (
MistralForCausalLM,
MistralForSequenceClassification,

View File

@ -149,6 +149,7 @@ from . import (
megatron_bert,
megatron_gpt2,
mgp_str,
mimi,
mistral,
mixtral,
mluke,

View File

@ -167,6 +167,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("mega", "MegaConfig"),
("megatron-bert", "MegatronBertConfig"),
("mgp-str", "MgpstrConfig"),
("mimi", "MimiConfig"),
("mistral", "MistralConfig"),
("mixtral", "MixtralConfig"),
("mobilebert", "MobileBertConfig"),
@ -468,6 +469,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("megatron-bert", "Megatron-BERT"),
("megatron_gpt2", "Megatron-GPT2"),
("mgp-str", "MGP-STR"),
("mimi", "Mimi"),
("mistral", "Mistral"),
("mixtral", "Mixtral"),
("mluke", "mLUKE"),

View File

@ -69,6 +69,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("levit", "LevitFeatureExtractor"),
("maskformer", "MaskFormerFeatureExtractor"),
("mctct", "MCTCTFeatureExtractor"),
("mimi", "EncodecFeatureExtractor"),
("mobilenet_v1", "MobileNetV1FeatureExtractor"),
("mobilenet_v2", "MobileNetV2FeatureExtractor"),
("mobilevit", "MobileViTFeatureExtractor"),

View File

@ -158,6 +158,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("mega", "MegaModel"),
("megatron-bert", "MegatronBertModel"),
("mgp-str", "MgpstrForSceneTextRecognition"),
("mimi", "MimiModel"),
("mistral", "MistralModel"),
("mixtral", "MixtralModel"),
("mobilebert", "MobileBertModel"),

View File

@ -0,0 +1,57 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
_import_structure = {
"configuration_mimi": ["MimiConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_mimi"] = [
"MimiModel",
"MimiPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_mimi import (
MimiConfig,
)
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_mimi import (
MimiModel,
MimiPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@ -0,0 +1,234 @@
# coding=utf-8
# Copyright 2024 Meta Platforms, Inc. and affiliates, and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mimi model configuration"""
import math
import numpy as np
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class MimiConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of an [`MimiModel`]. It is used to instantiate a
Mimi model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the
[kyutai/mimi](https://huggingface.co/kyutai/mimi) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
sampling_rate (`int`, *optional*, defaults to 24000):
The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
frame_rate (`float`, *optional*, defaults to 12.5):
Framerate of the model.
audio_channels (`int`, *optional*, defaults to 1):
Number of channels in the audio data. Either 1 for mono or 2 for stereo.
hidden_size (`int`, *optional*, defaults to 512):
Intermediate representation dimension.
num_filters (`int`, *optional*, defaults to 64):
Number of convolution kernels of first `MimiConv1d` down sampling layer.
num_residual_layers (`int`, *optional*, defaults to 1):
Number of residual layers.
upsampling_ratios (`Sequence[int]`, *optional*):
Kernel size and stride ratios. The encoder uses downsampling ratios instead of upsampling ratios, hence it
will use the ratios in the reverse order to the ones specified here that must match the decoder order.
If not specified, will defaults to `[8, 6, 5, 4]`
kernel_size (`int`, *optional*, defaults to 7):
Kernel size for the initial convolution.
last_kernel_size (`int`, *optional*, defaults to 3):
Kernel size for the last convolution layer.
residual_kernel_size (`int`, *optional*, defaults to 3):
Kernel size for the residual layers.
dilation_growth_rate (`int`, *optional*, defaults to 2):
How much to increase the dilation with each layer.
use_causal_conv (`bool`, *optional*, defaults to `True`):
Whether to use fully causal convolution.
pad_mode (`str`, *optional*, defaults to `"constant"`):
Padding mode for the convolutions.
compress (`int`, *optional*, defaults to 2):
Reduced dimensionality in residual branches.
trim_right_ratio (`float`, *optional*, defaults to 1.0):
Ratio for trimming at the right of the transposed convolution under the `use_causal_conv = True` setup. If
equal to 1.0, it means that all the trimming is done at the right.
codebook_size (`int`, *optional*, defaults to 2048):
Number of discret codes in each codebooks.
codebook_dim (`int`, *optional*, defaults to 256):
Dimension of the unquantized codebook vectors. If not defined, uses `hidden_size`.
num_quantizers (`int`, *optional*, defaults to 32):
Number of quantizer channels, or codebooks, in the quantizer.
use_conv_shortcut (`bool`, *optional*, defaults to `False`):
Whether to use a convolutional layer as the 'skip' connection in the `MimiResnetBlock` block. If False,
an identity function will be used, giving a generic residual connection.
vector_quantization_hidden_dimension (`int`, *optional*, defaults to 256):
Intermediate representation dimension in the residual vector quantization space.
num_semantic_quantizers (`int`, *optional*, defaults to 1):
Number of semantic quantizer channels, or codebooks, in the semantic quantizer. Must be lower than `num_quantizers`.
upsample_groups (`int`, *optional*, defaults to 512):
If `frame_rate!=encodec_frame_rate`, indicates the number of groups used in the upsampling operation to go from one rate to another.
num_hidden_layers (`int`, *optional*, defaults to 8):
Number of hidden layers in the Transformer models.
intermediate_size (`int`, *optional*, defaults to 2048):
Dimension of the MLP representations.
num_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
The attention head dimension.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 8000):
The maximum sequence length that this model might ever be used with. Mimi's sliding window attention
allows sequence of up to 8000 tokens.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the LayerNorm normalization layers.
use_cache (`bool`, *optional*, defaults to `False`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
sliding_window (`int`, *optional*, defaults to 250):
Sliding window attention window size. If not specified, will default to `250`.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
layer_scale_initial_scale (`float`, *optional*, defaults to 0.01):
Initiale scale of the residual rescaling operation done in the Transformer models.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
Example:
```python
>>> from transformers import MimiModel, MimiConfig
>>> # Initializing a "kyutai/mimi" style configuration
>>> configuration = MimiConfig()
>>> # Initializing a model (with random weights) from the "kyutai/mimi" style configuration
>>> model = MimiModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "mimi"
def __init__(
self,
sampling_rate=24_000,
frame_rate=12.5,
audio_channels=1,
hidden_size=512,
num_filters=64,
num_residual_layers=1,
upsampling_ratios=None,
kernel_size=7,
last_kernel_size=3,
residual_kernel_size=3,
dilation_growth_rate=2,
use_causal_conv=True,
pad_mode="constant",
compress=2,
trim_right_ratio=1.0,
codebook_size=2048,
codebook_dim=256,
num_quantizers=32,
use_conv_shortcut=False,
vector_quantization_hidden_dimension=256,
num_semantic_quantizers=1,
upsample_groups=512,
num_hidden_layers=8,
intermediate_size=2048,
num_attention_heads=8,
num_key_value_heads=8,
head_dim=None,
hidden_act="gelu",
max_position_embeddings=8000,
initializer_range=0.02,
norm_eps=1e-5,
use_cache=False,
rope_theta=10000.0,
sliding_window=250,
attention_dropout=0.0,
layer_scale_initial_scale=0.01,
attention_bias=False,
**kwargs,
):
self.sampling_rate = sampling_rate
self.frame_rate = frame_rate
self.audio_channels = audio_channels
self.hidden_size = hidden_size
self.num_filters = num_filters
self.num_residual_layers = num_residual_layers
self.upsampling_ratios = upsampling_ratios if upsampling_ratios else [8, 6, 5, 4]
self.kernel_size = kernel_size
self.last_kernel_size = last_kernel_size
self.residual_kernel_size = residual_kernel_size
self.dilation_growth_rate = dilation_growth_rate
self.use_causal_conv = use_causal_conv
self.pad_mode = pad_mode
self.compress = compress
self.trim_right_ratio = trim_right_ratio
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size
self.num_quantizers = num_quantizers
self.use_conv_shortcut = use_conv_shortcut
self.vector_quantization_hidden_dimension = vector_quantization_hidden_dimension
self.upsample_groups = upsample_groups
self.num_hidden_layers = num_hidden_layers
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.norm_eps = norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.attention_dropout = attention_dropout
self.head_dim = head_dim or hidden_size // num_attention_heads
self.layer_scale_initial_scale = layer_scale_initial_scale
self.attention_bias = attention_bias
if num_semantic_quantizers >= self.num_quantizers:
raise ValueError(
f"The number of semantic quantizers should be lower than the total number of quantizers {self.num_quantizers}, but is currently {num_semantic_quantizers}."
)
self.num_semantic_quantizers = num_semantic_quantizers
super().__init__(**kwargs)
@property
def encodec_frame_rate(self) -> int:
hop_length = np.prod(self.upsampling_ratios)
return math.ceil(self.sampling_rate / hop_length)
@property
def num_codebooks(self) -> int:
# alias to num_quantizers
return self.num_quantizers

View File

@ -0,0 +1,198 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Mimi checkpoints."""
import argparse
import safetensors
import torch
from transformers import (
EncodecFeatureExtractor,
MimiConfig,
MimiModel,
logging,
)
logging.set_verbosity_info()
logger = logging.get_logger("transformers.models.mimi")
def assert_param_count(model_1, model_2):
count_1 = sum(p[1].numel() for p in model_1.named_parameters() if "final_proj" not in p[0])
count_2 = sum(p[1].numel() for p in model_2.named_parameters() if "final_proj" not in p[0])
assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}"
def param_count(model):
return sum(p[1].numel() for p in model.named_parameters() if "final_proj" not in p[0])
def _grab_best_device(use_gpu=True):
if torch.cuda.device_count() > 0 and use_gpu:
device = "cuda"
else:
device = "cpu"
return torch.device(device)
convert_list = [
# GENERAL
("conv.conv.conv", "conv"),
("convtr.convtr.convtr", "conv"),
("conv.conv", "conv"),
("convtr.convtr", "conv"),
# QUANTIZER
("quantizer.rvq_first.vq", "quantizer.semantic_residual_vector_quantizer"),
("quantizer.rvq_first", "quantizer.semantic_residual_vector_quantizer"),
("quantizer.rvq_rest.vq", "quantizer.acoustic_residual_vector_quantizer"),
("quantizer.rvq_rest", "quantizer.acoustic_residual_vector_quantizer"),
("_codebook", "codebook"),
("_initialized", "initialized"),
("embedding_sum", "embed_sum"),
# ENCODER PART
("encoder.model", "encoder.layers"),
("decoder.model", "decoder.layers"),
# TRANSFORMERS PART
("encoder_transformer.transformer", "encoder_transformer"),
("decoder_transformer.transformer", "decoder_transformer"),
("linear1", "mlp.fc1"),
("linear2", "mlp.fc2"),
("self_attn.out_proj", "self_attn.o_proj"),
("norm1", "input_layernorm"),
("norm2", "post_attention_layernorm"),
("layer_scale_1", "self_attn_layer_scale"),
("layer_scale_2", "mlp_layer_scale"),
]
def _convert_model(
state_dict,
hf_model,
convert_list,
device,
config,
unwanted_prefix=None,
):
hidden_size = config.hidden_size
head_dim = config.head_dim
num_heads = int(config.hidden_size // config.head_dim)
num_key_value_heads = config.num_key_value_heads
key_value_head_dim = config.num_key_value_heads * head_dim
# permute for sliced rotary
def permute(w, n_heads, dim1=hidden_size, dim2=hidden_size):
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
for k, v in list(state_dict.items()):
new_k = k if unwanted_prefix is None else k[len(unwanted_prefix) :]
for old_layer_name, new_layer_name in convert_list:
if old_layer_name in new_k:
new_k = new_k.replace(old_layer_name, new_layer_name)
if "in_proj_weight" in new_k:
# split qkv into query key and value
mixed_qkv = state_dict.pop(k)
qkv_dim = mixed_qkv.size(0) // 3
query_layer = mixed_qkv[:qkv_dim]
key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
value_layer = mixed_qkv[qkv_dim * 2 :]
state_dict[new_k.replace("in_proj_weight", "q_proj.weight")] = permute(query_layer, num_heads)
state_dict[new_k.replace("in_proj_weight", "k_proj.weight")] = permute(
key_layer, num_key_value_heads, dim1=key_value_head_dim
)
state_dict[new_k.replace("in_proj_weight", "v_proj.weight")] = value_layer
else:
state_dict[new_k] = state_dict.pop(k)
extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys())
missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys())
if len(extra_keys) != 0:
raise ValueError(f"extra keys found: {extra_keys}")
if len(missing_keys) != 0:
raise ValueError(f"missing keys: {missing_keys}")
hf_model.load_state_dict(state_dict, strict=True)
n_params = param_count(hf_model)
logger.info(f"model loaded: {round(n_params/1e6,1)}M params")
hf_model.eval()
hf_model.to(device)
del state_dict
return hf_model
@torch.no_grad()
def convert_checkpoint(
checkpoint_path,
pytorch_dump_folder_path,
config_path=None,
repo_id=None,
):
"""
Copy/paste/tweak model's weights to transformers design.
"""
device = _grab_best_device()
if config_path is not None:
config = MimiConfig.from_pretrained(config_path)
else:
config = MimiConfig()
model = MimiModel(config)
feature_extractor = EncodecFeatureExtractor(
feature_size=config.audio_channels,
sampling_rate=config.sampling_rate,
)
feature_extractor.save_pretrained(pytorch_dump_folder_path)
original_checkpoint = safetensors.torch.load_file(checkpoint_path)
if "best_state" in original_checkpoint:
# we might have a training state saved, in which case discard the yaml results and just retain the weights
original_checkpoint = original_checkpoint["best_state"]
model = _convert_model(original_checkpoint, model, convert_list, device, config)
model.save_pretrained(pytorch_dump_folder_path)
if repo_id:
print("Pushing to the hub...")
feature_extractor.push_to_hub(repo_id)
model.push_to_hub(repo_id)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
parser.add_argument(
"--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
)
parser.add_argument(
"--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
)
args = parser.parse_args()
convert_checkpoint(
args.checkpoint_path,
args.pytorch_dump_folder_path,
args.config_path,
args.push_to_hub,
)

File diff suppressed because it is too large Load Diff

View File

@ -5840,6 +5840,20 @@ class MgpstrPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class MimiModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MimiPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MistralForCausalLM(metaclass=DummyObject):
_backends = ["torch"]

View File

View File

@ -0,0 +1,890 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Mimi model."""
import inspect
import os
import tempfile
import unittest
import numpy as np
from datasets import Audio, load_dataset
from packaging import version
from parameterized import parameterized
from pytest import mark
from transformers import AutoFeatureExtractor, MimiConfig
from transformers.testing_utils import (
is_flaky,
is_torch_available,
require_flash_attn,
require_torch,
require_torch_gpu,
require_torch_sdpa,
slow,
torch_device,
)
from transformers.utils import (
is_torch_bf16_available_on_device,
is_torch_fp16_available_on_device,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
if is_torch_available():
import torch
from transformers import MimiModel
# Copied from transformers.tests.encodec.test_modeling_encodec.prepare_inputs_dict
def prepare_inputs_dict(
config,
input_ids=None,
input_values=None,
decoder_input_ids=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if input_ids is not None:
encoder_dict = {"input_ids": input_ids}
else:
encoder_dict = {"input_values": input_values}
decoder_dict = {"decoder_input_ids": decoder_input_ids} if decoder_input_ids is not None else {}
return {**encoder_dict, **decoder_dict}
@require_torch
class MimiModelTester:
def __init__(
self,
parent,
batch_size=5,
num_channels=1,
is_training=False,
intermediate_size=40,
hidden_size=32,
num_filters=8,
num_residual_layers=1,
upsampling_ratios=[8, 4],
codebook_size=64,
vector_quantization_hidden_dimension=64,
codebook_dim=64,
upsample_groups=32,
num_hidden_layers=2,
num_attention_heads=2,
num_key_value_heads=2,
sliding_window=4,
use_cache=False,
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.is_training = is_training
self.intermediate_size = intermediate_size
self.hidden_size = hidden_size
self.num_filters = num_filters
self.num_residual_layers = num_residual_layers
self.upsampling_ratios = upsampling_ratios
self.codebook_size = codebook_size
self.vector_quantization_hidden_dimension = vector_quantization_hidden_dimension
self.codebook_dim = codebook_dim
self.upsample_groups = upsample_groups
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.sliding_window = sliding_window
self.use_cache = use_cache
def prepare_config_and_inputs(self):
input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
config = self.get_config()
inputs_dict = {"input_values": input_values}
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
def prepare_config_and_inputs_for_model_class(self, model_class):
config, inputs_dict = self.prepare_config_and_inputs()
inputs_dict["audio_codes"] = ids_tensor([self.batch_size, 1, self.num_channels], self.codebook_size).type(
torch.int32
)
return config, inputs_dict
def get_config(self):
return MimiConfig(
audio_channels=self.num_channels,
chunk_in_sec=None,
hidden_size=self.hidden_size,
num_filters=self.num_filters,
num_residual_layers=self.num_residual_layers,
upsampling_ratios=self.upsampling_ratios,
codebook_size=self.codebook_size,
vector_quantization_hidden_dimension=self.vector_quantization_hidden_dimension,
upsample_groups=self.upsample_groups,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_key_value_heads,
sliding_window=self.sliding_window,
codebook_dim=self.codebook_dim,
use_cache=self.use_cache,
)
def create_and_check_model_forward(self, config, inputs_dict):
model = MimiModel(config=config).to(torch_device).eval()
input_values = inputs_dict["input_values"]
result = model(input_values)
self.parent.assertEqual(
result.audio_values.shape, (self.batch_size, self.num_channels, self.intermediate_size)
)
@require_torch
class MimiModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (MimiModel,) if is_torch_available() else ()
is_encoder_decoder = True
test_pruning = False
test_headmasking = False
test_resize_embeddings = False
test_torchscript = False
input_name = "input_values"
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
# model does support returning hidden states
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if "output_attentions" in inputs_dict:
inputs_dict.pop("output_attentions")
if "output_hidden_states" in inputs_dict:
inputs_dict.pop("output_hidden_states")
return inputs_dict
def setUp(self):
self.model_tester = MimiModelTester(self)
self.config_tester = ConfigTester(
self, config_class=MimiConfig, hidden_size=37, common_properties=[], has_text_modality=False
)
def test_config(self):
self.config_tester.run_common_tests()
def test_model_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_forward(*config_and_inputs)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["input_values", "padding_mask", "num_quantizers"]
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
@unittest.skip(reason="The MimiModel does not have `inputs_embeds` logics")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="The MimiModel does not have `inputs_embeds` logics")
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="The MimiModel does not have the usual `attention` logic")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="The MimiModel does not have the usual `attention` logic")
def test_torchscript_output_attentions(self):
pass
@unittest.skip(reason="The MimiModel does not have the usual `hidden_states` logic")
def test_torchscript_output_hidden_state(self):
pass
# Copied from transformers.tests.encodec.test_modeling_encodec.MimiModelTest._create_and_check_torchscript
def _create_and_check_torchscript(self, config, inputs_dict):
if not self.test_torchscript:
self.skipTest(reason="test_torchscript is set to False")
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
configs_no_init.return_dict = False
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
main_input_name = model_class.main_input_name
try:
main_input = inputs[main_input_name]
model(main_input)
traced_model = torch.jit.trace(model, main_input)
except RuntimeError:
self.fail("Couldn't trace module.")
with tempfile.TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
try:
torch.jit.save(traced_model, pt_file_name)
except Exception:
self.fail("Couldn't save module.")
try:
loaded_model = torch.jit.load(pt_file_name)
except Exception:
self.fail("Couldn't load module.")
model.to(torch_device)
model.eval()
loaded_model.to(torch_device)
loaded_model.eval()
model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict()
non_persistent_buffers = {}
for key in loaded_model_state_dict.keys():
if key not in model_state_dict.keys():
non_persistent_buffers[key] = loaded_model_state_dict[key]
loaded_model_state_dict = {
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
}
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
model_buffers = list(model.buffers())
for non_persistent_buffer in non_persistent_buffers.values():
found_buffer = False
for i, model_buffer in enumerate(model_buffers):
if torch.equal(non_persistent_buffer, model_buffer):
found_buffer = True
break
self.assertTrue(found_buffer)
model_buffers.pop(i)
model_buffers = list(model.buffers())
for non_persistent_buffer in non_persistent_buffers.values():
found_buffer = False
for i, model_buffer in enumerate(model_buffers):
if torch.equal(non_persistent_buffer, model_buffer):
found_buffer = True
break
self.assertTrue(found_buffer)
model_buffers.pop(i)
models_equal = True
for layer_name, p1 in model_state_dict.items():
if layer_name in loaded_model_state_dict:
p2 = loaded_model_state_dict[layer_name]
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
self.assertTrue(models_equal)
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
# (Even with this call, there are still memory leak by ~0.04MB)
self.clear_torch_jit_class_registry()
@unittest.skip(reason="The MimiModel does not have the usual `attention` logic")
def test_attention_outputs(self):
pass
@unittest.skip(reason="The MimiModel does not have the usual `hidden_states` logic")
def test_hidden_states_output(self):
pass
# Copied from transformers.tests.encodec.test_modeling_encodec.MimiModelTest.test_determinism
def test_determinism(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def check_determinism(first, second):
# outputs are not tensors but list (since each sequence don't have the same frame_length)
out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
if isinstance(first, tuple) and isinstance(second, tuple):
for tensor1, tensor2 in zip(first, second):
check_determinism(tensor1, tensor2)
else:
check_determinism(first, second)
# Copied from transformers.tests.encodec.test_modeling_encodec.MimiModelTest.test_model_outputs_equivalence
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
with torch.no_grad():
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs)
self.assertTrue(isinstance(tuple_output, tuple))
self.assertTrue(isinstance(dict_output, dict))
for tuple_value, dict_value in zip(tuple_output, dict_output.values()):
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_value), set_nan_tensor_to_zero(dict_value), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_value - dict_value))}. Tuple has `nan`:"
f" {torch.isnan(tuple_value).any()} and `inf`: {torch.isinf(tuple_value)}. Dict has"
f" `nan`: {torch.isnan(dict_value).any()} and `inf`: {torch.isinf(dict_value)}."
),
)
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
uniform_init_parms = ["conv", "input_proj", "output_proj"]
if param.requires_grad:
if any(x in name for x in uniform_init_parms):
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
# Copied from transformers.tests.encodec.test_modeling_encodec.MimiModelTest.test_identity_shortcut
def test_identity_shortcut(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
config.use_conv_shortcut = False
self.model_tester.create_and_check_model_forward(config, inputs_dict)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,
}
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
# FIXME: we deactivate boolean mask for models using "use_mask_token" in their constructors.
# These models support masking only in the case `use_mask_token=True`. Otherwise they cannot consume an input mask.
# This means that the class needs to be instantiated much later, after `use_mask` is set, which means a significant refactor of the code.
# However masking there is not done at any layers that matters (i.e self-attention), therefore we can safely deactivate it.
deactivate_mask = "use_mask_token" in inspect.signature(model_class).parameters
is_encoder_decoder = model.config.is_encoder_decoder
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for output_attentions in [True, False]:
can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
if not (self.has_attentions and can_output_attn) and output_attentions:
continue
for batch_size in [1, 5]:
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
dummy_input = dummy_input[:batch_size]
if dummy_input.shape[0] != batch_size:
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
extension = torch.rand(
batch_size - dummy_input.shape[0],
*dummy_input.shape[1:],
dtype=torch_dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
else:
extension = torch.randint(
high=5,
size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
dtype=dummy_input.dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
if not use_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
if is_encoder_decoder:
seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
else:
seqlen = dummy_input.shape[-1]
dummy_attention_mask = (
torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
)
dummy_attention_mask = dummy_attention_mask[:batch_size]
if dummy_attention_mask.shape[0] != batch_size:
extension = torch.ones(
batch_size - dummy_attention_mask.shape[0],
*dummy_attention_mask.shape[1:],
dtype=dummy_attention_mask.dtype,
device=torch_device,
)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask.to(torch_device)
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :-1] = 1
dummy_attention_mask[-1, -4:] = 0
elif padding_side == "right":
dummy_attention_mask[-1, 1:] = 1
dummy_attention_mask[-1, :3] = 0
for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
if is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[
:batch_size
]
if decoder_input_ids.shape[0] != batch_size:
extension = torch.ones(
batch_size - decoder_input_ids.shape[0],
*decoder_input_ids.shape[1:],
dtype=decoder_input_ids.dtype,
device=torch_device,
)
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
decoder_input_ids = decoder_input_ids.to(torch_device)
# TODO: never an `attention_mask` arg here?
processed_inputs = {
model.main_input_name: dummy_input,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
else:
processed_inputs = {
model.main_input_name: dummy_input,
"output_hidden_states": True,
}
# Otherwise fails for e.g. WhisperEncoderModel
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
processed_inputs["attention_mask"] = dummy_attention_mask
if (
self.has_attentions
and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
):
processed_inputs["output_attentions"] = output_attentions
if not deactivate_mask and (
"bool_masked_pos" in inspect.signature(model_eager.forward).parameters
):
dummy_mask = torch.ones((self.model_tester.num_masks,))
# In case of additional token (like class) we define a custom `mask_length`
if hasattr(self.model_tester, "mask_length"):
mask_length = self.model_tester.mask_length - dummy_mask.size(0)
else:
mask_length = self.model_tester.seq_length - dummy_mask.size(0)
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
if "noise" in inspect.signature(model_eager.forward).parameters:
np.random.seed(2)
num_patches = int(
(self.model_tester.image_size // self.model_tester.patch_size) ** 2
)
noise = np.random.uniform(size=(batch_size, num_patches))
processed_inputs["noise"] = torch.from_numpy(noise)
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with torch.backends.cuda.sdp_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
# Ignore copy
logits_eager = outputs_eager.audio_values
# Ignore copy
logits_sdpa = outputs_sdpa.audio_values
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
if padding_side == "left":
sub_sdpa = logits_sdpa[:-1]
sub_eager = logits_eager[:-1]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
sub_sdpa = logits_sdpa[-1, :-4]
sub_eager = logits_eager[-1, :-4]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
# Testing the padding tokens is not really meaningful but anyway
# sub_sdpa = logits_sdpa[-1, -4:]
# sub_eager = logits_eager[-1, -4:]
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
elif padding_side == "right":
sub_sdpa = logits_sdpa[:-1]
sub_eager = logits_eager[:-1]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
sub_sdpa = logits_sdpa[-1, 3:]
sub_eager = logits_eager[-1, 3:]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
# Testing the padding tokens is not really meaningful but anyway
# sub_sdpa = logits_sdpa[-1, :3]
# sub_eager = logits_eager[-1, :3]
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
else:
if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
@is_flaky()
def test_flash_attn_2_inference_equivalence(self):
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
dummy_input = inputs_dict[model.main_input_name][:1]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
outputs = model(dummy_input)
outputs_fa = model_fa(dummy_input)
logits = outputs[1]
logits_fa = outputs_fa[1]
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
@unittest.skip(reason="The MimiModel does not support right padding")
def test_flash_attn_2_inference_equivalence_right_padding(self):
pass
@unittest.skip(reason="The MimiModel does not have support dynamic compile yet")
def test_sdpa_can_compile_dynamic(self):
pass
# For now, Let's focus only on GPU for `torch.compile`
@slow
@require_torch_gpu
def test_torch_compile(self):
if version.parse(torch.__version__) < version.parse("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
n_iter = 3
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
model.forward = torch.compile(model.forward)
for i in range(n_iter):
_ = model(inputs_dict["input_values"].to(torch_device))
@is_flaky()
def test_batching_equivalence(self):
super().test_batching_equivalence()
# Copied from transformers.tests.encodec.test_modeling_encodec.normalize
def normalize(arr):
norm = np.linalg.norm(arr)
normalized_arr = arr / norm
return normalized_arr
# Copied from transformers.tests.encodec.test_modeling_encodec.compute_rmse
def compute_rmse(arr1, arr2):
arr1_normalized = normalize(arr1)
arr2_normalized = normalize(arr2)
return np.sqrt(((arr1_normalized - arr2_normalized) ** 2).mean())
@slow
@require_torch
class MimiIntegrationTest(unittest.TestCase):
def test_integration_using_cache_decode(self):
expected_rmse = {
"8": 0.0018785292,
"32": 0.0012330565,
}
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
model_id = "kyutai/mimi"
model = MimiModel.from_pretrained(model_id, use_cache=True).to(torch_device)
processor = AutoFeatureExtractor.from_pretrained(model_id)
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
audio_sample = librispeech_dummy[-1]["audio"]["array"]
inputs = processor(
raw_audio=audio_sample,
sampling_rate=processor.sampling_rate,
return_tensors="pt",
).to(torch_device)
for num_codebooks, expected_rmse in expected_rmse.items():
with torch.no_grad():
# use max bandwith for best possible reconstruction
encoder_outputs = model.encode(inputs["input_values"], num_quantizers=int(num_codebooks))
audio_codes = encoder_outputs[0]
decoder_outputs_first_part = model.decode(audio_codes[:, :, : audio_codes.shape[2] // 2])
decoder_outputs_second_part = model.decode(
audio_codes[:, :, audio_codes.shape[2] // 2 :],
decoder_past_key_values=decoder_outputs_first_part.decoder_past_key_values,
)
audio_output_entire_context = model.decode(audio_codes)[0]
audio_output_concat_context = torch.cat(
[decoder_outputs_first_part[0], decoder_outputs_second_part[0]], dim=2
)
# make sure audios are more or less equal
# the RMSE of two random gaussian noise vectors with ~N(0, 1) is around 1.0
rmse = compute_rmse(
audio_output_concat_context.squeeze().cpu().numpy(),
audio_output_entire_context.squeeze().cpu().numpy(),
)
self.assertTrue(rmse < 1e-3)
def test_integration(self):
expected_rmses = {
"8": 0.0018785292,
"32": 0.0012330565,
}
expected_codesums = {
"8": 430423,
"32": 1803071,
}
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
model_id = "kyutai/mimi"
processor = AutoFeatureExtractor.from_pretrained(model_id)
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
audio_sample = librispeech_dummy[-1]["audio"]["array"]
inputs = processor(
raw_audio=audio_sample,
sampling_rate=processor.sampling_rate,
return_tensors="pt",
).to(torch_device)
for use_cache in [False, True]:
model = MimiModel.from_pretrained(model_id, use_cache=use_cache).to(torch_device)
for num_codebooks, expected_rmse in expected_rmses.items():
with torch.no_grad():
# use max bandwith for best possible reconstruction
encoder_outputs = model.encode(inputs["input_values"], num_quantizers=int(num_codebooks))
audio_code_sums = encoder_outputs[0].sum().cpu().item()
# make sure audio encoded codes are correct
# assert relative difference less than a threshold, because `audio_code_sums` varies a bit
# depending on torch version
self.assertTrue(
np.abs(audio_code_sums - expected_codesums[num_codebooks]) <= (3e-3 * audio_code_sums)
)
input_values_dec = model.decode(encoder_outputs[0], padding_mask=inputs["padding_mask"])[0]
input_values_enc_dec = model(
inputs["input_values"], inputs["padding_mask"], num_quantizers=int(num_codebooks)
)[1]
# make sure forward and decode gives same result
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec))
# make sure shape matches
self.assertTrue(inputs["input_values"].shape == input_values_enc_dec.shape)
arr = inputs["input_values"][0].cpu().numpy()
arr_enc_dec = input_values_enc_dec[0].cpu().numpy()
# make sure audios are more or less equal
# the RMSE of two random gaussian noise vectors with ~N(0, 1) is around 1.0
rmse = compute_rmse(arr, arr_enc_dec)
self.assertTrue(np.abs(rmse - expected_rmse) < 1e-5)