
Whisper
Overview
The Whisper model was proposed in Robust Speech Recognition via Large-Scale Weak Supervision by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, and Ilya Sutskever.
The abstract from the paper is the following:
We study the capabilities of speech processing systems trained simply to predict large amounts of transcripts of audio on the internet. When scaled to 680,000 hours of multilingual and multitask supervision, the resulting models generalize well to standard benchmarks and are often competitive with prior fully supervised results but in a zeroshot transfer setting without the need for any finetuning. When compared to humans, the models approach their accuracy and robustness. We are releasing models and inference code to serve as a foundation for further work on robust speech processing.
This model was contributed by Arthur Zucker. The TensorFlow version of this model was contributed by amyeroberts. The original code can be found here.
Usage tips
- The model usually performs well without requiring any fine-tuning.
- The architecture follows a classic encoder-decoder design, which means that it relies on the [`~generation.GenerationMixin.generate`] function for inference.
- One can use [`WhisperProcessor`] to prepare audio for the model, and to decode the predicted token IDs back into text.
- To convert the model and the processor, we recommend using the following:

```bash
python src/transformers/models/whisper/convert_openai_to_hf.py --checkpoint_path "" --pytorch_dump_folder_path "Arthur/whisper-3" --convert_preprocessor True
```

The script will automatically determine all necessary parameters from the OpenAI checkpoint. The `tiktoken` library needs to be installed to perform the conversion of the OpenAI tokenizer to the `tokenizers` version.
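For example, a minimal way to install that extra dependency before running the conversion script (assuming a standard pip environment) is:

```bash
pip install tiktoken
```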
Inference
Here is a step-by-step guide to transcribing an audio sample using a pre-trained Whisper model:
```python
>>> from datasets import load_dataset
>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration

>>> # Select an audio file and read it:
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> audio_sample = ds[0]["audio"]

>>> # Load the Whisper model in Hugging Face format:
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

>>> # Use the model and processor to transcribe the audio:
>>> input_features = processor(
...     audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
... ).input_features

>>> # Generate token ids
>>> predicted_ids = model.generate(input_features)

>>> # Decode token ids to text
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

>>> transcription[0]
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```
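The English-only "`.en`" checkpoints always transcribe to English. With a multilingual checkpoint, recent versions of `transformers` also let you pin the language and task directly in `generate`. The snippet below is a sketch along those lines: it assumes `input_features` was prepared as above, but from a French audio clip and with the multilingual `openai/whisper-tiny` processor:

```python
>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration

>>> # Multilingual checkpoint (no ".en" suffix)
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

>>> # Transcribe in the original (French) language ...
>>> predicted_ids = model.generate(input_features, language="french", task="transcribe")

>>> # ... or translate the speech into English
>>> predicted_ids = model.generate(input_features, language="french", task="translate")

>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
```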
Whisper is compatible with the following optimisations:
- PyTorch Scaled Dot Product Attention (SDPA): flash attention and memory-efficient attention kernels. Enabled by default for `torch>=2.1.1`.
- Flash Attention 2: improved implementation of flash attention through better parallelism and work partitioning.
- torch.compile: JIT-compiles the forward pass to dispatch to efficient fused kernels.

As an example, the following code snippet enables SDPA and `torch.compile` for up to 5x faster inference:
```python
>>> import torch
>>> from datasets import load_dataset
>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration

>>> # Select an audio file and read it:
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> audio_sample = ds[0]["audio"]

>>> # Load the Whisper model with SDPA attention
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", attn_implementation="sdpa")

>>> # Enable static cache and compile the forward pass
>>> model.generation_config.cache_implementation = "static"
>>> model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

>>> # Use the model and processor to transcribe the audio:
>>> input_features = processor(
...     audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
... ).input_features

>>> # Compile the forward pass
>>> _ = model.generate(input_features)

>>> # Generate token ids using the compiled graph (fast!)
>>> predicted_ids = model.generate(input_features)

>>> # Decode token ids to text
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

>>> transcription[0]
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```
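As another example, Flash Attention 2 can be requested when loading the model. The snippet below is a minimal sketch; it assumes a CUDA GPU, half-precision weights, and that the `flash-attn` package is installed:

```python
import torch
from transformers import WhisperForConditionalGeneration

# Load the model with Flash Attention 2 kernels, in fp16 on GPU
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny.en",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")
```

With this setup, the `input_features` passed to `generate` should be moved to the same device and dtype, e.g. `input_features.to("cuda", dtype=torch.float16)`.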
For more details on each optimisation, refer to the documentation linked above.
Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Whisper. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
- Fine-tune Whisper on your own dataset for better downstream performance.
- Distil-Whisper: up to 6x faster, 2x smaller distilled Whisper models for English. We release the model checkpoints and distillation code.
- A fork with a script to convert a Whisper model in Hugging Face format to OpenAI format. 🌎 Usage example:
```bash
pip install -U openai-whisper
python convert_hf_to_openai.py \
    --checkpoint openai/whisper-tiny \
    --whisper_dump_path whisper-tiny-openai.pt
```
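Once converted, the checkpoint can be loaded with the original `openai-whisper` package. The following is a sketch assuming the dump path used above and a hypothetical local audio file:

```python
import whisper

# Load the converted checkpoint with the original openai-whisper package
model = whisper.load_model("whisper-tiny-openai.pt")

# Transcribe a local audio file (hypothetical path)
result = model.transcribe("audio.mp3")
print(result["text"])
```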
WhisperConfig
autodoc WhisperConfig
WhisperTokenizer
autodoc WhisperTokenizer - set_prefix_tokens - build_inputs_with_special_tokens - get_special_tokens_mask - create_token_type_ids_from_sequences - save_vocabulary - batch_decode - decode - basic_normalize - normalize
WhisperTokenizerFast
autodoc WhisperTokenizerFast - set_prefix_tokens - build_inputs_with_special_tokens - get_special_tokens_mask - create_token_type_ids_from_sequences - save_vocabulary - batch_decode - decode - basic_normalize - normalize
WhisperFeatureExtractor
autodoc WhisperFeatureExtractor - __call__
WhisperProcessor
autodoc WhisperProcessor - __call__ - from_pretrained - save_pretrained - batch_decode - decode
WhisperModel
autodoc WhisperModel - forward - _mask_input_features
WhisperForConditionalGeneration
autodoc WhisperForConditionalGeneration - forward - generate
WhisperForCausalLM
autodoc WhisperForCausalLM - forward
WhisperForAudioClassification
autodoc WhisperForAudioClassification - forward
TFWhisperModel
autodoc TFWhisperModel - __call__
TFWhisperForConditionalGeneration
autodoc TFWhisperForConditionalGeneration - __call__
FlaxWhisperModel
autodoc FlaxWhisperModel - __call__
FlaxWhisperForConditionalGeneration
autodoc FlaxWhisperForConditionalGeneration - __call__
FlaxWhisperForAudioClassification
autodoc FlaxWhisperForAudioClassification - __call__