diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 02595f30db2..016d7279353 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -740,6 +740,8 @@ title: Mimi - local: model_doc/mms title: MMS + - local: model_doc/moshi + title: Moshi - local: model_doc/musicgen title: MusicGen - local: model_doc/musicgen_melody diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 32a730e6bcf..bdea11a2456 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -223,6 +223,7 @@ Flax), PyTorch, and/or TensorFlow. | [MobileNetV2](model_doc/mobilenet_v2) | ✅ | ❌ | ❌ | | [MobileViT](model_doc/mobilevit) | ✅ | ✅ | ❌ | | [MobileViTV2](model_doc/mobilevitv2) | ✅ | ❌ | ❌ | +| [Moshi](model_doc/moshi) | ✅ | ❌ | ❌ | | [MPNet](model_doc/mpnet) | ✅ | ✅ | ❌ | | [MPT](model_doc/mpt) | ✅ | ❌ | ❌ | | [MRA](model_doc/mra) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/mimi.md b/docs/source/en/model_doc/mimi.md index 486d1836334..ad15a002da9 100644 --- a/docs/source/en/model_doc/mimi.md +++ b/docs/source/en/model_doc/mimi.md @@ -66,4 +66,4 @@ The original code can be found [here](https://github.com/kyutai-labs/moshi). [[autodoc]] MimiModel - decode - encode - - forward + - forward \ No newline at end of file diff --git a/docs/source/en/model_doc/moshi.md b/docs/source/en/model_doc/moshi.md new file mode 100644 index 00000000000..64216f570e3 --- /dev/null +++ b/docs/source/en/model_doc/moshi.md @@ -0,0 +1,183 @@ + + +# Moshi + +## Overview + +The Moshi model was proposed in [Moshi: a speech-text foundation model for real-time dialogue](https://kyutai.org/Moshi.pdf) by Alexandre Défossez, Laurent Mazaré, Manu Orsini, Amélie Royer, Patrick Pérez, Hervé Jégou, Edouard Grave and Neil Zeghidour. + +Moshi is a speech-text foundation model that casts spoken dialogue as speech-to-speech generation. Starting from a text language model backbone, Moshi generates speech as tokens from the residual quantizer of a neural audio codec, while modeling separately its own speech and that of the user into parallel streams. This allows for the removal of explicit speaker turns, and the modeling of arbitrary conversational dynamics. Moshi also predicts time-aligned text tokens as a prefix to audio tokens. This “Inner Monologue” method significantly improves the linguistic quality of generated speech and provides streaming speech recognition and text-to-speech. As a result, Moshi is the first real-time full-duplex spoken large language model, with a theoretical latency of 160ms, 200ms in practice. + +
+ +
+ +The abstract from the paper is the following: + +*We introduce Moshi, a speech-text foundation model and full-duplex spoken dialogue framework. Current systems for spoken dialogue rely on pipelines of independent components, namely voice activity detection, speech recognition, textual dialogue and text-to-speech. Such frameworks cannot emulate the experience of real conversations. First, their complexity induces a latency of several seconds between interactions. Second, text being the intermediate modality for dialogue, non-linguistic information that modifies meaning— such as emotion or non-speech sounds— is lost in the interaction. Finally, they rely on a segmentation into speaker turns, which does not take into account overlapping speech, interruptions and interjections. Moshi solves these independent issues altogether by casting spoken dialogue as speech-to-speech generation. Starting from a text language model backbone, Moshi generates speech as tokens from the residual quantizer of a neural audio codec, while modeling separately its own speech and that of the user into parallel streams. This allows for the removal of explicit speaker turns, and the modeling of arbitrary conversational dynamics. We moreover extend the hierarchical semantic-to-acoustic token generation of previous work to first predict time-aligned text tokens as a prefix to audio tokens. Not only this “Inner Monologue” method significantly improves the linguistic quality of generated speech, but we also illustrate how it can provide streaming speech recognition and text-to-speech. Our resulting model is the first real-time full-duplex spoken large language model, with a theoretical latency of 160ms, 200ms in practice, and is available at github.com/kyutai-labs/moshi.* + +Moshi deals with 3 streams of information: +1. The user's audio +2. Moshi's audio +3. Moshi's textual output + +Similarly to [`~MusicgenModel`], audio is represented with audio codebooks, which can be interpreted like tokens. The main difference between text tokens and audio codebooks is that audio codebooks introduce an additional dimension of information. +Text tokens are typically of dim `(batch_size, sequence_length)` but audio tokens are of dim `(batch_size, num_codebooks, sequence_length)`. + +Moshi's made of 3 components: + +**1. The main decoder (Helium in the paper)** + +It corresponds to [`MoshiForCausalLM`]. It is strictly a classic text LLM, that uses an architecture similar to [` ~GemmaForCausalLM`]. In other words, it takes text tokens, embeds them, pass them through the decoder and a language head, to get text logits. + +**2. The depth decoder** + +On its own, it's also a classic LLM, but this time, instead of generating over the time dimension, it generates over the codebook dimension. + +It also means that its context length is `num_codebooks`, thus it can't generate more than `num_codebooks`. + +Note that each timestamp - i.e each codebook - gets its own set of Linear Layers and Embeddings. + +**3. [`MimiModel`]** + +It's the audio encoder from Kyutai, that has recently been integrated to transformers, which is used to "tokenize" audio. It has the same use that [`~EncodecModel`] has in [`~MusicgenModel`]. + + +## Tips: + +The original checkpoints can be converted using the conversion script `src/transformers/models/moshi/convert_moshi_transformers.py` + + +### How to use the model: + +This implementation has two main aims: +1. quickly test model generation by simplifying the original API +2. simplify training. 
A training guide will come soon, but user contributions are welcomed! + + + +It is designed for intermediate use. We strongly recommend using the original [implementation](https://github.com/kyutai-labs/moshi) to infer the model in real-time streaming. + + + +**1. Model generation** + +Moshi is a streaming auto-regressive model with two streams of audio. To put it differently, one audio stream corresponds to what the model said/will say and the other audio stream corresponds to what the user said/will say. + +[`MoshiForConditionalGeneration.generate`] thus needs 3 inputs: +1. `input_ids` - corresponding to the text token history +2. `moshi_input_values` or `moshi_audio_codes`- corresponding to the model audio history +3. `user_input_values` or `user_audio_codes` - corresponding to the user audio history + +These three inputs must be synchronized. Meaning that their lengths must correspond to the same number of tokens. + +You can dynamically use the 3 inputs depending on what you want to test: +1. Simply check the model response to an user prompt - in that case, `input_ids` can be filled with pad tokens and `user_input_values` can be a zero tensor of the same shape than the user prompt. +2. Test more complex behaviour - in that case, you must be careful about how the input tokens are synchronized with the audios. + + + +The original model is synchronized text with audio by padding the text in between each token enunciation. + +To follow the example of the following image, `"Hello, I'm Moshi"` could be transformed to `"Hello,I'm Moshi"`. + + + +
+ +
+ + +[`MoshiForConditionalGeneration.generate`] then auto-regressively feeds to itself its own audio stream, but since it doesn't have access to the user input stream while using `transformers`, it will thus **assume that the user is producing blank audio**. + + + +```python +>>> from datasets import load_dataset, Audio +>>> import torch, math +>>> from transformers import MoshiForConditionalGeneration, AutoFeatureExtractor, AutoTokenizer +>>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + +>>> # prepare user input audio +>>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate)) +>>> audio_sample = librispeech_dummy[-1]["audio"]["array"] +>>> user_input_values = feature_extractor(raw_audio=audio_sample, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt").to(device=device, dtype=dtype) + +>>> # prepare moshi input values - we suppose moshi didn't say anything while the user spoke +>>> moshi_input_values = torch.zeros_like(user_input_values.input_values) + +>>> # prepare moshi input ids - we suppose moshi didn't say anything while the user spoke +>>> num_tokens = math.ceil(moshi_input_values.shape[-1] * waveform_to_token_ratio) +>>> input_ids = torch.ones((1, num_tokens), device=device, dtype=torch.int64) * tokenizer.encode("")[0] + +>>> # generate 25 new tokens (around 2s of audio) +>>> output = model.generate(input_ids=input_ids, user_input_values=user_input_values.input_values, moshi_input_values=moshi_input_values, max_new_tokens=25) + +>>> text_tokens = output.sequences +>>> audio_waveforms = output.audio_sequences +``` + +**2. Model training** + +Most of the work has to be done during data creation/pre-processing, because of the need to align/synchronize streams. + +Once it's done, you can simply forward `text_labels` and `audio_labels` to [`MoshiForConditionalGeneration.forward`], alongside the usual inputs, to get the model loss. + +A training guide will come soon, but user contributions are welcomed! + +### How does the model forward the inputs / generate: + +1. The input streams are embedded and combined into `inputs_embeds`. + +2. `inputs_embeds` is passed through the main decoder, which processes it like a normal LLM would. + +3. The main decoder outputs `text logits` but also its `last hidden state` which is called `temporal context` in the paper. + +3. The depth decoder switches the dimension on which we forward / generate (codebooks instead of time). It uses the token generated from `text logits` and the `temporal context` to auto-regressively generate audio codebooks. + + +This model was contributed by [Yoach Lacombe (ylacombe)](https://huggingface.co/ylacombe). + +The original code can be found [here](https://github.com/kyutai-labs/moshi). 
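+
+The generation example in the "How to use the model" section above assumes that `model`, `tokenizer`, `feature_extractor`, `device`, `dtype` and `waveform_to_token_ratio` are already defined. Below is a minimal setup sketch; it assumes the converted checkpoint is published under the `kmhf/hf-moshiko` id referenced in the config docstrings, and that the audio encoder config exposes a `frame_rate` attribute (as Mimi's does). Adapt the names to your own checkpoint.
+
+```python
+>>> import torch
+>>> from transformers import MoshiForConditionalGeneration, AutoFeatureExtractor, AutoTokenizer
+
+>>> device = "cuda" if torch.cuda.is_available() else "cpu"
+>>> dtype = torch.float16 if device == "cuda" else torch.float32
+
+>>> # checkpoint id assumed from the config docstrings; adapt to your own checkpoint
+>>> repo_id = "kmhf/hf-moshiko"
+
+>>> model = MoshiForConditionalGeneration.from_pretrained(repo_id, torch_dtype=dtype).to(device)
+>>> tokenizer = AutoTokenizer.from_pretrained(repo_id)
+>>> feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
+
+>>> # roughly one token per audio frame: the audio encoder produces `frame_rate` frames
+>>> # per second out of `sampling_rate` waveform samples per second
+>>> waveform_to_token_ratio = model.config.audio_encoder_config.frame_rate / feature_extractor.sampling_rate
+```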
+ + + +## MoshiConfig + +[[autodoc]] MoshiConfig + +## MoshiDepthConfig + +[[autodoc]] MoshiDepthConfig + +## MoshiModel + +[[autodoc]] MoshiModel + - forward + +## MoshiForCausalLM + +[[autodoc]] MoshiForCausalLM + - forward + +## MoshiForConditionalGeneration + +[[autodoc]] MoshiForConditionalGeneration + - forward + - generate + - get_unconditional_inputs diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 82d7f50f77d..2f0e9deb841 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -70,6 +70,7 @@ FlashAttention-2 is currently supported for the following architectures: * [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel) * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel) +* [Moshi](https://huggingface.co/docs/transformers/model_doc/moshi#transformers.MoshiModel) * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) * [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron) @@ -241,6 +242,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) * [Mllama](https://huggingface.co/docs/transformers/model_doc/mllama#transformers.MllamaForConditionalGeneration) * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel) +* [Moshi](https://huggingface.co/docs/transformers/model_doc/moshi#transformers.MoshiModel) * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) * [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index daffe11987e..236333fb1cb 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -590,6 +590,10 @@ _import_structure = { "models.mobilenet_v2": ["MobileNetV2Config"], "models.mobilevit": ["MobileViTConfig"], "models.mobilevitv2": ["MobileViTV2Config"], + "models.moshi": [ + "MoshiConfig", + "MoshiDepthConfig", + ], "models.mpnet": [ "MPNetConfig", "MPNetTokenizer", @@ -2783,6 +2787,14 @@ else: "MobileViTV2PreTrainedModel", ] ) + _import_structure["models.moshi"].extend( + [ + "MoshiForCausalLM", + "MoshiForConditionalGeneration", + "MoshiModel", + "MoshiPreTrainedModel", + ] + ) _import_structure["models.mpnet"].extend( [ "MPNetForMaskedLM", @@ -5448,6 +5460,10 @@ if TYPE_CHECKING: from .models.mobilevitv2 import ( MobileViTV2Config, ) + from .models.moshi import ( + MoshiConfig, + MoshiDepthConfig, + ) from .models.mpnet import ( MPNetConfig, MPNetTokenizer, @@ -7386,6 +7402,12 @@ if TYPE_CHECKING: MobileViTV2Model, MobileViTV2PreTrainedModel, ) + from .models.moshi import ( + MoshiForCausalLM, + MoshiForConditionalGeneration, + MoshiModel, + MoshiPreTrainedModel, + ) from .models.mpnet import ( MPNetForMaskedLM, MPNetForMultipleChoice, diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py index 92371415918..f37f589d5d5 100644 --- 
a/src/transformers/convert_slow_tokenizer.py +++ b/src/transformers/convert_slow_tokenizer.py @@ -1405,6 +1405,47 @@ class MarkupLMConverter(Converter): return tokenizer +class MoshiConverter(SpmConverter): + handle_byte_fallback = True + + def __init__(self, vocab_file, model_max_length=None, **kwargs): + requires_backends(self, "protobuf") + + Converter.__init__(self, vocab_file) + + # from .utils import sentencepiece_model_pb2 as model_pb2 + model_pb2 = import_protobuf() + + m = model_pb2.ModelProto() + with open(vocab_file, "rb") as f: + m.ParseFromString(f.read()) + self.proto = m + + def normalizer(self, proto): + precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap + _normalizers = [ + normalizers.Replace(" ", "▁"), + ] + if not precompiled_charsmap: + return normalizers.Sequence(_normalizers) + else: + return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers) + + def decoder(self, replacement, add_prefix_space): + sequence = [ + decoders.Replace("▁", " "), + decoders.ByteFallback(), + decoders.Fuse(), + ] + if add_prefix_space: + sequence += [decoders.Strip(content=" ", left=1)] + return decoders.Sequence(sequence) + + def pre_tokenizer(self, replacement, add_prefix_space): + prepend_scheme = "first" + return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False) + + # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode def bytes_to_unicode(): """ diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f9ab6fce6cf..6d71b754d6f 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1594,8 +1594,10 @@ class GenerationMixin: cache_dtype = self.get_output_embeddings().weight.dtype def get_layer_device_map(execution_device_map: Optional[dict] = None): - if execution_device_map is None or len(execution_device_map) <= 1: + if execution_device_map is None: return None + elif len(execution_device_map) == 1 and "" in execution_device_map: + return {idx: execution_device_map[""] for idx in range(self.config.num_hidden_layers)} layer_device_map = {} for layer in execution_device_map: for idx in range(self.config.num_hidden_layers): diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 804957c0a55..069c7f90564 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -161,6 +161,7 @@ from . 
import ( mobilenet_v2, mobilevit, mobilevitv2, + moshi, mpnet, mpt, mra, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 17219570684..05d6e717be2 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -179,6 +179,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("mobilenet_v2", "MobileNetV2Config"), ("mobilevit", "MobileViTConfig"), ("mobilevitv2", "MobileViTV2Config"), + ("moshi", "MoshiConfig"), ("mpnet", "MPNetConfig"), ("mpt", "MptConfig"), ("mra", "MraConfig"), @@ -490,6 +491,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("mobilenet_v2", "MobileNetV2"), ("mobilevit", "MobileViT"), ("mobilevitv2", "MobileViTV2"), + ("moshi", "Moshi"), ("mpnet", "MPNet"), ("mpt", "MPT"), ("mra", "MRA"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 98d679ef09c..0ddab5681f2 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -73,6 +73,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict( ("mobilenet_v1", "MobileNetV1FeatureExtractor"), ("mobilenet_v2", "MobileNetV2FeatureExtractor"), ("mobilevit", "MobileViTFeatureExtractor"), + ("moshi", "EncodecFeatureExtractor"), ("nat", "ViTFeatureExtractor"), ("owlvit", "OwlViTFeatureExtractor"), ("perceiver", "PerceiverFeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index dbfcccaa468..5a98e761adc 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -169,6 +169,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("mobilenet_v2", "MobileNetV2Model"), ("mobilevit", "MobileViTModel"), ("mobilevitv2", "MobileViTV2Model"), + ("moshi", "MoshiModel"), ("mpnet", "MPNetModel"), ("mpt", "MptModel"), ("mra", "MraModel"), @@ -506,6 +507,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("mistral", "MistralForCausalLM"), ("mixtral", "MixtralForCausalLM"), ("mllama", "MllamaForCausalLM"), + ("moshi", "MoshiForCausalLM"), ("mpt", "MptForCausalLM"), ("musicgen", "MusicgenForCausalLM"), ("musicgen_melody", "MusicgenMelodyForCausalLM"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 8c3a7a82a60..3a3428e0995 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -309,6 +309,7 @@ else: ("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)), ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)), + ("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)), ("mpt", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ("mra", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/moshi/__init__.py b/src/transformers/models/moshi/__init__.py new file mode 100644 index 00000000000..69da6e940ea --- /dev/null +++ b/src/transformers/models/moshi/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_moshi import * + from .modeling_moshi import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/moshi/configuration_moshi.py b/src/transformers/models/moshi/configuration_moshi.py new file mode 100644 index 00000000000..654e4e82a49 --- /dev/null +++ b/src/transformers/models/moshi/configuration_moshi.py @@ -0,0 +1,333 @@ +# coding=utf-8 +# Copyright 2024 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Moshi model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto.configuration_auto import AutoConfig + + +logger = logging.get_logger(__name__) + + +class MoshiDepthConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MoshiDepthDecoder`]. It is used to instantiate a + Moshi depth decoder model according to the specified arguments, defining the Moshi depth decoder config. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the MoshiDepthDecoder model. Defines the number of different tokens that can be + represented by the `inputs_ids` passed when calling [`MoshiDepthDecoder`]. + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the layers and the pooler layer of the depth decoder. + input_size (`int`, *optional*, defaults to 4096): + Dimensionality of the input hidden states. Used to connect the main decoder to the depth decoder. + num_hidden_layers (`int`, *optional*, defaults to 6): + Number of depth decoder layers. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the depth decoder block. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. + audio_vocab_size (`int`, *optional*, defaults to 2048): + Vocabulary size of the audio part of model. Defines the number of different tokens that can be + represented by the `audio_codes` passed when calling the Moshi models. + max_position_embeddings (`int`, *optional*, defaults to 9): + The maximum sequence length that this model might ever be used with. Typically, set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the depth decoder. + head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`): + The attention head dimension. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + sliding_window (`int`, *optional*, defaults to 8): + Sliding window attention window size. If not specified, will default to `8`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + ffn_dim (`int`, *optional*, defaults to 5632): + Dimensionality of the "intermediate" (often named feed-forward) layer in the depth decoder block. Must be even. + rms_norm_eps (`float`, *optional*, defaults to 1e-08): + The epsilon used by the rms normalization layers. + num_codebooks (`int`, *optional*, defaults to 8): + The number of audio codebooks for each audio channels. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + kwargs (*optional*): + Dictionary of keyword arguments. Notably: + - **audio_encoder_config** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that + defines the audio encoder config. + + Example: + + ```python + >>> from transformers import ( + ... MoshiDepthConfig, + ... MoshiDepthDecoder, + ... 
) + + >>> configuration = MoshiDepthConfig() + + >>> # Initializing a MoshiDepthDecoder (with random weights) from the kmhf/hf-moshiko style configuration + >>> model = MoshiDepthDecoder(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "moshi_depth" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=1024, + input_size=4096, + num_hidden_layers=6, + num_attention_heads=16, + num_key_value_heads=None, + audio_vocab_size=2048, + max_position_embeddings=9, + hidden_act="silu", + head_dim=None, + initializer_range=0.02, + use_cache=True, + sliding_window=8, + attention_dropout=0.0, + ffn_dim=5632, + rms_norm_eps=1e-8, + num_codebooks=8, + tie_word_embeddings=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.input_size = input_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.hidden_act = hidden_act + self.head_dim = head_dim or hidden_size // num_attention_heads + self.initializer_range = initializer_range + self.use_cache = use_cache + self.sliding_window = sliding_window + self.attention_dropout = attention_dropout + if ffn_dim % 2 == 1: + raise ValueError(f"`ffn_dim={ffn_dim}` must be even.") + self.ffn_dim = ffn_dim + self.rms_norm_eps = rms_norm_eps + self.num_codebooks = num_codebooks + self.audio_vocab_size = audio_vocab_size + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +class MoshiConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MoshiModel`]. It is used to instantiate a + Moshi model according to the specified arguments, defining the audio encoder, Moshi depth decoder and Moshi decoder + configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the Moshiko model, + e.g. [kmhf/hf-moshiko](https://huggingface.co/kmhf/hf-moshiko) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the MoshiDecoder model. Defines the number of different tokens that can be + represented by the `inputs_ids` passed when calling [`MoshiDecoder`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimensionality of the layers and the pooler layer of the main decoder. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of decoder layers. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the main decoder block. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). 
If it is not specified, will default to `num_attention_heads`. + audio_vocab_size (`int`, *optional*): + Vocabulary size of the audio part of the model. Defines the number of different tokens that can be + represented by the `audio_codes` passed when calling the Moshi models. + max_position_embeddings (`int`, *optional*, defaults to 3000): + The maximum sequence length that this model might ever be used with. Typically, set this to something large + just in case (e.g., 512 or 1024 or 2048). + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`): + The attention head dimension. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + sliding_window (`int`, *optional*, defaults to 3000): + Sliding window attention window size. If not specified, will default to `3000`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + ffn_dim (`int`, *optional*, defaults to 22528): + Dimensionality of the "intermediate" (often named feed-forward) layer in the main decoder block. Must be even. + rms_norm_eps (`float`, *optional*, defaults to 1e-08): + The epsilon used by the rms normalization layers. + num_codebooks (`int`, *optional*, defaults to 8): + The number of audio codebooks for each audio channel. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings. + kwargs (*optional*): + Dictionary of keyword arguments. Notably: + - **audio_encoder_config** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that + defines the audio encoder config. + - **depth_decoder_config** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that + defines the depth decoder config. + + + Example: + + ```python + >>> from transformers import ( + ... MoshiConfig, + ... MoshiForConditionalGeneration, + ...
) + + >>> configuration = MoshiConfig() + + >>> # Initializing a MoshiForConditionalGeneration (with random weights) from the kmhf/hf-moshiko style configuration + >>> model = MoshiForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # Saving the model, including its configuration + >>> model.save_pretrained("kmhf/hf-moshiko") + + >>> # loading model and config from pretrained folder + >>> moshi_config = MoshiConfig.from_pretrained("kmhf/hf-moshiko") + >>> model = MoshiForConditionalGeneration.from_pretrained("kmhf/hf-moshiko", config=moshi_config) + ```""" + + model_type = "moshi" + is_composition = True + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + audio_vocab_size=None, + max_position_embeddings=3000, + rope_theta=10000.0, + hidden_act="silu", + head_dim=None, + initializer_range=0.02, + use_cache=True, + sliding_window=3000, + attention_dropout=0.0, + ffn_dim=22528, + rms_norm_eps=1e-8, + num_codebooks=8, + tie_word_embeddings=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.hidden_act = hidden_act + self.head_dim = head_dim or hidden_size // num_attention_heads + self.initializer_range = initializer_range + self.use_cache = use_cache + self.sliding_window = sliding_window + self.attention_dropout = attention_dropout + if ffn_dim % 2 == 1: + raise ValueError(f"`ffn_dim={ffn_dim}` must be even.") + self.ffn_dim = ffn_dim + self.rms_norm_eps = rms_norm_eps + self.num_codebooks = num_codebooks + + audio_encoder_config = kwargs.pop("audio_encoder_config", {}) + audio_encoder_model_type = audio_encoder_config.pop("model_type", "mimi") + + self.audio_encoder_config = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config) + + if self.num_codebooks > self.audio_encoder_config.num_codebooks: + raise ValueError( + f"`num_codebooks={num_codebooks}` is greater than the maximum number of codebooks that the audio encoder can deal with ({self.audio_encoder_config.num_codebooks}). Please lower it." + ) + + self.audio_vocab_size = ( + self.audio_encoder_config.codebook_size if audio_vocab_size is None else audio_vocab_size + ) + + depth_decoder_config = kwargs.pop("depth_decoder_config", {}) + depth_decoder_config.update( + { + "audio_vocab_size": self.audio_vocab_size, + "input_size": hidden_size, + "vocab_size": vocab_size, + "num_codebooks": num_codebooks, + } + ) + + self.depth_decoder_config = MoshiDepthConfig(**depth_decoder_config) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + @property + def sampling_rate(self): + return self.audio_encoder_config.sampling_rate + + @classmethod + def from_audio_encoder_config( + cls, + audio_encoder_config: PretrainedConfig, + **kwargs, + ): + r""" + Instantiate a [`MoshiConfig`] (or a derived class) from an audio encoder configuration. 
+ + Returns: + [`MoshiConfig`]: An instance of a configuration object + """ + + return cls( + audio_encoder_config=audio_encoder_config.to_dict(), + **kwargs, + ) + + +__all__ = ["MoshiConfig", "MoshiDepthConfig"] diff --git a/src/transformers/models/moshi/convert_moshi_transformers.py b/src/transformers/models/moshi/convert_moshi_transformers.py new file mode 100644 index 00000000000..1caaee25ef6 --- /dev/null +++ b/src/transformers/models/moshi/convert_moshi_transformers.py @@ -0,0 +1,311 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Moshi checkpoints.""" + +import argparse + +import safetensors +import sentencepiece +import torch + +from transformers import ( + AutoFeatureExtractor, + GenerationConfig, + MimiModel, # initial audio encoder + MoshiConfig, + MoshiForConditionalGeneration, + PreTrainedTokenizerFast, + logging, +) +from transformers.convert_slow_tokenizer import MoshiConverter + + +logging.set_verbosity_info() +logger = logging.get_logger("transformers.models.mimi") + + +def assert_param_count(model_1, model_2): + count_1 = sum(p[1].numel() for p in model_1.named_parameters() if "final_proj" not in p[0]) + count_2 = sum(p[1].numel() for p in model_2.named_parameters() if "final_proj" not in p[0]) + assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}" + + +def param_count(model): + return sum(p[1].numel() for p in model.named_parameters() if "final_proj" not in p[0]) + + +def _grab_best_device(use_gpu=True): + if torch.cuda.device_count() > 0 and use_gpu: + device = "cuda" + else: + device = "cpu" + return torch.device(device) + + +convert_list = [ + # GENERAL + ("out_norm", "decoder.model.norm"), + ("depformer_emb", "depth_decoder.emb"), + ("depformer_text_emb", "depth_decoder.text_emb"), + ("text_emb", "decoder.model.emb"), + ("emb", "embed_tokens"), + ("text_linear", "decoder.lm_head"), + ("depformer", "depth_decoder"), + ("transformer", "decoder.model"), + # TRANSFORMERS PART + ("gating.linear_in", "mlp.fc1"), + ("gating.linear_out", "mlp.fc2"), + ("self_attn.out_proj", "self_attn.o_proj.linear"), + ("norm1", "input_layernorm"), + ("norm2", "post_attention_layernorm"), + ("layer_scale_1", "self_attn_layer_scale"), + ("layer_scale_2", "mlp_layer_scale"), + ("alpha", "weight"), +] + + +def _preprocess_state_dict(state_dict, config): + # Moshi original weights are using a gating mechanism + + # pattern for depth transformer: + # stack(gating.{i}.linear_in)->mlp.fc1 + # stack(gating.{i}.linear_out)->mlp.fc2 + + for layer_idx in range(config.depth_decoder_config.num_hidden_layers): + linear_layers_in = [ + state_dict.pop(f"depformer.layers.{layer_idx}.gating.{i}.linear_in.weight") + for i in range(config.num_codebooks) + ] + linear_layers_out = [ + state_dict.pop(f"depformer.layers.{layer_idx}.gating.{i}.linear_out.weight") + for i in range(config.num_codebooks) + ] + + state_dict[f"depth_decoder.layers.{layer_idx}.mlp.fc1.weight"] = 
torch.stack(linear_layers_in) + state_dict[f"depth_decoder.layers.{layer_idx}.mlp.fc2.weight"] = torch.stack(linear_layers_out) + + input_projections = [] + lm_heads = [] + for codebook_idx in range(config.num_codebooks): + input_projections.append(state_dict.pop(f"depformer_in.{codebook_idx}.weight")) + lm_heads.append(state_dict.pop(f"linears.{codebook_idx}.weight")) + + state_dict["depth_decoder.input_projections.weight"] = torch.stack(input_projections, dim=0) + state_dict["depth_decoder.lm_heads.weight"] = torch.stack(lm_heads, dim=0) + + return state_dict + + +def _convert_model( + state_dict, + hf_model, + convert_list, + device, + config, + unwanted_prefix=None, +): + hidden_size = config.hidden_size + head_dim = config.head_dim + num_heads = int(config.hidden_size // config.head_dim) + num_key_value_heads = config.num_key_value_heads + key_value_head_dim = config.num_key_value_heads * head_dim + + state_dict = _preprocess_state_dict(state_dict, config) + + # permute for sliced rotary + def permute(w, n_heads, dim1=hidden_size, dim2=hidden_size): + return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + for k, v in list(state_dict.items()): + if "audio_encoder" not in k: + new_k = k if unwanted_prefix is None else k[len(unwanted_prefix) :] + for old_layer_name, new_layer_name in convert_list: + if old_layer_name in new_k: + new_k = new_k.replace(old_layer_name, new_layer_name) + + if "alpha" in k: + state_dict[k] = state_dict[k].squeeze() + + if "in_proj_weight" in new_k: + # split qkv into query key and value + mixed_qkv = state_dict.pop(k) + if "depth_decoder" in new_k: + mixed_qkv = mixed_qkv.view(config.num_codebooks, -1, mixed_qkv.shape[-1]) + + qkv_dim = mixed_qkv.size(1) // 3 + + query_layer = mixed_qkv[:, :qkv_dim] + key_layer = mixed_qkv[:, qkv_dim : qkv_dim * 2] + value_layer = mixed_qkv[:, qkv_dim * 2 :] + state_dict[new_k.replace("in_proj_weight", "q_proj.linear.weight")] = query_layer + state_dict[new_k.replace("in_proj_weight", "k_proj.linear.weight")] = key_layer + + else: + qkv_dim = mixed_qkv.size(0) // 3 + + query_layer = mixed_qkv[:qkv_dim] + key_layer = mixed_qkv[qkv_dim : qkv_dim * 2] + value_layer = mixed_qkv[qkv_dim * 2 :] + state_dict[new_k.replace("in_proj_weight", "q_proj.linear.weight")] = permute( + query_layer, num_heads, hidden_size, hidden_size + ) + state_dict[new_k.replace("in_proj_weight", "k_proj.linear.weight")] = permute( + key_layer, num_key_value_heads, key_value_head_dim, hidden_size + ) + + state_dict[new_k.replace("in_proj_weight", "v_proj.linear.weight")] = value_layer + elif "o_proj" in new_k and "depth_decoder" in new_k: + output_layer = state_dict.pop(k) + state_dict[new_k] = output_layer.view(config.num_codebooks, -1, output_layer.shape[-1]) + else: + state_dict[new_k] = state_dict.pop(k) + + # Do the last one by hand + state_dict["depth_decoder.text_embed_tokens.weight"] = state_dict.pop( + "depth_decoder.decoder.model.embed_tokens.weight" + ) + + extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys()) + missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys()) + if len(extra_keys) != 0: + raise ValueError(f"extra keys found: {extra_keys}") + if len(missing_keys) != 0: + raise ValueError(f"missing keys: {missing_keys}") + hf_model.load_state_dict(state_dict, strict=True) + n_params = param_count(hf_model) + + logger.info(f"model loaded: {round(n_params/1e6,1)}M params") + + hf_model.eval() + hf_model.to(device) + del state_dict + + return hf_model + + 
+@torch.no_grad() +def convert_checkpoint( + checkpoint_path, + pytorch_dump_folder_path, + mimi_repo_id, + config_path=None, + repo_id=None, +): + """ + Copy/paste/tweak model's weights to transformers design. + """ + device = _grab_best_device() + + mimi_model = MimiModel.from_pretrained(mimi_repo_id, torch_dtype=torch.bfloat16) + + if config_path is not None: + config = MoshiConfig.from_pretrained(config_path) + else: + audio_encoder_config = mimi_model.config + config = MoshiConfig.from_audio_encoder_config(audio_encoder_config) + + model = MoshiForConditionalGeneration(config).to(torch.bfloat16) + + depth_decoder_generation_config = GenerationConfig( + do_sample=True, + temperature=0.8, + top_k=250, + min_length=config.num_codebooks + 1, + max_length=config.num_codebooks + 1, + cache_implementation="sliding_window", + ) + + generation_config = GenerationConfig( + do_sample=True, + temp=0.7, + top_k=25, + cache_implementation="sliding_window", + pad_token_id=config.vocab_size, + bos_token_id=config.vocab_size, + ) + generation_config.depth_decoder_config = depth_decoder_generation_config.to_diff_dict() + + model.generation_config = generation_config + + original_checkpoint = safetensors.torch.load_file(checkpoint_path) + if "best_state" in original_checkpoint: + # we might have a training state saved, in which case discard the yaml results and just retain the weights + original_checkpoint = original_checkpoint["best_state"] + + audio_checkpoint = mimi_model.state_dict() + original_checkpoint.update({f"audio_encoder.{key}": value for (key, value) in audio_checkpoint.items()}) + + model = _convert_model(original_checkpoint, model, convert_list, device, config) + + model.save_pretrained(pytorch_dump_folder_path) + + if repo_id: + print("Pushing to the hub...") + model.push_to_hub(repo_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint") + parser.add_argument( + "--tokenizer_vocab_path", required=False, default=None, type=str, help="Path to original tokenizer vocab file" + ) + parser.add_argument("--mimi_repo_id", required=True, default=None, type=str, help="Repository id to HF Mimi.") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") + parser.add_argument( + "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub." 
+ ) + + args = parser.parse_args() + + # convert tokenizer + if args.tokenizer_vocab_path: + original_tokenizer = sentencepiece.SentencePieceProcessor(args.tokenizer_vocab_path) + tokenizer = MoshiConverter(args.tokenizer_vocab_path).converted() + tokenizer = PreTrainedTokenizerFast( + tokenizer_object=tokenizer, + chat_template=None, + unk_token="", + model_input_names=["input_ids", "attention_mask"], + clean_up_tokenization_spaces=False, + bos_token_id=original_tokenizer.bos_id(), + eos_token_id=original_tokenizer.eos_id(), + pad_token_id=original_tokenizer.pad_id(), + ) + + tokenizer.save_pretrained(args.pytorch_dump_folder_path) + + if args.push_to_hub: + print("Pushing the tokenizer to the hub...") + tokenizer.push_to_hub(args.push_to_hub) + + # upload feature extractor + feature_extractor = AutoFeatureExtractor.from_pretrained(args.mimi_repo_id) + feature_extractor.save_pretrained(args.pytorch_dump_folder_path) + + if args.push_to_hub: + print("Pushing the feature extractor to the hub...") + feature_extractor.push_to_hub(args.push_to_hub) + + convert_checkpoint( + args.checkpoint_path, + args.pytorch_dump_folder_path, + args.mimi_repo_id, + args.config_path, + args.push_to_hub, + ) diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py new file mode 100644 index 00000000000..5746a5934bd --- /dev/null +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -0,0 +1,2813 @@ +# coding=utf-8 +# Copyright 2024 Kyutai and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Moshi model.""" + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import ( + GenerationConfig, + GenerationMixin, +) +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + ModelOutput, + Seq2SeqLMOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from ..auto.modeling_auto import AutoModel +from .configuration_moshi import MoshiConfig, MoshiDepthConfig + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MoshiConfig" +_CHECKPOINT_FOR_DOC = "kmhf/hf-moshiko" + + +@dataclass +class MoshiConditionalGenerationGenerateOutput(ModelOutput): + """ + Outputs of [`MoshiForConditionalConditionalGeneration.generate`]. 
+ + Args: + audio_sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, 1, sequence_length)`, *optional*): + The generated audio waveforms. + sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): + The generated text sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`): + Final beam scores of the generated `sequences`. + scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True`): + Beam transition scores for each vocabulary token at each generation step. Beam transition scores consist + of log probabilities of tokens conditioned on the log softmax of previously generated tokens in this beam. + Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), + with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. + logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True`): + Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`): + Beam indices of the generated token id at each generation step. `torch.LongTensor` of shape + `(batch_size*num_return_sequences, sequence_length)`. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. + audio_codes (`torch.LongTensor` of shape `(batch_size*num_return_sequences, num_codebooks, sequence_length)`, *optional*): + The generated audio codes. Returned if `return_audio_codes=True`. Intermediate audio "tokens" which transform into `audio_sequences` once passed through the audio decoder.
+ """ + + audio_sequences: Optional[torch.Tensor] = None + sequences: torch.LongTensor = None + sequences_scores: Optional[torch.FloatTensor] = None + scores: Optional[Tuple[torch.FloatTensor]] = None + logits: Optional[Tuple[torch.FloatTensor]] = None + beam_indices: Optional[torch.LongTensor] = None + attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None + audio_codes: Optional[torch.LongTensor] = None + + +@dataclass +class MoshiCausalLMOutputWithPast(ModelOutput): + """ + `MoshiForCausalLM` outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class MoshiConditionalGenerationOutputWithPast(ModelOutput): + """ + `MoshiForConditionalGeneration` outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `text_labels` is provided): + Text language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the text language modeling head (scores for each vocabulary token before SoftMax). + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. 
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + depth_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `audio_labels` is provided): + Audio language modeling loss (for next-token prediction). + audio_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the audio language modeling heads. + depth_past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Past key-values of the depth decoder. + depth_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Hidden states of the depth decoder + depth_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Depth decoder's Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + depth_loss: Optional[torch.FloatTensor] = None + audio_logits: torch.FloatTensor = None + depth_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + depth_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + depth_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class MoshiUnconditionalInput(ModelOutput): + """ + Args: + input_ids (`torch.Tensor `of shape `(batch_size, sequence_length), *optional*): + The sequence used as a text prompt for the generation. + user_audio_codes (`torch.Tensor `of shape `(batch_size, num_codebooks, sequence_length), *optional*): + The audio codes used as audio user prompt for the generation. Has priority over `user_input_values` and represents the audio "tokens" of `user_input_values` once passed through the audio encoder. 
+        moshi_audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, sequence_length)`, *optional*):
+            The audio codes used as audio Moshi prompt for the generation. Has priority over `moshi_input_values` and represents the audio "tokens" of `moshi_input_values` once passed through the audio encoder.
+        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Attention mask to avoid performing attention on padding token indices. Mask values selected in `[0,
+            1]`: 1 for tokens that are **not masked**, 0 for tokens that are **masked**.
+    """
+
+    input_ids: torch.LongTensor = None
+    user_audio_codes: torch.Tensor = None
+    moshi_audio_codes: torch.Tensor = None
+    attention_mask: torch.LongTensor = None
+
+
+# Copied from transformers.models.gemma.modeling_gemma.GemmaRMSNorm with Gemma->Moshi
+class MoshiRMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))  # Ignore copy
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    # Ignore copy
+    def forward(self, x):
+        output = self._norm(x.float())
+        output = output * self.weight.float()
+        return output.type_as(x)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
+ALL_LAYERNORM_LAYERS.append(MoshiRMSNorm)
+
+
+class MoshiFlexibleLinear(nn.Module):
+    def __init__(self, input_size, output_size, num_layers):
+        super().__init__()
+        # Stack the weights for N layers into a single tensor (num_layers, output_size, input_size)
+        self.weight = nn.Parameter(torch.randn(num_layers, output_size, input_size))
+
+    def forward(self, x, layer_idx=None):
+        """
+        `MoshiFlexibleLinear` creates one linear layer per codebook. There are multiple ways to use it.
+        In the default case, `sequence_length=num_layers`, so each element of the sequence is multiplied by the weights corresponding to its index in the sequence.
+
+        For more advanced cases, `layer_idx` can be used to specify which codebook's layer(s) to use.
+        If `layer_idx` contains a single index, every element of the sequence is multiplied by that single codebook's layer.
+        If `layer_idx` is a tensor of shape `(seq_length,)`, the i-th element of the input sequence is multiplied by the layer selected by `layer_idx[i]`.
+
+
+        Args:
+            x (`torch.FloatTensor`): input to the layer of shape `(batch, num_layers, embed_dim)` or of shape `(batch, seq_length, embed_dim)`
+            layer_idx (`torch.Tensor`, *optional*):
+                Can be used to specify which codebook's layer(s) to use.
+                If it's a tensor of shape `(seq_length,)`, the i-th element of the sequence is multiplied by the layer selected by `layer_idx[i]`.
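+
+        For illustration (toy sizes, chosen arbitrarily):
+
+            proj = MoshiFlexibleLinear(input_size=16, output_size=32, num_layers=8)
+            x = torch.randn(2, 8, 16)                     # (batch, num_layers, embed_dim)
+            y = proj(x)                                   # position i uses weight[i]      -> (2, 8, 32)
+            y0 = proj(x, layer_idx=torch.tensor([0]))     # every position uses weight[0]  -> (2, 8, 32)
+            yi = proj(x, layer_idx=torch.arange(8))       # explicit form of the default   -> (2, 8, 32)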
+ """ + + # Use torch.gather to select the corresponding weights for each sample + # (codebooks, output_size, hidden_size) + selected_weights = torch.index_select(self.weight, 0, layer_idx) if layer_idx is not None else self.weight + + # (1, codebooks, hidden_size, output_size) + selected_weights = selected_weights.transpose(1, 2)[None, :, :, :] + + # (batch_size, codebooks, 1, hidden_size) x (1, codebooks, hidden_size, output_size) + # -> (batch_size, codebooks, 1, output_size) + x = torch.matmul(x[:, :, None, :], selected_weights) + + # (batch_size, codebooks, output_size) + return x.squeeze(2) + + +class MoshiLinear(nn.Module): + def __init__(self, input_dim, output_dim, num_codebooks, use_flexible_linear=False): + super().__init__() + + self.use_flexible_linear = use_flexible_linear + + if not use_flexible_linear: + self.linear = nn.Linear(input_dim, output_dim, bias=False) + else: + self.linear = MoshiFlexibleLinear(input_dim, output_dim, num_layers=num_codebooks) + + def forward(self, x, layer_idx=None): + if self.use_flexible_linear: + return self.linear(x, layer_idx) + else: + return self.linear(x) + + +# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Moshi +class MoshiRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward + # TODO(joao): add me back asap :) + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MoshiGatingMLP(nn.Module): + def __init__(self, config, use_flexible_linear=False): + super().__init__() + + self.activation_fn = ACT2FN[config.hidden_act] + ffn_dim = config.ffn_dim + hidden_size = config.hidden_size + num_layers = config.num_codebooks if use_flexible_linear else 1 + if num_layers == 1: + self.fc1 = nn.Linear(hidden_size, ffn_dim, bias=False) + self.fc2 = nn.Linear(ffn_dim // 2, hidden_size, bias=False) + else: + self.fc1 = MoshiFlexibleLinear(hidden_size, ffn_dim, num_layers) + self.fc2 = MoshiFlexibleLinear(ffn_dim // 2, hidden_size, num_layers) + + def forward(self, hidden_states: torch.Tensor, layer_idx: int = None) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) if layer_idx is None else self.fc1(hidden_states, layer_idx) + + batch_size, sequence_length, _ = hidden_states.shape + hidden_states = hidden_states.view(batch_size, sequence_length, 2, -1) + hidden_states = self.activation_fn(hidden_states[..., 0, :]) * hidden_states[..., 1, :] + hidden_states = self.fc2(hidden_states) if layer_idx is None else self.fc2(hidden_states, layer_idx) + return hidden_states + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class MoshiAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, use_flexible_linear=False, use_rope=True): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.is_causal = True + self.scaling = 1 / math.sqrt(self.head_dim) + + if self.hidden_size % self.num_heads != 0: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = MoshiLinear( + self.hidden_size, self.num_heads * self.head_dim, config.num_codebooks, use_flexible_linear + ) + self.k_proj = MoshiLinear( + self.hidden_size, self.num_key_value_heads * self.head_dim, config.num_codebooks, use_flexible_linear + ) + self.v_proj = MoshiLinear( + self.hidden_size, self.num_key_value_heads * self.head_dim, config.num_codebooks, use_flexible_linear + ) + self.o_proj = MoshiLinear( + self.num_heads * self.head_dim, self.hidden_size, config.num_codebooks, use_flexible_linear + ) + + # rotary embeddings are not used in the depth decoder + self.rotary_emb = None + if use_rope: + self.rope_theta = config.rope_theta + self.rotary_emb = MoshiRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + # Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states, cache_position) # Ignore copy + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.rotary_emb is not None: # Ignore copy + cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = ( + {"sin": sin, "cos": cos, "cache_position": cache_position} + if self.rotary_emb is not None + else {"cache_position": cache_position} + ) # Ignore copy + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + 
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output, cache_position) # Ignore copy + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Moshi +class MoshiFlashAttention2(MoshiAttention): + """ + Moshi flash attention module. This module inherits from `MoshiAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states, cache_position) # Ignore copy + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.rotary_emb is not None: # Ignore copy + cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = ( + {"sin": sin, "cos": cos, "cache_position": cache_position} + if self.rotary_emb is not None + else {"cache_position": cache_position} + ) # Ignore copy + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (MoshiRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. 
We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output, cache_position) # Ignore copy + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Moshi +class MoshiSdpaAttention(MoshiAttention): + """ + Moshi attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MoshiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MoshiAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MoshiModel is using MoshiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states, cache_position) # Ignore copy + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.rotary_emb is not None: # Ignore copy + cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = ( + {"sin": sin, "cos": cos, "cache_position": cache_position} + if self.rotary_emb is not None + else {"cache_position": cache_position} + ) # Ignore copy + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
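+        # Note: `is_causal=True` lets SDPA apply causal masking internally. It is only used when no explicit mask is
+        # passed and more than one query token is processed (a single decoded token can attend to the whole past, so
+        # no mask is needed).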
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output, cache_position) # Ignore copy + + return attn_output, None, past_key_value + + +MOSHI_ATTENTION_CLASSES = { + "eager": MoshiAttention, + "flash_attention_2": MoshiFlashAttention2, + "sdpa": MoshiSdpaAttention, +} + + +class MoshiDecoderLayer(nn.Module): + def __init__(self, config: MoshiConfig, layer_idx: int, use_flexible_linear: bool, use_rope=True): + super().__init__() + self.hidden_size = config.hidden_size + self.use_flexible_linear = use_flexible_linear + + self.self_attn = MOSHI_ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx, use_flexible_linear=use_flexible_linear, use_rope=use_rope + ) + + self.mlp = MoshiGatingMLP(config, use_flexible_linear) + self.input_layernorm = MoshiRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MoshiRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.sliding_window = config.sliding_window + + self._attn_implementation = config._attn_implementation + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = ( + self.mlp(hidden_states) if not self.use_flexible_linear else self.mlp(hidden_states, cache_position) + ) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class MoshiPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = MoshiConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MoshiDecoderLayer", "MimiTransformerLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + main_input_name = "input_ids" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MOSHI_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MoshiConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +MOSHI_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence text tokens in the vocabulary. 
Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + user_input_values (`torch.Tensor `of shape `(batch_size, 1, audio_sequence_length), *optional*): + The audio waveforms used as audio user prompt for the generation. + user_audio_codes (`torch.Tensor `of shape `(batch_size, num_codebooks, sequence_length), *optional*): + The audio codes used as audio user prompt for the generation. Has priority over `user_input_values` and represents the audio "tokens" of `user_input_values` once passed through the audio encoder. + moshi_input_values (`torch.Tensor `of shape `(batch_size, 1, audio_sequence_length), *optional*): + The audio waveforms used as audio Moshi prompt for the generation. + moshi_audio_codes (`torch.Tensor `of shape `(batch_size, num_codebooks, sequence_length), *optional*): + The audio codes used as audio Moshi prompt for the generation. Has priority over `moshi_input_values` and represents the audio "tokens" of `moshi_input_values` once passed through the audio encoder. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `input_ids` and `inputs_embeds` are both unset, `inputs_embeds` takes the value + of `inputs_embeds`. + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + text_labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for text language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + audio_labels (`torch.LongTensor` of shape `(batch_size, num_codebooks, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.audio_vocab_size]` + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +MOSHI_DECODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. 
+ + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin): + """ + Transformer depth decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoshiTransformerLayer`] + + Args: + config: MoshiConfig + """ + + config_class = MoshiDepthConfig + + def __init__(self, config: MoshiDepthConfig): + super().__init__(config) + + self.text_embed_tokens = nn.Embedding(config.vocab_size + 1, config.hidden_size) + + # the last codebook is never used as input + self.embed_tokens = nn.ModuleList( + [nn.Embedding(config.audio_vocab_size + 1, config.hidden_size) for _ in range(config.num_codebooks - 1)] + ) + + self.input_projections = MoshiFlexibleLinear(config.input_size, config.hidden_size, config.num_codebooks) + + self.layers = nn.ModuleList( + [ + MoshiDecoderLayer(config, layer_idx, use_flexible_linear=True, use_rope=False) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.lm_heads = MoshiFlexibleLinear(config.hidden_size, config.audio_vocab_size, config.num_codebooks) + self._attn_implementation = config._attn_implementation + self.gradient_checkpointing = False + self.config = config + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + last_hidden_state: torch.LongTensor = None, + attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens. 
The first element of the sequence must be the text token associated with the audio codebooks.
+                The rest of the elements must be flattened audio codebooks. The `cache_position` argument can be used to indicate which index each token is associated with.
+            last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Sequence of hidden-states at the output of the last layer of the main decoder. Used to contextualize `input_ids`.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+                `past_key_values`).
+
+                If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+                and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+                information on the default strategy.
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+            past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+                Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+                blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+                returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+                Two formats are allowed:
+                - a [`~cache_utils.Cache`] instance;
+                - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+                cache format.
+
+                The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+                legacy cache format will be returned.
+
+                If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+                have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+                of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+                is useful if you want more control over how to convert the inputs into associated vectors than the
+                model's internal embedding lookup matrix.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+                `past_key_values`).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+                tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+                more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if use_cache and past_key_values is None and not self.training: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() + if cache_position is None: + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + input_ids.shape[1], device=input_ids.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # If inputs_embeds is provided, it has the priority over input_ids, which won't be used + if inputs_embeds is None: + inputs_embeds = [] + for position_idx in cache_position: + position_idx = position_idx.item() + if position_idx == 0: + inputs_embeds.append(self.text_embed_tokens(input_ids[:, [position_idx]])) + else: + inputs_embeds.append( + self.embed_tokens[(position_idx - 1)](input_ids[:, [position_idx - past_seen_tokens]]) + ) + + inputs_embeds = torch.cat(inputs_embeds, dim=1) + + inputs_embeds += self.input_projections(last_hidden_state, cache_position) + + causal_mask = None + if attention_mask is not None: + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + hidden_states = inputs_embeds + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: 
+ all_self_attns += (layer_outputs[1],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + logits = self.lm_heads(hidden_states, cache_position) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + loss_fct = CrossEntropyLoss() + + labels = labels.masked_fill(labels == self.config.audio_vocab_size, -100).reshape(-1) + # Enable model parallelism + labels = labels.to(logits.device) + loss = loss_fct(logits.reshape(-1, self.config.audio_vocab_size), labels) + + if not return_dict: + return tuple(v for v in [loss, logits, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
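+        # For example, a 2D padding mask like [[1, 1, 0]] becomes a float mask of shape
+        # (batch_size, 1, query_length, key_value_length) in which allowed positions hold 0.0 and masked positions
+        # hold `torch.finfo(dtype).min`, so it can be added directly to the attention scores.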
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->MoshiDepth + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: MoshiDepthConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`MoshiDepthConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask |= sliding_attend_mask + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + use_cache: bool = True, + num_logits_to_keep: Optional[int] = None, + **kwargs, + ): + """ + Prepare the model inputs for generation. In includes operations like computing the 4D attention mask or + slicing inputs given the existing cache. + See the documentation in the used model for the arguments (different models might have different requirements + for e.g. `past_key_values`). Should work as is for most LLMs. + """ + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s + # `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the + # decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, + # `position_ids` is already contiguous but with varying stride which retriggers a capture. 
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create + # the 4D causal mask exists, it should be present in the base model (XXXModel class). + attention_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.max_cache_len, + dtype=self.text_embed_tokens.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "last_hidden_state": kwargs.get("last_hidden_state"), + } + ) + return model_inputs + + +@add_start_docstrings( + "The bare Moshi Model outputting raw hidden-states without any specific head on top.", + MOSHI_START_DOCSTRING, +) +class MoshiModel(MoshiPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`MoshiDecoderLayer`] + + Args: + config: MoshiConfig + """ + + def __init__(self, config: MoshiConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size + 1, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ + MoshiDecoderLayer(config, layer_idx, use_flexible_linear=False) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = MoshiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(MOSHI_DECODER_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False # noqa: F841 + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True # noqa: F841 + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = None + if attention_mask is not None: + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
" + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" + ) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Moshi + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: MoshiConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. 
+            batch_size (`int`):
+                Batch size.
+            config (`MoshiConfig`):
+                The model's configuration class.
+            past_key_values (`Cache`):
+                The cache class that is being used currently to generate.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            if config.sliding_window is not None:
+                # if we have a sliding window, we should not attend to tokens beyond the sliding window length, so we also mask them out
+                # the check is needed to verify if the current checkpoint was trained with sliding window or not
+                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
+                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                        cache_position.reshape(-1, 1) - config.sliding_window
+                    )
+                    diagonal_attend_mask |= sliding_attend_mask
+            causal_mask *= diagonal_attend_mask
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                if attention_mask.shape[-1] > target_length:
+                    attention_mask = attention_mask[:, :target_length]
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+        return causal_mask
+
+
+@add_start_docstrings(
+    "The Moshi decoder model with a text language modelling head on top. 
Only usable for text.", + MOSHI_START_DOCSTRING, +) +class MoshiForCausalLM(MoshiPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"] + + # Copied from transformers.models.gemma.modeling_gemma.GemmaForCausalLM.__init__ with Gemma->Moshi + def __init__(self, config): + super().__init__(config) + self.model = MoshiModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MOSHI_DECODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoshiCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, MoshiCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MoshiForCausalLM + + >>> model = MoshiForCausalLM.from_pretrained("kmhf/hf-moshiko") + >>> tokenizer = AutoTokenizer.from_pretrained("kmhf/hf-moshiko") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = ( + logits, + hidden_states, + ) + outputs[1:] + return (loss,) + output if loss is not None else output + + return MoshiCausalLMOutputWithPast( + loss=loss, + logits=logits, + last_hidden_state=hidden_states, # Ignore copy + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + "The original Moshi model with an audio encoder, a Moshi depth decoder and a Moshi decoder, " + "for speech-to-speech.", + MOSHI_START_DOCSTRING, +) +class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["decoder.model.embed_tokens.weight", "decoder.lm_head.weight"] + config_class = MoshiConfig + main_input_name = "input_ids" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def __init__(self, config: MoshiConfig): + super().__init__(config) + # We have 2 * num_codebooks audio embedding layers because we have the user input channel and the model output channel. 
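+        # Embedding layers 0..num_codebooks-1 embed Moshi's own audio codebooks and the remaining ones embed the
+        # user's, matching the `[moshi_audio_codes, user_audio_codes]` concatenation order used in `forward`.
+        # The extra `+ 1` entry serves as the audio BOS/padding id (`config.audio_vocab_size`).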
+        self.embed_tokens = nn.ModuleList(
+            [nn.Embedding(config.audio_vocab_size + 1, config.hidden_size) for _ in range(2 * config.num_codebooks)]
+        )
+        self.audio_encoder = AutoModel.from_config(
+            config.audio_encoder_config, attn_implementation=config._attn_implementation
+        )
+        self.decoder = MoshiForCausalLM(config)
+
+        config.depth_decoder_config._attn_implementation_internal = config._attn_implementation
+        self.depth_decoder = MoshiDepthDecoder(config.depth_decoder_config)
+
+        self.num_codebooks = config.num_codebooks
+        self.post_init()
+
+    def get_audio_encoder(self):
+        return self.audio_encoder
+
+    def get_depth_decoder(self):
+        return self.depth_decoder
+
+    def get_decoder(self):
+        return self.decoder
+
+    @add_start_docstrings_to_model_forward(MOSHI_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        user_input_values: Optional[torch.FloatTensor] = None,
+        user_audio_codes: Optional[torch.Tensor] = None,
+        moshi_input_values: Optional[torch.FloatTensor] = None,
+        moshi_audio_codes: Optional[torch.Tensor] = None,
+        past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        text_labels: Optional[torch.LongTensor] = None,
+        audio_labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Tuple, Seq2SeqLMOutput]:
+        r"""
+        Returns:
+
+        Examples:
+        ```python
+        >>> from transformers import MoshiForConditionalGeneration
+        >>> import torch
+
+        >>> model = MoshiForConditionalGeneration.from_pretrained("kmhf/hf-moshiko")
+        >>> inputs = model.get_unconditional_inputs()
+
+        >>> logits = model(**inputs).logits
+        >>> logits.shape  # (bsz, seq_len, text_vocab_size)
+        torch.Size([1, 1, 32000])
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        kwargs_audio_encoder = {
+            argument[len("audio_encoder_") :]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("audio_encoder_")
+        }
+
+        kwargs_decoder = {
+            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+        }
+
+        kwargs_depth_decoder = {
+            argument[len("depth_decoder_") :]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("depth_decoder_")
+        }
+
+        # If inputs_embeds is provided, it has the priority over input_ids and audio_codes, which won't be used
+        if inputs_embeds is None:
+            if user_input_values is not None and user_audio_codes is None:
+                user_audio_codes = self.audio_encoder.encode(
+                    user_input_values, num_quantizers=self.num_codebooks, **kwargs_audio_encoder
+                )[0]
+
+            if moshi_input_values is not None and moshi_audio_codes is None:
+                moshi_audio_codes = self.audio_encoder.encode(
+                    moshi_input_values, num_quantizers=self.num_codebooks, **kwargs_audio_encoder
+                )[0]
+
+            audio_codes = torch.cat([moshi_audio_codes, user_audio_codes], dim=1)
+
+            if input_ids is None and audio_codes is None:
+                raise ValueError(
+                    "You must provide at least one of `input_ids`, `inputs_embeds`, `input_values` and `audio_codes`."
+ ) + + if input_ids is not None: + inputs_embeds = self.decoder.model.embed_tokens(input_ids) + + if audio_codes is not None: + audio_inputs_embeds = sum( + [self.embed_tokens[codebook](audio_codes[:, codebook]) for codebook in range(audio_codes.shape[1])] + ) + inputs_embeds = ( + audio_inputs_embeds + if inputs_embeds is None + else audio_inputs_embeds + inputs_embeds.to(audio_inputs_embeds.device) + ) + + # Decode + decoder_outputs = self.decoder( + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + past_key_values=past_key_values, + return_dict=True, + labels=text_labels, + **kwargs_decoder, + ) + + decoder_last_hidden_state = decoder_outputs.last_hidden_state + + depth_decoder_outputs = None + final_loss = decoder_outputs.loss + if text_labels is not None and audio_labels is not None: + # To use depth decoder forward here, we actually need oracle input ids since we're supposed to pass the true input ids + + audio_labels = self.build_delay_pattern_mask( + audio_labels, + bos_token_id=self.config.audio_vocab_size, + pad_token_id=self.config.audio_vocab_size, + max_length=audio_labels.shape[-1] + 1, + )[0] + + # (batch_size, sequence_length) -> (batch_size * sequence_length, 1) + text_labels = text_labels.view(-1, 1) + + # (batch_size, num_codebooks, sequence_length) -> (batch_size * sequence_length, num_codebooks) + audio_labels = audio_labels.transpose(1, 2).reshape(-1, audio_labels.shape[1]) + + depth_input_ids = torch.cat([text_labels, audio_labels], dim=1) + # keep the last codebook out of input_ids + depth_input_ids = depth_input_ids[:, :-1] + + # (batch_size, sequence_length, dim) -> (batch_size * sequence_length, 1, dim) + decoder_last_hidden_state = decoder_last_hidden_state.view(-1, 1, decoder_last_hidden_state.shape[-1]) + + depth_decoder_outputs = self.depth_decoder( + last_hidden_state=decoder_last_hidden_state, + input_ids=depth_input_ids, + attention_mask=attention_mask, + labels=audio_labels, + **kwargs_depth_decoder, + ) + + final_loss += depth_decoder_outputs.loss + + if not return_dict: + outputs = decoder_outputs.to_tuple() + if depth_decoder_outputs is not None: + outputs += depth_decoder_outputs.to_tuple() + return outputs + + return MoshiConditionalGenerationOutputWithPast( + loss=decoder_outputs.loss, + logits=decoder_outputs.logits, + last_hidden_state=decoder_last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + depth_loss=None if depth_decoder_outputs is None else depth_decoder_outputs.loss, + audio_logits=None if depth_decoder_outputs is None else depth_decoder_outputs.logits, + depth_past_key_values=None if decoder_outputs is None else decoder_outputs.past_key_values, + depth_hidden_states=None if decoder_outputs is None else decoder_outputs.hidden_states, + depth_attentions=None if decoder_outputs is None else decoder_outputs.attentions, + ) + + def _prepare_inputs_embeds_for_generation( + self, + input_ids: Optional[torch.LongTensor] = None, + user_input_values: Optional[torch.FloatTensor] = None, + user_audio_codes: Optional[torch.Tensor] = None, + moshi_input_values: Optional[torch.FloatTensor] = None, + moshi_audio_codes: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + 
apply_delay_pattern_mask: bool = False, + concat_unconditional_inputs: bool = False, + ): + user_delay_pattern_mask = None + moshi_delay_pattern_mask = None + + if ( + inputs_embeds is None + and input_ids is None + and user_input_values is None + and user_audio_codes is None + and moshi_input_values is None + and moshi_audio_codes is None + ): + raise ValueError( + "You must provide at least one of `input_ids`, `user_input_values`, `moshi_input_values`, `user_audio_codes`, `moshi_audio_codes` or `inputs_embeds`." + ) + + # in case inputs_embeds is passed, we might still need to create delay pattern masks + if inputs_embeds is None or apply_delay_pattern_mask: + if user_input_values is not None and user_audio_codes is None: + user_audio_codes = self.audio_encoder.encode(user_input_values, num_quantizers=self.num_codebooks)[0] + + if moshi_input_values is not None and moshi_audio_codes is None: + moshi_audio_codes = self.audio_encoder.encode(moshi_input_values, num_quantizers=self.num_codebooks)[0] + + if inputs_embeds is None and concat_unconditional_inputs: + unconditional_inputs = self.get_unconditional_inputs(num_samples=user_audio_codes.shape[0]) + moshi_audio_codes = torch.cat([unconditional_inputs.moshi_audio_codes, moshi_audio_codes], dim=2) + user_audio_codes = torch.cat([unconditional_inputs.user_audio_codes, user_audio_codes], dim=2) + input_ids = torch.cat([unconditional_inputs.input_ids, input_ids], dim=1) + if attention_mask is not None: + attention_mask = torch.cat([unconditional_inputs.attention_mask, attention_mask], dim=1) + + if inputs_embeds is None or apply_delay_pattern_mask: + if apply_delay_pattern_mask and user_audio_codes is not None: + user_audio_codes, user_delay_pattern_mask = self.build_delay_pattern_mask( + user_audio_codes, + bos_token_id=self.config.audio_vocab_size, + pad_token_id=self.config.audio_vocab_size, + max_length=generation_config.max_length, + ) + + if apply_delay_pattern_mask and moshi_audio_codes is not None: + moshi_audio_codes, moshi_delay_pattern_mask = self.build_delay_pattern_mask( + moshi_audio_codes, + bos_token_id=self.config.audio_vocab_size, + pad_token_id=self.config.audio_vocab_size, + max_length=generation_config.max_length, + ) + + # If inputs_embeds is provided, it has the priority over input_ids and audio_codes, which won't be used + if inputs_embeds is None: + audio_inputs_embeds = None + if user_audio_codes is not None and moshi_audio_codes is not None: + audio_codes = torch.cat([moshi_audio_codes, user_audio_codes], dim=1) + audio_inputs_embeds = sum( + [self.embed_tokens[codebook](audio_codes[:, codebook]) for codebook in range(audio_codes.shape[1])] + ) + elif moshi_audio_codes is not None: + audio_codes = moshi_audio_codes + audio_inputs_embeds = sum( + [self.embed_tokens[codebook](audio_codes[:, codebook]) for codebook in range(audio_codes.shape[1])] + ) + elif user_audio_codes is not None: + audio_codes = user_audio_codes + audio_inputs_embeds = sum( + [ + self.embed_tokens[codebook](audio_codes[:, codebook + self.num_codebooks]) + for codebook in range(audio_codes.shape[1]) + ] + ) + + if input_ids is not None: + inputs_embeds = self.decoder.model.embed_tokens(input_ids) + + if audio_inputs_embeds is not None: + inputs_embeds = ( + audio_inputs_embeds + if inputs_embeds is None + else audio_inputs_embeds + inputs_embeds.to(audio_inputs_embeds.device) + ) + + return ( + inputs_embeds, + input_ids, + user_audio_codes, + moshi_audio_codes, + user_delay_pattern_mask, + moshi_delay_pattern_mask, + attention_mask, + ) + + 
@torch.no_grad()
+    def generate(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        user_input_values: Optional[torch.FloatTensor] = None,
+        user_audio_codes: Optional[torch.Tensor] = None,
+        moshi_input_values: Optional[torch.FloatTensor] = None,
+        moshi_audio_codes: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        return_audio_waveforms: Optional[bool] = True,
+        return_audio_codes: Optional[bool] = None,
+        concat_unconditional_inputs: Optional[bool] = True,
+        **kwargs,
+    ) -> torch.LongTensor:
+        """
+        Generates sequences of text token ids and audio token ids.
+
+        Parameters:
+            input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                The sequence used as a text prompt for the generation.
+            user_input_values (`torch.Tensor` of shape `(batch_size, 1, audio_sequence_length)`, *optional*):
+                The audio waveforms used as audio user prompt for the generation.
+            user_audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, sequence_length)`, *optional*):
+                The audio codes used as audio user prompt for the generation. Has priority over `user_input_values` and represents the audio "tokens" of `user_input_values` once passed through the audio encoder.
+            moshi_input_values (`torch.Tensor` of shape `(batch_size, 1, audio_sequence_length)`, *optional*):
+                The audio waveforms used as audio Moshi prompt for the generation.
+            moshi_audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, sequence_length)`, *optional*):
+                The audio codes used as audio Moshi prompt for the generation. Has priority over `moshi_input_values` and represents the audio "tokens" of `moshi_input_values` once passed through the audio encoder.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` and the audio inputs you can choose to directly pass an embedded representation. This
+                is useful if you want more control over how to convert the inputs into associated vectors than the
+                model's internal embedding lookup matrix.
+            return_audio_waveforms (`bool`, *optional*, defaults to `True`):
+                If `False`, won't generate the audio waveforms.
+            return_audio_codes (`bool`, *optional*):
+                If `True`, will also return the generated audio codes, i.e. the intermediate audio "tokens" which transform into `audio_sequences` once passed through the audio decoder.
+            concat_unconditional_inputs (`bool`, *optional*, defaults to `True`):
+                If `False`, won't concatenate initial audio and text tokens.
+            kwargs (`Dict[str, Any]`, *optional*):
+                Remaining dictionary of keyword arguments that are passed to the `generate` method. Refer to the
+                original [`generate` docstrings](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate)
+                for more information on how to use them.
+                Note that keywords with a *depth_* prefix will be passed to the `generate` method of the
+                depth decoder. Otherwise, the latter will use its default generation config.
+ Return: + [`MoshiConditionalGenerationGenerateOutput`] + """ + # multiple generate -> need to create/update device map + if hasattr(self, "hf_device_map") and not hasattr(self.depth_decoder, "hf_device_map"): + self.depth_decoder.hf_device_map = {} + if "" in self.hf_device_map: + self.depth_decoder.hf_device_map = self.hf_device_map + else: + main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0] + self.depth_decoder.hf_device_map = { + key[len("depth_decoder") :]: main_device if value in ["cpu", "disk"] else value + for key, value in self.hf_device_map.items() + if key.startswith("depth_decoder") + } + # need to remove depth_decoder from the top device_map so that we assign correctly the device for each layer idx in the cache + self.hf_device_map = { + key: value for key, value in self.hf_device_map.items() if not key.startswith("depth_decoder") + } + # retrieve depth decoder kwargs + depth_decoder_kwargs_keys = {argument for argument in kwargs if argument.startswith("depth_decoder_")} + kwargs_depth_decoder = { + argument[len("depth_decoder_") :]: kwargs.pop(argument) for argument in depth_decoder_kwargs_keys + } + + # needs to prepare generation config, even though it'll be done again in `generate` + generation_config, kwargs = self._prepare_generation_config(kwargs.pop("generation_config", None), **kwargs) + + input_ids, user_audio_codes, moshi_audio_codes, concat_unconditional_inputs = ( + self._check_and_maybe_initalize_inputs( + input_ids=input_ids, + user_input_values=user_input_values, + user_audio_codes=user_audio_codes, + moshi_input_values=moshi_input_values, + moshi_audio_codes=moshi_audio_codes, + inputs_embeds=inputs_embeds, + concat_unconditional_inputs=concat_unconditional_inputs, + ) + ) + + inputs = inputs_embeds if input_ids is None else input_ids + + input_ids_length = inputs.shape[-1] + 1 if concat_unconditional_inputs else inputs.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name="inputs_embeds" if input_ids is None else "input_ids", + inputs_tensor=inputs, + input_ids_length=input_ids_length, + ) + + # retrieve depth decoder generation config if it exists + if hasattr(generation_config, "depth_decoder_config"): + depth_decoder_generation_config = generation_config.depth_decoder_config + else: + # we need to control the number of tokens generated by the depth decoder + depth_decoder_generation_config = { + "min_length": self.num_codebooks + 1, + "max_length": self.num_codebooks + 1, + "cache_implementation": "sliding_window", + } + # update kwargs_depth_decoder: kwargs_depth_decoder have priority over depth_decoder_generation_config + depth_decoder_generation_config.update(kwargs_depth_decoder) + kwargs_depth_decoder = depth_decoder_generation_config + + attention_mask = kwargs.pop("attention_mask", None) + ( + inputs_embeds, + input_ids, + user_audio_codes, + moshi_audio_codes, + user_delay_pattern_mask, + moshi_delay_pattern_mask, + attention_mask, + ) = self._prepare_inputs_embeds_for_generation( + input_ids=input_ids, + user_input_values=user_input_values, + user_audio_codes=user_audio_codes, + moshi_input_values=moshi_input_values, + 
moshi_audio_codes=moshi_audio_codes, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + generation_config=generation_config, + apply_delay_pattern_mask=True, + concat_unconditional_inputs=concat_unconditional_inputs, + ) + + # create blank user inputs - moshi needs a constant stream of user inputs + blank_input_values = torch.zeros( + (inputs_embeds.shape[0], 1, int(self.config.sampling_rate / self.config.audio_encoder_config.frame_rate)), + dtype=self.dtype, + device=self.device, + ) + blank_user_audio_codes = self.audio_encoder.encode(blank_input_values, num_quantizers=self.num_codebooks)[0] + + # set delay pattern mask for the rest of the generation + kwargs["user_delay_pattern_mask"] = ( + user_delay_pattern_mask if user_delay_pattern_mask is not None else kwargs.get("user_delay_pattern_mask") + ) + kwargs["moshi_delay_pattern_mask"] = ( + moshi_delay_pattern_mask + if moshi_delay_pattern_mask is not None + else kwargs.get("moshi_delay_pattern_mask") + ) + + self.generated_audio_codes = torch.repeat_interleave( + moshi_audio_codes, max(generation_config.num_beams, generation_config.num_return_sequences), dim=0 + ) + + return_dict_in_generate = generation_config.num_beams > 1 or generation_config.return_dict_in_generate + output_scores = generation_config.num_beams > 1 or generation_config.output_scores + outputs = super().generate( + inputs_embeds=inputs_embeds, + input_ids=input_ids, + generation_config=generation_config, + blank_user_audio_codes=blank_user_audio_codes, + kwargs_depth_decoder=kwargs_depth_decoder, + return_dict_in_generate=return_dict_in_generate, + output_scores=output_scores, + attention_mask=attention_mask, + **kwargs, + ) + + if not return_audio_waveforms and not return_audio_codes: + if return_dict_in_generate and not generation_config.return_dict_in_generate: + return outputs.sequences + return outputs + + # check if outputs is a dict or tokens + if not return_dict_in_generate: + output_text_ids = outputs + else: + output_text_ids = outputs.sequences + + if generation_config.num_return_sequences > 1: + moshi_delay_pattern_mask = torch.repeat_interleave( + moshi_delay_pattern_mask, generation_config.num_return_sequences, dim=0 + ) + + if generation_config.num_beams > 1: + # we need to reorganize self.last_hidden_states and generated audio codes according to the beam_indices + + # Beam indices are of shape `input_length + number_generated_tokens` but actually starts + # indexing indices at index 0 instead of index `input_length-1`. + # We thus discard the last `input_length` indices that are never used. 
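+            # Here, `moshi_audio_codes.shape[-1]` is the prompt length, i.e. the `input_length` referred to above.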
+ beam_indices = outputs.beam_indices[:, : -moshi_audio_codes.shape[-1]] + + generated_audio_codes = self.generated_audio_codes[:, :, moshi_audio_codes.shape[-1] :] + + # we've generated audio tokens `number_generated_tokens-1` times, so we use the corresponding beam indices to + # retrieve the right audio tokens + expanded_beam_indices = beam_indices[:, :-1].unsqueeze(1).expand(-1, self.num_codebooks, -1) + generated_audio_codes = torch.gather(generated_audio_codes, dim=0, index=expanded_beam_indices) + + # now, rebuild generated audio codes, this time with the right beam tracking + moshi_audio_codes = torch.repeat_interleave( + moshi_audio_codes, generation_config.num_return_sequences, dim=0 + ) + self.generated_audio_codes = torch.cat((moshi_audio_codes, generated_audio_codes), dim=2) + + # use the last beam indice to retrieve the right self.last_hidden_state + self.last_hidden_state = torch.index_select(self.last_hidden_state, dim=0, index=beam_indices[:, -1]) + + # we need to make a last generation with the latest generated tokens + last_hidden_state = self.last_hidden_state.view(-1, 1, self.last_hidden_state.shape[-1]) + + last_generated_audio_codes = self.depth_decoder.generate( + last_hidden_state=last_hidden_state, + input_ids=output_text_ids[:, -1:].view(-1, 1), + **kwargs_depth_decoder, + ) + + last_generated_audio_codes = last_generated_audio_codes[:, 1:].unsqueeze(2) + + self.generated_audio_codes = torch.cat([self.generated_audio_codes, last_generated_audio_codes], dim=2) + + # apply the pattern mask to the final audio ids + output_audio_codes = self.apply_delay_pattern_mask(self.generated_audio_codes, moshi_delay_pattern_mask) + + # revert the pattern delay mask by filtering the pad token id and bos token ids + mask = moshi_delay_pattern_mask != self.config.audio_vocab_size + + output_audio_codes = output_audio_codes[mask].reshape(mask.shape[0], self.num_codebooks, -1) + + output_values = None + if return_audio_waveforms: + output_values = self.audio_encoder.decode( + output_audio_codes, + ).audio_values + + output_audio_codes = output_audio_codes if return_audio_codes else None + + if generation_config.return_dict_in_generate: + return MoshiConditionalGenerationGenerateOutput( + audio_sequences=output_values, audio_codes=output_audio_codes, **outputs + ) + + return MoshiConditionalGenerationGenerateOutput( + audio_sequences=output_values, sequences=output_text_ids, audio_codes=output_audio_codes + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + user_delay_pattern_mask=None, + moshi_delay_pattern_mask=None, + kwargs_depth_decoder=None, + blank_user_audio_codes: Optional[torch.FloatTensor] = None, + **kwargs, + ): + # 1. 
Do usual operations done on LLMs like Gemma - because we pre-processed inputs, the first pass always has inputs_embeds + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + dtype = self.decoder.dtype + + attention_mask = self.decoder.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.max_cache_len, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + + # 2. 
Now that everything is prepared, generate audio_codes using the depth decoder + + # we want to do it after a first token has been generated + if model_inputs["input_ids"] is not None: + last_hidden_state = kwargs.get("last_hidden_state") + # (batch_size, sequence_length, dim) -> (batch_size * sequence_length, 1, dim) + last_hidden_state = last_hidden_state.view(-1, 1, last_hidden_state.shape[-1]) + + input_ids = model_inputs.pop("input_ids") + + generated_audio_codes = self.depth_decoder.generate( + last_hidden_state=last_hidden_state, + input_ids=input_ids.view(-1, 1), + **kwargs_depth_decoder, + ) + + # the first tokens are text tokens + generated_audio_codes = generated_audio_codes[:, 1:].unsqueeze(2) + + user_audio_codes = self.apply_delay_pattern_mask( + torch.cat( + [self.generated_audio_codes, blank_user_audio_codes.to(self.generated_audio_codes.device)], dim=2 + ), + user_delay_pattern_mask, + )[:, :, -1:] + self.generated_audio_codes = self.apply_delay_pattern_mask( + torch.cat([self.generated_audio_codes, generated_audio_codes], dim=2), moshi_delay_pattern_mask + ) + + inputs_embeds, _, _, _, _, _, _ = self._prepare_inputs_embeds_for_generation( + input_ids, moshi_audio_codes=self.generated_audio_codes[:, :, -1:], user_audio_codes=user_audio_codes + ) + + model_inputs["input_ids"] = None + model_inputs["inputs_embeds"] = inputs_embeds + + return model_inputs + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> Dict[str, Any]: + model_kwargs = super()._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder, num_new_tokens + ) + + # update last_hidden_state that'll be used in the depth decoder + model_kwargs["last_hidden_state"] = outputs.get("last_hidden_state")[:, -1:] + + # dirty, but we need to make a last depth_decoder.generate + self.last_hidden_state = outputs.get("last_hidden_state")[:, -1:] + return model_kwargs + + def get_input_embeddings(self): + return self.decoder.get_input_embeddings() + + def set_input_embeddings(self, value): + self.decoder.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.decoder.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.decoder.set_output_embeddings(new_embeddings) + + def freeze_audio_encoder(self): + """ + Freeze the audio encoder weights. + """ + for param in self.audio_encoder.parameters(): + param.requires_grad = False + self.audio_encoder._requires_grad = False + + def freeze_depth_decoder(self): + """ + Freeze the depth encoder weights. + """ + for param in self.depth_decoder.parameters(): + param.requires_grad = False + self.depth_decoder._requires_grad = False + + @staticmethod + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForCausalLM.apply_delay_pattern_mask + def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask): + """Apply a delay pattern mask to the decoder input ids, only preserving predictions where + the mask is set to -1, and otherwise setting to the value detailed in the mask.""" + seq_len = input_ids.shape[-1] + decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len] + input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask) + return input_ids + + def build_delay_pattern_mask( + self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None + ): + """Build a delayed pattern mask to the input_ids. 
Each codebook, except the first one, is offset by
+        one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
+        are 4 codebooks and a max sequence length of 6: we have a delayed pattern mask of shape `(codebooks,
+        seq_len)`:
+        - [-1, -1, -1, -1, -1, P]
+        - [ B, -1, -1, -1, -1, -1]
+        - [ B, -1, -1, -1, -1, -1]
+        - [ B, -1, -1, -1, -1, -1]
+        where B is the beginning-of-sentence token, P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
+        a prompt (input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
+        mask is set to the value in the prompt:
+        - [ a0, a1, -1, -1, -1, P]
+        - [ B, b0, b1, -1, -1, -1]
+        - [ B, c0, c1, -1, -1, -1]
+        - [ B, d0, d1, -1, -1, -1]
+        where a-d indicate the codebook channel and 0/1 indicates the time step. Now, we only override the -1
+        tokens in our prediction.
+        """
+        bsz, num_codebooks, seq_len = input_ids.shape
+
+        max_length = max_length if max_length is not None else self.generation_config.max_length
+        input_ids_shifted = (
+            torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
+        )
+
+        # the first codebook channel is not shifted
+        seq_len_to_keep = min(seq_len, max_length - 1)
+        input_ids_shifted[:, 0, :seq_len_to_keep] = input_ids[:, 0, :seq_len_to_keep]
+
+        # fill the shifted ids with the prompt entries
+        input_ids_shifted[:, 1:, 1 : seq_len_to_keep + 1] = input_ids[:, 1:, :seq_len_to_keep]
+
+        # fill with BOS and PAD
+        input_ids_shifted[:, 1:, 0] = bos_token_id
+        input_ids_shifted[:, 0, -1] = pad_token_id
+
+        # construct a pattern mask that indicates the positions of BOS and PAD tokens for each codebook
+        pattern_mask = input_ids_shifted
+
+        input_ids = input_ids_shifted[..., :seq_len_to_keep]
+        return input_ids, pattern_mask
+
+    def get_unconditional_inputs(self, num_samples=1):
+        """
+        Helper function to get null inputs for unconditional generation, enabling the model to be used without the
+        feature extractor or tokenizer.
+
+        Args:
+            num_samples (int, *optional*, defaults to 1):
+                Number of audio samples to unconditionally generate.
+ + Example: + ```python + >>> from transformers import MoshiForConditionalGeneration + + >>> model = MoshiForConditionalGeneration.from_pretrained("kmhf/hf-moshiko-pytorch-bf16") + + >>> # get the unconditional (or 'null') inputs for the model + >>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1) + >>> audio_samples = model.generate(**unconditional_inputs, max_new_tokens=256) + ```""" + + input_ids = torch.ones((num_samples, 1), device=self.device, dtype=torch.int64) * self.config.vocab_size + user_audio_codes = ( + torch.ones((num_samples, self.num_codebooks, 1), device=self.device, dtype=torch.int64) + * self.config.audio_vocab_size + ) + moshi_audio_codes = ( + torch.ones((num_samples, self.num_codebooks, 1), device=self.device, dtype=torch.int64) + * self.config.audio_vocab_size + ) + attention_mask = torch.ones((num_samples, 1), device=self.device, dtype=torch.long) + + return MoshiUnconditionalInput( + input_ids=input_ids, + user_audio_codes=user_audio_codes, + moshi_audio_codes=moshi_audio_codes, + attention_mask=attention_mask, + ) + + def _check_and_maybe_initalize_inputs( + self, + input_ids=None, + user_input_values=None, + user_audio_codes=None, + moshi_input_values=None, + moshi_audio_codes=None, + inputs_embeds=None, + concat_unconditional_inputs=None, + ): + inputs = input_ids if inputs_embeds is None else inputs_embeds + user_input = user_audio_codes if user_input_values is None else user_input_values + moshi_input = moshi_audio_codes if moshi_input_values is None else moshi_input_values + + one_input_has_been_passed = (user_input is not None) or (moshi_input is not None) or (inputs is not None) + + # concat_unconditional_inputs will be False if inputs_embeds is used + concat_unconditional_inputs = concat_unconditional_inputs and not ( + inputs_embeds is not None and input_ids is None + ) + + # if one or two of the three required inputs have been passed, throws an error + if one_input_has_been_passed and (user_input is None): + raise ValueError( + "No user audio inputs have been passed alongside the other inputs. Make sure either `user_input_values` or `user_audio_codes` is passed or use `MoshiForConditionalGeneration.get_unconditional_inputs`. Check the `MoshiForConditionalGeneration` docstrings for more information." + ) + elif one_input_has_been_passed and (moshi_input is None): + raise ValueError( + "No Moshi audio inputs have been passed alongside the other inputs. Make sure either `moshi_input_values` or `moshi_audio_codes` is passed or use `MoshiForConditionalGeneration.get_unconditional_inputs`. Check the `MoshiForConditionalGeneration` docstrings for more information." + ) + elif one_input_has_been_passed and (inputs is None): + raise ValueError( + "No `input_ids` or `inputs_embeds` have been passed alongside the other inputs. Make sure `input_ids` is passed or use `MoshiForConditionalGeneration.get_unconditional_inputs`. Check the `MoshiForConditionalGeneration` docstrings for more information." 
+ ) + elif not one_input_has_been_passed: + # if no inputs have been passed, use default values + unconditional_inputs = self.get_unconditional_inputs() + input_ids = unconditional_inputs.input_ids + user_audio_codes = unconditional_inputs.user_audio_codes + moshi_audio_codes = unconditional_inputs.moshi_audio_codes + + # in that case, no need to concat unconditional inputs + concat_unconditional_inputs = False + else: + # check if same sequence length + user_seq_length = user_input.shape[-1] + moshi_seq_length = moshi_input.shape[-1] + tokens_seq_length = inputs.shape[1] + + ratio = self.config.audio_encoder_config.frame_rate / self.config.sampling_rate + moshi_seq_length = math.ceil(moshi_seq_length * ratio) if moshi_audio_codes is None else moshi_seq_length + user_seq_length = math.ceil(user_seq_length * ratio) if user_audio_codes is None else user_seq_length + + if tokens_seq_length != moshi_seq_length or tokens_seq_length != user_seq_length: + raise ValueError( + "At least one of the 3 inputs of `MoshiForConditionalGeneration` doesn't have the same sequence length as the others." + "Make sure that they all have the same sequence length. Check the `MoshiForConditionalGeneration` docstrings for more information." + ) + + return input_ids, user_audio_codes, moshi_audio_codes, concat_unconditional_inputs + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +__all__ = ["MoshiForCausalLM", "MoshiForConditionalGeneration", "MoshiModel", "MoshiPreTrainedModel"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 4ca25bc7914..d7570c57c62 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -6219,6 +6219,34 @@ class MobileViTV2PreTrainedModel(metaclass=DummyObject): requires_backends(self, ["torch"]) +class MoshiForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MoshiForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MoshiModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MoshiPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MPNetForMaskedLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index a1bc5265667..5165e43c099 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1098,6 +1098,7 @@ class GenerationTesterMixin: "bigbirdpegasus", "led", "mega", + "moshi", "speech2text", "git", "prophetnet", @@ -1172,6 +1173,7 @@ class GenerationTesterMixin: "bigbirdpegasus", "led", "mega", + "moshi", "speech2text", "git", "prophetnet", @@ -1285,6 +1287,7 @@ class GenerationTesterMixin: 
"bigbirdpegasus", "led", "mega", + "moshi", "speech2text", "git", "prophetnet", diff --git a/tests/models/mimi/test_modeling_mimi.py b/tests/models/mimi/test_modeling_mimi.py index ab6184ce2bb..074dceae155 100644 --- a/tests/models/mimi/test_modeling_mimi.py +++ b/tests/models/mimi/test_modeling_mimi.py @@ -790,6 +790,7 @@ class MimiIntegrationTest(unittest.TestCase): } librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + model_id = "kyutai/mimi" model = MimiModel.from_pretrained(model_id, use_cache=True).to(torch_device) @@ -840,6 +841,7 @@ class MimiIntegrationTest(unittest.TestCase): "32": 1803071, } librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + model_id = "kyutai/mimi" processor = AutoFeatureExtractor.from_pretrained(model_id) diff --git a/tests/models/moshi/__init__.py b/tests/models/moshi/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/models/moshi/test_modeling_moshi.py b/tests/models/moshi/test_modeling_moshi.py new file mode 100644 index 00000000000..b299b414d60 --- /dev/null +++ b/tests/models/moshi/test_modeling_moshi.py @@ -0,0 +1,1126 @@ +# coding=utf-8 +# Copyright 2024, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the PyTorch Moshi model.""" + +import copy +import tempfile +import unittest + +import numpy as np +import pytest +from datasets import Audio, load_dataset +from parameterized import parameterized + +from transformers import ( + MoshiConfig, + PretrainedConfig, +) +from transformers.integrations.deepspeed import ( + is_deepspeed_available, + is_deepspeed_zero3_enabled, +) +from transformers.testing_utils import ( + is_flaky, + is_torch_available, + require_torch, + require_torch_fp16, + require_torch_sdpa, + slow, + torch_device, +) +from transformers.utils import cached_property + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_deepspeed_available(): + import deepspeed + +if is_torch_available(): + import torch + + from transformers import ( + AutoFeatureExtractor, + AutoTokenizer, + MoshiForCausalLM, + MoshiForConditionalGeneration, + MoshiModel, + ) + + +def _config_zero_init(config): + configs_no_init = copy.deepcopy(config) + for key in configs_no_init.__dict__.keys(): + if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key: + setattr(configs_no_init, key, 1e-10) + if isinstance(getattr(configs_no_init, key, None), PretrainedConfig): + no_init_subconfig = _config_zero_init(getattr(configs_no_init, key)) + setattr(configs_no_init, key, no_init_subconfig) + return configs_no_init + + +class MoshiDecoderTester: + def __init__( + self, + parent, + batch_size=4, # need batch_size != num_hidden_layers + seq_length=7, + is_training=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=4, + hidden_act="silu", + rms_norm_eps=0.001, + ffn_dim=32, + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=100, + pad_token_id=25, + num_codebooks=4, + audio_encoder_type="mimi", + attn_implementation="eager", + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.rms_norm_eps = rms_norm_eps + self.ffn_dim = ffn_dim + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.pad_token_id = pad_token_id + self.num_codebooks = num_codebooks + self.audio_encoder_type = audio_encoder_type + self.attn_implementation = attn_implementation + + def prepare_config_and_inputs(self, batch_size=None): + batch_size = self.batch_size if batch_size is None else batch_size + input_ids = ids_tensor([batch_size, self.seq_length], self.vocab_size) + config = self.get_config() + + attention_mask = input_ids.ne(self.pad_token_id) + + inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask} + return config, inputs_dict + + def get_config(self): + config = MoshiConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + d_ff=self.intermediate_size, + num_codebooks=self.num_codebooks, + rms_norm_eps=self.rms_norm_eps, + 
tie_word_embeddings=False, + pad_token_id=self.pad_token_id, + ffn_dim=self.ffn_dim, + audio_encoder_config={"model_type": self.audio_encoder_type}, + attn_implementation=self.attn_implementation, + ) + return config + + def prepare_config_and_inputs_for_common(self, batch_size=None): + config, inputs_dict = self.prepare_config_and_inputs(batch_size) + return config, inputs_dict + + +@require_torch +class MoshiDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (MoshiModel, MoshiForCausalLM) if is_torch_available() else () + all_generative_model_classes = ( + (MoshiForCausalLM,) if is_torch_available() else () + ) # we don't want to run all the generation tests, only a specific subset + test_pruning = False + test_resize_embeddings = True + test_head_masking = False + pipeline_model_mapping = ( + { + "feature-extraction": MoshiModel, + "text-generation": MoshiForCausalLM, + } + if is_torch_available() + else {} + ) + + def setUp(self): + self.model_tester = MoshiDecoderTester(self) + self.config_tester = ConfigTester( + self, + config_class=MoshiConfig, + hidden_size=16, + audio_encoder_config={"model_type": self.model_tester.audio_encoder_type}, + ) + + @unittest.skip(reason="The MoshiModel does not have support dynamic compile yet") + def test_sdpa_can_compile_dynamic(self): + pass + + def _get_input_ids_and_config(self, batch_size=1): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(batch_size) + input_ids = inputs_dict.pop("input_ids").to(torch_device) + attention_mask = inputs_dict.pop("attention_mask").to(torch_device) + + return config, input_ids, attention_mask, inputs_dict + + def _get_logits_processor_kwargs(self, do_sample=False, config=None): + logits_processor_kwargs = {} + return logits_processor_kwargs + + @require_torch_sdpa + @slow + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + self.skipTest(reason="Moshi has no strict equivalence between two modes, skipping this test.") + + # Copied from tests.test_modeling_common.ModelTesterMixin.test_resize_tokens_embeddings + def test_resize_tokens_embeddings(self): + if not self.test_resize_embeddings: + self.skipTest(reason="test_resize_embeddings is set to `False`") + + ( + original_config, + inputs_dict, + ) = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + config = copy.deepcopy(original_config) + if is_deepspeed_zero3_enabled(): + with deepspeed.zero.Init(): + model = model_class(config) + else: + model = model_class(config) + model.to(torch_device) + + model_embed_pre_resize = model.get_input_embeddings() + type_model_embed_pre_resize = type(model_embed_pre_resize) + + if self.model_tester.is_training is False: + model.eval() + + model_vocab_size = config.get_text_config().vocab_size + # Retrieve the embeddings and clone theme + model_embed = model.resize_token_embeddings(model_vocab_size) + cloned_embeddings = model_embed.weight.clone() + + # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size + model_embed = model.resize_token_embeddings(model_vocab_size + 10) + new_model_vocab_size = model.config.get_text_config().vocab_size + self.assertEqual(new_model_vocab_size, model_vocab_size + 10) + # Check that it actually resizes the embeddings matrix + self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10) + # Check to make 
sure the type of embeddings returned post resizing is the same as the type of input
+            type_model_embed_post_resize = type(model_embed)
+            self.assertEqual(type_model_embed_pre_resize, type_model_embed_post_resize)
+            # Check that added embeddings mean is close to the old embeddings mean
+            if is_deepspeed_zero3_enabled():
+                with deepspeed.zero.GatheredParameters(model_embed.weight, modifier_rank=None):
+                    old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0)
+                    new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0)
+            else:
+                old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0)
+                new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0)
+            torch.testing.assert_close(old_embeddings_mean, new_embeddings_mean, atol=1e-3, rtol=1e-1)
+
+            # Check that the model can still do a forward pass successfully (every parameter should be resized)
+            if not is_deepspeed_zero3_enabled():
+                # A distributed launcher is needed for the forward pass when deepspeed is enabled
+                model(**self._prepare_for_class(inputs_dict, model_class))
+
+            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
+            model_embed = model.resize_token_embeddings(model_vocab_size - 15)
+            new_model_vocab_size = model.config.get_text_config().vocab_size
+            self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
+            # Check that it actually resizes the embeddings matrix
+            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
+
+            # Check that the model can still do a forward pass successfully (every parameter should be resized)
+            # Input ids should be clamped to the maximum size of the vocabulary
+            inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
+
+            # make sure that decoder_input_ids are resized as well
+            if not is_deepspeed_zero3_enabled():
+                # A distributed launcher is needed for the forward pass when deepspeed is enabled
+                if "decoder_input_ids" in inputs_dict:
+                    inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
+                model(**self._prepare_for_class(inputs_dict, model_class))
+
+            # Check that adding and removing tokens has not modified the first part of the embedding matrix.
+ models_equal = True + for p1, p2 in zip(cloned_embeddings, model_embed.weight): + if p1.data.ne(p2.data).sum() > 0: + models_equal = False + + self.assertTrue(models_equal) + + del model + if is_deepspeed_zero3_enabled(): + with deepspeed.zero.Init(): + model = model_class(config) + else: + model = model_class(config) + model.to(torch_device) + + model_vocab_size = config.get_text_config().vocab_size + model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1) + new_model_vocab_size = model.config.get_text_config().vocab_size + self.assertTrue(new_model_vocab_size + 10, model_vocab_size) + + model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64) + new_model_vocab_size = model.config.get_text_config().vocab_size + self.assertTrue(model_embed.weight.shape[0] // 64, 0) + + self.assertTrue(model_embed.weight.shape[0], new_model_vocab_size) + self.assertTrue(new_model_vocab_size, model.vocab_size) + + model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64) + self.assertTrue(model_embed.weight.shape[0] // 64, 0) + + # Check that resizing a model to a multiple of pad_to_multiple leads to a model of exactly that size + target_dimension = 128 + model_embed = model.resize_token_embeddings(target_dimension, pad_to_multiple_of=64) + self.assertTrue(model_embed.weight.shape[0], target_dimension) + + with self.assertRaisesRegex( + ValueError, + "Asking to pad the embedding matrix to a multiple of `1.3`, which is not and integer. Please make sure to pass an integer", + ): + model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3) + + # Test when `vocab_size` is smaller than `hidden_size`. + del model + config.vocab_size = 4 + config.pad_token_id = 4 # Ignore copy + if is_deepspeed_zero3_enabled(): + with deepspeed.zero.Init(): + model = model_class(config) + else: + model = model_class(config) + model.to(torch_device) + + model_vocab_size = config.get_text_config().vocab_size + # Retrieve the embeddings and clone theme + model_embed = model.resize_token_embeddings(model_vocab_size) + cloned_embeddings = model_embed.weight.clone() + + # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size + model_embed = model.resize_token_embeddings(model_vocab_size + 10) + new_model_vocab_size = model.config.get_text_config().vocab_size + self.assertEqual(new_model_vocab_size, model_vocab_size + 10) + # Check that it actually resizes the embeddings matrix + self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10) + # Check to make sure the type of embeddings returned post resizing is same as type of input + type_model_embed_post_resize = type(model_embed) + self.assertEqual(type_model_embed_pre_resize, type_model_embed_post_resize) + # Check that added embeddings mean is close to the old embeddings mean + if is_deepspeed_zero3_enabled(): + with deepspeed.zero.GatheredParameters(model_embed.weight, modifier_rank=None): + old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0) + new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0) + else: + old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0) + new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0) + torch.testing.assert_close(old_embeddings_mean, new_embeddings_mean, atol=1e-3, rtol=1e-1) + + @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. 
Skip for now.") + def test_cpu_offload(self): + pass + + @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.") + def test_disk_offload_bin(self): + pass + + @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.") + def test_disk_offload_safetensors(self): + pass + + @is_flaky(max_attempts=5, description="flaky on some models.") + def test_save_load(self): + super().test_save_load() + + +class MoshiTester: + def __init__( + self, + parent, + batch_size=4, # need batch_size != num_hidden_layers + seq_length=7, + is_training=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=8, + intermediate_size=4, + hidden_act="silu", + rms_norm_eps=0.001, + ffn_dim=32, + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=100, + pad_token_id=25, + bos_token_id=25, + num_codebooks=4, + audio_encoder_type="mimi", + attn_implementation="eager", + depth_hidden_size=16, + depth_num_hidden_layers=2, + depth_max_position_embeddings=5, + depth_num_attention_heads=8, + depth_ffn_dim=16, + depth_sliding_window=4, + mimi_intermediate_size=40, + mimi_hidden_size=32, + mimi_num_filters=8, + mimi_num_residual_layers=1, + mimi_upsampling_ratios=[8, 4], + mimi_codebook_size=64, + mimi_vector_quantization_hidden_dimension=64, + mimi_codebook_dim=64, + mimi_upsample_groups=32, + mimi_num_hidden_layers=2, + mimi_num_attention_heads=2, + mimi_num_key_value_heads=2, + mimi_sliding_window=3, + sampling_rate=800, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.rms_norm_eps = rms_norm_eps + self.ffn_dim = ffn_dim + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.num_codebooks = num_codebooks + self.attn_implementation = attn_implementation + self.depth_hidden_size = depth_hidden_size + self.depth_num_hidden_layers = depth_num_hidden_layers + self.depth_max_position_embeddings = depth_max_position_embeddings + self.depth_num_attention_heads = depth_num_attention_heads + self.depth_ffn_dim = depth_ffn_dim + self.depth_sliding_window = depth_sliding_window + + self.audio_encoder_type = audio_encoder_type + self.mimi_intermediate_size = mimi_intermediate_size + self.mimi_hidden_size = mimi_hidden_size + self.mimi_num_filters = mimi_num_filters + self.mimi_num_residual_layers = mimi_num_residual_layers + self.mimi_upsampling_ratios = mimi_upsampling_ratios + self.mimi_codebook_size = mimi_codebook_size + self.mimi_vector_quantization_hidden_dimension = mimi_vector_quantization_hidden_dimension + self.mimi_codebook_dim = mimi_codebook_dim + self.mimi_upsample_groups = mimi_upsample_groups + self.mimi_num_hidden_layers = mimi_num_hidden_layers + self.mimi_num_attention_heads = mimi_num_attention_heads + self.mimi_num_key_value_heads = mimi_num_key_value_heads + self.mimi_sliding_window = mimi_sliding_window + self.sampling_rate = sampling_rate + + self.num_hidden_states_types = 2 + + def prepare_config_and_inputs(self, batch_size=None): + batch_size = 
self.batch_size if batch_size is None else batch_size + + input_ids = ids_tensor([batch_size, self.seq_length], self.vocab_size) + + moshi_audio_codes = ids_tensor([batch_size, self.num_codebooks, self.seq_length], self.mimi_codebook_size) + user_audio_codes = ids_tensor([batch_size, self.num_codebooks, self.seq_length], self.mimi_codebook_size) + attention_mask = input_ids.ne(self.pad_token_id) + + config = self.get_config() + inputs_dict = { + "input_ids": input_ids, + "moshi_audio_codes": moshi_audio_codes, + "user_audio_codes": user_audio_codes, + "attention_mask": attention_mask, + } + return config, inputs_dict + + def get_config(self): + mimi_dict_config = { + "model_type": self.audio_encoder_type, + "audio_channels": 1, + "hidden_size": self.mimi_hidden_size, + "num_filters": self.mimi_num_filters, + "num_residual_layers": self.mimi_num_residual_layers, + "upsampling_ratios": self.mimi_upsampling_ratios, + "codebook_size": self.mimi_codebook_size, + "vector_quantization_hidden_dimension": self.mimi_vector_quantization_hidden_dimension, + "upsample_groups": self.mimi_upsample_groups, + "num_hidden_layers": self.mimi_num_hidden_layers, + "num_attention_heads": self.mimi_num_attention_heads, + "num_key_value_heads": self.mimi_num_key_value_heads, + "sliding_window": self.mimi_sliding_window, + "codebook_dim": self.mimi_codebook_dim, + "use_cache": False, + "sampling_rate": self.sampling_rate, + } + + depth_dict_config = { + "hidden_size": self.depth_hidden_size, + "num_hidden_layers": self.depth_num_hidden_layers, + "max_position_embeddings": self.depth_max_position_embeddings, + "num_attention_heads": self.depth_num_attention_heads, + "ffn_dim": self.depth_ffn_dim, + "sliding_window": self.depth_sliding_window, + } + + config = MoshiConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + d_ff=self.intermediate_size, + num_codebooks=self.num_codebooks, + rms_norm_eps=self.rms_norm_eps, + tie_word_embeddings=False, + pad_token_id=self.pad_token_id, + bos_token_id=self.bos_token_id, + ffn_dim=self.ffn_dim, + audio_encoder_config=mimi_dict_config, + depth_decoder_config=depth_dict_config, + attn_implementation=self.attn_implementation, + ) + return config + + def prepare_config_and_inputs_for_common(self, batch_size=None): + config, inputs_dict = self.prepare_config_and_inputs(batch_size) + return config, inputs_dict + + +@require_torch +class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (MoshiForConditionalGeneration,) if is_torch_available() else () + all_generative_model_classes = (MoshiForConditionalGeneration,) if is_torch_available() else () + test_pruning = False # training is not supported yet for Moshi + test_headmasking = False + test_resize_embeddings = False + test_torchscript = False + + def setUp(self): + self.model_tester = MoshiTester(self) + + # special case for labels + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + inputs_dict["text_labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length), + dtype=torch.long, + device=torch_device, + ) + return inputs_dict + + def _get_input_ids_and_config(self, batch_size=2): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(batch_size) + input_ids = 
inputs_dict.pop("input_ids").to(torch_device)
+        attention_mask = inputs_dict.pop("attention_mask").to(torch_device)
+
+        # Make sure we only return `input_ids`.
+        # Note that audio_codes will still be generated internally, so the ability to test audio codes is still there.
+        # There are further tests to test that audio waveforms and codes are well generated.
+        inputs_dict["return_audio_waveforms"] = False
+        inputs_dict["return_audio_codes"] = False
+        inputs_dict["concat_unconditional_inputs"] = False
+
+        return config, input_ids, attention_mask, inputs_dict
+
+    def prepare_config_and_inputs_for_generate(self, batch_size=2):
+        config, filtered_inputs_dict = super().prepare_config_and_inputs_for_generate()
+
+        # Make sure we only return `input_ids`.
+        # Note that audio_codes will still be generated internally, so the ability to test audio codes is still there.
+        # There are further tests to test that audio waveforms and codes are well generated.
+        filtered_inputs_dict["return_audio_waveforms"] = False
+        filtered_inputs_dict["return_audio_codes"] = False
+        filtered_inputs_dict["concat_unconditional_inputs"] = False
+
+        return config, filtered_inputs_dict
+
+    def _check_hidden_states_for_generate(
+        self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        # Overwrite because the generate method actually always uses `inputs_embeds` so `use_cache` is always `True`
+        self.assertIsInstance(hidden_states, tuple)
+        self.assertListEqual(
+            [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
+            [True] * len(hidden_states),
+        )
+        self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
+
+        for idx, iter_hidden_states in enumerate(hidden_states):
+            seq_len = min_length if idx == 0 else 1
+            expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
+            # check hidden size
+            self.assertListEqual(
+                [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
+                [expected_shape] * len(iter_hidden_states),
+            )
+
+    def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
+        # Overwrite because the generate method actually always uses `inputs_embeds` so `use_cache` is always `True`
+        super()._check_outputs(output, input_ids, config, use_cache=True, num_return_sequences=num_return_sequences)
+
+    def _check_hidden_states_for_generate(
+        self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        # Overwrite because the generate method actually always uses `inputs_embeds` so `use_cache` is always `True`
+        self.assertIsInstance(hidden_states, tuple)
+        self.assertListEqual(
+            [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
+            [True] * len(hidden_states),
+        )
+        self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
+
+        for idx, iter_hidden_states in enumerate(hidden_states):
+            seq_len = 1
+            expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
+            # check hidden size
+            self.assertListEqual(
+                [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
+                [expected_shape] * len(iter_hidden_states),
+            )
+
+    def _check_attentions_for_generate(
+        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        # Overwrite because the generate method actually always uses `inputs_embeds` so `use_cache` is always `True`
+        self.assertIsInstance(attentions, tuple)
+        self.assertListEqual(
+ [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) + ) + self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) + + for idx, iter_attentions in enumerate(attentions): + tgt_len = 1 + src_len = min_length + idx + + expected_shape = ( + batch_size * num_beam_groups, + config.num_attention_heads, + tgt_len, + src_len, + ) + # check attn size + self.assertListEqual( + [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions) + ) + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + uniform_init_parms = ["conv", "input_proj", "output_proj"] + if param.requires_grad: + if any(x in name for x in uniform_init_parms): + self.assertTrue( + -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + @pytest.mark.generate + def test_generate_from_inputs_embeds_decoder_only(self): + for model_class in self.all_generative_model_classes: + config, input_ids, _, inputs_dict = self._get_input_ids_and_config() + + model = model_class(config).to(torch_device).eval() + + # Traditional way of generating text + outputs_from_ids = model.generate( + input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True, **inputs_dict + ) + self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5)) + + # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output) + inputs_embeds = model.get_input_embeddings()(input_ids) + outputs_from_embeds = model.generate( + input_ids, + inputs_embeds=inputs_embeds, + max_new_tokens=5, + return_dict_in_generate=True, + output_scores=True, + **inputs_dict, + ) + + # But if we pass different inputs_embeds, we should get different outputs (the output text may be the + # same, but the logits will almost surely be different) + random_embeds = torch.rand_like(inputs_embeds) + outputs_from_rand_embeds = model.generate( + input_ids, + inputs_embeds=random_embeds, + max_new_tokens=5, + return_dict_in_generate=True, + output_scores=True, + **inputs_dict, + ) + for i in range(len(outputs_from_rand_embeds.scores)): + self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i])) + + # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same + outputs_from_embeds_wo_ids = model.generate( + inputs_embeds=inputs_embeds, + max_new_tokens=5, + return_dict_in_generate=True, + output_scores=True, + **inputs_dict, + ) + self.assertListEqual( + outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :].tolist(), + outputs_from_embeds_wo_ids.sequences.tolist(), + ) + + @unittest.skip(reason="Continuing from past key values is not straightforward as we're dealing with 3 inputs") + def test_generate_continue_from_past_key_values(self): + pass + + @unittest.skip("Moshi doesn't support contrastive generation yet.") + def test_contrastive_generate(self): + pass + + @unittest.skip("Moshi doesn't support contrastive generation yet.") + def test_contrastive_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip("Moshi doesn't support contrastive generation yet.") + def 
test_contrastive_generate_low_memory(self): + pass + + @unittest.skip("Adapting this test is costly. `test_eager_matches_sdpa_generate` tests this already.") + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + pass + + @unittest.skip(reason="The Moshi model does not have support dynamic compile yet") + def test_sdpa_can_compile_dynamic(self): + pass + + @pytest.mark.generate + def test_left_padding_compatibility(self): + # NOTE: left-padding results in small numerical differences. This is expected. + # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 + + # Then, test left-padding + + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, input_dict = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() + + # no cache as some models require special cache classes to be init outside forward + model.generation_config.use_cache = False + + # Without padding + next_logits_wo_padding = model(input_ids=input_ids, attention_mask=attention_mask, **input_dict).logits[ + :, -1, : + ] + + # With left-padding (length 32) + # can hardcode pad_token to be 0 as we'll do attn masking anyway + pad_token_id = ( + config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0 + ) + pad_size = (input_ids.shape[0], 32) + padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id + padded_input_ids = torch.cat((padding, input_ids), dim=1) + + padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1) + + padding = ( + torch.ones( + (pad_size[0], self.model_tester.num_codebooks, 32), dtype=input_ids.dtype, device=torch_device + ) + * config.audio_vocab_size + ) + padded_moshi_audio_codes = torch.cat((padding, input_dict["moshi_audio_codes"]), dim=2) + padded_user_audio_codes = torch.cat((padding, input_dict["user_audio_codes"]), dim=2) + + model_kwargs = { + "input_ids": padded_input_ids, + "attention_mask": padded_attention_mask, + "moshi_audio_codes": padded_moshi_audio_codes, + "user_audio_codes": padded_user_audio_codes, + } + + next_logits_with_padding = model(**model_kwargs).logits[:, -1, :] + + # They should result in very similar logits + self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-5)) + + @require_torch_sdpa + @slow + @is_flaky(max_attempts=5, description="flaky on some models.") + def test_eager_matches_sdpa_generate(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + max_new_tokens = 5 + + if len(self.all_generative_model_classes) == 0: + self.skipTest(f"{self.__class__.__name__} tests a model that does support generate: skipping this test") + + for model_class in self.all_generative_model_classes: + if not model_class._supports_sdpa: + self.skipTest(f"{model_class.__name__} does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + inputs_dict[model_class.main_input_name] = dummy_input + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + 
+ model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + model_sdpa = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + attn_implementation="eager", + ).to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa: + raise ValueError("The SDPA model should have SDPA attention layers") + + # Just test that a large cache works as expected + res_eager = model_eager.generate( + **inputs_dict, + max_new_tokens=max_new_tokens, + do_sample=False, + depth_decoder_do_sample=False, + ) + + res_sdpa = model_sdpa.generate( + **inputs_dict, + max_new_tokens=max_new_tokens, + do_sample=False, + depth_decoder_do_sample=False, + ) + + self.assertTrue(torch.allclose(res_eager.sequences, res_sdpa.sequences)) + self.assertTrue(torch.allclose(res_eager.audio_sequences, res_sdpa.audio_sequences)) + + @pytest.mark.generate + def test_generate_without_input_ids(self): + config, _, _, _ = self._get_input_ids_and_config() + + for model_class in self.all_generative_model_classes: + model = model_class(config).to(torch_device) + model.eval() + + output_ids_generate = model.generate( + do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True + ) + self.assertIsNotNone(output_ids_generate) + + @unittest.skip(reason="The audio encoder has no gradients.") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="The audio encoder has no gradients.") + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip(reason="The audio encoder has no gradients.") + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + def test_generate_from_input_values(self): + for model_class in self.all_generative_model_classes: + config, input_ids, _, _ = self._get_input_ids_and_config() + + model = model_class(config).to(torch_device).eval() + + input_values_length = int( + self.model_tester.seq_length * config.sampling_rate / config.audio_encoder_config.frame_rate + ) + + user_input_values = floats_tensor((input_ids.shape[0], 1, input_values_length)) + moshi_input_values = floats_tensor((input_ids.shape[0], 1, input_values_length)) + + user_audio_codes = model.audio_encoder.encode(user_input_values, num_quantizers=model.num_codebooks)[0] + moshi_audio_codes = model.audio_encoder.encode(moshi_input_values, num_quantizers=model.num_codebooks)[0] + + outputs_from_audio_codes = model.generate( + input_ids, max_new_tokens=5, user_audio_codes=user_audio_codes, moshi_audio_codes=moshi_audio_codes + ) + + outputs_from_audio_values = model.generate( + input_ids, max_new_tokens=5, user_input_values=user_input_values, moshi_input_values=moshi_input_values + ) + + self.assertTrue((outputs_from_audio_values.sequences == 
outputs_from_audio_codes.sequences).all())
+        self.assertTrue(
+            torch.allclose(outputs_from_audio_codes.audio_sequences, outputs_from_audio_values.audio_sequences)
+        )
+
+    def test_generate_depth_decoder_kwargs(self):
+        # test sampling and beam search
+        for model_class in self.all_generative_model_classes:
+            config, input_ids, _, input_dict = self._get_input_ids_and_config()
+
+            model = model_class(config).to(torch_device).eval()
+
+            model.generate(input_ids, max_new_tokens=5, **input_dict, depth_decoder_do_sample=True)
+
+            model.generate(
+                input_ids, max_new_tokens=5, **input_dict, depth_decoder_do_sample=True, depth_decoder_num_beams=5
+            )
+
+    def test_generate_from_unconditional(self):
+        # test sampling and beam search
+        for model_class in self.all_generative_model_classes:
+            config, input_ids, _, input_dict = self._get_input_ids_and_config()
+
+            model = model_class(config).to(torch_device).eval()
+
+            # check bs>1
+            model.generate(
+                **model.get_unconditional_inputs(num_samples=4), max_new_tokens=5, concat_unconditional_inputs=False
+            )
+
+            # check same results from unconditional or no inputs
+            outputs_from_unconditional = model.generate(
+                **model.get_unconditional_inputs(num_samples=1), max_new_tokens=5, concat_unconditional_inputs=False
+            )
+            outputs_from_none = model.generate(max_new_tokens=5)
+
+            self.assertTrue((outputs_from_unconditional.sequences == outputs_from_none.sequences).all())
+            self.assertTrue(
+                torch.allclose(outputs_from_unconditional.audio_sequences, outputs_from_none.audio_sequences)
+            )
+
+    @unittest.skip(reason="Compile is not yet supported for Moshi models")
+    def test_sdpa_can_dispatch_on_flash(self):
+        pass
+
+    @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
+    def test_cpu_offload(self):
+        pass
+
+    @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
+    def test_disk_offload_bin(self):
+        pass
+
+    @unittest.skip(reason="Some undefined behavior encountered with test versions of this model.
Skip for now.") + def test_disk_offload_safetensors(self): + pass + + @is_flaky(max_attempts=5, description="flaky on some models.") + def test_save_load(self): + super().test_save_load() + + +def place_dict_on_device(dict_to_place, device): + for key in dict_to_place: + if dict_to_place[key] is not None and isinstance(dict_to_place[key], torch.Tensor): + dict_to_place[key] = dict_to_place[key].to(device) + return dict_to_place + + +@require_torch +class MoshiIntegrationTests(unittest.TestCase): + @cached_property + def feature_extractor(self): + return AutoFeatureExtractor.from_pretrained("kmhf/hf-moshiko") + + @cached_property + def tokenizer(self): + return AutoTokenizer.from_pretrained("kmhf/hf-moshiko") + + def _load_datasample(self): + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + dataset = ds.cast_column("audio", Audio(sampling_rate=self.feature_extractor.sampling_rate)) + # automatic decoding with librispeech + speech_sample = dataset.sort("id")[0]["audio"]["array"] + return speech_sample + + @slow + def test_moshika_conditional_greedy(self): + model = MoshiForConditionalGeneration.from_pretrained( + "kmhf/hf-moshika", torch_dtype=torch.float16, device_map="auto" + ) + inputs = self.feature_extractor(self._load_datasample(), return_tensors="pt").to( + device=torch_device, dtype=torch.float16 + ) + + user_audio_codes = model.audio_encoder.encode(**inputs, num_quantizers=8).audio_codes + + input_ids = self.tokenizer.encode(" Hello,", return_tensors="pt").to( + torch_device + ) + + # fmt: off + moshi_audio_codes = [[[1049, 127, 1880, 972, 972, 1156, 1913, 415, 1933], + [1700, 243, 91, 91, 91, 745, 1478, 638, 57], + [1626, 457, 457, 457, 457, 1839, 200, 2011, 1142], + [546, 290, 390, 390, 290, 1408, 1812, 1187, 1911], + [306, 306, 1314, 1314, 1314, 759, 796, 854, 1466], + [1443, 1443, 1030, 317, 347, 1178, 613, 1576, 2023], + [1871, 428, 1433, 1433, 1978, 1405, 1755, 820, 610], + [2008, 1744, 1511, 568, 1533, 550, 237, 1412, 1401]]] + # fmt: on + + moshi_audio_codes = torch.tensor(moshi_audio_codes, device=torch_device) + user_audio_codes = user_audio_codes[:, :, : moshi_audio_codes.shape[-1]] + + model_outputs = model.generate( + user_audio_codes=user_audio_codes, + moshi_audio_codes=moshi_audio_codes, + input_ids=input_ids, + do_sample=False, + depth_decoder_do_sample=False, + return_audio_codes=True, + max_new_tokens=2, + ) + + expected_text_token = 452 + expected_audio_tokens = [916, 1396, 1238, 579, 1105, 914, 1257, 810] # fmt: skip + + self.assertTrue(expected_text_token == model_outputs.sequences[0, -2].cpu().item()) + self.assertTrue(expected_audio_tokens == model_outputs.audio_codes[0, :, -1].cpu().tolist()) + + @slow + def test_moshiko_greedy_unconditional_fp16_eager(self): + model = MoshiForConditionalGeneration.from_pretrained( + "kmhf/hf-moshiko", torch_dtype=torch.float16, device_map="auto" + ) + some_expected_audio_tokens = [[1049, 127], [1700, 243], [1626, 457], [546, 290], [306, 306], [1443, 1443], [1871, 428], [2008, 1744]] # fmt: skip + + model_outputs = model.generate( + do_sample=False, depth_decoder_do_sample=False, return_audio_codes=True, max_new_tokens=10 + ) + + # eager equivalence is not as strict as sdpa. 
+ self.assertTrue(some_expected_audio_tokens == model_outputs.audio_codes[0, :, :2].cpu().tolist()) + + @slow + def test_moshiko_greedy_unconditional_fp32(self): + model = MoshiForConditionalGeneration.from_pretrained( + "kmhf/hf-moshiko", torch_dtype=torch.float32, device_map="auto" + ) + + expected_audio_codesum = 72065 + expected_text_tokens = [3, 3, 3, 0, 11725, 261, 3, 3, 3, 3] # fmt: skip + some_expected_audio_tokens = [[1049, 127], [1700, 243], [1626, 457], [546, 290], [306, 306], [1443, 1443], [1871, 428], [2008, 1744]] # fmt: skip + + model_outputs = model.generate( + do_sample=False, depth_decoder_do_sample=False, return_audio_codes=True, max_new_tokens=10 + ) + + # make sure audio encoded codes are correct + audio_code_sums = model_outputs.audio_codes.sum().item() + self.assertTrue(np.abs(audio_code_sums - expected_audio_codesum) <= (3e-3 * audio_code_sums)) + + self.assertTrue(expected_text_tokens == model_outputs.sequences[0, 1:].cpu().tolist()) + self.assertTrue(some_expected_audio_tokens == model_outputs.audio_codes[0, :, :2].cpu().tolist()) + + @slow + @require_torch_fp16 + def test_moshiko_greedy_unconditional_fp16(self): + model = MoshiForConditionalGeneration.from_pretrained( + "kmhf/hf-moshiko", torch_dtype=torch.float16, device_map="auto" + ) + + expected_audio_codesum = 72065 + expected_text_tokens = [3, 3, 3, 0, 11725, 261, 3, 3, 3, 3] # fmt: skip + some_expected_audio_tokens = [[1049, 127], [1700, 243], [1626, 457], [546, 290], [306, 306], [1443, 1443], [1871, 428], [2008, 1744]] # fmt: skip + + model_outputs = model.generate( + do_sample=False, depth_decoder_do_sample=False, return_audio_codes=True, max_new_tokens=10 + ) + + # make sure audio encoded codes are correct + audio_code_sums = model_outputs.audio_codes.sum().item() + self.assertTrue(np.abs(audio_code_sums - expected_audio_codesum) <= (3e-3 * audio_code_sums)) + + self.assertTrue(expected_text_tokens == model_outputs.sequences[0, 1:].cpu().tolist()) + self.assertTrue(some_expected_audio_tokens == model_outputs.audio_codes[0, :, :2].cpu().tolist()) + + @slow + @require_torch_fp16 + def test_moshika_greedy_unconditional_fp16(self): + model = MoshiForConditionalGeneration.from_pretrained( + "kmhf/hf-moshika", torch_dtype=torch.float16, device_map="auto" + ) + + expected_audio_codesum = 72932 + expected_text_tokens = [3, 3, 3, 0, 667, 263, 3, 3, 0, 705] # fmt: skip + some_expected_audio_tokens = [[1049, 127], [1700, 243], [1626, 457], [546, 290], [306, 306], [1443, 347], [1871, 428], [2008, 2008]] # fmt: skip + + model_outputs = model.generate( + do_sample=False, depth_decoder_do_sample=False, return_audio_codes=True, max_new_tokens=10 + ) + + # make sure audio encoded codes are correct + audio_code_sums = model_outputs.audio_codes.sum().item() + self.assertTrue(np.abs(audio_code_sums - expected_audio_codesum) <= 2048) + + self.assertTrue(expected_text_tokens == model_outputs.sequences[0, 1:].cpu().tolist()) + self.assertTrue(some_expected_audio_tokens == model_outputs.audio_codes[0, :, :2].cpu().tolist()) diff --git a/tests/models/moshi/test_tokenization_moshi.py b/tests/models/moshi/test_tokenization_moshi.py new file mode 100644 index 00000000000..ad3a34a197f --- /dev/null +++ b/tests/models/moshi/test_tokenization_moshi.py @@ -0,0 +1,447 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import pickle
+import shutil
+import tempfile
+import unittest
+
+from transformers import (
+    SPIECE_UNDERLINE,
+    AddedToken,
+    AutoTokenizer,
+    PreTrainedTokenizerFast,
+    SpecialTokensMixin,
+)
+from transformers.convert_slow_tokenizer import MoshiConverter
+from transformers.testing_utils import (
+    get_tests_dir,
+    nested_simplify,
+    require_sentencepiece,
+    require_tokenizers,
+    require_torch,
+)
+
+from ...test_tokenization_common import SMALL_TRAINING_CORPUS, TokenizerTesterMixin
+
+
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
+
+
+@require_sentencepiece
+@require_tokenizers
+class MoshiTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = ["kmhf/hf-moshiko"]
+    rust_tokenizer_class = PreTrainedTokenizerFast
+
+    test_slow_tokenizer = False
+    test_rust_tokenizer = True
+    from_pretrained_kwargs = {}
+
+    def setUp(self):
+        super().setUp()
+
+        # We have a SentencePiece fixture for testing
+        tokenizer = PreTrainedTokenizerFast(
+            tokenizer_object=MoshiConverter(vocab_file=SAMPLE_VOCAB).converted(),
+            bos_token="<s>",
+            unk_token="<unk>",
+            eos_token="</s>",
+        )
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.save_pretrained(self.tmpdirname)
+
+    def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
+        return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
+
+    @unittest.skip(reason="No slow tokenizer")
+    def test_added_tokens_serialization(self):
+        pass
+
+    @unittest.skip(reason="PreTrainedTokenizerFast doesn't have tokenizer_file in its signature")
+    def test_rust_tokenizer_signature(self):
+        pass
+
+    @unittest.skip(reason="No slow tokenizer")
+    def test_encode_decode_with_spaces(self):
+        pass
+
+    def test_full_tokenizer(self):
+        tokenizer = PreTrainedTokenizerFast(
+            tokenizer_object=MoshiConverter(vocab_file=SAMPLE_VOCAB).converted(),
+            bos_token="<s>",
+            unk_token="<unk>",
+            eos_token="</s>",
+        )
+
+        tokens = tokenizer.tokenize("This is a test")
+        self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
+
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(tokens),
+            [285, 46, 10, 170, 382],
+        )
+
+        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
+        self.assertListEqual(
+            tokens,
+            [
+                SPIECE_UNDERLINE + "I",
+                SPIECE_UNDERLINE + "was",
+                SPIECE_UNDERLINE + "b",
+                "or",
+                "n",
+                SPIECE_UNDERLINE + "in",
+                SPIECE_UNDERLINE + "",
+                "9",
+                "2",
+                "0",
+                "0",
+                "0",
+                ",",
+                SPIECE_UNDERLINE + "and",
+                SPIECE_UNDERLINE + "this",
+                SPIECE_UNDERLINE + "is",
+                SPIECE_UNDERLINE + "f",
+                "al",
+                "s",
+                "é",
+                ".",
+            ],
+        )
+        ids = tokenizer.convert_tokens_to_ids(tokens)
+        self.assertListEqual(
+            ids,
+            [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4],
+        )
+
+        back_tokens = tokenizer.convert_ids_to_tokens(ids)
+        self.assertListEqual(
+            back_tokens,
+            [
+                SPIECE_UNDERLINE + "I",
+                SPIECE_UNDERLINE + "was",
+                SPIECE_UNDERLINE + "b",
+                "or",
+                "n",
+                SPIECE_UNDERLINE + "in",
+                SPIECE_UNDERLINE + "",
+                "<unk>",
+                "2",
+                "0",
+                "0",
+                "0",
+                ",",
+                SPIECE_UNDERLINE + "and",
+                SPIECE_UNDERLINE + "this",
+                SPIECE_UNDERLINE + "is",
+                SPIECE_UNDERLINE + "f",
+                "al",
+                "s",
+                "<unk>",
+                ".",
+            ],
+        )
+
+    def test_special_tokens_initialization(self):
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                added_tokens = [AddedToken("<special>", lstrip=True)]
+
+                tokenizer_r = self.rust_tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
+                )
+                r_output = tokenizer_r.encode("Hey this is a <special> token")
+
+                special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]
+
+                self.assertTrue(special_token_id in r_output)
+
+    def test_picklable(self):
+        with tempfile.NamedTemporaryFile() as f:
+            shutil.copyfile(SAMPLE_VOCAB, f.name)
+            tokenizer = PreTrainedTokenizerFast(
+                tokenizer_object=MoshiConverter(vocab_file=f.name).converted(),
+                bos_token="<s>",
+                unk_token="<unk>",
+                eos_token="</s>",
+            )
+            pickled_tokenizer = pickle.dumps(tokenizer)
+            pickle.loads(pickled_tokenizer)
+
+    def test_training_new_tokenizer(self):
+        # This feature only exists for fast tokenizers
+        if not self.test_rust_tokenizer:
+            self.skipTest(reason="test_rust_tokenizer is set to False")
+
+        tokenizer = self.get_rust_tokenizer()
+        new_tokenizer = tokenizer.train_new_from_iterator(SMALL_TRAINING_CORPUS, 100)
+
+        # Test we can use the new tokenizer with something not seen during training
+        inputs = new_tokenizer(["This is the first sentence", "This sentence is different 🤗."])
+        self.assertEqual(len(inputs["input_ids"]), 2)
+        decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
+        expected_result = "This is the first sentence"
+
+        self.assertEqual(expected_result, decoded_input)
+
+        # We check that the parameters of the tokenizer remained the same
+        # Check we have the same number of added_tokens for both pair and non-pair inputs.
+        self.assertEqual(tokenizer.num_special_tokens_to_add(False), new_tokenizer.num_special_tokens_to_add(False))
+        self.assertEqual(tokenizer.num_special_tokens_to_add(True), new_tokenizer.num_special_tokens_to_add(True))
+
+        # Check we have the correct max_length for both pair and non-pair inputs.
+        self.assertEqual(tokenizer.max_len_single_sentence, new_tokenizer.max_len_single_sentence)
+        self.assertEqual(tokenizer.max_len_sentences_pair, new_tokenizer.max_len_sentences_pair)
+
+        # Assert the set of special tokens match as we didn't ask to change them
+        self.assertSequenceEqual(
+            tokenizer.all_special_tokens_extended,
+            new_tokenizer.all_special_tokens_extended,
+        )
+
+        self.assertDictEqual(tokenizer.special_tokens_map, new_tokenizer.special_tokens_map)
+
+    def test_training_new_tokenizer_with_special_tokens_change(self):
+        # This feature only exists for fast tokenizers
+        if not self.test_rust_tokenizer:
+            self.skipTest(reason="test_rust_tokenizer is set to False")
+
+        tokenizer = self.get_rust_tokenizer()
+        # Test with a special tokens map
+        class_signature = inspect.signature(tokenizer.__class__)
+        if "cls_token" in class_signature.parameters:
+            new_tokenizer = tokenizer.train_new_from_iterator(
+                SMALL_TRAINING_CORPUS, 100, special_tokens_map={tokenizer.cls_token: "<cls>"}
+            )
+            cls_id = new_tokenizer.get_vocab()["<cls>"]
+            self.assertEqual(new_tokenizer.cls_token, "<cls>")
+            self.assertEqual(new_tokenizer.cls_token_id, cls_id)
+
+        # Create a new mapping from the special tokens defined in the original tokenizer
+        special_tokens_list = SpecialTokensMixin.SPECIAL_TOKENS_ATTRIBUTES.copy()
+        special_tokens_list.remove("additional_special_tokens")
+        special_tokens_map = {}
+        for token in special_tokens_list:
+            # Get the private one to avoid unnecessary warnings.
+            if getattr(tokenizer, f"_{token}") is not None:
+                special_token = getattr(tokenizer, token)
+                special_tokens_map[special_token] = f"{special_token}a"
+
+        # Train new tokenizer
+        new_tokenizer = tokenizer.train_new_from_iterator(
+            SMALL_TRAINING_CORPUS, 100, special_tokens_map=special_tokens_map
+        )
+
+        # Check the changes
+        for token in special_tokens_list:
+            # Get the private one to avoid unnecessary warnings.
+            if getattr(tokenizer, f"_{token}") is None:
+                continue
+            special_token = getattr(tokenizer, token)
+            if special_token in special_tokens_map:
+                new_special_token = getattr(new_tokenizer, token)
+                self.assertEqual(special_tokens_map[special_token], new_special_token)
+
+                new_id = new_tokenizer.get_vocab()[new_special_token]
+                self.assertEqual(getattr(new_tokenizer, f"{token}_id"), new_id)
+
+        # Check if the AddedToken / string format has been kept
+        for special_token in tokenizer.all_special_tokens_extended:
+            if isinstance(special_token, AddedToken) and special_token.content not in special_tokens_map:
+                # The special token must appear identically in the list of the new tokenizer.
+                self.assertTrue(
+                    special_token in new_tokenizer.all_special_tokens_extended,
+                    f"'{special_token}' should be in {new_tokenizer.all_special_tokens_extended}",
+                )
+            elif isinstance(special_token, AddedToken):
+                # The special token must appear in the list of the new tokenizer as an object of type AddedToken with
+                # the same parameters as the old AddedToken except the content that the user has requested to change.
+ special_token_str = special_token.content + new_special_token_str = special_tokens_map[special_token_str] + + find = False + for candidate in new_tokenizer.all_special_tokens_extended: + if ( + isinstance(candidate, AddedToken) + and candidate.content == new_special_token_str + and candidate.lstrip == special_token.lstrip + and candidate.rstrip == special_token.rstrip + and candidate.normalized == special_token.normalized + and candidate.single_word == special_token.single_word + ): + find = True + break + special_token.content = new_special_token_str + self.assertTrue( + find, + f"'{special_token.__repr__()}' should appear as an `AddedToken` in the all_special_tokens_extended = " + f"{[k for k in new_tokenizer.all_special_tokens_extended if str(k)==new_special_token_str]} but it is missing" + ", this means that the new tokenizers did not keep the `rstrip`, `lstrip`, `normalized` etc attributes.", + ) + elif special_token not in special_tokens_map: + # The special token must appear identically in the list of the new tokenizer. + self.assertTrue( + special_token in new_tokenizer.all_special_tokens_extended, + f"'{special_token.__repr__()}' should be in {new_tokenizer.all_special_tokens_extended}", + ) + + else: + # The special token must appear in the list of the new tokenizer as an object of type string. + self.assertTrue(special_tokens_map[special_token] in new_tokenizer.all_special_tokens_extended) + + # Test we can use the new tokenizer with something not seen during training + inputs = new_tokenizer(["This is the first sentence", "This sentence is different 🤗."]) + self.assertEqual(len(inputs["input_ids"]), 2) + decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True) + expected_result = "This is the first sentence" + + self.assertEqual(expected_result, decoded_input) + + def test_alignement_methods(self): + # TODO: @ArthurZucker - alignment is broken + pass + + def test_added_tokens_do_lower_case(self): + # TODO: @ArthurZucker + pass + + +@require_torch +@require_sentencepiece +@require_tokenizers +class MoshiIntegrationTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + checkpoint_name = "kmhf/hf-moshiko" + cls.rust_tokenizer = AutoTokenizer.from_pretrained(checkpoint_name) + return cls + + @require_torch + def integration_tests(self): + inputs = self.tokenizer( + ["The following string should be properly encoded: Hello.", "But ird and ปี ird ด"], + return_tensors="pt", + ) + + long_attention_mask = [1] * 21 + + # fmt: off + self.assertEqual( + nested_simplify(inputs), + { + "input_ids": [ + [287, 547, 2359, 457, 297, 3708, 11488, 279, 11725, 263], + [588, 478, 1442, 267, 260, 228, 188, 159, 228, 188, 185, 260, 260, 478, 1442, 260, 260, 260, 228, 188, 152], + ], + "attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], long_attention_mask], + }, + ) + # fmt: on + + def test_fast_special_tokens(self): + fast_tokenizer = self.rust_tokenizer + + fast_tokenizer.add_eos_token = False + fast = fast_tokenizer.encode("A sample test", add_special_tokens=True) + assert fast == [318, 1145, 694] + + fast_tokenizer.add_eos_token = True + fast = fast_tokenizer.encode("A sample test", add_special_tokens=True) + assert fast == [318, 1145, 694] + + self.rust_tokenizer.add_eos_token = False + + def test_simple_encode_decode(self): + rust_tokenizer = self.rust_tokenizer + + self.assertEqual(rust_tokenizer.encode("This is a test"), [353, 275, 272, 694]) + self.assertEqual(rust_tokenizer.decode([353, 275, 272, 694], skip_special_tokens=True), "This is a test") + + 
# bytefallback showcase + bytefallback_tokens = [260, 235, 152, 163, 234, 184, 191, 13340, 235, 160, 163, 236, 180, 159, 234, 156, 179] # fmt: skip + self.assertEqual(rust_tokenizer.encode("生活的真谛是"), bytefallback_tokens) + self.assertEqual( + rust_tokenizer.decode(bytefallback_tokens, skip_special_tokens=True), + "生活的真谛是", + ) + + # Inner spaces showcase + self.assertEqual(rust_tokenizer.encode("Hi Hello"), [2769, 260, 11725]) + self.assertEqual(rust_tokenizer.decode([2769, 260, 11725], skip_special_tokens=True), "Hi Hello") + + self.assertEqual(rust_tokenizer.encode("Hi Hello"), [2769, 260, 260, 11725]) + self.assertEqual(rust_tokenizer.decode([2769, 260, 260, 11725], skip_special_tokens=True), "Hi Hello") + + # TODO: @ArthurZucker + # self.assertEqual(rust_tokenizer.encode(""), []) + + # self.assertEqual(rust_tokenizer.encode(" "), [260, 260]) + + # self.assertEqual(rust_tokenizer.encode(" "), [260, 260, 260]) + + # self.assertEqual(rust_tokenizer.encode(" Hello"), [260, 11725]) + + # self.assertEqual(rust_tokenizer.encode(""), [607, 266, 578]) + + def test_no_differences_decode(self): + rust_tokenizer = self.rust_tokenizer + + self.assertEqual(rust_tokenizer.decode([869]), "levels") + + self.assertEqual(rust_tokenizer.decode([30112, 869]), "unanswered levels") + + +@require_sentencepiece +@require_tokenizers +class CommonSpmIntegrationTests(unittest.TestCase): + """ + A class that regroups important test to make sure that we properly handle the special tokens. + """ + + def test_edge_case_tabulation(self): + fast_tokenizer = AutoTokenizer.from_pretrained("kmhf/hf-moshiko") + input_text = "Hey. \t\t \n\nyou é @#😈 🤗! , 1234 15 5,61" + EXPECTED_IDS = [11510, 934, 4451, 266, 578, 263, 260, 13, 13, 260, 14, 14, 5209, 260, 260, 1202, 260, 527, 1322, 244, 163, 156, 140, 260, 260, 244, 163, 168, 155, 430, 1047, 261, 260, 265, 270, 278, 281, 260, 265, 280, 260, 280, 261, 285, 265] # fmt: skip + EXPECTED_TOKENS = ['▁Hey', '<', 'eo', 's', '>', '.', '▁', '<0x09>', '<0x09>', '▁', '<0x0A>', '<0x0A>', 'you', '▁', '▁', 'é', '▁', '▁@', '#', '<0xF0>', '<0x9F>', '<0x98>', '<0x88>', '▁', '▁', '<0xF0>', '<0x9F>', '<0xA4>', '<0x97>', '!', '▁▁▁▁▁▁▁', ',', '▁', '1', '2', '3', '4', '▁', '1', '5', '▁', '5', ',', '6', '1'] # fmt: skip + + tokens = fast_tokenizer.tokenize(input_text) + with self.subTest("test fast edge case fast"): + self.assertEqual(tokens, EXPECTED_TOKENS) + + input_ids = fast_tokenizer.encode(input_text) + with self.subTest("test fast edge case fast"): + self.assertEqual(input_ids, EXPECTED_IDS) + + text = fast_tokenizer.decode(EXPECTED_IDS) + with self.subTest("test fast edge case fast"): + self.assertEqual(text, "Hey. \t\t \n\nyou é @#😈 🤗! 
, 1234 15 5,61") + + input_text = "\t\t\t\t \n\n61" + EXPECTED_IDS = [260, 13, 13, 13, 13, 260, 14, 14, 285, 265] + EXPECTED_TOKENS = ["▁", "<0x09>", "<0x09>", "<0x09>", "<0x09>", "▁", "<0x0A>", "<0x0A>", "6", "1"] + + tokens = fast_tokenizer.tokenize(input_text) + with self.subTest("test fast edge case fast"): + self.assertEqual(tokens, EXPECTED_TOKENS) + + input_ids = fast_tokenizer.encode(input_text) + with self.subTest("test fast edge case fast"): + self.assertEqual(input_ids, EXPECTED_IDS) + + text = fast_tokenizer.decode(EXPECTED_IDS) + with self.subTest("test fast edge case fast"): + self.assertEqual(text, "\t\t\t\t \n\n61") diff --git a/utils/check_repo.py b/utils/check_repo.py index 98f96bcc78a..6872dada3d9 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -327,6 +327,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ "SiglipVisionModel", "SiglipTextModel", "ChameleonVQVAE", # no autoclass for VQ-VAE models + "MoshiForConditionalGeneration", # no auto class for speech-to-speech ] # DO NOT edit this list!