# Mixtral

## Overview

Mixtral-8x7B was introduced in the [Mixtral of Experts blogpost](https://mistral.ai/news/mixtral-of-experts/) by Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.

The introduction of the blog post says:

*Today, the team is proud to release Mixtral 8x7B, a high-quality sparse mixture of experts model (SMoE) with open weights. Licensed under Apache 2.0. Mixtral outperforms Llama 2 70B on most benchmarks with 6x faster inference. It is the strongest open-weight model with a permissive license and the best model overall regarding cost/performance trade-offs. In particular, it matches or outperforms GPT3.5 on most standard benchmarks.*

Mixtral-8x7B is the second large language model (LLM) released by [mistral.ai](https://mistral.ai/), after [Mistral-7B](mistral).

### Architectural details

Mixtral-8x7B is a decoder-only Transformer with the following architectural choices:

- Mixtral is a Mixture of Experts (MoE) model with 8 experts per MLP, with a total of 45 billion parameters. To learn more about mixture-of-experts, refer to the [blog post](https://huggingface.co/blog/moe).
- Despite the model having 45 billion parameters, the compute required for a single forward pass is the same as that of a 14 billion parameter model. This is because even though each of the experts has to be loaded in RAM (a 70B-like RAM requirement), each token of the hidden states is dispatched to only two experts (top-2 routing), so the compute required at each forward pass is just 2 x sequence_length expert operations.

The following implementation details are shared with Mistral AI's first model [Mistral-7B](mistral):

- Sliding Window Attention - Trained with 8k context length and fixed cache size, with a theoretical attention span of 128K tokens
- GQA (Grouped Query Attention) - allowing faster inference and lower cache size.
- Byte-fallback BPE tokenizer - ensures that characters are never mapped to out of vocabulary tokens.

For more details refer to the [release blog post](https://mistral.ai/news/mixtral-of-experts/).

### License

`Mixtral-8x7B` is released under the Apache 2.0 license.

## Usage tips

The Mistral team has released 2 checkpoints:

- a base model, [Mixtral-8x7B-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1), which has been pre-trained to predict the next token on internet-scale data.
- an instruction tuned model, [Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1), which is the base model optimized for chat purposes using supervised fine-tuning (SFT) and direct preference optimization (DPO).

The base model can be used as follows:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

>>> prompt = "My favourite condiment is"

>>> model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"My favourite condiment is to ..."
```
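
As a quick sanity check of the mixture-of-experts setup described in the architectural details above, the expert count and top-2 routing can be read off the model configuration. The snippet below is a minimal sketch; the attribute names follow the `MixtralConfig` class in Transformers:

```python
>>> from transformers import AutoConfig

>>> config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
>>> config.num_local_experts, config.num_experts_per_tok  # 8 experts per MLP, 2 experts per token
(8, 2)
```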

The instruction tuned model can be used as follows:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")

>>> messages = [
...     {"role": "user", "content": "What is your favourite condiment?"},
...     {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
...     {"role": "user", "content": "Do you have mayonnaise recipes?"}
... ]

>>> model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")

>>> generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"Mayonnaise can be made as follows: (...)"
```

As can be seen, the instruction-tuned model requires a [chat template](../chat_templating) to be applied to make sure the inputs are prepared in the right format.

## Speeding up Mixtral by using Flash Attention

The code snippets above showcase inference without any optimization tricks. However, one can drastically speed up the model by leveraging [Flash Attention](../perf_train_gpu_one.md#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.

First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.

```bash
pip install -U flash-attn --no-build-isolation
```

Also make sure that your hardware is compatible with Flash Attention 2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). Finally, make sure to load your model in half-precision (e.g. `torch.float16`).

To load and run a model using Flash Attention 2, refer to the snippet below:

```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

>>> prompt = "My favourite condiment is"

>>> model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"The expected output"
```

### Expected speedups

Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using the `mistralai/Mixtral-8x7B-v0.1` checkpoint and the Flash Attention 2 version of the model.
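
If you want a rough comparison on your own hardware, a minimal timing sketch along the following lines can be used. This is an illustrative sketch, not the benchmark behind the diagram: it assumes a CUDA setup with enough GPU memory to load the checkpoint in `float16`, and it only times a single 100-token generation per attention backend.

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "mistralai/Mixtral-8x7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(["My favourite condiment is"], return_tensors="pt").to("cuda")


def time_generation(attn_implementation: str) -> float:
    # Load the checkpoint in half-precision with the requested attention backend.
    # Note: each load needs enough GPU memory for the full 45B-parameter model in float16.
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.float16,
        attn_implementation=attn_implementation,
        device_map="auto",
    )
    # Warm-up pass so one-time setup costs are not included in the measurement.
    model.generate(**inputs, max_new_tokens=8, do_sample=False)
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=100, do_sample=False)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    # Free the weights before the next backend is loaded.
    del model
    torch.cuda.empty_cache()
    return elapsed


print(f"eager            : {time_generation('eager'):.2f}s")
print(f"flash_attention_2: {time_generation('flash_attention_2'):.2f}s")
```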