mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Moshi integration (#33624)
* clean mimi commit * some nits suggestions from Arthur * make fixup * first moshi WIP * converting weights working + configuration + generation configuration * finalize converting script - still missing tokenizer and FE and processor * fix saving model w/o default config * working generation * use GenerationMixin instead of inheriting * add delay pattern mask * fix right order: moshi codes then user codes * unconditional inputs + generation config * get rid of MoshiGenerationConfig * blank user inputs * update convert script:fix conversion, add tokenizer, feature extractor and bf16 * add and correct Auto classes * update modeling code, configuration and tests * make fixup * fix some copies * WIP: add integration tests * add dummy objects * propose better readiblity and code organisation * update tokenization tests * update docstrigns, eval and modeling * add .md * make fixup * add MoshiForConditionalGeneration to ignore Auto * revert mimi changes * re * further fix * Update moshi.md * correct md formating * move prepare causal mask to class * fix copies * fix depth decoder causal * fix and correct some tests * make style and update .md * correct config checkpoitn * Update tests/models/moshi/test_tokenization_moshi.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update tests/models/moshi/test_tokenization_moshi.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * make style * Update src/transformers/models/moshi/__init__.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fixup * change firm in copyrights * udpate config with nested dict * replace einsum * make style * change split to True * add back splt=False * remove tests in convert * Update tests/models/moshi/test_modeling_moshi.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add default config repo + add model to FA2 docstrings * remove logits float * fix some tokenization tests and ignore some others * make style tokenization tests * update modeling with sliding window + update modeling tests * [run-slow] moshi * remove prepare for generation frol CausalLM * isort * remove copied from * ignore offload tests * update causal mask and prepare 4D mask aligned with recent changes * further test refine + add back prepare_inputs_for_generation for depth decoder * correct conditional use of prepare mask * update slow integration tests * fix multi-device forward * remove previous solution to device_map * save_load is flaky * fix generate multi-devices * fix device * move tensor to int --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Marc Sun <marc@huggingface.co>
This commit is contained in:
parent
d087165db0
commit
9ba021ea75
@ -740,6 +740,8 @@
|
||||
title: Mimi
|
||||
- local: model_doc/mms
|
||||
title: MMS
|
||||
- local: model_doc/moshi
|
||||
title: Moshi
|
||||
- local: model_doc/musicgen
|
||||
title: MusicGen
|
||||
- local: model_doc/musicgen_melody
|
||||
|
@ -223,6 +223,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| [MobileNetV2](model_doc/mobilenet_v2) | ✅ | ❌ | ❌ |
|
||||
| [MobileViT](model_doc/mobilevit) | ✅ | ✅ | ❌ |
|
||||
| [MobileViTV2](model_doc/mobilevitv2) | ✅ | ❌ | ❌ |
|
||||
| [Moshi](model_doc/moshi) | ✅ | ❌ | ❌ |
|
||||
| [MPNet](model_doc/mpnet) | ✅ | ✅ | ❌ |
|
||||
| [MPT](model_doc/mpt) | ✅ | ❌ | ❌ |
|
||||
| [MRA](model_doc/mra) | ✅ | ❌ | ❌ |
|
||||
|
@ -66,4 +66,4 @@ The original code can be found [here](https://github.com/kyutai-labs/moshi).
|
||||
[[autodoc]] MimiModel
|
||||
- decode
|
||||
- encode
|
||||
- forward
|
||||
- forward
|
183
docs/source/en/model_doc/moshi.md
Normal file
183
docs/source/en/model_doc/moshi.md
Normal file
@ -0,0 +1,183 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Moshi
|
||||
|
||||
## Overview
|
||||
|
||||
The Moshi model was proposed in [Moshi: a speech-text foundation model for real-time dialogue](https://kyutai.org/Moshi.pdf) by Alexandre Défossez, Laurent Mazaré, Manu Orsini, Amélie Royer, Patrick Pérez, Hervé Jégou, Edouard Grave and Neil Zeghidour.
|
||||
|
||||
Moshi is a speech-text foundation model that casts spoken dialogue as speech-to-speech generation. Starting from a text language model backbone, Moshi generates speech as tokens from the residual quantizer of a neural audio codec, while modeling separately its own speech and that of the user into parallel streams. This allows for the removal of explicit speaker turns, and the modeling of arbitrary conversational dynamics. Moshi also predicts time-aligned text tokens as a prefix to audio tokens. This “Inner Monologue” method significantly improves the linguistic quality of generated speech and provides streaming speech recognition and text-to-speech. As a result, Moshi is the first real-time full-duplex spoken large language model, with a theoretical latency of 160ms, 200ms in practice.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/ylacombe/benchmark-comparison/resolve/main/moshi_architecture.png">
|
||||
</div>
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*We introduce Moshi, a speech-text foundation model and full-duplex spoken dialogue framework. Current systems for spoken dialogue rely on pipelines of independent components, namely voice activity detection, speech recognition, textual dialogue and text-to-speech. Such frameworks cannot emulate the experience of real conversations. First, their complexity induces a latency of several seconds between interactions. Second, text being the intermediate modality for dialogue, non-linguistic information that modifies meaning— such as emotion or non-speech sounds— is lost in the interaction. Finally, they rely on a segmentation into speaker turns, which does not take into account overlapping speech, interruptions and interjections. Moshi solves these independent issues altogether by casting spoken dialogue as speech-to-speech generation. Starting from a text language model backbone, Moshi generates speech as tokens from the residual quantizer of a neural audio codec, while modeling separately its own speech and that of the user into parallel streams. This allows for the removal of explicit speaker turns, and the modeling of arbitrary conversational dynamics. We moreover extend the hierarchical semantic-to-acoustic token generation of previous work to first predict time-aligned text tokens as a prefix to audio tokens. Not only this “Inner Monologue” method significantly improves the linguistic quality of generated speech, but we also illustrate how it can provide streaming speech recognition and text-to-speech. Our resulting model is the first real-time full-duplex spoken large language model, with a theoretical latency of 160ms, 200ms in practice, and is available at github.com/kyutai-labs/moshi.*
|
||||
|
||||
Moshi deals with 3 streams of information:
|
||||
1. The user's audio
|
||||
2. Moshi's audio
|
||||
3. Moshi's textual output
|
||||
|
||||
Similarly to [`~MusicgenModel`], audio is represented with audio codebooks, which can be interpreted like tokens. The main difference between text tokens and audio codebooks is that audio codebooks introduce an additional dimension of information.
|
||||
Text tokens are typically of dim `(batch_size, sequence_length)` but audio tokens are of dim `(batch_size, num_codebooks, sequence_length)`.
|
||||
|
||||
Moshi's made of 3 components:
|
||||
|
||||
**1. The main decoder (Helium in the paper)**
|
||||
|
||||
It corresponds to [`MoshiForCausalLM`]. It is strictly a classic text LLM, that uses an architecture similar to [` ~GemmaForCausalLM`]. In other words, it takes text tokens, embeds them, pass them through the decoder and a language head, to get text logits.
|
||||
|
||||
**2. The depth decoder**
|
||||
|
||||
On its own, it's also a classic LLM, but this time, instead of generating over the time dimension, it generates over the codebook dimension.
|
||||
|
||||
It also means that its context length is `num_codebooks`, thus it can't generate more than `num_codebooks`.
|
||||
|
||||
Note that each timestamp - i.e each codebook - gets its own set of Linear Layers and Embeddings.
|
||||
|
||||
**3. [`MimiModel`]**
|
||||
|
||||
It's the audio encoder from Kyutai, that has recently been integrated to transformers, which is used to "tokenize" audio. It has the same use that [`~EncodecModel`] has in [`~MusicgenModel`].
|
||||
|
||||
|
||||
## Tips:
|
||||
|
||||
The original checkpoints can be converted using the conversion script `src/transformers/models/moshi/convert_moshi_transformers.py`
|
||||
|
||||
|
||||
### How to use the model:
|
||||
|
||||
This implementation has two main aims:
|
||||
1. quickly test model generation by simplifying the original API
|
||||
2. simplify training. A training guide will come soon, but user contributions are welcomed!
|
||||
|
||||
<Tip>
|
||||
|
||||
It is designed for intermediate use. We strongly recommend using the original [implementation](https://github.com/kyutai-labs/moshi) to infer the model in real-time streaming.
|
||||
|
||||
</Tip>
|
||||
|
||||
**1. Model generation**
|
||||
|
||||
Moshi is a streaming auto-regressive model with two streams of audio. To put it differently, one audio stream corresponds to what the model said/will say and the other audio stream corresponds to what the user said/will say.
|
||||
|
||||
[`MoshiForConditionalGeneration.generate`] thus needs 3 inputs:
|
||||
1. `input_ids` - corresponding to the text token history
|
||||
2. `moshi_input_values` or `moshi_audio_codes`- corresponding to the model audio history
|
||||
3. `user_input_values` or `user_audio_codes` - corresponding to the user audio history
|
||||
|
||||
These three inputs must be synchronized. Meaning that their lengths must correspond to the same number of tokens.
|
||||
|
||||
You can dynamically use the 3 inputs depending on what you want to test:
|
||||
1. Simply check the model response to an user prompt - in that case, `input_ids` can be filled with pad tokens and `user_input_values` can be a zero tensor of the same shape than the user prompt.
|
||||
2. Test more complex behaviour - in that case, you must be careful about how the input tokens are synchronized with the audios.
|
||||
|
||||
<Tip>
|
||||
|
||||
The original model is synchronized text with audio by padding the text in between each token enunciation.
|
||||
|
||||
To follow the example of the following image, `"Hello, I'm Moshi"` could be transformed to `"Hello,<pad><unk>I'm Moshi"`.
|
||||
|
||||
</Tip>
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/ylacombe/benchmark-comparison/resolve/main/moshi_text_sync.png">
|
||||
</div>
|
||||
|
||||
|
||||
[`MoshiForConditionalGeneration.generate`] then auto-regressively feeds to itself its own audio stream, but since it doesn't have access to the user input stream while using `transformers`, it will thus **assume that the user is producing blank audio**.
|
||||
|
||||
|
||||
|
||||
```python
|
||||
>>> from datasets import load_dataset, Audio
|
||||
>>> import torch, math
|
||||
>>> from transformers import MoshiForConditionalGeneration, AutoFeatureExtractor, AutoTokenizer
|
||||
>>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
|
||||
|
||||
>>> # prepare user input audio
|
||||
>>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
|
||||
>>> audio_sample = librispeech_dummy[-1]["audio"]["array"]
|
||||
>>> user_input_values = feature_extractor(raw_audio=audio_sample, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt").to(device=device, dtype=dtype)
|
||||
|
||||
>>> # prepare moshi input values - we suppose moshi didn't say anything while the user spoke
|
||||
>>> moshi_input_values = torch.zeros_like(user_input_values.input_values)
|
||||
|
||||
>>> # prepare moshi input ids - we suppose moshi didn't say anything while the user spoke
|
||||
>>> num_tokens = math.ceil(moshi_input_values.shape[-1] * waveform_to_token_ratio)
|
||||
>>> input_ids = torch.ones((1, num_tokens), device=device, dtype=torch.int64) * tokenizer.encode("<pad>")[0]
|
||||
|
||||
>>> # generate 25 new tokens (around 2s of audio)
|
||||
>>> output = model.generate(input_ids=input_ids, user_input_values=user_input_values.input_values, moshi_input_values=moshi_input_values, max_new_tokens=25)
|
||||
|
||||
>>> text_tokens = output.sequences
|
||||
>>> audio_waveforms = output.audio_sequences
|
||||
```
|
||||
|
||||
**2. Model training**
|
||||
|
||||
Most of the work has to be done during data creation/pre-processing, because of the need to align/synchronize streams.
|
||||
|
||||
Once it's done, you can simply forward `text_labels` and `audio_labels` to [`MoshiForConditionalGeneration.forward`], alongside the usual inputs, to get the model loss.
|
||||
|
||||
A training guide will come soon, but user contributions are welcomed!
|
||||
|
||||
### How does the model forward the inputs / generate:
|
||||
|
||||
1. The input streams are embedded and combined into `inputs_embeds`.
|
||||
|
||||
2. `inputs_embeds` is passed through the main decoder, which processes it like a normal LLM would.
|
||||
|
||||
3. The main decoder outputs `text logits` but also its `last hidden state` which is called `temporal context` in the paper.
|
||||
|
||||
3. The depth decoder switches the dimension on which we forward / generate (codebooks instead of time). It uses the token generated from `text logits` and the `temporal context` to auto-regressively generate audio codebooks.
|
||||
|
||||
|
||||
This model was contributed by [Yoach Lacombe (ylacombe)](https://huggingface.co/ylacombe).
|
||||
|
||||
The original code can be found [here](https://github.com/kyutai-labs/moshi).
|
||||
|
||||
|
||||
|
||||
## MoshiConfig
|
||||
|
||||
[[autodoc]] MoshiConfig
|
||||
|
||||
## MoshiDepthConfig
|
||||
|
||||
[[autodoc]] MoshiDepthConfig
|
||||
|
||||
## MoshiModel
|
||||
|
||||
[[autodoc]] MoshiModel
|
||||
- forward
|
||||
|
||||
## MoshiForCausalLM
|
||||
|
||||
[[autodoc]] MoshiForCausalLM
|
||||
- forward
|
||||
|
||||
## MoshiForConditionalGeneration
|
||||
|
||||
[[autodoc]] MoshiForConditionalGeneration
|
||||
- forward
|
||||
- generate
|
||||
- get_unconditional_inputs
|
@ -70,6 +70,7 @@ FlashAttention-2 is currently supported for the following architectures:
|
||||
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
|
||||
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
|
||||
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
|
||||
* [Moshi](https://huggingface.co/docs/transformers/model_doc/moshi#transformers.MoshiModel)
|
||||
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
|
||||
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
|
||||
* [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron)
|
||||
@ -241,6 +242,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
||||
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
|
||||
* [Mllama](https://huggingface.co/docs/transformers/model_doc/mllama#transformers.MllamaForConditionalGeneration)
|
||||
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
|
||||
* [Moshi](https://huggingface.co/docs/transformers/model_doc/moshi#transformers.MoshiModel)
|
||||
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
|
||||
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
|
||||
* [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)
|
||||
|
@ -590,6 +590,10 @@ _import_structure = {
|
||||
"models.mobilenet_v2": ["MobileNetV2Config"],
|
||||
"models.mobilevit": ["MobileViTConfig"],
|
||||
"models.mobilevitv2": ["MobileViTV2Config"],
|
||||
"models.moshi": [
|
||||
"MoshiConfig",
|
||||
"MoshiDepthConfig",
|
||||
],
|
||||
"models.mpnet": [
|
||||
"MPNetConfig",
|
||||
"MPNetTokenizer",
|
||||
@ -2783,6 +2787,14 @@ else:
|
||||
"MobileViTV2PreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.moshi"].extend(
|
||||
[
|
||||
"MoshiForCausalLM",
|
||||
"MoshiForConditionalGeneration",
|
||||
"MoshiModel",
|
||||
"MoshiPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.mpnet"].extend(
|
||||
[
|
||||
"MPNetForMaskedLM",
|
||||
@ -5448,6 +5460,10 @@ if TYPE_CHECKING:
|
||||
from .models.mobilevitv2 import (
|
||||
MobileViTV2Config,
|
||||
)
|
||||
from .models.moshi import (
|
||||
MoshiConfig,
|
||||
MoshiDepthConfig,
|
||||
)
|
||||
from .models.mpnet import (
|
||||
MPNetConfig,
|
||||
MPNetTokenizer,
|
||||
@ -7386,6 +7402,12 @@ if TYPE_CHECKING:
|
||||
MobileViTV2Model,
|
||||
MobileViTV2PreTrainedModel,
|
||||
)
|
||||
from .models.moshi import (
|
||||
MoshiForCausalLM,
|
||||
MoshiForConditionalGeneration,
|
||||
MoshiModel,
|
||||
MoshiPreTrainedModel,
|
||||
)
|
||||
from .models.mpnet import (
|
||||
MPNetForMaskedLM,
|
||||
MPNetForMultipleChoice,
|
||||
|
@ -1405,6 +1405,47 @@ class MarkupLMConverter(Converter):
|
||||
return tokenizer
|
||||
|
||||
|
||||
class MoshiConverter(SpmConverter):
|
||||
handle_byte_fallback = True
|
||||
|
||||
def __init__(self, vocab_file, model_max_length=None, **kwargs):
|
||||
requires_backends(self, "protobuf")
|
||||
|
||||
Converter.__init__(self, vocab_file)
|
||||
|
||||
# from .utils import sentencepiece_model_pb2 as model_pb2
|
||||
model_pb2 = import_protobuf()
|
||||
|
||||
m = model_pb2.ModelProto()
|
||||
with open(vocab_file, "rb") as f:
|
||||
m.ParseFromString(f.read())
|
||||
self.proto = m
|
||||
|
||||
def normalizer(self, proto):
|
||||
precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
|
||||
_normalizers = [
|
||||
normalizers.Replace(" ", "▁"),
|
||||
]
|
||||
if not precompiled_charsmap:
|
||||
return normalizers.Sequence(_normalizers)
|
||||
else:
|
||||
return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)
|
||||
|
||||
def decoder(self, replacement, add_prefix_space):
|
||||
sequence = [
|
||||
decoders.Replace("▁", " "),
|
||||
decoders.ByteFallback(),
|
||||
decoders.Fuse(),
|
||||
]
|
||||
if add_prefix_space:
|
||||
sequence += [decoders.Strip(content=" ", left=1)]
|
||||
return decoders.Sequence(sequence)
|
||||
|
||||
def pre_tokenizer(self, replacement, add_prefix_space):
|
||||
prepend_scheme = "first"
|
||||
return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
|
||||
|
||||
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
|
@ -1594,8 +1594,10 @@ class GenerationMixin:
|
||||
cache_dtype = self.get_output_embeddings().weight.dtype
|
||||
|
||||
def get_layer_device_map(execution_device_map: Optional[dict] = None):
|
||||
if execution_device_map is None or len(execution_device_map) <= 1:
|
||||
if execution_device_map is None:
|
||||
return None
|
||||
elif len(execution_device_map) == 1 and "" in execution_device_map:
|
||||
return {idx: execution_device_map[""] for idx in range(self.config.num_hidden_layers)}
|
||||
layer_device_map = {}
|
||||
for layer in execution_device_map:
|
||||
for idx in range(self.config.num_hidden_layers):
|
||||
|
@ -161,6 +161,7 @@ from . import (
|
||||
mobilenet_v2,
|
||||
mobilevit,
|
||||
mobilevitv2,
|
||||
moshi,
|
||||
mpnet,
|
||||
mpt,
|
||||
mra,
|
||||
|
@ -179,6 +179,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("mobilenet_v2", "MobileNetV2Config"),
|
||||
("mobilevit", "MobileViTConfig"),
|
||||
("mobilevitv2", "MobileViTV2Config"),
|
||||
("moshi", "MoshiConfig"),
|
||||
("mpnet", "MPNetConfig"),
|
||||
("mpt", "MptConfig"),
|
||||
("mra", "MraConfig"),
|
||||
@ -490,6 +491,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("mobilenet_v2", "MobileNetV2"),
|
||||
("mobilevit", "MobileViT"),
|
||||
("mobilevitv2", "MobileViTV2"),
|
||||
("moshi", "Moshi"),
|
||||
("mpnet", "MPNet"),
|
||||
("mpt", "MPT"),
|
||||
("mra", "MRA"),
|
||||
|
@ -73,6 +73,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("mobilenet_v1", "MobileNetV1FeatureExtractor"),
|
||||
("mobilenet_v2", "MobileNetV2FeatureExtractor"),
|
||||
("mobilevit", "MobileViTFeatureExtractor"),
|
||||
("moshi", "EncodecFeatureExtractor"),
|
||||
("nat", "ViTFeatureExtractor"),
|
||||
("owlvit", "OwlViTFeatureExtractor"),
|
||||
("perceiver", "PerceiverFeatureExtractor"),
|
||||
|
@ -169,6 +169,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("mobilenet_v2", "MobileNetV2Model"),
|
||||
("mobilevit", "MobileViTModel"),
|
||||
("mobilevitv2", "MobileViTV2Model"),
|
||||
("moshi", "MoshiModel"),
|
||||
("mpnet", "MPNetModel"),
|
||||
("mpt", "MptModel"),
|
||||
("mra", "MraModel"),
|
||||
@ -506,6 +507,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("mistral", "MistralForCausalLM"),
|
||||
("mixtral", "MixtralForCausalLM"),
|
||||
("mllama", "MllamaForCausalLM"),
|
||||
("moshi", "MoshiForCausalLM"),
|
||||
("mpt", "MptForCausalLM"),
|
||||
("musicgen", "MusicgenForCausalLM"),
|
||||
("musicgen_melody", "MusicgenMelodyForCausalLM"),
|
||||
|
@ -309,6 +309,7 @@ else:
|
||||
("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("mpt", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("mra", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
|
27
src/transformers/models/moshi/__init__.py
Normal file
27
src/transformers/models/moshi/__init__.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_moshi import *
|
||||
from .modeling_moshi import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
333
src/transformers/models/moshi/configuration_moshi.py
Normal file
333
src/transformers/models/moshi/configuration_moshi.py
Normal file
@ -0,0 +1,333 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Meta AI and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Moshi model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class MoshiDepthConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MoshiDepthDecoder`]. It is used to instantiate a
|
||||
Moshi depth decoder model according to the specified arguments, defining the Moshi depth decoder config.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the MoshiDepthDecoder model. Defines the number of different tokens that can be
|
||||
represented by the `inputs_ids` passed when calling [`MoshiDepthDecoder`].
|
||||
hidden_size (`int`, *optional*, defaults to 1024):
|
||||
Dimensionality of the layers and the pooler layer of the depth decoder.
|
||||
input_size (`int`, *optional*, defaults to 4096):
|
||||
Dimensionality of the input hidden states. Used to connect the main decoder to the depth decoder.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 6):
|
||||
Number of depth decoder layers.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the depth decoder block.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`.
|
||||
audio_vocab_size (`int`, *optional*, defaults to 2048):
|
||||
Vocabulary size of the audio part of model. Defines the number of different tokens that can be
|
||||
represented by the `audio_codes` passed when calling the Moshi models.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 9):
|
||||
The maximum sequence length that this model might ever be used with. Typically, set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the depth decoder.
|
||||
head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
|
||||
The attention head dimension.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
sliding_window (`int`, *optional*, defaults to 8):
|
||||
Sliding window attention window size. If not specified, will default to `8`.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
ffn_dim (`int`, *optional*, defaults to 5632):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in the depth decoder block. Must be even.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-08):
|
||||
The epsilon used by the rms normalization layers.
|
||||
num_codebooks (`int`, *optional*, defaults to 8):
|
||||
The number of audio codebooks for each audio channels.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
kwargs (*optional*):
|
||||
Dictionary of keyword arguments. Notably:
|
||||
- **audio_encoder_config** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
|
||||
defines the audio encoder config.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import (
|
||||
... MoshiDepthConfig,
|
||||
... MoshiDepthDecoder,
|
||||
... )
|
||||
|
||||
>>> configuration = MoshiDepthConfig()
|
||||
|
||||
>>> # Initializing a MoshiDepthDecoder (with random weights) from the kmhf/hf-moshiko style configuration
|
||||
>>> model = MoshiDepthDecoder(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "moshi_depth"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=1024,
|
||||
input_size=4096,
|
||||
num_hidden_layers=6,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=None,
|
||||
audio_vocab_size=2048,
|
||||
max_position_embeddings=9,
|
||||
hidden_act="silu",
|
||||
head_dim=None,
|
||||
initializer_range=0.02,
|
||||
use_cache=True,
|
||||
sliding_window=8,
|
||||
attention_dropout=0.0,
|
||||
ffn_dim=5632,
|
||||
rms_norm_eps=1e-8,
|
||||
num_codebooks=8,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.input_size = input_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_act = hidden_act
|
||||
self.head_dim = head_dim or hidden_size // num_attention_heads
|
||||
self.initializer_range = initializer_range
|
||||
self.use_cache = use_cache
|
||||
self.sliding_window = sliding_window
|
||||
self.attention_dropout = attention_dropout
|
||||
if ffn_dim % 2 == 1:
|
||||
raise ValueError(f"`ffn_dim={ffn_dim}` must be even.")
|
||||
self.ffn_dim = ffn_dim
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.num_codebooks = num_codebooks
|
||||
self.audio_vocab_size = audio_vocab_size
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
||||
|
||||
class MoshiConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MoshiModel`]. It is used to instantiate a
|
||||
Moshi model according to the specified arguments, defining the audio encoder, Moshi depth decoder and Moshi decoder
|
||||
configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the Moshiko model,
|
||||
e.g. [kmhf/hf-moshiko](https://huggingface.co/kmhf/hf-moshiko)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the MoshiDecoder model. Defines the number of different tokens that can be
|
||||
represented by the `inputs_ids` passed when calling [`MoshiDecoder`].
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimensionality of the layers and the pooler layer of the main decoder.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of decoder layers.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the main decoder block.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`.
|
||||
audio_vocab_size (`int`, *optional*):
|
||||
Vocabulary size of the audio part of model. Defines the number of different tokens that can be
|
||||
represented by the `audio_codes` passed when calling the Moshi models.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 3000):
|
||||
The maximum sequence length that this model might ever be used with. Typically, set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
|
||||
The attention head dimension.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
sliding_window (`int`, *optional*, defaults to 3000):
|
||||
Sliding window attention window size. If not specified, will default to `3000`.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
ffn_dim (`int`, *optional*, defaults to 22528):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in the main decoder block. Must be even.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-08):
|
||||
The epsilon used by the rms normalization layers.
|
||||
num_codebooks (`int`, *optional*, defaults to 8):
|
||||
The number of audio codebooks for each audio channels.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
kwargs (*optional*):
|
||||
Dictionary of keyword arguments. Notably:
|
||||
- **audio_encoder_config** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
|
||||
defines the audio encoder config.
|
||||
- **depth__config** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
|
||||
defines the depth decoder config.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import (
|
||||
... MoshiConfig,
|
||||
... MoshiForConditionalGeneration,
|
||||
... )
|
||||
|
||||
>>> configuration = MoshiConfig()
|
||||
|
||||
>>> # Initializing a MoshiForConditionalGeneration (with random weights) from the kmhf/hf-moshiko style configuration
|
||||
>>> model = MoshiForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
|
||||
>>> # Saving the model, including its configuration
|
||||
>>> model.save_pretrained("kmhf/hf-moshiko")
|
||||
|
||||
>>> # loading model and config from pretrained folder
|
||||
>>> moshi_config = MoshiConfig.from_pretrained("kmhf/hf-moshiko")
|
||||
>>> model = MoshiForConditionalGeneration.from_pretrained("kmhf/hf-moshiko", config=moshi_config)
|
||||
```"""
|
||||
|
||||
model_type = "moshi"
|
||||
is_composition = True
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=4096,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=None,
|
||||
audio_vocab_size=None,
|
||||
max_position_embeddings=3000,
|
||||
rope_theta=10000.0,
|
||||
hidden_act="silu",
|
||||
head_dim=None,
|
||||
initializer_range=0.02,
|
||||
use_cache=True,
|
||||
sliding_window=3000,
|
||||
attention_dropout=0.0,
|
||||
ffn_dim=22528,
|
||||
rms_norm_eps=1e-8,
|
||||
num_codebooks=8,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.rope_theta = rope_theta
|
||||
self.hidden_act = hidden_act
|
||||
self.head_dim = head_dim or hidden_size // num_attention_heads
|
||||
self.initializer_range = initializer_range
|
||||
self.use_cache = use_cache
|
||||
self.sliding_window = sliding_window
|
||||
self.attention_dropout = attention_dropout
|
||||
if ffn_dim % 2 == 1:
|
||||
raise ValueError(f"`ffn_dim={ffn_dim}` must be even.")
|
||||
self.ffn_dim = ffn_dim
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.num_codebooks = num_codebooks
|
||||
|
||||
audio_encoder_config = kwargs.pop("audio_encoder_config", {})
|
||||
audio_encoder_model_type = audio_encoder_config.pop("model_type", "mimi")
|
||||
|
||||
self.audio_encoder_config = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
|
||||
|
||||
if self.num_codebooks > self.audio_encoder_config.num_codebooks:
|
||||
raise ValueError(
|
||||
f"`num_codebooks={num_codebooks}` is greater than the maximum number of codebooks that the audio encoder can deal with ({self.audio_encoder_config.num_codebooks}). Please lower it."
|
||||
)
|
||||
|
||||
self.audio_vocab_size = (
|
||||
self.audio_encoder_config.codebook_size if audio_vocab_size is None else audio_vocab_size
|
||||
)
|
||||
|
||||
depth_decoder_config = kwargs.pop("depth_decoder_config", {})
|
||||
depth_decoder_config.update(
|
||||
{
|
||||
"audio_vocab_size": self.audio_vocab_size,
|
||||
"input_size": hidden_size,
|
||||
"vocab_size": vocab_size,
|
||||
"num_codebooks": num_codebooks,
|
||||
}
|
||||
)
|
||||
|
||||
self.depth_decoder_config = MoshiDepthConfig(**depth_decoder_config)
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
||||
@property
|
||||
def sampling_rate(self):
|
||||
return self.audio_encoder_config.sampling_rate
|
||||
|
||||
@classmethod
|
||||
def from_audio_encoder_config(
|
||||
cls,
|
||||
audio_encoder_config: PretrainedConfig,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Instantiate a [`MoshiConfig`] (or a derived class) from an audio encoder configuration.
|
||||
|
||||
Returns:
|
||||
[`MoshiConfig`]: An instance of a configuration object
|
||||
"""
|
||||
|
||||
return cls(
|
||||
audio_encoder_config=audio_encoder_config.to_dict(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["MoshiConfig", "MoshiDepthConfig"]
|
311
src/transformers/models/moshi/convert_moshi_transformers.py
Normal file
311
src/transformers/models/moshi/convert_moshi_transformers.py
Normal file
@ -0,0 +1,311 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Moshi checkpoints."""
|
||||
|
||||
import argparse
|
||||
|
||||
import safetensors
|
||||
import sentencepiece
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoFeatureExtractor,
|
||||
GenerationConfig,
|
||||
MimiModel, # initial audio encoder
|
||||
MoshiConfig,
|
||||
MoshiForConditionalGeneration,
|
||||
PreTrainedTokenizerFast,
|
||||
logging,
|
||||
)
|
||||
from transformers.convert_slow_tokenizer import MoshiConverter
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger("transformers.models.mimi")
|
||||
|
||||
|
||||
def assert_param_count(model_1, model_2):
|
||||
count_1 = sum(p[1].numel() for p in model_1.named_parameters() if "final_proj" not in p[0])
|
||||
count_2 = sum(p[1].numel() for p in model_2.named_parameters() if "final_proj" not in p[0])
|
||||
assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}"
|
||||
|
||||
|
||||
def param_count(model):
|
||||
return sum(p[1].numel() for p in model.named_parameters() if "final_proj" not in p[0])
|
||||
|
||||
|
||||
def _grab_best_device(use_gpu=True):
|
||||
if torch.cuda.device_count() > 0 and use_gpu:
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
return torch.device(device)
|
||||
|
||||
|
||||
convert_list = [
|
||||
# GENERAL
|
||||
("out_norm", "decoder.model.norm"),
|
||||
("depformer_emb", "depth_decoder.emb"),
|
||||
("depformer_text_emb", "depth_decoder.text_emb"),
|
||||
("text_emb", "decoder.model.emb"),
|
||||
("emb", "embed_tokens"),
|
||||
("text_linear", "decoder.lm_head"),
|
||||
("depformer", "depth_decoder"),
|
||||
("transformer", "decoder.model"),
|
||||
# TRANSFORMERS PART
|
||||
("gating.linear_in", "mlp.fc1"),
|
||||
("gating.linear_out", "mlp.fc2"),
|
||||
("self_attn.out_proj", "self_attn.o_proj.linear"),
|
||||
("norm1", "input_layernorm"),
|
||||
("norm2", "post_attention_layernorm"),
|
||||
("layer_scale_1", "self_attn_layer_scale"),
|
||||
("layer_scale_2", "mlp_layer_scale"),
|
||||
("alpha", "weight"),
|
||||
]
|
||||
|
||||
|
||||
def _preprocess_state_dict(state_dict, config):
|
||||
# Moshi original weights are using a gating mechanism
|
||||
|
||||
# pattern for depth transformer:
|
||||
# stack(gating.{i}.linear_in)->mlp.fc1
|
||||
# stack(gating.{i}.linear_out)->mlp.fc2
|
||||
|
||||
for layer_idx in range(config.depth_decoder_config.num_hidden_layers):
|
||||
linear_layers_in = [
|
||||
state_dict.pop(f"depformer.layers.{layer_idx}.gating.{i}.linear_in.weight")
|
||||
for i in range(config.num_codebooks)
|
||||
]
|
||||
linear_layers_out = [
|
||||
state_dict.pop(f"depformer.layers.{layer_idx}.gating.{i}.linear_out.weight")
|
||||
for i in range(config.num_codebooks)
|
||||
]
|
||||
|
||||
state_dict[f"depth_decoder.layers.{layer_idx}.mlp.fc1.weight"] = torch.stack(linear_layers_in)
|
||||
state_dict[f"depth_decoder.layers.{layer_idx}.mlp.fc2.weight"] = torch.stack(linear_layers_out)
|
||||
|
||||
input_projections = []
|
||||
lm_heads = []
|
||||
for codebook_idx in range(config.num_codebooks):
|
||||
input_projections.append(state_dict.pop(f"depformer_in.{codebook_idx}.weight"))
|
||||
lm_heads.append(state_dict.pop(f"linears.{codebook_idx}.weight"))
|
||||
|
||||
state_dict["depth_decoder.input_projections.weight"] = torch.stack(input_projections, dim=0)
|
||||
state_dict["depth_decoder.lm_heads.weight"] = torch.stack(lm_heads, dim=0)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def _convert_model(
|
||||
state_dict,
|
||||
hf_model,
|
||||
convert_list,
|
||||
device,
|
||||
config,
|
||||
unwanted_prefix=None,
|
||||
):
|
||||
hidden_size = config.hidden_size
|
||||
head_dim = config.head_dim
|
||||
num_heads = int(config.hidden_size // config.head_dim)
|
||||
num_key_value_heads = config.num_key_value_heads
|
||||
key_value_head_dim = config.num_key_value_heads * head_dim
|
||||
|
||||
state_dict = _preprocess_state_dict(state_dict, config)
|
||||
|
||||
# permute for sliced rotary
|
||||
def permute(w, n_heads, dim1=hidden_size, dim2=hidden_size):
|
||||
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
|
||||
|
||||
for k, v in list(state_dict.items()):
|
||||
if "audio_encoder" not in k:
|
||||
new_k = k if unwanted_prefix is None else k[len(unwanted_prefix) :]
|
||||
for old_layer_name, new_layer_name in convert_list:
|
||||
if old_layer_name in new_k:
|
||||
new_k = new_k.replace(old_layer_name, new_layer_name)
|
||||
|
||||
if "alpha" in k:
|
||||
state_dict[k] = state_dict[k].squeeze()
|
||||
|
||||
if "in_proj_weight" in new_k:
|
||||
# split qkv into query key and value
|
||||
mixed_qkv = state_dict.pop(k)
|
||||
if "depth_decoder" in new_k:
|
||||
mixed_qkv = mixed_qkv.view(config.num_codebooks, -1, mixed_qkv.shape[-1])
|
||||
|
||||
qkv_dim = mixed_qkv.size(1) // 3
|
||||
|
||||
query_layer = mixed_qkv[:, :qkv_dim]
|
||||
key_layer = mixed_qkv[:, qkv_dim : qkv_dim * 2]
|
||||
value_layer = mixed_qkv[:, qkv_dim * 2 :]
|
||||
state_dict[new_k.replace("in_proj_weight", "q_proj.linear.weight")] = query_layer
|
||||
state_dict[new_k.replace("in_proj_weight", "k_proj.linear.weight")] = key_layer
|
||||
|
||||
else:
|
||||
qkv_dim = mixed_qkv.size(0) // 3
|
||||
|
||||
query_layer = mixed_qkv[:qkv_dim]
|
||||
key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
|
||||
value_layer = mixed_qkv[qkv_dim * 2 :]
|
||||
state_dict[new_k.replace("in_proj_weight", "q_proj.linear.weight")] = permute(
|
||||
query_layer, num_heads, hidden_size, hidden_size
|
||||
)
|
||||
state_dict[new_k.replace("in_proj_weight", "k_proj.linear.weight")] = permute(
|
||||
key_layer, num_key_value_heads, key_value_head_dim, hidden_size
|
||||
)
|
||||
|
||||
state_dict[new_k.replace("in_proj_weight", "v_proj.linear.weight")] = value_layer
|
||||
elif "o_proj" in new_k and "depth_decoder" in new_k:
|
||||
output_layer = state_dict.pop(k)
|
||||
state_dict[new_k] = output_layer.view(config.num_codebooks, -1, output_layer.shape[-1])
|
||||
else:
|
||||
state_dict[new_k] = state_dict.pop(k)
|
||||
|
||||
# Do the last one by hand
|
||||
state_dict["depth_decoder.text_embed_tokens.weight"] = state_dict.pop(
|
||||
"depth_decoder.decoder.model.embed_tokens.weight"
|
||||
)
|
||||
|
||||
extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys())
|
||||
missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys())
|
||||
if len(extra_keys) != 0:
|
||||
raise ValueError(f"extra keys found: {extra_keys}")
|
||||
if len(missing_keys) != 0:
|
||||
raise ValueError(f"missing keys: {missing_keys}")
|
||||
hf_model.load_state_dict(state_dict, strict=True)
|
||||
n_params = param_count(hf_model)
|
||||
|
||||
logger.info(f"model loaded: {round(n_params/1e6,1)}M params")
|
||||
|
||||
hf_model.eval()
|
||||
hf_model.to(device)
|
||||
del state_dict
|
||||
|
||||
return hf_model
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_checkpoint(
|
||||
checkpoint_path,
|
||||
pytorch_dump_folder_path,
|
||||
mimi_repo_id,
|
||||
config_path=None,
|
||||
repo_id=None,
|
||||
):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
device = _grab_best_device()
|
||||
|
||||
mimi_model = MimiModel.from_pretrained(mimi_repo_id, torch_dtype=torch.bfloat16)
|
||||
|
||||
if config_path is not None:
|
||||
config = MoshiConfig.from_pretrained(config_path)
|
||||
else:
|
||||
audio_encoder_config = mimi_model.config
|
||||
config = MoshiConfig.from_audio_encoder_config(audio_encoder_config)
|
||||
|
||||
model = MoshiForConditionalGeneration(config).to(torch.bfloat16)
|
||||
|
||||
depth_decoder_generation_config = GenerationConfig(
|
||||
do_sample=True,
|
||||
temperature=0.8,
|
||||
top_k=250,
|
||||
min_length=config.num_codebooks + 1,
|
||||
max_length=config.num_codebooks + 1,
|
||||
cache_implementation="sliding_window",
|
||||
)
|
||||
|
||||
generation_config = GenerationConfig(
|
||||
do_sample=True,
|
||||
temp=0.7,
|
||||
top_k=25,
|
||||
cache_implementation="sliding_window",
|
||||
pad_token_id=config.vocab_size,
|
||||
bos_token_id=config.vocab_size,
|
||||
)
|
||||
generation_config.depth_decoder_config = depth_decoder_generation_config.to_diff_dict()
|
||||
|
||||
model.generation_config = generation_config
|
||||
|
||||
original_checkpoint = safetensors.torch.load_file(checkpoint_path)
|
||||
if "best_state" in original_checkpoint:
|
||||
# we might have a training state saved, in which case discard the yaml results and just retain the weights
|
||||
original_checkpoint = original_checkpoint["best_state"]
|
||||
|
||||
audio_checkpoint = mimi_model.state_dict()
|
||||
original_checkpoint.update({f"audio_encoder.{key}": value for (key, value) in audio_checkpoint.items()})
|
||||
|
||||
model = _convert_model(original_checkpoint, model, convert_list, device, config)
|
||||
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if repo_id:
|
||||
print("Pushing to the hub...")
|
||||
model.push_to_hub(repo_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
|
||||
parser.add_argument(
|
||||
"--tokenizer_vocab_path", required=False, default=None, type=str, help="Path to original tokenizer vocab file"
|
||||
)
|
||||
parser.add_argument("--mimi_repo_id", required=True, default=None, type=str, help="Repository id to HF Mimi.")
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# convert tokenizer
|
||||
if args.tokenizer_vocab_path:
|
||||
original_tokenizer = sentencepiece.SentencePieceProcessor(args.tokenizer_vocab_path)
|
||||
tokenizer = MoshiConverter(args.tokenizer_vocab_path).converted()
|
||||
tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_object=tokenizer,
|
||||
chat_template=None,
|
||||
unk_token="<unk>",
|
||||
model_input_names=["input_ids", "attention_mask"],
|
||||
clean_up_tokenization_spaces=False,
|
||||
bos_token_id=original_tokenizer.bos_id(),
|
||||
eos_token_id=original_tokenizer.eos_id(),
|
||||
pad_token_id=original_tokenizer.pad_id(),
|
||||
)
|
||||
|
||||
tokenizer.save_pretrained(args.pytorch_dump_folder_path)
|
||||
|
||||
if args.push_to_hub:
|
||||
print("Pushing the tokenizer to the hub...")
|
||||
tokenizer.push_to_hub(args.push_to_hub)
|
||||
|
||||
# upload feature extractor
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(args.mimi_repo_id)
|
||||
feature_extractor.save_pretrained(args.pytorch_dump_folder_path)
|
||||
|
||||
if args.push_to_hub:
|
||||
print("Pushing the feature extractor to the hub...")
|
||||
feature_extractor.push_to_hub(args.push_to_hub)
|
||||
|
||||
convert_checkpoint(
|
||||
args.checkpoint_path,
|
||||
args.pytorch_dump_folder_path,
|
||||
args.mimi_repo_id,
|
||||
args.config_path,
|
||||
args.push_to_hub,
|
||||
)
|
2813
src/transformers/models/moshi/modeling_moshi.py
Normal file
2813
src/transformers/models/moshi/modeling_moshi.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -6219,6 +6219,34 @@ class MobileViTV2PreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MoshiForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MoshiForConditionalGeneration(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MoshiModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MoshiPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MPNetForMaskedLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -1098,6 +1098,7 @@ class GenerationTesterMixin:
|
||||
"bigbirdpegasus",
|
||||
"led",
|
||||
"mega",
|
||||
"moshi",
|
||||
"speech2text",
|
||||
"git",
|
||||
"prophetnet",
|
||||
@ -1172,6 +1173,7 @@ class GenerationTesterMixin:
|
||||
"bigbirdpegasus",
|
||||
"led",
|
||||
"mega",
|
||||
"moshi",
|
||||
"speech2text",
|
||||
"git",
|
||||
"prophetnet",
|
||||
@ -1285,6 +1287,7 @@ class GenerationTesterMixin:
|
||||
"bigbirdpegasus",
|
||||
"led",
|
||||
"mega",
|
||||
"moshi",
|
||||
"speech2text",
|
||||
"git",
|
||||
"prophetnet",
|
||||
|
@ -790,6 +790,7 @@ class MimiIntegrationTest(unittest.TestCase):
|
||||
}
|
||||
|
||||
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
|
||||
model_id = "kyutai/mimi"
|
||||
|
||||
model = MimiModel.from_pretrained(model_id, use_cache=True).to(torch_device)
|
||||
@ -840,6 +841,7 @@ class MimiIntegrationTest(unittest.TestCase):
|
||||
"32": 1803071,
|
||||
}
|
||||
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
|
||||
model_id = "kyutai/mimi"
|
||||
|
||||
processor = AutoFeatureExtractor.from_pretrained(model_id)
|
||||
|
0
tests/models/moshi/__init__.py
Normal file
0
tests/models/moshi/__init__.py
Normal file
1126
tests/models/moshi/test_modeling_moshi.py
Normal file
1126
tests/models/moshi/test_modeling_moshi.py
Normal file
File diff suppressed because it is too large
Load Diff
447
tests/models/moshi/test_tokenization_moshi.py
Normal file
447
tests/models/moshi/test_tokenization_moshi.py
Normal file
@ -0,0 +1,447 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import pickle
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import (
|
||||
SPIECE_UNDERLINE,
|
||||
AddedToken,
|
||||
AutoTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
SpecialTokensMixin,
|
||||
)
|
||||
from transformers.convert_slow_tokenizer import MoshiConverter
|
||||
from transformers.testing_utils import (
|
||||
get_tests_dir,
|
||||
nested_simplify,
|
||||
require_sentencepiece,
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
)
|
||||
|
||||
from ...test_tokenization_common import SMALL_TRAINING_CORPUS, TokenizerTesterMixin
|
||||
|
||||
|
||||
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class MoshiTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
from_pretrained_id = ["kmhf/hf-moshiko"]
|
||||
rust_tokenizer_class = PreTrainedTokenizerFast
|
||||
|
||||
test_slow_tokenizer = False
|
||||
test_rust_tokenizer = True
|
||||
from_pretrained_kwargs = {}
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
# We have a SentencePiece fixture for testing
|
||||
tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_object=MoshiConverter(vocab_file=SAMPLE_VOCAB).converted(),
|
||||
bos_token="<s>",
|
||||
unk_token="<unk>",
|
||||
eos_token="</s>",
|
||||
)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
|
||||
def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
|
||||
return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
@unittest.skip(reason="No slow tokenizer")
|
||||
def test_added_tokens_serialization(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="PreTrainedTokenizerFast doesn't have tokenizer_file in its signature")
|
||||
def test_rust_tokenizer_signature(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="No slow tokenizer")
|
||||
def test_encode_decode_with_spaces(self):
|
||||
pass
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_object=MoshiConverter(vocab_file=SAMPLE_VOCAB).converted(),
|
||||
bos_token="<s>",
|
||||
unk_token="<unk>",
|
||||
eos_token="</s>",
|
||||
)
|
||||
|
||||
tokens = tokenizer.tokenize("This is a test")
|
||||
self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
|
||||
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(tokens),
|
||||
[285, 46, 10, 170, 382],
|
||||
)
|
||||
|
||||
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
|
||||
self.assertListEqual(
|
||||
tokens,
|
||||
[
|
||||
SPIECE_UNDERLINE + "I",
|
||||
SPIECE_UNDERLINE + "was",
|
||||
SPIECE_UNDERLINE + "b",
|
||||
"or",
|
||||
"n",
|
||||
SPIECE_UNDERLINE + "in",
|
||||
SPIECE_UNDERLINE + "",
|
||||
"9",
|
||||
"2",
|
||||
"0",
|
||||
"0",
|
||||
"0",
|
||||
",",
|
||||
SPIECE_UNDERLINE + "and",
|
||||
SPIECE_UNDERLINE + "this",
|
||||
SPIECE_UNDERLINE + "is",
|
||||
SPIECE_UNDERLINE + "f",
|
||||
"al",
|
||||
"s",
|
||||
"é",
|
||||
".",
|
||||
],
|
||||
)
|
||||
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
self.assertListEqual(
|
||||
ids,
|
||||
[8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4],
|
||||
)
|
||||
|
||||
back_tokens = tokenizer.convert_ids_to_tokens(ids)
|
||||
self.assertListEqual(
|
||||
back_tokens,
|
||||
[
|
||||
SPIECE_UNDERLINE + "I",
|
||||
SPIECE_UNDERLINE + "was",
|
||||
SPIECE_UNDERLINE + "b",
|
||||
"or",
|
||||
"n",
|
||||
SPIECE_UNDERLINE + "in",
|
||||
SPIECE_UNDERLINE + "",
|
||||
"<unk>",
|
||||
"2",
|
||||
"0",
|
||||
"0",
|
||||
"0",
|
||||
",",
|
||||
SPIECE_UNDERLINE + "and",
|
||||
SPIECE_UNDERLINE + "this",
|
||||
SPIECE_UNDERLINE + "is",
|
||||
SPIECE_UNDERLINE + "f",
|
||||
"al",
|
||||
"s",
|
||||
"<unk>",
|
||||
".",
|
||||
],
|
||||
)
|
||||
|
||||
def test_special_tokens_initialization(self):
|
||||
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||
added_tokens = [AddedToken("<special>", lstrip=True)]
|
||||
|
||||
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
|
||||
pretrained_name, additional_special_tokens=added_tokens, **kwargs
|
||||
)
|
||||
r_output = tokenizer_r.encode("Hey this is a <special> token")
|
||||
|
||||
special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]
|
||||
|
||||
self.assertTrue(special_token_id in r_output)
|
||||
|
||||
def test_picklable(self):
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
shutil.copyfile(SAMPLE_VOCAB, f.name)
|
||||
tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_object=MoshiConverter(vocab_file=f.name).converted(),
|
||||
bos_token="<s>",
|
||||
unk_token="<unk>",
|
||||
eos_token="</s>",
|
||||
)
|
||||
pickled_tokenizer = pickle.dumps(tokenizer)
|
||||
pickle.loads(pickled_tokenizer)
|
||||
|
||||
def test_training_new_tokenizer(self):
|
||||
# This feature only exists for fast tokenizers
|
||||
if not self.test_rust_tokenizer:
|
||||
self.skipTest(reason="test_rust_tokenizer is set to False")
|
||||
|
||||
tokenizer = self.get_rust_tokenizer()
|
||||
new_tokenizer = tokenizer.train_new_from_iterator(SMALL_TRAINING_CORPUS, 100)
|
||||
|
||||
# Test we can use the new tokenizer with something not seen during training
|
||||
inputs = new_tokenizer(["This is the first sentence", "This sentence is different 🤗."])
|
||||
self.assertEqual(len(inputs["input_ids"]), 2)
|
||||
decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
|
||||
expected_result = "This is the first sentence"
|
||||
|
||||
self.assertEqual(expected_result, decoded_input)
|
||||
|
||||
# We check that the parameters of the tokenizer remained the same
|
||||
# Check we have the same number of added_tokens for both pair and non-pair inputs.
|
||||
self.assertEqual(tokenizer.num_special_tokens_to_add(False), new_tokenizer.num_special_tokens_to_add(False))
|
||||
self.assertEqual(tokenizer.num_special_tokens_to_add(True), new_tokenizer.num_special_tokens_to_add(True))
|
||||
|
||||
# Check we have the correct max_length for both pair and non-pair inputs.
|
||||
self.assertEqual(tokenizer.max_len_single_sentence, new_tokenizer.max_len_single_sentence)
|
||||
self.assertEqual(tokenizer.max_len_sentences_pair, new_tokenizer.max_len_sentences_pair)
|
||||
|
||||
# Assert the set of special tokens match as we didn't ask to change them
|
||||
self.assertSequenceEqual(
|
||||
tokenizer.all_special_tokens_extended,
|
||||
new_tokenizer.all_special_tokens_extended,
|
||||
)
|
||||
|
||||
self.assertDictEqual(tokenizer.special_tokens_map, new_tokenizer.special_tokens_map)
|
||||
|
||||
def test_training_new_tokenizer_with_special_tokens_change(self):
|
||||
# This feature only exists for fast tokenizers
|
||||
if not self.test_rust_tokenizer:
|
||||
self.skipTest(reason="test_rust_tokenizer is set to False")
|
||||
|
||||
tokenizer = self.get_rust_tokenizer()
|
||||
# Test with a special tokens map
|
||||
class_signature = inspect.signature(tokenizer.__class__)
|
||||
if "cls_token" in class_signature.parameters:
|
||||
new_tokenizer = tokenizer.train_new_from_iterator(
|
||||
SMALL_TRAINING_CORPUS, 100, special_tokens_map={tokenizer.cls_token: "<cls>"}
|
||||
)
|
||||
cls_id = new_tokenizer.get_vocab()["<cls>"]
|
||||
self.assertEqual(new_tokenizer.cls_token, "<cls>")
|
||||
self.assertEqual(new_tokenizer.cls_token_id, cls_id)
|
||||
|
||||
# Create a new mapping from the special tokens defined in the original tokenizer
|
||||
special_tokens_list = SpecialTokensMixin.SPECIAL_TOKENS_ATTRIBUTES.copy()
|
||||
special_tokens_list.remove("additional_special_tokens")
|
||||
special_tokens_map = {}
|
||||
for token in special_tokens_list:
|
||||
# Get the private one to avoid unnecessary warnings.
|
||||
if getattr(tokenizer, f"_{token}") is not None:
|
||||
special_token = getattr(tokenizer, token)
|
||||
special_tokens_map[special_token] = f"{special_token}a"
|
||||
|
||||
# Train new tokenizer
|
||||
new_tokenizer = tokenizer.train_new_from_iterator(
|
||||
SMALL_TRAINING_CORPUS, 100, special_tokens_map=special_tokens_map
|
||||
)
|
||||
|
||||
# Check the changes
|
||||
for token in special_tokens_list:
|
||||
# Get the private one to avoid unnecessary warnings.
|
||||
if getattr(tokenizer, f"_{token}") is None:
|
||||
continue
|
||||
special_token = getattr(tokenizer, token)
|
||||
if special_token in special_tokens_map:
|
||||
new_special_token = getattr(new_tokenizer, token)
|
||||
self.assertEqual(special_tokens_map[special_token], new_special_token)
|
||||
|
||||
new_id = new_tokenizer.get_vocab()[new_special_token]
|
||||
self.assertEqual(getattr(new_tokenizer, f"{token}_id"), new_id)
|
||||
|
||||
# Check if the AddedToken / string format has been kept
|
||||
for special_token in tokenizer.all_special_tokens_extended:
|
||||
if isinstance(special_token, AddedToken) and special_token.content not in special_tokens_map:
|
||||
# The special token must appear identically in the list of the new tokenizer.
|
||||
self.assertTrue(
|
||||
special_token in new_tokenizer.all_special_tokens_extended,
|
||||
f"'{special_token}' should be in {new_tokenizer.all_special_tokens_extended}",
|
||||
)
|
||||
elif isinstance(special_token, AddedToken):
|
||||
# The special token must appear in the list of the new tokenizer as an object of type AddedToken with
|
||||
# the same parameters as the old AddedToken except the content that the user has requested to change.
|
||||
special_token_str = special_token.content
|
||||
new_special_token_str = special_tokens_map[special_token_str]
|
||||
|
||||
find = False
|
||||
for candidate in new_tokenizer.all_special_tokens_extended:
|
||||
if (
|
||||
isinstance(candidate, AddedToken)
|
||||
and candidate.content == new_special_token_str
|
||||
and candidate.lstrip == special_token.lstrip
|
||||
and candidate.rstrip == special_token.rstrip
|
||||
and candidate.normalized == special_token.normalized
|
||||
and candidate.single_word == special_token.single_word
|
||||
):
|
||||
find = True
|
||||
break
|
||||
special_token.content = new_special_token_str
|
||||
self.assertTrue(
|
||||
find,
|
||||
f"'{special_token.__repr__()}' should appear as an `AddedToken` in the all_special_tokens_extended = "
|
||||
f"{[k for k in new_tokenizer.all_special_tokens_extended if str(k)==new_special_token_str]} but it is missing"
|
||||
", this means that the new tokenizers did not keep the `rstrip`, `lstrip`, `normalized` etc attributes.",
|
||||
)
|
||||
elif special_token not in special_tokens_map:
|
||||
# The special token must appear identically in the list of the new tokenizer.
|
||||
self.assertTrue(
|
||||
special_token in new_tokenizer.all_special_tokens_extended,
|
||||
f"'{special_token.__repr__()}' should be in {new_tokenizer.all_special_tokens_extended}",
|
||||
)
|
||||
|
||||
else:
|
||||
# The special token must appear in the list of the new tokenizer as an object of type string.
|
||||
self.assertTrue(special_tokens_map[special_token] in new_tokenizer.all_special_tokens_extended)
|
||||
|
||||
# Test we can use the new tokenizer with something not seen during training
|
||||
inputs = new_tokenizer(["This is the first sentence", "This sentence is different 🤗."])
|
||||
self.assertEqual(len(inputs["input_ids"]), 2)
|
||||
decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
|
||||
expected_result = "This is the first sentence"
|
||||
|
||||
self.assertEqual(expected_result, decoded_input)
|
||||
|
||||
def test_alignement_methods(self):
|
||||
# TODO: @ArthurZucker - alignment is broken
|
||||
pass
|
||||
|
||||
def test_added_tokens_do_lower_case(self):
|
||||
# TODO: @ArthurZucker
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class MoshiIntegrationTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
checkpoint_name = "kmhf/hf-moshiko"
|
||||
cls.rust_tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
|
||||
return cls
|
||||
|
||||
@require_torch
|
||||
def integration_tests(self):
|
||||
inputs = self.tokenizer(
|
||||
["The following string should be properly encoded: Hello.", "But ird and ปี ird ด"],
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
long_attention_mask = [1] * 21
|
||||
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
nested_simplify(inputs),
|
||||
{
|
||||
"input_ids": [
|
||||
[287, 547, 2359, 457, 297, 3708, 11488, 279, 11725, 263],
|
||||
[588, 478, 1442, 267, 260, 228, 188, 159, 228, 188, 185, 260, 260, 478, 1442, 260, 260, 260, 228, 188, 152],
|
||||
],
|
||||
"attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], long_attention_mask],
|
||||
},
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def test_fast_special_tokens(self):
|
||||
fast_tokenizer = self.rust_tokenizer
|
||||
|
||||
fast_tokenizer.add_eos_token = False
|
||||
fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||
assert fast == [318, 1145, 694]
|
||||
|
||||
fast_tokenizer.add_eos_token = True
|
||||
fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||
assert fast == [318, 1145, 694]
|
||||
|
||||
self.rust_tokenizer.add_eos_token = False
|
||||
|
||||
def test_simple_encode_decode(self):
|
||||
rust_tokenizer = self.rust_tokenizer
|
||||
|
||||
self.assertEqual(rust_tokenizer.encode("This is a test"), [353, 275, 272, 694])
|
||||
self.assertEqual(rust_tokenizer.decode([353, 275, 272, 694], skip_special_tokens=True), "This is a test")
|
||||
|
||||
# bytefallback showcase
|
||||
bytefallback_tokens = [260, 235, 152, 163, 234, 184, 191, 13340, 235, 160, 163, 236, 180, 159, 234, 156, 179] # fmt: skip
|
||||
self.assertEqual(rust_tokenizer.encode("生活的真谛是"), bytefallback_tokens)
|
||||
self.assertEqual(
|
||||
rust_tokenizer.decode(bytefallback_tokens, skip_special_tokens=True),
|
||||
"生活的真谛是",
|
||||
)
|
||||
|
||||
# Inner spaces showcase
|
||||
self.assertEqual(rust_tokenizer.encode("Hi Hello"), [2769, 260, 11725])
|
||||
self.assertEqual(rust_tokenizer.decode([2769, 260, 11725], skip_special_tokens=True), "Hi Hello")
|
||||
|
||||
self.assertEqual(rust_tokenizer.encode("Hi Hello"), [2769, 260, 260, 11725])
|
||||
self.assertEqual(rust_tokenizer.decode([2769, 260, 260, 11725], skip_special_tokens=True), "Hi Hello")
|
||||
|
||||
# TODO: @ArthurZucker
|
||||
# self.assertEqual(rust_tokenizer.encode(""), [])
|
||||
|
||||
# self.assertEqual(rust_tokenizer.encode(" "), [260, 260])
|
||||
|
||||
# self.assertEqual(rust_tokenizer.encode(" "), [260, 260, 260])
|
||||
|
||||
# self.assertEqual(rust_tokenizer.encode(" Hello"), [260, 11725])
|
||||
|
||||
# self.assertEqual(rust_tokenizer.encode("<s>"), [607, 266, 578])
|
||||
|
||||
def test_no_differences_decode(self):
|
||||
rust_tokenizer = self.rust_tokenizer
|
||||
|
||||
self.assertEqual(rust_tokenizer.decode([869]), "levels")
|
||||
|
||||
self.assertEqual(rust_tokenizer.decode([30112, 869]), "unanswered levels")
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class CommonSpmIntegrationTests(unittest.TestCase):
|
||||
"""
|
||||
A class that regroups important test to make sure that we properly handle the special tokens.
|
||||
"""
|
||||
|
||||
def test_edge_case_tabulation(self):
|
||||
fast_tokenizer = AutoTokenizer.from_pretrained("kmhf/hf-moshiko")
|
||||
input_text = "Hey<eos>. \t\t \n\nyou é @#😈 🤗! , 1234 15 5,61"
|
||||
EXPECTED_IDS = [11510, 934, 4451, 266, 578, 263, 260, 13, 13, 260, 14, 14, 5209, 260, 260, 1202, 260, 527, 1322, 244, 163, 156, 140, 260, 260, 244, 163, 168, 155, 430, 1047, 261, 260, 265, 270, 278, 281, 260, 265, 280, 260, 280, 261, 285, 265] # fmt: skip
|
||||
EXPECTED_TOKENS = ['▁Hey', '<', 'eo', 's', '>', '.', '▁', '<0x09>', '<0x09>', '▁', '<0x0A>', '<0x0A>', 'you', '▁', '▁', 'é', '▁', '▁@', '#', '<0xF0>', '<0x9F>', '<0x98>', '<0x88>', '▁', '▁', '<0xF0>', '<0x9F>', '<0xA4>', '<0x97>', '!', '▁▁▁▁▁▁▁', ',', '▁', '1', '2', '3', '4', '▁', '1', '5', '▁', '5', ',', '6', '1'] # fmt: skip
|
||||
|
||||
tokens = fast_tokenizer.tokenize(input_text)
|
||||
with self.subTest("test fast edge case fast"):
|
||||
self.assertEqual(tokens, EXPECTED_TOKENS)
|
||||
|
||||
input_ids = fast_tokenizer.encode(input_text)
|
||||
with self.subTest("test fast edge case fast"):
|
||||
self.assertEqual(input_ids, EXPECTED_IDS)
|
||||
|
||||
text = fast_tokenizer.decode(EXPECTED_IDS)
|
||||
with self.subTest("test fast edge case fast"):
|
||||
self.assertEqual(text, "Hey<eos>. \t\t \n\nyou é @#😈 🤗! , 1234 15 5,61")
|
||||
|
||||
input_text = "\t\t\t\t \n\n61"
|
||||
EXPECTED_IDS = [260, 13, 13, 13, 13, 260, 14, 14, 285, 265]
|
||||
EXPECTED_TOKENS = ["▁", "<0x09>", "<0x09>", "<0x09>", "<0x09>", "▁", "<0x0A>", "<0x0A>", "6", "1"]
|
||||
|
||||
tokens = fast_tokenizer.tokenize(input_text)
|
||||
with self.subTest("test fast edge case fast"):
|
||||
self.assertEqual(tokens, EXPECTED_TOKENS)
|
||||
|
||||
input_ids = fast_tokenizer.encode(input_text)
|
||||
with self.subTest("test fast edge case fast"):
|
||||
self.assertEqual(input_ids, EXPECTED_IDS)
|
||||
|
||||
text = fast_tokenizer.decode(EXPECTED_IDS)
|
||||
with self.subTest("test fast edge case fast"):
|
||||
self.assertEqual(text, "\t\t\t\t \n\n61")
|
@ -327,6 +327,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
||||
"SiglipVisionModel",
|
||||
"SiglipTextModel",
|
||||
"ChameleonVQVAE", # no autoclass for VQ-VAE models
|
||||
"MoshiForConditionalGeneration", # no auto class for speech-to-speech
|
||||
]
|
||||
|
||||
# DO NOT edit this list!
|
||||
|
Loading…
Reference in New Issue
Block a user