# Mamba 2
[Mamba 2](https://huggingface.co/papers/2405.21060) is based on the state space duality (SSD) framework which connects structured state space models (SSMs) and attention variants. It uses a more efficient SSD algorithm that is 2-8x faster than Mamba and modifies the architecture to enable tensor parallelism and a grouped-value attention (GVA) head structure.
You can find all the original Mamba 2 checkpoints under the [State Space Models](https://huggingface.co/state-spaces) organization, but the examples shown below use [mistralai/Mamba-Codestral-7B-v0.1](https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1) because a Hugging Face implementation isn't supported yet for the original checkpoints.
> [!TIP]
> Click on the Mamba models in the right sidebar for more examples of how to apply Mamba to different language tasks.
The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`], and from the command line.
hfoptions id="usage">
```python
import torch
from transformers import pipeline
pipeline = pipeline(
task="text-generation",
model="mistralai/Mamba-Codestral-7B-v0.1",
torch_dtype=torch.bfloat16,
device=0
)
pipeline("Plants create energy through a process known as")
```
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1", torch_dtype=torch.bfloat16, device_map="auto")
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
output = model.generate(**input_ids)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
```bash
echo -e "Plants create energy through a process known as" | transformers-cli run --task text-generation --model mistralai/Mamba-Codestral-7B-v0.1 --device 0
```
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
The example below uses [torchao](../quantization/torchao) to only quantize the weights to 4-bit integers.
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1", torch_dtype=torch.bfloat16, quantization_config=quantization_config, device_map="auto")
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
output = model.generate(**input_ids)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
## Notes
- Codestral Mamba has `groups=8` which are similar to the number of kv heads in an attention-based model.
- Codestral Mamba has two different forward passes, `torch_forward` or `cuda_kernels_forward`, and their results are expected to be slightly different.
- `torch_forward` without compilation is 3-4x faster than `cuda_kernels_forward`.
- `cuda_kernels_forward` uses the original CUDA kernels if they're available in your environment. It is slower during prefill because it requires a "warmup run" due to the higher CPU overhead (see [these](https://github.com/state-spaces/mamba/issues/389#issuecomment-2171755306) [comments](https://github.com/state-spaces/mamba/issues/355#issuecomment-2147597457) for more details).
- There are no positional embeddings in this model, but there is an `attention_mask` and a specific logic to mask out hidden states in two places in the case of batched generation (see this [comment](https://github.com/state-spaces/mamba/issues/66#issuecomment-1863563829) for more details). This (and the addition of the reimplemented Mamba 2 kernels) results in a slight discrepancy between batched and cached generation.
- The SSM algorithm heavily relies on tensor contractions, which have matmul equivalents but the order of operations is slightly different. This makes the difference greater at smaller precisions.
- Hidden states that correspond to padding tokens is shutdown in 2 places and is mostly tested with left-padding. Right-padding propagates noise down the line and is not guaranteed to yield satisfactory results. `tokenizer.padding_side = "left"` ensures you are using the correct padding side.
- The example below demonstrates how to fine-tune Mamba 2 with [PEFT](https://huggingface.co/docs/peft).
```python
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, Mamba2ForCausalLM, TrainingArguments
model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left" #enforce padding side left
model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
dataset = load_dataset("Abirate/english_quotes", split="train")
# Without CUDA kernels, batch size of 2 occupies one 80GB device
# but precision can be reduced.
# Experiments and trials welcome!
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=2,
logging_dir='./logs',
logging_steps=10,
learning_rate=2e-3
)
lora_config = LoraConfig(
r=8,
target_modules=["embeddings", "in_proj", "out_proj"],
task_type="CAUSAL_LM",
bias="none"
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
dataset_text_field="quote",
)
trainer.train()
```
## Mamba2Config
[[autodoc]] Mamba2Config
## Mamba2Model
[[autodoc]] Mamba2Model
- forward
## Mamba2LMHeadModel
[[autodoc]] Mamba2ForCausalLM
- forward