
Mixtral
Overview
Mixtral-8x7B is Mistral AI's second Large Language Model (LLM).
The Mixtral model was proposed by the Mistral AI team.
It was introduced in the Mixtral of Experts blog post with the following introduction:
Today, the team is proud to release Mixtral 8x7B, a high-quality sparse mixture of experts models (SMoE) with open weights. Licensed under Apache 2.0. Mixtral outperforms Llama 2 70B on most benchmarks with 6x faster inference. It is the strongest open-weight model with a permissive license and the best model overall regarding cost/performance trade-offs. In particular, it matches or outperforms GPT3.5 on most standard benchmarks.
Tips:
- The raw (original) checkpoints need to be converted with the conversion script before they can be used with transformers.
- If the model is quantized to 4 bits, a single A100 is enough to fit the entire model (about 46.7B parameters).
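The 4-bit tip can be sanity-checked with a little arithmetic. The sketch below assumes the hyperparameters of the published Mixtral-8x7B config (hidden size 4096, expert intermediate size 14336, 32 layers, 32 query heads, 8 key/value heads, 8 experts, 2 experts per token, a 32000-token vocabulary and an untied output head); treat them as assumptions rather than values read from the Hub. It counts weights only, so activations, the KV cache and quantization overhead are ignored.

```python
# Back-of-the-envelope sketch; every hyperparameter below is an ASSUMED value
# intended to match the published Mixtral-8x7B config, not read from the Hub.
HIDDEN = 4096          # hidden_size
INTERMEDIATE = 14336   # intermediate_size of each expert MLP
LAYERS = 32            # num_hidden_layers
HEADS = 32             # num_attention_heads
KV_HEADS = 8           # num_key_value_heads (GQA)
EXPERTS = 8            # num_local_experts
TOP_K = 2              # num_experts_per_tok
VOCAB = 32000          # vocab_size


def count_params(experts_used: int) -> int:
    """Count weights, treating only `experts_used` experts per layer as present."""
    head_dim = HIDDEN // HEADS
    attn = 2 * HIDDEN * HIDDEN + 2 * HIDDEN * KV_HEADS * head_dim  # q, o + k, v projections
    expert = 3 * HIDDEN * INTERMEDIATE  # three projection matrices of one gated expert MLP
    router = HIDDEN * EXPERTS           # gate that scores every expert
    norms = 2 * HIDDEN                  # two RMSNorm weight vectors per layer
    per_layer = attn + experts_used * expert + router + norms
    embeddings = 2 * VOCAB * HIDDEN     # input embeddings + untied lm_head
    return LAYERS * per_layer + embeddings + HIDDEN  # + final norm


total = count_params(EXPERTS)   # parameters that must sit in memory
active = count_params(TOP_K)    # parameters actually used for one token
print(f"total params:  {total / 1e9:.1f}B")   # 46.7B
print(f"active params: {active / 1e9:.1f}B")  # 12.9B

# Weights only: roughly 93 GB in float16 versus roughly 23 GB in 4-bit.
for name, bytes_per_param in [("float16", 2.0), ("4-bit", 0.5)]:
    print(f"{name:>8}: ~{total * bytes_per_param / 1e9:.0f} GB")
```

The totals land on the 46.7B total and 12.9B active parameters reported by Mistral AI, and only the 4-bit footprint fits comfortably on a single 40 GB or 80 GB A100.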
This model was contributed by Younes Belkada and Arthur Zucker. The original code can be found here.
Model Details
Mixtral-8x7B is a decoder-based LM with the following architectural choices:
- Mixtral is a Mixture of Experts (MoE) model with 8 experts per MLP layer and a total of about 46.7B parameters, but the compute required per token is close to that of a 13B dense model (about 12.9B active parameters). This is because, although all experts have to be loaded in memory, each token from the hidden states is dispatched to only 2 of the 8 experts (top-2 routing), so each forward pass evaluates only 2 expert MLPs per token in each layer.
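To make top-2 routing concrete, here is a toy, pure-Python sketch of one sparse MoE layer. Hidden states are scalars and the experts are simple stand-in functions (both hypothetical simplifications, not the real expert MLPs). The router takes a softmax over all 8 expert logits, keeps the 2 largest probabilities and renormalizes them, which to our understanding mirrors the order used in the transformers Mixtral block. The point is the bookkeeping: every token triggers exactly 2 expert calls, however many experts are loaded.

```python
import math
import random

NUM_EXPERTS = 8  # experts per MoE layer
TOP_K = 2        # experts each token is dispatched to


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def route(router_logits, top_k=TOP_K):
    """Return [(expert_index, weight), ...] for one token, weights summing to 1."""
    probs = softmax(router_logits)
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    norm = sum(probs[i] for i in top)  # renormalize over the selected experts only
    return [(i, probs[i] / norm) for i in top]


def moe_layer(hidden_states, router_logits, experts):
    """Weighted sum of the selected experts' outputs for each token."""
    outputs, expert_calls = [], 0
    for h, token_logits in zip(hidden_states, router_logits):
        out = 0.0
        for idx, weight in route(token_logits):
            out += weight * experts[idx](h)
            expert_calls += 1
        outputs.append(out)
    return outputs, expert_calls


random.seed(0)
experts = [lambda h, k=k: (k + 1) * h for k in range(NUM_EXPERTS)]  # toy stand-ins for expert MLPs
tokens = [0.5, -1.0, 2.0, 0.1, 3.0]
logits = [[random.gauss(0, 1) for _ in range(NUM_EXPERTS)] for _ in tokens]

outputs, calls = moe_layer(tokens, logits, experts)
print(calls)  # 10 = 5 tokens x 2 experts (a dense pass over all experts would be 40)
```

All 8 experts still have to be resident in memory because any token may pick any pair, which is why memory scales with the total parameter count while compute scales with the 2 selected experts.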
The following implementation details are shared with Mistral AI's first model, Mistral-7B:
- Sliding Window Attention: trained with an 8k context length and a fixed cache size, with a theoretical attention span of 128K tokens.
- GQA (Grouped Query Attention): allows faster inference and a smaller cache.
- Byte-fallback BPE tokenizer: ensures that characters are never mapped to out-of-vocabulary tokens.
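Grouped Query Attention lets several query heads share one key/value head, which is where the smaller cache comes from. The sketch below assumes 32 query heads, 8 key/value heads, a head dimension of 128 and 32 layers (values we believe match the Mixtral-8x7B config); kv_head_for and kv_cache_bytes are hypothetical helpers for illustration only.

```python
NUM_HEADS = 32     # query heads (assumed)
NUM_KV_HEADS = 8   # key/value heads (assumed)
HEAD_DIM = 128     # hidden_size // num_heads (assumed)
LAYERS = 32        # decoder layers (assumed)

GROUP_SIZE = NUM_HEADS // NUM_KV_HEADS  # query heads that share one KV head


def kv_head_for(query_head: int) -> int:
    """Index of the key/value head that a given query head attends with."""
    return query_head // GROUP_SIZE


def kv_cache_bytes(num_tokens: int, kv_heads: int, bytes_per_value: int = 2) -> int:
    """Size of the key + value cache across all layers (float16 by default)."""
    return 2 * LAYERS * kv_heads * HEAD_DIM * bytes_per_value * num_tokens


print([kv_head_for(q) for q in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
mib = 1024 ** 2
print(kv_cache_bytes(4096, NUM_KV_HEADS) / mib)  # 512.0 MiB with GQA
print(kv_cache_bytes(4096, NUM_HEADS) / mib)     # 2048.0 MiB if every query head had its own KV head
```

With 4 query heads per KV head, the cache is 4 times smaller than with standard multi-head attention for the same sequence length.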
They also provide an instruction fine-tuned model, mistralai/Mixtral-8x7B-Instruct-v0.1, which can be used for chat-based inference.
For more details, please read the release blog post.
License
Mixtral-8x7B is released under the Apache 2.0 license.
Usage tips
Mixtral-8x7B checkpoints can be found on the Hugging Face Hub. These ready-to-use checkpoints can be downloaded and used as follows:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
>>> prompt = "My favourite condiment is"
>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
>>> model.to(device)
>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"The expected output"
To use the raw checkpoints with Hugging Face, you can use the convert_mixtral_weights_to_hf.py script to convert them to the Hugging Face format:
python src/transformers/models/mixtral/convert_mixtral_weights_to_hf.py \
--input_dir /path/to/downloaded/mistral/weights --output_dir /output/path
You can then load the converted model from /output/path:
from transformers import MixtralForCausalLM, LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained("/output/path")
model = MixtralForCausalLM.from_pretrained("/output/path")
Combining Mixtral and Flash Attention 2
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
pip install -U flash-attn --no-build-isolation
Also make sure that your hardware is compatible with Flash Attention 2; read more about it in the official documentation of the flash-attn repository. Finally, make sure to load your model in half precision (e.g. torch.float16).
To load and run a model using Flash Attention 2, refer to the snippet below:
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
>>> prompt = "My favourite condiment is"
>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
>>> model.to(device)
>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"The expected output"
Expected speedups
Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers and the Flash Attention 2 version of the model, using the mistralai/Mixtral-8x7B-v0.1 checkpoint.

Sliding window Attention
The current implementation supports the sliding window attention mechanism and memory efficient cache management.
To enable sliding window attention, make sure to have a flash-attn version that is compatible with sliding window attention (>=2.3.0).
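A sliding window restricts each query to the most recent keys instead of the whole prefix. The toy mask below uses the convention that a query at position i sees keys j with i - window < j <= i; implementations differ on the exact off-by-one, so treat it as illustrative. Stacking layers widens the reach, since information can move one window further per layer, which is where a theoretical span of roughly layers × window comes from (the 4096-token window and 32 layers below are assumed Mistral-7B-style values).

```python
def sliding_window_mask(seq_len: int, window: int) -> list[list[bool]]:
    """mask[i][j] is True when the query at position i may attend to the key at position j."""
    return [[i - window < j <= i for j in range(seq_len)] for i in range(seq_len)]


for row in sliding_window_mask(seq_len=6, window=3):
    print("".join("x" if allowed else "." for allowed in row))
# x.....
# xx....
# xxx...
# .xxx..
# ..xxx.
# ...xxx

WINDOW, LAYERS = 4096, 32  # assumed values
print(WINDOW * LAYERS)  # 131072 tokens, i.e. roughly the "128K" theoretical span
```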
The Flash Attention 2 model also uses a more memory-efficient cache slicing mechanism. Following the official implementation of the Mistral model, which uses a rolling cache, we keep the cache size fixed (self.config.sliding_window), support batched generation only with padding_side="left", and use the absolute position of the current token to compute the positional embedding.
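The rolling cache idea can be illustrated with a toy buffer: the token at absolute position p is written to slot p % window, so memory stays fixed at window entries and the oldest entries are overwritten. RollingKVCache is a hypothetical class for illustration, not the cache implementation used in transformers. Note that the position fed to the positional embedding is the token's absolute position, not its slot index.

```python
class RollingKVCache:
    """Toy fixed-size cache: the entry for absolute position p lives in slot p % window."""

    def __init__(self, window: int):
        self.window = window
        self.entries = [None] * window  # each slot holds (absolute_position, key_value) or None

    def append(self, position: int, key_value) -> None:
        # Overwrites whatever was stored `window` positions earlier.
        self.entries[position % self.window] = (position, key_value)

    def visible(self) -> list:
        """Cached entries ordered by absolute position, oldest first."""
        return sorted(entry for entry in self.entries if entry is not None)


cache = RollingKVCache(window=4)
for pos in range(10):
    cache.append(pos, f"kv{pos}")

print([p for p, _ in cache.visible()])  # [6, 7, 8, 9]
print(len(cache.entries))               # 4: memory does not grow with the sequence
# The newest token sits in slot 9 % 4 == 1, but its rotary embedding still uses position 9.
```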
The Mistral Team
Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
MixtralConfig
autodoc MixtralConfig
MixtralModel
autodoc MixtralModel - forward
MixtralForCausalLM
autodoc MixtralForCausalLM - forward
MixtralForSequenceClassification
autodoc MixtralForSequenceClassification - forward