mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Codec integration (#33565)
* clean mimi commit * some nits suggestions from Arthur * make fixup * rename repo id + change readme * Update docs/source/en/model_doc/mimi.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add flaky flag to batching equivalence due to audio_codes failing sometimes --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
6019f3ff78
commit
5af7d41e49
@ -722,6 +722,8 @@
|
||||
title: Hubert
|
||||
- local: model_doc/mctct
|
||||
title: MCTCT
|
||||
- local: model_doc/mimi
|
||||
title: Mimi
|
||||
- local: model_doc/mms
|
||||
title: MMS
|
||||
- local: model_doc/musicgen
|
||||
|
@ -210,6 +210,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| [Megatron-BERT](model_doc/megatron-bert) | ✅ | ❌ | ❌ |
|
||||
| [Megatron-GPT2](model_doc/megatron_gpt2) | ✅ | ✅ | ✅ |
|
||||
| [MGP-STR](model_doc/mgp-str) | ✅ | ❌ | ❌ |
|
||||
| [Mimi](model_doc/mimi) | ✅ | ❌ | ❌ |
|
||||
| [Mistral](model_doc/mistral) | ✅ | ✅ | ✅ |
|
||||
| [Mixtral](model_doc/mixtral) | ✅ | ❌ | ❌ |
|
||||
| [mLUKE](model_doc/mluke) | ✅ | ❌ | ❌ |
|
||||
|
69
docs/source/en/model_doc/mimi.md
Normal file
69
docs/source/en/model_doc/mimi.md
Normal file
@ -0,0 +1,69 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Mimi
|
||||
|
||||
## Overview
|
||||
|
||||
The Mimi model was proposed in [Moshi: a speech-text foundation model for real-time dialogue](https://kyutai.org/Moshi.pdf) by Alexandre Défossez, Laurent Mazaré, Manu Orsini, Amélie Royer, Patrick Pérez, Hervé Jégou, Edouard Grave and Neil Zeghidour. Mimi is a high-fidelity audio codec model developed by the Kyutai team, that combines semantic and acoustic information into audio tokens running at 12Hz and a bitrate of 1.1kbps. In other words, it can be used to map audio waveforms into “audio tokens”, known as “codebooks”.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*We introduce Moshi, a speech-text foundation model and full-duplex spoken dialogue framework. Current systems for spoken dialogue rely on pipelines of independent components, namely voice activity detection, speech recognition, textual dialogue and text-to-speech. Such frameworks cannot emulate the experience of real conversations. First, their complexity induces a latency of several seconds between interactions. Second, text being the intermediate modality for dialogue, non-linguistic information that modifies meaning— such as emotion or non-speech sounds— is lost in the interaction. Finally, they rely on a segmentation into speaker turns, which does not take into account overlapping speech, interruptions and interjections. Moshi solves these independent issues altogether by casting spoken dialogue as speech-to-speech generation. Starting from a text language model backbone, Moshi generates speech as tokens from the residual quantizer of a neural audio codec, while modeling separately its own speech and that of the user into parallel streams. This allows for the removal of explicit speaker turns, and the modeling of arbitrary conversational dynamics. We moreover extend the hierarchical semantic-to-acoustic token generation of previous work to first predict time-aligned text tokens as a prefix to audio tokens. Not only this “Inner Monologue” method significantly improves the linguistic quality of generated speech, but we also illustrate how it can provide streaming speech recognition and text-to-speech. Our resulting model is the first real-time full-duplex spoken large language model, with a theoretical latency of 160ms, 200ms in practice, and is available at github.com/kyutai-labs/moshi.*
|
||||
|
||||
Its architecture is based on [Encodec](model_doc/encodec) with several major differences:
|
||||
* it uses a much lower frame-rate.
|
||||
* it uses additional transformers for encoding and decoding for better latent contextualization
|
||||
* it uses a different quantization scheme: one codebook is dedicated to semantic projection.
|
||||
|
||||
## Usage example
|
||||
|
||||
Here is a quick example of how to encode and decode an audio using this model:
|
||||
|
||||
```python
|
||||
>>> from datasets import load_dataset, Audio
|
||||
>>> from transformers import MimiModel, AutoFeatureExtractor
|
||||
>>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
|
||||
>>> # load model and feature extractor
|
||||
>>> model = MimiModel.from_pretrained("kyutai/mimi")
|
||||
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi")
|
||||
|
||||
>>> # load audio sample
|
||||
>>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
|
||||
>>> audio_sample = librispeech_dummy[-1]["audio"]["array"]
|
||||
>>> inputs = feature_extractor(raw_audio=audio_sample, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
|
||||
|
||||
>>> encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"])
|
||||
>>> audio_values = model.decode(encoder_outputs.audio_codes, inputs["padding_mask"])[0]
|
||||
>>> # or the equivalent with a forward pass
|
||||
>>> audio_values = model(inputs["input_values"], inputs["padding_mask"]).audio_values
|
||||
```
|
||||
|
||||
This model was contributed by [Yoach Lacombe (ylacombe)](https://huggingface.co/ylacombe).
|
||||
The original code can be found [here](https://github.com/kyutai-labs/moshi).
|
||||
|
||||
|
||||
## MimiConfig
|
||||
|
||||
[[autodoc]] MimiConfig
|
||||
|
||||
## MimiModel
|
||||
|
||||
[[autodoc]] MimiModel
|
||||
- decode
|
||||
- encode
|
||||
- forward
|
@ -61,6 +61,7 @@ FlashAttention-2 is currently supported for the following architectures:
|
||||
* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
|
||||
* [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video)
|
||||
* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision)
|
||||
* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi)
|
||||
* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
|
||||
* [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava)
|
||||
* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)
|
||||
@ -228,6 +229,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
||||
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
|
||||
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
|
||||
* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision)
|
||||
* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi)
|
||||
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
|
||||
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
|
||||
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
|
||||
|
@ -573,6 +573,7 @@ _import_structure = {
|
||||
"MgpstrProcessor",
|
||||
"MgpstrTokenizer",
|
||||
],
|
||||
"models.mimi": ["MimiConfig"],
|
||||
"models.mistral": ["MistralConfig"],
|
||||
"models.mixtral": ["MixtralConfig"],
|
||||
"models.mluke": [],
|
||||
@ -2666,6 +2667,12 @@ else:
|
||||
"MgpstrPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.mimi"].extend(
|
||||
[
|
||||
"MimiModel",
|
||||
"MimiPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.mistral"].extend(
|
||||
[
|
||||
"MistralForCausalLM",
|
||||
@ -5345,6 +5352,9 @@ if TYPE_CHECKING:
|
||||
MgpstrProcessor,
|
||||
MgpstrTokenizer,
|
||||
)
|
||||
from .models.mimi import (
|
||||
MimiConfig,
|
||||
)
|
||||
from .models.mistral import MistralConfig
|
||||
from .models.mixtral import MixtralConfig
|
||||
from .models.mobilebert import (
|
||||
@ -7212,6 +7222,10 @@ if TYPE_CHECKING:
|
||||
MgpstrModel,
|
||||
MgpstrPreTrainedModel,
|
||||
)
|
||||
from .models.mimi import (
|
||||
MimiModel,
|
||||
MimiPreTrainedModel,
|
||||
)
|
||||
from .models.mistral import (
|
||||
MistralForCausalLM,
|
||||
MistralForSequenceClassification,
|
||||
|
@ -149,6 +149,7 @@ from . import (
|
||||
megatron_bert,
|
||||
megatron_gpt2,
|
||||
mgp_str,
|
||||
mimi,
|
||||
mistral,
|
||||
mixtral,
|
||||
mluke,
|
||||
|
@ -167,6 +167,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("mega", "MegaConfig"),
|
||||
("megatron-bert", "MegatronBertConfig"),
|
||||
("mgp-str", "MgpstrConfig"),
|
||||
("mimi", "MimiConfig"),
|
||||
("mistral", "MistralConfig"),
|
||||
("mixtral", "MixtralConfig"),
|
||||
("mobilebert", "MobileBertConfig"),
|
||||
@ -468,6 +469,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("megatron-bert", "Megatron-BERT"),
|
||||
("megatron_gpt2", "Megatron-GPT2"),
|
||||
("mgp-str", "MGP-STR"),
|
||||
("mimi", "Mimi"),
|
||||
("mistral", "Mistral"),
|
||||
("mixtral", "Mixtral"),
|
||||
("mluke", "mLUKE"),
|
||||
|
@ -69,6 +69,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("levit", "LevitFeatureExtractor"),
|
||||
("maskformer", "MaskFormerFeatureExtractor"),
|
||||
("mctct", "MCTCTFeatureExtractor"),
|
||||
("mimi", "EncodecFeatureExtractor"),
|
||||
("mobilenet_v1", "MobileNetV1FeatureExtractor"),
|
||||
("mobilenet_v2", "MobileNetV2FeatureExtractor"),
|
||||
("mobilevit", "MobileViTFeatureExtractor"),
|
||||
|
@ -158,6 +158,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("mega", "MegaModel"),
|
||||
("megatron-bert", "MegatronBertModel"),
|
||||
("mgp-str", "MgpstrForSceneTextRecognition"),
|
||||
("mimi", "MimiModel"),
|
||||
("mistral", "MistralModel"),
|
||||
("mixtral", "MixtralModel"),
|
||||
("mobilebert", "MobileBertModel"),
|
||||
|
57
src/transformers/models/mimi/__init__.py
Normal file
57
src/transformers/models/mimi/__init__.py
Normal file
@ -0,0 +1,57 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
is_torch_available,
|
||||
)
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_mimi": ["MimiConfig"],
|
||||
}
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_mimi"] = [
|
||||
"MimiModel",
|
||||
"MimiPreTrainedModel",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_mimi import (
|
||||
MimiConfig,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_mimi import (
|
||||
MimiModel,
|
||||
MimiPreTrainedModel,
|
||||
)
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
234
src/transformers/models/mimi/configuration_mimi.py
Normal file
234
src/transformers/models/mimi/configuration_mimi.py
Normal file
@ -0,0 +1,234 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Meta Platforms, Inc. and affiliates, and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Mimi model configuration"""
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class MimiConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of an [`MimiModel`]. It is used to instantiate a
|
||||
Mimi model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the
|
||||
[kyutai/mimi](https://huggingface.co/kyutai/mimi) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
sampling_rate (`int`, *optional*, defaults to 24000):
|
||||
The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
|
||||
frame_rate (`float`, *optional*, defaults to 12.5):
|
||||
Framerate of the model.
|
||||
audio_channels (`int`, *optional*, defaults to 1):
|
||||
Number of channels in the audio data. Either 1 for mono or 2 for stereo.
|
||||
hidden_size (`int`, *optional*, defaults to 512):
|
||||
Intermediate representation dimension.
|
||||
num_filters (`int`, *optional*, defaults to 64):
|
||||
Number of convolution kernels of first `MimiConv1d` down sampling layer.
|
||||
num_residual_layers (`int`, *optional*, defaults to 1):
|
||||
Number of residual layers.
|
||||
upsampling_ratios (`Sequence[int]`, *optional*):
|
||||
Kernel size and stride ratios. The encoder uses downsampling ratios instead of upsampling ratios, hence it
|
||||
will use the ratios in the reverse order to the ones specified here that must match the decoder order.
|
||||
If not specified, will defaults to `[8, 6, 5, 4]`
|
||||
kernel_size (`int`, *optional*, defaults to 7):
|
||||
Kernel size for the initial convolution.
|
||||
last_kernel_size (`int`, *optional*, defaults to 3):
|
||||
Kernel size for the last convolution layer.
|
||||
residual_kernel_size (`int`, *optional*, defaults to 3):
|
||||
Kernel size for the residual layers.
|
||||
dilation_growth_rate (`int`, *optional*, defaults to 2):
|
||||
How much to increase the dilation with each layer.
|
||||
use_causal_conv (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use fully causal convolution.
|
||||
pad_mode (`str`, *optional*, defaults to `"constant"`):
|
||||
Padding mode for the convolutions.
|
||||
compress (`int`, *optional*, defaults to 2):
|
||||
Reduced dimensionality in residual branches.
|
||||
trim_right_ratio (`float`, *optional*, defaults to 1.0):
|
||||
Ratio for trimming at the right of the transposed convolution under the `use_causal_conv = True` setup. If
|
||||
equal to 1.0, it means that all the trimming is done at the right.
|
||||
codebook_size (`int`, *optional*, defaults to 2048):
|
||||
Number of discret codes in each codebooks.
|
||||
codebook_dim (`int`, *optional*, defaults to 256):
|
||||
Dimension of the unquantized codebook vectors. If not defined, uses `hidden_size`.
|
||||
num_quantizers (`int`, *optional*, defaults to 32):
|
||||
Number of quantizer channels, or codebooks, in the quantizer.
|
||||
use_conv_shortcut (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a convolutional layer as the 'skip' connection in the `MimiResnetBlock` block. If False,
|
||||
an identity function will be used, giving a generic residual connection.
|
||||
vector_quantization_hidden_dimension (`int`, *optional*, defaults to 256):
|
||||
Intermediate representation dimension in the residual vector quantization space.
|
||||
num_semantic_quantizers (`int`, *optional*, defaults to 1):
|
||||
Number of semantic quantizer channels, or codebooks, in the semantic quantizer. Must be lower than `num_quantizers`.
|
||||
upsample_groups (`int`, *optional*, defaults to 512):
|
||||
If `frame_rate!=encodec_frame_rate`, indicates the number of groups used in the upsampling operation to go from one rate to another.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 8):
|
||||
Number of hidden layers in the Transformer models.
|
||||
intermediate_size (`int`, *optional*, defaults to 2048):
|
||||
Dimension of the MLP representations.
|
||||
num_attention_heads (`int`, *optional*, defaults to 8):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
|
||||
head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
|
||||
The attention head dimension.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 8000):
|
||||
The maximum sequence length that this model might ever be used with. Mimi's sliding window attention
|
||||
allows sequence of up to 8000 tokens.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the LayerNorm normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
sliding_window (`int`, *optional*, defaults to 250):
|
||||
Sliding window attention window size. If not specified, will default to `250`.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
layer_scale_initial_scale (`float`, *optional*, defaults to 0.01):
|
||||
Initiale scale of the residual rescaling operation done in the Transformer models.
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import MimiModel, MimiConfig
|
||||
|
||||
>>> # Initializing a "kyutai/mimi" style configuration
|
||||
>>> configuration = MimiConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights) from the "kyutai/mimi" style configuration
|
||||
>>> model = MimiModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "mimi"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sampling_rate=24_000,
|
||||
frame_rate=12.5,
|
||||
audio_channels=1,
|
||||
hidden_size=512,
|
||||
num_filters=64,
|
||||
num_residual_layers=1,
|
||||
upsampling_ratios=None,
|
||||
kernel_size=7,
|
||||
last_kernel_size=3,
|
||||
residual_kernel_size=3,
|
||||
dilation_growth_rate=2,
|
||||
use_causal_conv=True,
|
||||
pad_mode="constant",
|
||||
compress=2,
|
||||
trim_right_ratio=1.0,
|
||||
codebook_size=2048,
|
||||
codebook_dim=256,
|
||||
num_quantizers=32,
|
||||
use_conv_shortcut=False,
|
||||
vector_quantization_hidden_dimension=256,
|
||||
num_semantic_quantizers=1,
|
||||
upsample_groups=512,
|
||||
num_hidden_layers=8,
|
||||
intermediate_size=2048,
|
||||
num_attention_heads=8,
|
||||
num_key_value_heads=8,
|
||||
head_dim=None,
|
||||
hidden_act="gelu",
|
||||
max_position_embeddings=8000,
|
||||
initializer_range=0.02,
|
||||
norm_eps=1e-5,
|
||||
use_cache=False,
|
||||
rope_theta=10000.0,
|
||||
sliding_window=250,
|
||||
attention_dropout=0.0,
|
||||
layer_scale_initial_scale=0.01,
|
||||
attention_bias=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.sampling_rate = sampling_rate
|
||||
self.frame_rate = frame_rate
|
||||
self.audio_channels = audio_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_filters = num_filters
|
||||
self.num_residual_layers = num_residual_layers
|
||||
self.upsampling_ratios = upsampling_ratios if upsampling_ratios else [8, 6, 5, 4]
|
||||
self.kernel_size = kernel_size
|
||||
self.last_kernel_size = last_kernel_size
|
||||
self.residual_kernel_size = residual_kernel_size
|
||||
self.dilation_growth_rate = dilation_growth_rate
|
||||
self.use_causal_conv = use_causal_conv
|
||||
self.pad_mode = pad_mode
|
||||
self.compress = compress
|
||||
self.trim_right_ratio = trim_right_ratio
|
||||
self.codebook_size = codebook_size
|
||||
self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size
|
||||
self.num_quantizers = num_quantizers
|
||||
self.use_conv_shortcut = use_conv_shortcut
|
||||
self.vector_quantization_hidden_dimension = vector_quantization_hidden_dimension
|
||||
self.upsample_groups = upsample_groups
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.norm_eps = norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.sliding_window = sliding_window
|
||||
self.attention_dropout = attention_dropout
|
||||
self.head_dim = head_dim or hidden_size // num_attention_heads
|
||||
self.layer_scale_initial_scale = layer_scale_initial_scale
|
||||
self.attention_bias = attention_bias
|
||||
|
||||
if num_semantic_quantizers >= self.num_quantizers:
|
||||
raise ValueError(
|
||||
f"The number of semantic quantizers should be lower than the total number of quantizers {self.num_quantizers}, but is currently {num_semantic_quantizers}."
|
||||
)
|
||||
self.num_semantic_quantizers = num_semantic_quantizers
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def encodec_frame_rate(self) -> int:
|
||||
hop_length = np.prod(self.upsampling_ratios)
|
||||
return math.ceil(self.sampling_rate / hop_length)
|
||||
|
||||
@property
|
||||
def num_codebooks(self) -> int:
|
||||
# alias to num_quantizers
|
||||
return self.num_quantizers
|
@ -0,0 +1,198 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Mimi checkpoints."""
|
||||
|
||||
import argparse
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
EncodecFeatureExtractor,
|
||||
MimiConfig,
|
||||
MimiModel,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger("transformers.models.mimi")
|
||||
|
||||
|
||||
def assert_param_count(model_1, model_2):
|
||||
count_1 = sum(p[1].numel() for p in model_1.named_parameters() if "final_proj" not in p[0])
|
||||
count_2 = sum(p[1].numel() for p in model_2.named_parameters() if "final_proj" not in p[0])
|
||||
assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}"
|
||||
|
||||
|
||||
def param_count(model):
|
||||
return sum(p[1].numel() for p in model.named_parameters() if "final_proj" not in p[0])
|
||||
|
||||
|
||||
def _grab_best_device(use_gpu=True):
|
||||
if torch.cuda.device_count() > 0 and use_gpu:
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
return torch.device(device)
|
||||
|
||||
|
||||
convert_list = [
|
||||
# GENERAL
|
||||
("conv.conv.conv", "conv"),
|
||||
("convtr.convtr.convtr", "conv"),
|
||||
("conv.conv", "conv"),
|
||||
("convtr.convtr", "conv"),
|
||||
# QUANTIZER
|
||||
("quantizer.rvq_first.vq", "quantizer.semantic_residual_vector_quantizer"),
|
||||
("quantizer.rvq_first", "quantizer.semantic_residual_vector_quantizer"),
|
||||
("quantizer.rvq_rest.vq", "quantizer.acoustic_residual_vector_quantizer"),
|
||||
("quantizer.rvq_rest", "quantizer.acoustic_residual_vector_quantizer"),
|
||||
("_codebook", "codebook"),
|
||||
("_initialized", "initialized"),
|
||||
("embedding_sum", "embed_sum"),
|
||||
# ENCODER PART
|
||||
("encoder.model", "encoder.layers"),
|
||||
("decoder.model", "decoder.layers"),
|
||||
# TRANSFORMERS PART
|
||||
("encoder_transformer.transformer", "encoder_transformer"),
|
||||
("decoder_transformer.transformer", "decoder_transformer"),
|
||||
("linear1", "mlp.fc1"),
|
||||
("linear2", "mlp.fc2"),
|
||||
("self_attn.out_proj", "self_attn.o_proj"),
|
||||
("norm1", "input_layernorm"),
|
||||
("norm2", "post_attention_layernorm"),
|
||||
("layer_scale_1", "self_attn_layer_scale"),
|
||||
("layer_scale_2", "mlp_layer_scale"),
|
||||
]
|
||||
|
||||
|
||||
def _convert_model(
|
||||
state_dict,
|
||||
hf_model,
|
||||
convert_list,
|
||||
device,
|
||||
config,
|
||||
unwanted_prefix=None,
|
||||
):
|
||||
hidden_size = config.hidden_size
|
||||
head_dim = config.head_dim
|
||||
num_heads = int(config.hidden_size // config.head_dim)
|
||||
num_key_value_heads = config.num_key_value_heads
|
||||
key_value_head_dim = config.num_key_value_heads * head_dim
|
||||
|
||||
# permute for sliced rotary
|
||||
def permute(w, n_heads, dim1=hidden_size, dim2=hidden_size):
|
||||
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
|
||||
|
||||
for k, v in list(state_dict.items()):
|
||||
new_k = k if unwanted_prefix is None else k[len(unwanted_prefix) :]
|
||||
for old_layer_name, new_layer_name in convert_list:
|
||||
if old_layer_name in new_k:
|
||||
new_k = new_k.replace(old_layer_name, new_layer_name)
|
||||
|
||||
if "in_proj_weight" in new_k:
|
||||
# split qkv into query key and value
|
||||
mixed_qkv = state_dict.pop(k)
|
||||
qkv_dim = mixed_qkv.size(0) // 3
|
||||
|
||||
query_layer = mixed_qkv[:qkv_dim]
|
||||
key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
|
||||
value_layer = mixed_qkv[qkv_dim * 2 :]
|
||||
|
||||
state_dict[new_k.replace("in_proj_weight", "q_proj.weight")] = permute(query_layer, num_heads)
|
||||
state_dict[new_k.replace("in_proj_weight", "k_proj.weight")] = permute(
|
||||
key_layer, num_key_value_heads, dim1=key_value_head_dim
|
||||
)
|
||||
state_dict[new_k.replace("in_proj_weight", "v_proj.weight")] = value_layer
|
||||
else:
|
||||
state_dict[new_k] = state_dict.pop(k)
|
||||
|
||||
extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys())
|
||||
missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys())
|
||||
if len(extra_keys) != 0:
|
||||
raise ValueError(f"extra keys found: {extra_keys}")
|
||||
if len(missing_keys) != 0:
|
||||
raise ValueError(f"missing keys: {missing_keys}")
|
||||
hf_model.load_state_dict(state_dict, strict=True)
|
||||
n_params = param_count(hf_model)
|
||||
|
||||
logger.info(f"model loaded: {round(n_params/1e6,1)}M params")
|
||||
|
||||
hf_model.eval()
|
||||
hf_model.to(device)
|
||||
del state_dict
|
||||
|
||||
return hf_model
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_checkpoint(
|
||||
checkpoint_path,
|
||||
pytorch_dump_folder_path,
|
||||
config_path=None,
|
||||
repo_id=None,
|
||||
):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
device = _grab_best_device()
|
||||
|
||||
if config_path is not None:
|
||||
config = MimiConfig.from_pretrained(config_path)
|
||||
else:
|
||||
config = MimiConfig()
|
||||
|
||||
model = MimiModel(config)
|
||||
|
||||
feature_extractor = EncodecFeatureExtractor(
|
||||
feature_size=config.audio_channels,
|
||||
sampling_rate=config.sampling_rate,
|
||||
)
|
||||
feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
original_checkpoint = safetensors.torch.load_file(checkpoint_path)
|
||||
if "best_state" in original_checkpoint:
|
||||
# we might have a training state saved, in which case discard the yaml results and just retain the weights
|
||||
original_checkpoint = original_checkpoint["best_state"]
|
||||
|
||||
model = _convert_model(original_checkpoint, model, convert_list, device, config)
|
||||
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if repo_id:
|
||||
print("Pushing to the hub...")
|
||||
feature_extractor.push_to_hub(repo_id)
|
||||
model.push_to_hub(repo_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_checkpoint(
|
||||
args.checkpoint_path,
|
||||
args.pytorch_dump_folder_path,
|
||||
args.config_path,
|
||||
args.push_to_hub,
|
||||
)
|
1722
src/transformers/models/mimi/modeling_mimi.py
Normal file
1722
src/transformers/models/mimi/modeling_mimi.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -5840,6 +5840,20 @@ class MgpstrPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MimiModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MimiPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MistralForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
0
tests/models/mimi/__init__.py
Normal file
0
tests/models/mimi/__init__.py
Normal file
890
tests/models/mimi/test_modeling_mimi.py
Normal file
890
tests/models/mimi/test_modeling_mimi.py
Normal file
@ -0,0 +1,890 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Mimi model."""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from datasets import Audio, load_dataset
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
from pytest import mark
|
||||
|
||||
from transformers import AutoFeatureExtractor, MimiConfig
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
is_torch_available,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_sdpa,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import (
|
||||
is_torch_bf16_available_on_device,
|
||||
is_torch_fp16_available_on_device,
|
||||
)
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import MimiModel
|
||||
|
||||
|
||||
# Copied from transformers.tests.encodec.test_modeling_encodec.prepare_inputs_dict
|
||||
def prepare_inputs_dict(
|
||||
config,
|
||||
input_ids=None,
|
||||
input_values=None,
|
||||
decoder_input_ids=None,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if input_ids is not None:
|
||||
encoder_dict = {"input_ids": input_ids}
|
||||
else:
|
||||
encoder_dict = {"input_values": input_values}
|
||||
|
||||
decoder_dict = {"decoder_input_ids": decoder_input_ids} if decoder_input_ids is not None else {}
|
||||
|
||||
return {**encoder_dict, **decoder_dict}
|
||||
|
||||
|
||||
@require_torch
|
||||
class MimiModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=5,
|
||||
num_channels=1,
|
||||
is_training=False,
|
||||
intermediate_size=40,
|
||||
hidden_size=32,
|
||||
num_filters=8,
|
||||
num_residual_layers=1,
|
||||
upsampling_ratios=[8, 4],
|
||||
codebook_size=64,
|
||||
vector_quantization_hidden_dimension=64,
|
||||
codebook_dim=64,
|
||||
upsample_groups=32,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
sliding_window=4,
|
||||
use_cache=False,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_filters = num_filters
|
||||
self.num_residual_layers = num_residual_layers
|
||||
self.upsampling_ratios = upsampling_ratios
|
||||
self.codebook_size = codebook_size
|
||||
self.vector_quantization_hidden_dimension = vector_quantization_hidden_dimension
|
||||
self.codebook_dim = codebook_dim
|
||||
self.upsample_groups = upsample_groups
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.use_cache = use_cache
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
|
||||
config = self.get_config()
|
||||
inputs_dict = {"input_values": input_values}
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, inputs_dict = self.prepare_config_and_inputs()
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_model_class(self, model_class):
|
||||
config, inputs_dict = self.prepare_config_and_inputs()
|
||||
inputs_dict["audio_codes"] = ids_tensor([self.batch_size, 1, self.num_channels], self.codebook_size).type(
|
||||
torch.int32
|
||||
)
|
||||
|
||||
return config, inputs_dict
|
||||
|
||||
def get_config(self):
|
||||
return MimiConfig(
|
||||
audio_channels=self.num_channels,
|
||||
chunk_in_sec=None,
|
||||
hidden_size=self.hidden_size,
|
||||
num_filters=self.num_filters,
|
||||
num_residual_layers=self.num_residual_layers,
|
||||
upsampling_ratios=self.upsampling_ratios,
|
||||
codebook_size=self.codebook_size,
|
||||
vector_quantization_hidden_dimension=self.vector_quantization_hidden_dimension,
|
||||
upsample_groups=self.upsample_groups,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
num_key_value_heads=self.num_key_value_heads,
|
||||
sliding_window=self.sliding_window,
|
||||
codebook_dim=self.codebook_dim,
|
||||
use_cache=self.use_cache,
|
||||
)
|
||||
|
||||
def create_and_check_model_forward(self, config, inputs_dict):
|
||||
model = MimiModel(config=config).to(torch_device).eval()
|
||||
|
||||
input_values = inputs_dict["input_values"]
|
||||
result = model(input_values)
|
||||
self.parent.assertEqual(
|
||||
result.audio_values.shape, (self.batch_size, self.num_channels, self.intermediate_size)
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class MimiModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (MimiModel,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_resize_embeddings = False
|
||||
test_torchscript = False
|
||||
input_name = "input_values"
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
# model does support returning hidden states
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
if "output_attentions" in inputs_dict:
|
||||
inputs_dict.pop("output_attentions")
|
||||
if "output_hidden_states" in inputs_dict:
|
||||
inputs_dict.pop("output_hidden_states")
|
||||
return inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MimiModelTester(self)
|
||||
self.config_tester = ConfigTester(
|
||||
self, config_class=MimiConfig, hidden_size=37, common_properties=[], has_text_modality=False
|
||||
)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model_forward(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_forward(*config_and_inputs)
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["input_values", "padding_mask", "num_quantizers"]
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
|
||||
@unittest.skip(reason="The MimiModel does not have `inputs_embeds` logics")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="The MimiModel does not have `inputs_embeds` logics")
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="The MimiModel does not have the usual `attention` logic")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="The MimiModel does not have the usual `attention` logic")
|
||||
def test_torchscript_output_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="The MimiModel does not have the usual `hidden_states` logic")
|
||||
def test_torchscript_output_hidden_state(self):
|
||||
pass
|
||||
|
||||
# Copied from transformers.tests.encodec.test_modeling_encodec.MimiModelTest._create_and_check_torchscript
|
||||
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||
if not self.test_torchscript:
|
||||
self.skipTest(reason="test_torchscript is set to False")
|
||||
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
configs_no_init.torchscript = True
|
||||
configs_no_init.return_dict = False
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
main_input_name = model_class.main_input_name
|
||||
|
||||
try:
|
||||
main_input = inputs[main_input_name]
|
||||
model(main_input)
|
||||
traced_model = torch.jit.trace(model, main_input)
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't trace module.")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
|
||||
|
||||
try:
|
||||
torch.jit.save(traced_model, pt_file_name)
|
||||
except Exception:
|
||||
self.fail("Couldn't save module.")
|
||||
|
||||
try:
|
||||
loaded_model = torch.jit.load(pt_file_name)
|
||||
except Exception:
|
||||
self.fail("Couldn't load module.")
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
loaded_model.to(torch_device)
|
||||
loaded_model.eval()
|
||||
|
||||
model_state_dict = model.state_dict()
|
||||
loaded_model_state_dict = loaded_model.state_dict()
|
||||
|
||||
non_persistent_buffers = {}
|
||||
for key in loaded_model_state_dict.keys():
|
||||
if key not in model_state_dict.keys():
|
||||
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||
|
||||
loaded_model_state_dict = {
|
||||
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||
}
|
||||
|
||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||
|
||||
model_buffers = list(model.buffers())
|
||||
for non_persistent_buffer in non_persistent_buffers.values():
|
||||
found_buffer = False
|
||||
for i, model_buffer in enumerate(model_buffers):
|
||||
if torch.equal(non_persistent_buffer, model_buffer):
|
||||
found_buffer = True
|
||||
break
|
||||
|
||||
self.assertTrue(found_buffer)
|
||||
model_buffers.pop(i)
|
||||
|
||||
model_buffers = list(model.buffers())
|
||||
for non_persistent_buffer in non_persistent_buffers.values():
|
||||
found_buffer = False
|
||||
for i, model_buffer in enumerate(model_buffers):
|
||||
if torch.equal(non_persistent_buffer, model_buffer):
|
||||
found_buffer = True
|
||||
break
|
||||
|
||||
self.assertTrue(found_buffer)
|
||||
model_buffers.pop(i)
|
||||
|
||||
models_equal = True
|
||||
for layer_name, p1 in model_state_dict.items():
|
||||
if layer_name in loaded_model_state_dict:
|
||||
p2 = loaded_model_state_dict[layer_name]
|
||||
if p1.data.ne(p2.data).sum() > 0:
|
||||
models_equal = False
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
|
||||
# (Even with this call, there are still memory leak by ~0.04MB)
|
||||
self.clear_torch_jit_class_registry()
|
||||
|
||||
@unittest.skip(reason="The MimiModel does not have the usual `attention` logic")
|
||||
def test_attention_outputs(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="The MimiModel does not have the usual `hidden_states` logic")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
# Copied from transformers.tests.encodec.test_modeling_encodec.MimiModelTest.test_determinism
|
||||
def test_determinism(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def check_determinism(first, second):
|
||||
# outputs are not tensors but list (since each sequence don't have the same frame_length)
|
||||
out_1 = first.cpu().numpy()
|
||||
out_2 = second.cpu().numpy()
|
||||
out_1 = out_1[~np.isnan(out_1)]
|
||||
out_2 = out_2[~np.isnan(out_2)]
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
||||
second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
||||
|
||||
if isinstance(first, tuple) and isinstance(second, tuple):
|
||||
for tensor1, tensor2 in zip(first, second):
|
||||
check_determinism(tensor1, tensor2)
|
||||
else:
|
||||
check_determinism(first, second)
|
||||
|
||||
# Copied from transformers.tests.encodec.test_modeling_encodec.MimiModelTest.test_model_outputs_equivalence
|
||||
def test_model_outputs_equivalence(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def set_nan_tensor_to_zero(t):
|
||||
t[t != t] = 0
|
||||
return t
|
||||
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||
with torch.no_grad():
|
||||
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs)
|
||||
|
||||
self.assertTrue(isinstance(tuple_output, tuple))
|
||||
self.assertTrue(isinstance(dict_output, dict))
|
||||
|
||||
for tuple_value, dict_value in zip(tuple_output, dict_output.values()):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
set_nan_tensor_to_zero(tuple_value), set_nan_tensor_to_zero(dict_value), atol=1e-5
|
||||
),
|
||||
msg=(
|
||||
"Tuple and dict output are not equal. Difference:"
|
||||
f" {torch.max(torch.abs(tuple_value - dict_value))}. Tuple has `nan`:"
|
||||
f" {torch.isnan(tuple_value).any()} and `inf`: {torch.isinf(tuple_value)}. Dict has"
|
||||
f" `nan`: {torch.isnan(dict_value).any()} and `inf`: {torch.isinf(dict_value)}."
|
||||
),
|
||||
)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
uniform_init_parms = ["conv", "input_proj", "output_proj"]
|
||||
if param.requires_grad:
|
||||
if any(x in name for x in uniform_init_parms):
|
||||
self.assertTrue(
|
||||
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
# Copied from transformers.tests.encodec.test_modeling_encodec.MimiModelTest.test_identity_shortcut
|
||||
def test_identity_shortcut(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||
config.use_conv_shortcut = False
|
||||
self.model_tester.create_and_check_model_forward(config, inputs_dict)
|
||||
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
if not self.all_model_classes[0]._supports_sdpa:
|
||||
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
|
||||
|
||||
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
|
||||
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
|
||||
|
||||
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
|
||||
self.skipTest(
|
||||
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
|
||||
)
|
||||
|
||||
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
|
||||
if torch_dtype == "float16":
|
||||
torch_dtype = torch.float16
|
||||
elif torch_dtype == "bfloat16":
|
||||
torch_dtype = torch.bfloat16
|
||||
elif torch_dtype == "float32":
|
||||
torch_dtype = torch.float32
|
||||
|
||||
atols = {
|
||||
("cpu", False, torch.float32): 1e-6,
|
||||
("cpu", False, torch.bfloat16): 1e-2,
|
||||
("cpu", True, torch.float32): 1e-6,
|
||||
("cpu", True, torch.bfloat16): 1e-2,
|
||||
("cuda", False, torch.float32): 1e-6,
|
||||
("cuda", False, torch.bfloat16): 1e-2,
|
||||
("cuda", False, torch.float16): 5e-3,
|
||||
("cuda", True, torch.float32): 1e-6,
|
||||
("cuda", True, torch.bfloat16): 1e-2,
|
||||
("cuda", True, torch.float16): 5e-3,
|
||||
}
|
||||
rtols = {
|
||||
("cpu", False, torch.float32): 1e-4,
|
||||
("cpu", False, torch.bfloat16): 1e-2,
|
||||
("cpu", True, torch.float32): 1e-4,
|
||||
("cpu", True, torch.bfloat16): 1e-2,
|
||||
("cuda", False, torch.float32): 1e-4,
|
||||
("cuda", False, torch.bfloat16): 1e-2,
|
||||
("cuda", False, torch.float16): 5e-3,
|
||||
("cuda", True, torch.float32): 1e-4,
|
||||
("cuda", True, torch.bfloat16): 3e-2,
|
||||
("cuda", True, torch.float16): 5e-3,
|
||||
}
|
||||
|
||||
def get_mean_reldiff(failcase, x, ref, atol, rtol):
|
||||
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
# FIXME: we deactivate boolean mask for models using "use_mask_token" in their constructors.
|
||||
# These models support masking only in the case `use_mask_token=True`. Otherwise they cannot consume an input mask.
|
||||
# This means that the class needs to be instantiated much later, after `use_mask` is set, which means a significant refactor of the code.
|
||||
# However masking there is not done at any layers that matters (i.e self-attention), therefore we can safely deactivate it.
|
||||
deactivate_mask = "use_mask_token" in inspect.signature(model_class).parameters
|
||||
|
||||
is_encoder_decoder = model.config.is_encoder_decoder
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
|
||||
model_sdpa = model_sdpa.eval().to(torch_device)
|
||||
|
||||
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
|
||||
|
||||
model_eager = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch_dtype,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
model_eager = model_eager.eval().to(torch_device)
|
||||
|
||||
self.assertTrue(model_eager.config._attn_implementation == "eager")
|
||||
|
||||
for name, submodule in model_eager.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
has_sdpa = False
|
||||
for name, submodule in model_sdpa.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
has_sdpa = True
|
||||
break
|
||||
if not has_sdpa and model_sdpa.config.model_type != "falcon":
|
||||
raise ValueError("The SDPA model should have SDPA attention layers")
|
||||
|
||||
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
|
||||
# but it would be nicer to have an efficient way to use parameterized.expand
|
||||
fail_cases = []
|
||||
for padding_side in ["left", "right"]:
|
||||
for use_mask in [False, True]:
|
||||
for output_attentions in [True, False]:
|
||||
can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
|
||||
if not (self.has_attentions and can_output_attn) and output_attentions:
|
||||
continue
|
||||
for batch_size in [1, 5]:
|
||||
dummy_input = inputs_dict[model.main_input_name]
|
||||
|
||||
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch_dtype)
|
||||
|
||||
dummy_input = dummy_input[:batch_size]
|
||||
if dummy_input.shape[0] != batch_size:
|
||||
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
|
||||
extension = torch.rand(
|
||||
batch_size - dummy_input.shape[0],
|
||||
*dummy_input.shape[1:],
|
||||
dtype=torch_dtype,
|
||||
device=torch_device,
|
||||
)
|
||||
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
|
||||
else:
|
||||
extension = torch.randint(
|
||||
high=5,
|
||||
size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
|
||||
dtype=dummy_input.dtype,
|
||||
device=torch_device,
|
||||
)
|
||||
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
|
||||
|
||||
if not use_mask:
|
||||
dummy_attention_mask = None
|
||||
else:
|
||||
dummy_attention_mask = inputs_dict.get("attention_mask", None)
|
||||
if dummy_attention_mask is None:
|
||||
if is_encoder_decoder:
|
||||
seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
|
||||
else:
|
||||
seqlen = dummy_input.shape[-1]
|
||||
dummy_attention_mask = (
|
||||
torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
|
||||
)
|
||||
|
||||
dummy_attention_mask = dummy_attention_mask[:batch_size]
|
||||
if dummy_attention_mask.shape[0] != batch_size:
|
||||
extension = torch.ones(
|
||||
batch_size - dummy_attention_mask.shape[0],
|
||||
*dummy_attention_mask.shape[1:],
|
||||
dtype=dummy_attention_mask.dtype,
|
||||
device=torch_device,
|
||||
)
|
||||
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
|
||||
dummy_attention_mask = dummy_attention_mask.to(torch_device)
|
||||
|
||||
dummy_attention_mask[:] = 1
|
||||
if padding_side == "left":
|
||||
dummy_attention_mask[-1, :-1] = 1
|
||||
dummy_attention_mask[-1, -4:] = 0
|
||||
elif padding_side == "right":
|
||||
dummy_attention_mask[-1, 1:] = 1
|
||||
dummy_attention_mask[-1, :3] = 0
|
||||
|
||||
for enable_kernels in [False, True]:
|
||||
failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
|
||||
if is_encoder_decoder:
|
||||
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[
|
||||
:batch_size
|
||||
]
|
||||
if decoder_input_ids.shape[0] != batch_size:
|
||||
extension = torch.ones(
|
||||
batch_size - decoder_input_ids.shape[0],
|
||||
*decoder_input_ids.shape[1:],
|
||||
dtype=decoder_input_ids.dtype,
|
||||
device=torch_device,
|
||||
)
|
||||
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
|
||||
decoder_input_ids = decoder_input_ids.to(torch_device)
|
||||
|
||||
# TODO: never an `attention_mask` arg here?
|
||||
processed_inputs = {
|
||||
model.main_input_name: dummy_input,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": dummy_attention_mask,
|
||||
"output_hidden_states": True,
|
||||
}
|
||||
else:
|
||||
processed_inputs = {
|
||||
model.main_input_name: dummy_input,
|
||||
"output_hidden_states": True,
|
||||
}
|
||||
|
||||
# Otherwise fails for e.g. WhisperEncoderModel
|
||||
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
|
||||
processed_inputs["attention_mask"] = dummy_attention_mask
|
||||
|
||||
if (
|
||||
self.has_attentions
|
||||
and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
|
||||
):
|
||||
processed_inputs["output_attentions"] = output_attentions
|
||||
if not deactivate_mask and (
|
||||
"bool_masked_pos" in inspect.signature(model_eager.forward).parameters
|
||||
):
|
||||
dummy_mask = torch.ones((self.model_tester.num_masks,))
|
||||
|
||||
# In case of additional token (like class) we define a custom `mask_length`
|
||||
if hasattr(self.model_tester, "mask_length"):
|
||||
mask_length = self.model_tester.mask_length - dummy_mask.size(0)
|
||||
else:
|
||||
mask_length = self.model_tester.seq_length - dummy_mask.size(0)
|
||||
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
|
||||
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
|
||||
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
|
||||
|
||||
if "noise" in inspect.signature(model_eager.forward).parameters:
|
||||
np.random.seed(2)
|
||||
num_patches = int(
|
||||
(self.model_tester.image_size // self.model_tester.patch_size) ** 2
|
||||
)
|
||||
noise = np.random.uniform(size=(batch_size, num_patches))
|
||||
processed_inputs["noise"] = torch.from_numpy(noise)
|
||||
|
||||
# TODO: test gradients as well (& for FA2 as well!)
|
||||
with torch.no_grad():
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=enable_kernels,
|
||||
enable_math=True,
|
||||
enable_mem_efficient=enable_kernels,
|
||||
):
|
||||
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
|
||||
outputs_eager = model_eager(**prepared_inputs)
|
||||
outputs_sdpa = model_sdpa(**prepared_inputs)
|
||||
|
||||
# Ignore copy
|
||||
logits_eager = outputs_eager.audio_values
|
||||
# Ignore copy
|
||||
logits_sdpa = outputs_sdpa.audio_values
|
||||
|
||||
if torch_device in ["cpu", "cuda"]:
|
||||
atol = atols[torch_device, enable_kernels, torch_dtype]
|
||||
rtol = rtols[torch_device, enable_kernels, torch_dtype]
|
||||
else:
|
||||
atol = 1e-7
|
||||
rtol = 1e-4
|
||||
|
||||
# Masked tokens output slightly deviates - we don't mind that.
|
||||
if use_mask:
|
||||
if padding_side == "left":
|
||||
sub_sdpa = logits_sdpa[:-1]
|
||||
sub_eager = logits_eager[:-1]
|
||||
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
|
||||
fail_cases.append(
|
||||
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
|
||||
)
|
||||
|
||||
sub_sdpa = logits_sdpa[-1, :-4]
|
||||
sub_eager = logits_eager[-1, :-4]
|
||||
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
|
||||
fail_cases.append(
|
||||
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
|
||||
)
|
||||
|
||||
# Testing the padding tokens is not really meaningful but anyway
|
||||
# sub_sdpa = logits_sdpa[-1, -4:]
|
||||
# sub_eager = logits_eager[-1, -4:]
|
||||
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
|
||||
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
|
||||
elif padding_side == "right":
|
||||
sub_sdpa = logits_sdpa[:-1]
|
||||
sub_eager = logits_eager[:-1]
|
||||
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
|
||||
fail_cases.append(
|
||||
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
|
||||
)
|
||||
|
||||
sub_sdpa = logits_sdpa[-1, 3:]
|
||||
sub_eager = logits_eager[-1, 3:]
|
||||
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
|
||||
fail_cases.append(
|
||||
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
|
||||
)
|
||||
|
||||
# Testing the padding tokens is not really meaningful but anyway
|
||||
# sub_sdpa = logits_sdpa[-1, :3]
|
||||
# sub_eager = logits_eager[-1, :3]
|
||||
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
|
||||
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
|
||||
|
||||
else:
|
||||
if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol):
|
||||
fail_cases.append(
|
||||
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
|
||||
)
|
||||
|
||||
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
|
||||
model.to(torch_device)
|
||||
|
||||
dummy_input = inputs_dict[model.main_input_name][:1]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
|
||||
outputs = model(dummy_input)
|
||||
outputs_fa = model_fa(dummy_input)
|
||||
|
||||
logits = outputs[1]
|
||||
logits_fa = outputs_fa[1]
|
||||
|
||||
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
|
||||
|
||||
@unittest.skip(reason="The MimiModel does not support right padding")
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="The MimiModel does not have support dynamic compile yet")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
# For now, Let's focus only on GPU for `torch.compile`
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_torch_compile(self):
|
||||
if version.parse(torch.__version__) < version.parse("2.3"):
|
||||
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
n_iter = 3
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.forward = torch.compile(model.forward)
|
||||
for i in range(n_iter):
|
||||
_ = model(inputs_dict["input_values"].to(torch_device))
|
||||
|
||||
@is_flaky()
|
||||
def test_batching_equivalence(self):
|
||||
super().test_batching_equivalence()
|
||||
|
||||
|
||||
# Copied from transformers.tests.encodec.test_modeling_encodec.normalize
|
||||
def normalize(arr):
|
||||
norm = np.linalg.norm(arr)
|
||||
normalized_arr = arr / norm
|
||||
return normalized_arr
|
||||
|
||||
|
||||
# Copied from transformers.tests.encodec.test_modeling_encodec.compute_rmse
|
||||
def compute_rmse(arr1, arr2):
|
||||
arr1_normalized = normalize(arr1)
|
||||
arr2_normalized = normalize(arr2)
|
||||
return np.sqrt(((arr1_normalized - arr2_normalized) ** 2).mean())
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
class MimiIntegrationTest(unittest.TestCase):
|
||||
def test_integration_using_cache_decode(self):
|
||||
expected_rmse = {
|
||||
"8": 0.0018785292,
|
||||
"32": 0.0012330565,
|
||||
}
|
||||
|
||||
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
model_id = "kyutai/mimi"
|
||||
|
||||
model = MimiModel.from_pretrained(model_id, use_cache=True).to(torch_device)
|
||||
processor = AutoFeatureExtractor.from_pretrained(model_id)
|
||||
|
||||
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
|
||||
audio_sample = librispeech_dummy[-1]["audio"]["array"]
|
||||
|
||||
inputs = processor(
|
||||
raw_audio=audio_sample,
|
||||
sampling_rate=processor.sampling_rate,
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
for num_codebooks, expected_rmse in expected_rmse.items():
|
||||
with torch.no_grad():
|
||||
# use max bandwith for best possible reconstruction
|
||||
encoder_outputs = model.encode(inputs["input_values"], num_quantizers=int(num_codebooks))
|
||||
|
||||
audio_codes = encoder_outputs[0]
|
||||
|
||||
decoder_outputs_first_part = model.decode(audio_codes[:, :, : audio_codes.shape[2] // 2])
|
||||
decoder_outputs_second_part = model.decode(
|
||||
audio_codes[:, :, audio_codes.shape[2] // 2 :],
|
||||
decoder_past_key_values=decoder_outputs_first_part.decoder_past_key_values,
|
||||
)
|
||||
|
||||
audio_output_entire_context = model.decode(audio_codes)[0]
|
||||
audio_output_concat_context = torch.cat(
|
||||
[decoder_outputs_first_part[0], decoder_outputs_second_part[0]], dim=2
|
||||
)
|
||||
|
||||
# make sure audios are more or less equal
|
||||
# the RMSE of two random gaussian noise vectors with ~N(0, 1) is around 1.0
|
||||
rmse = compute_rmse(
|
||||
audio_output_concat_context.squeeze().cpu().numpy(),
|
||||
audio_output_entire_context.squeeze().cpu().numpy(),
|
||||
)
|
||||
self.assertTrue(rmse < 1e-3)
|
||||
|
||||
def test_integration(self):
|
||||
expected_rmses = {
|
||||
"8": 0.0018785292,
|
||||
"32": 0.0012330565,
|
||||
}
|
||||
expected_codesums = {
|
||||
"8": 430423,
|
||||
"32": 1803071,
|
||||
}
|
||||
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
model_id = "kyutai/mimi"
|
||||
|
||||
processor = AutoFeatureExtractor.from_pretrained(model_id)
|
||||
|
||||
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
|
||||
audio_sample = librispeech_dummy[-1]["audio"]["array"]
|
||||
|
||||
inputs = processor(
|
||||
raw_audio=audio_sample,
|
||||
sampling_rate=processor.sampling_rate,
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
for use_cache in [False, True]:
|
||||
model = MimiModel.from_pretrained(model_id, use_cache=use_cache).to(torch_device)
|
||||
for num_codebooks, expected_rmse in expected_rmses.items():
|
||||
with torch.no_grad():
|
||||
# use max bandwith for best possible reconstruction
|
||||
encoder_outputs = model.encode(inputs["input_values"], num_quantizers=int(num_codebooks))
|
||||
|
||||
audio_code_sums = encoder_outputs[0].sum().cpu().item()
|
||||
|
||||
# make sure audio encoded codes are correct
|
||||
# assert relative difference less than a threshold, because `audio_code_sums` varies a bit
|
||||
# depending on torch version
|
||||
self.assertTrue(
|
||||
np.abs(audio_code_sums - expected_codesums[num_codebooks]) <= (3e-3 * audio_code_sums)
|
||||
)
|
||||
|
||||
input_values_dec = model.decode(encoder_outputs[0], padding_mask=inputs["padding_mask"])[0]
|
||||
input_values_enc_dec = model(
|
||||
inputs["input_values"], inputs["padding_mask"], num_quantizers=int(num_codebooks)
|
||||
)[1]
|
||||
|
||||
# make sure forward and decode gives same result
|
||||
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec))
|
||||
|
||||
# make sure shape matches
|
||||
self.assertTrue(inputs["input_values"].shape == input_values_enc_dec.shape)
|
||||
|
||||
arr = inputs["input_values"][0].cpu().numpy()
|
||||
arr_enc_dec = input_values_enc_dec[0].cpu().numpy()
|
||||
|
||||
# make sure audios are more or less equal
|
||||
# the RMSE of two random gaussian noise vectors with ~N(0, 1) is around 1.0
|
||||
rmse = compute_rmse(arr, arr_enc_dec)
|
||||
self.assertTrue(np.abs(rmse - expected_rmse) < 1e-5)
|
Loading…
Reference in New Issue
Block a user