# Mistral
[Mistral](https://huggingface.co/papers/2310.06825) is a 7B parameter language model, available as a pretrained and an instruction-tuned variant, that balances the scaling costs of large models against performance and efficient inference. It uses sliding window attention (SWA), trained with an 8K context length and a fixed cache size, to handle longer sequences more effectively. Grouped-query attention (GQA) speeds up inference and reduces memory requirements. Mistral also uses a byte-fallback BPE tokenizer, which ensures characters are never mapped to out-of-vocabulary tokens.
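To make the memory effect of GQA and a fixed cache size concrete, the sketch below estimates the key-value cache size for a single sequence. The layer count, head counts, head dimension, and 4096-token window are the values reported in the Mistral 7B paper for the original release and are hardcoded here as assumptions (later checkpoints may configure the window differently). `kv_cache_bytes` is a hypothetical helper, not part of Transformers.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, cached_tokens, bytes_per_value=2):
    """Approximate bytes needed to cache keys and values for one sequence."""
    # The factor 2 accounts for storing both keys and values in every layer.
    return 2 * n_layers * n_kv_heads * head_dim * cached_tokens * bytes_per_value


# Assumed values, taken from the Mistral 7B paper (original release).
N_LAYERS, N_HEADS, N_KV_HEADS, HEAD_DIM, WINDOW = 32, 32, 8, 128, 4096

# With a fixed-size cache, at most WINDOW tokens are kept, however long the sequence is.
seq_len = 16_384
cached = min(seq_len, WINDOW)

gqa = kv_cache_bytes(N_LAYERS, N_KV_HEADS, HEAD_DIM, cached)    # GQA with a window
mha = kv_cache_bytes(N_LAYERS, N_HEADS, HEAD_DIM, cached)       # one KV head per query head
full = kv_cache_bytes(N_LAYERS, N_KV_HEADS, HEAD_DIM, seq_len)  # GQA without a window

MIB = 1024**2
print(f"GQA + window:   {gqa / MIB:.0f} MiB")
print(f"MHA + window:   {mha / MIB:.0f} MiB")
print(f"GQA, no window: {full / MIB:.0f} MiB")
```

Under these assumptions and 16-bit values, the cache takes 512 MiB with both GQA and the window, compared with 2048 MiB when either one is removed.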
You can find all the original Mistral checkpoints under the [Mistral AI](https://huggingface.co/mistralai) organization.
> [!TIP]
> Click on the Mistral models in the right sidebar for more examples of how to apply Mistral to different language tasks.
The examples below demonstrate how to chat with Mistral using [`Pipeline`], [`AutoModel`], or the command line.
```python
>>> import torch
>>> from transformers import pipeline
>>> messages = [
... {"role": "user", "content": "What is your favourite condiment?"},
... {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
... {"role": "user", "content": "Do you have mayonnaise recipes?"}
... ]
>>> chatbot = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.3", torch_dtype=torch.bfloat16, device=0)
>>> chatbot(messages)
```
```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", torch_dtype=torch.bfloat16, attn_implementation="sdpa", device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
>>> messages = [
... {"role": "user", "content": "What is your favourite condiment?"},
... {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
... {"role": "user", "content": "Do you have mayonnaise recipes?"}
... ]
>>> model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")
>>> generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"Mayonnaise can be made as follows: (...)"
```
```bash
echo -e "My favorite condiment is" | transformers chat mistralai/Mistral-7B-v0.3 --torch_dtype auto --device 0 --attn_implementation flash_attention_2
```
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
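For a rough sense of scale, the sketch below estimates weight memory at different precisions. It assumes exactly 7 billion parameters (the nominal size, not the exact count) and ignores quantization metadata, activations, and the key-value cache. `weight_gb` is a hypothetical helper used only for this illustration.

```python
def weight_gb(n_params, bits_per_weight):
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits_per_weight / 8 / 1e9


N_PARAMS = 7_000_000_000  # assumption: nominal "7B" size, not the exact parameter count

for name, bits in [("float32", 32), ("bfloat16", 16), ("int8", 8), ("4-bit", 4)]:
    print(f"{name:>8}: {weight_gb(N_PARAMS, bits):.1f} GB")
```

With these numbers, going from 16-bit to 4-bit weights cuts the estimate from about 14 GB to about 3.5 GB.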
The example below uses [bitsandbytes](../quantization/bitsandbytes) to quantize only the weights to 4 bits.
```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
>>> # specify how to quantize the model
>>> quantization_config = BitsAndBytesConfig(
... load_in_4bit=True,
... bnb_4bit_quant_type="nf4",
... bnb_4bit_compute_dtype=torch.float16,
... )
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", quantization_config=quantization_config, torch_dtype=torch.bfloat16, device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
>>> messages = [
... {"role": "user", "content": "What is your favourite condiment?"},
... {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
... {"role": "user", "content": "Do you have mayonnaise recipes?"}
... ]
>>> model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")
>>> generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"The expected output"
```
Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
```python
>>> from transformers.utils.attention_visualizer import AttentionMaskVisualizer
>>> visualizer = AttentionMaskVisualizer("mistralai/Mistral-7B-Instruct-v0.3")
>>> visualizer("Do you have mayonnaise recipes?")
```
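For intuition about the kind of pattern such a visualization shows, the sketch below builds a causal sliding-window mask in plain Python. It is an illustration only, not the code Transformers uses. `sliding_window_mask` is a hypothetical helper, the window of 3 is chosen for readability, and the convention that the window counts the query position itself is an assumption. Whether a given checkpoint enables a sliding window depends on its configuration.

```python
def sliding_window_mask(seq_len, window):
    """mask[i][j] is True when query position i may attend to key position j."""
    return [
        # Causal (j <= i) and within the last `window` positions, counting i itself.
        [j <= i and i - j < window for j in range(seq_len)]
        for i in range(seq_len)
    ]


mask = sliding_window_mask(seq_len=6, window=3)
for row in mask:
    # "x" marks a position the query may attend to, "." a masked position.
    print(" ".join("x" if allowed else "." for allowed in row))
```

Each row is a query position. From the fourth row on, the oldest positions fall out of the window, so every token attends to at most three positions.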
## MistralConfig
[[autodoc]] MistralConfig
## MistralModel
[[autodoc]] MistralModel
- forward
## MistralForCausalLM
[[autodoc]] MistralForCausalLM
- forward
## MistralForSequenceClassification
[[autodoc]] MistralForSequenceClassification
- forward
## MistralForTokenClassification
[[autodoc]] MistralForTokenClassification
- forward
## MistralForQuestionAnswering
[[autodoc]] MistralForQuestionAnswering
- forward
## FlaxMistralModel
[[autodoc]] FlaxMistralModel
- __call__
## FlaxMistralForCausalLM
[[autodoc]] FlaxMistralForCausalLM
- __call__
## TFMistralModel
[[autodoc]] TFMistralModel
- call
## TFMistralForCausalLM
[[autodoc]] TFMistralForCausalLM
- call
## TFMistralForSequenceClassification
[[autodoc]] TFMistralForSequenceClassification
- call