<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
# Kyutai Speech-To-Text

## Overview

Kyutai STT is a speech-to-text model architecture based on the [Mimi codec](https://huggingface.co/docs/transformers/en/model_doc/mimi), which encodes audio into discrete tokens in a streaming fashion, and a [Moshi-like](https://huggingface.co/docs/transformers/en/model_doc/moshi) autoregressive decoder. Kyutai has released two model checkpoints:

- [kyutai/stt-1b-en_fr](https://huggingface.co/kyutai/stt-1b-en_fr): a 1B-parameter model that transcribes both English and French
- [kyutai/stt-2.6b-en](https://huggingface.co/kyutai/stt-2.6b-en): a 2.6B-parameter English-only model, optimized for maximum transcription accuracy

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/eustlb/documentation-images/resolve/main/kyutai_stt.png"/>
</div>
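The configuration mirrors this two-part design: the decoder hyperparameters sit alongside a nested configuration for the Mimi codec. A minimal sketch to inspect them; the `codec_config` attribute name is an assumption based on this layout, so print the full config to see the actual field names:

```python
from transformers import KyutaiSpeechToTextConfig

config = KyutaiSpeechToTextConfig.from_pretrained("kyutai/stt-2.6b-en")
print(config)  # decoder hyperparameters plus the nested codec (Mimi) settings

# nested codec configuration (attribute name assumed)
print(config.codec_config)
```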
## Usage Tips

### Inference

```python
import torch
from datasets import load_dataset, Audio
from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration

# 1. load the model and the processor
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "kyutai/stt-2.6b-en"

processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)

# 2. load audio samples
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
# resample to the 24 kHz sampling rate expected by the Mimi codec
ds = ds.cast_column("audio", Audio(sampling_rate=24000))

# 3. prepare the model inputs
inputs = processor(ds[0]["audio"]["array"])
inputs = inputs.to(torch_device)

# 4. infer the model
output_tokens = model.generate(**inputs)

# 5. decode the generated tokens
print(processor.batch_decode(output_tokens, skip_special_tokens=True))
```
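Transcription runs through the standard `generate` API, so the usual generation arguments apply; for instance, you can cap the number of generated text tokens. The value below is illustrative, not a recommended setting:

```python
# cap the transcript length in tokens (512 is an illustrative value)
output_tokens = model.generate(**inputs, max_new_tokens=512)
print(processor.batch_decode(output_tokens, skip_special_tokens=True))
```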
### Batched Inference

```python
import torch
from datasets import load_dataset, Audio
from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration

# 1. load the model and the processor
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "kyutai/stt-2.6b-en"

processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)

# 2. load audio samples
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
# resample to the 24 kHz sampling rate expected by the Mimi codec
ds = ds.cast_column("audio", Audio(sampling_rate=24000))

# 3. prepare the model inputs
audio_arrays = [ds[i]["audio"]["array"] for i in range(4)]
inputs = processor(audio_arrays, return_tensors="pt", padding=True)
inputs = inputs.to(torch_device)

# 4. infer the model
output_tokens = model.generate(**inputs)

# 5. decode the generated tokens
decoded_outputs = processor.batch_decode(output_tokens, skip_special_tokens=True)
for output in decoded_outputs:
    print(output)
```
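When transcribing several samples at once, it helps to keep track of which transcript belongs to which input. A small sketch pairing each output with its dataset row, assuming the dataset exposes an `id` column as LibriSpeech does:

```python
# pair each transcript with the id of the sample it came from
for sample, text in zip(ds.select(range(4)), decoded_outputs):
    print(f"{sample['id']}: {text}")
```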
This model was contributed by [Eustache Le Bihan](https://huggingface.co/eustlb).
The original code can be found [here](https://github.com/kyutai-labs/moshi).
## KyutaiSpeechToTextConfig

[[autodoc]] KyutaiSpeechToTextConfig

## KyutaiSpeechToTextProcessor

[[autodoc]] KyutaiSpeechToTextProcessor
    - __call__

## KyutaiSpeechToTextFeatureExtractor

[[autodoc]] KyutaiSpeechToTextFeatureExtractor

## KyutaiSpeechToTextForConditionalGeneration

[[autodoc]] KyutaiSpeechToTextForConditionalGeneration
    - forward
    - generate

## KyutaiSpeechToTextModel

[[autodoc]] KyutaiSpeechToTextModel