transformers/docs/source/en/model_doc/mamba.md
Parag Ekbote 28d3148b07
Update Model Card for Mamba (#37863)
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-05-21 10:58:23 -07:00

PyTorch

Mamba

Mamba is a selective structured state space model (SSM) designed to work around the computational inefficiency of Transformers on long sequences. It is a completely attention-free architecture built from a single repeated block (the Mamba block), which combines the H3 block with a gated MLP. Mamba's "content-based reasoning" allows it to focus on specific parts of an input depending on the current token. Because this selection mechanism rules out the efficient convolutional computation used by earlier SSMs, Mamba uses a hardware-aware parallel scan algorithm instead. As a result, Mamba has fast inference, and its computation scales linearly with sequence length, so it can handle very long sequences.
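
The selection mechanism can be illustrated with a minimal NumPy sketch of a selective scan written as a plain sequential loop. It assumes a single sequence, a diagonal state matrix A stored as a (D, N) array, the simplified discretization exp(ΔA) for A and Δ·B for B, and hypothetical full-rank projections W_delta, W_B, and W_C. It is not the MambaMixer code or the hardware-aware parallel kernel, which compute this kind of recurrence far more efficiently.

```python
import numpy as np


def softplus(z):
    # Numerically stable log(1 + exp(z)); keeps the step size positive
    return np.logaddexp(0.0, z)


def selective_scan(x, A, W_delta, b_delta, W_B, W_C, D_skip):
    """Naive sequential selective SSM scan for one sequence (illustrative sketch).

    x: (L, D) inputs; A: (D, N) state matrix with negative entries;
    W_delta: (D, D), b_delta: (D,); W_B, W_C: (D, N); D_skip: (D,).
    All projection names are hypothetical simplifications.
    """
    seq_len, d_model = x.shape
    d_state = A.shape[1]
    # Selection mechanism: step size, B and C all depend on the current token
    delta = softplus(x @ W_delta + b_delta)  # (L, D)
    B = x @ W_B                              # (L, N)
    C = x @ W_C                              # (L, N)

    h = np.zeros((d_model, d_state))         # fixed-size recurrent state
    y = np.empty_like(x)
    for t in range(seq_len):
        A_bar = np.exp(delta[t][:, None] * A)                   # (D, N) discretized A
        Bx = delta[t][:, None] * B[t][None, :] * x[t][:, None]  # (D, N) discretized B times input
        h = A_bar * h + Bx                                      # state update
        y[t] = h @ C[t] + D_skip * x[t]                         # readout plus skip connection
    return y


rng = np.random.default_rng(0)
L, D, N = 6, 4, 3
x = rng.standard_normal((L, D))
params = dict(
    A=-np.exp(rng.standard_normal((D, N))),
    W_delta=0.1 * rng.standard_normal((D, D)),
    b_delta=np.zeros(D),
    W_B=rng.standard_normal((D, N)),
    W_C=rng.standard_normal((D, N)),
    D_skip=np.ones(D),
)
y = selective_scan(x, **params)
print(y.shape)  # (6, 4)
```

Because the state h keeps the shape (D, N) regardless of how many tokens have been processed, each new token costs the same work and memory at inference time, unlike a Transformer's key/value cache, which grows with the sequence.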

You can find all the original Mamba checkpoints under the State Space Models organization.

Tip

Click on the Mamba models in the right sidebar for more examples of how to apply Mamba to different language tasks.

The examples below demonstrate how to generate text with Pipeline, AutoModel, and from the command line.

Pipeline

import torch
from transformers import pipeline

pipeline = pipeline(
    task="text-generation",
    model="state-spaces/mamba-130m-hf",
    torch_dtype=torch.float16,
    device=0
)
pipeline("Plants create energy through a process known as")

AutoModel

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float16, device_map="auto")
# Move the tokenized inputs to the same device as the model
inputs = tokenizer("Plants create energy through a process known as", return_tensors="pt").to(model.device)

output = model.generate(**inputs)
print(tokenizer.decode(output[0], skip_special_tokens=True))

transformers CLI

echo -e "Plants create energy through a process known as" | transformers run --task text-generation --model state-spaces/mamba-130m-hf --device 0

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the Quantization overview for more available quantization backends.
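
As a rough illustration of the savings, weight memory can be estimated as parameters × bits per weight / 8. The sketch below assumes a nominal 2,800,000,000 parameters (about the size of state-spaces/mamba-2.8b-hf), assumes every parameter is quantized, and assumes one bf16 scale plus one bf16 zero-point per group of 128 weights. Real quantized checkpoints may keep non-linear layers such as embeddings and norms in higher precision, and the estimate covers weight storage only, not activations or the recurrent state.

```python
def estimate_weight_memory_gb(num_params, bits_per_weight, group_size=None, meta_bytes_per_group=0):
    """Rough weight-only memory estimate in decimal gigabytes (1 GB = 1e9 bytes)."""
    total_bytes = num_params * bits_per_weight / 8
    if group_size is not None:
        # Group-wise schemes store extra metadata (e.g. a scale and zero-point) per group
        total_bytes += (num_params / group_size) * meta_bytes_per_group
    return total_bytes / 1e9


NUM_PARAMS = 2_800_000_000  # assumption: nominal size of a 2.8B-parameter checkpoint

bf16_gb = estimate_weight_memory_gb(NUM_PARAMS, bits_per_weight=16)
# Assumption: one bf16 scale + one bf16 zero-point (4 bytes) per group of 128 weights
int4_gb = estimate_weight_memory_gb(NUM_PARAMS, bits_per_weight=4, group_size=128, meta_bytes_per_group=4)
print(f"bf16: {bf16_gb:.2f} GB, int4 (group_size=128): {int4_gb:.2f} GB")
```

Under these assumptions, 4-bit weights take roughly a quarter of the bf16 footprint, with the per-group metadata adding a small overhead on top.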

The example below uses torchao to quantize only the weights to 4-bit integers.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig

# Quantize weights to int4 in groups of 128 values
quant_config = Int4WeightOnlyConfig(group_size=128)
quantization_config = TorchAoConfig(quant_type=quant_config)
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-2.8b-hf")
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-2.8b-hf", torch_dtype=torch.bfloat16, quantization_config=quantization_config, device_map="auto")
inputs = tokenizer("Plants create energy through a process known as", return_tensors="pt").to(model.device)

output = model.generate(**inputs)
print(tokenizer.decode(output[0], skip_special_tokens=True))
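
Conceptually, 4-bit weight-only quantization with group_size=128 splits each row of a weight matrix into groups of 128 values and stores, per group, a scale and an offset that map those values onto the 16 levels a 4-bit integer can represent. The NumPy sketch below shows an asymmetric min/max variant of that idea. It is an assumption-based illustration, not torchao's kernel, which also packs two 4-bit values per byte and may choose scales and zero-points differently; the function names are hypothetical.

```python
import numpy as np


def quantize_int4_groupwise(w, group_size=128):
    """Asymmetric per-group 4-bit quantization of a 2-D weight matrix (sketch)."""
    rows, cols = w.shape
    if cols % group_size != 0:
        raise ValueError("number of columns must be divisible by group_size")
    groups = w.reshape(rows, cols // group_size, group_size)
    w_min = groups.min(axis=-1, keepdims=True)
    w_max = groups.max(axis=-1, keepdims=True)
    # 4 bits give 16 levels (0..15); one scale and one offset per group
    scale = (w_max - w_min) / 15.0
    scale = np.where(scale == 0, 1.0, scale)  # constant group: avoid division by zero
    q = np.clip(np.round((groups - w_min) / scale), 0, 15).astype(np.uint8)
    return q, scale, w_min


def dequantize_int4_groupwise(q, scale, w_min):
    """Map 4-bit codes back to approximate float weights."""
    groups = q.astype(np.float64) * scale + w_min
    return groups.reshape(q.shape[0], -1)


rng = np.random.default_rng(0)
w = rng.standard_normal((8, 256))
q, scale, w_min = quantize_int4_groupwise(w, group_size=128)
w_hat = dequantize_int4_groupwise(q, scale, w_min)
print(q.shape, float(np.abs(w_hat - w).max()))
```

Smaller groups track the local range of the weights more closely and reduce rounding error, at the cost of storing more scales and offsets.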

Notes

  • The current implementation uses the original CUDA kernels. These fast kernels, Mamba's equivalent of FlashAttention, are hosted in the mamba-ssm and causal_conv1d repositories. Install them if your hardware supports them; otherwise, a slower sequential implementation is used.

  • Mamba stacks mixer layers, which play the same role as attention layers in a Transformer. You can find the main logic of Mamba in the MambaMixer class.

  • The example below demonstrates how to fine-tune Mamba with PEFT.

    from datasets import load_dataset
    from peft import LoraConfig
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTConfig, SFTTrainer
    
    model_id = "state-spaces/mamba-130m-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    dataset = load_dataset("Abirate/english_quotes", split="train")
    
    # SFTConfig extends TrainingArguments with SFT options such as dataset_text_field
    training_args = SFTConfig(
        output_dir="./results",
        num_train_epochs=3,
        per_device_train_batch_size=4,
        logging_dir="./logs",
        logging_steps=10,
        learning_rate=2e-3,
        dataset_text_field="quote",  # train on the "quote" column
    )
    # Attach LoRA adapters to the embeddings and the mixer projections
    lora_config = LoraConfig(
        r=8,
        target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
        task_type="CAUSAL_LM",
        bias="none",
    )
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        args=training_args,
        peft_config=lora_config,
        train_dataset=dataset,
    )
    trainer.train()
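
Whether the optional fast kernels are present can be checked without importing them. The sketch below only reports whether the import names mamba_ssm and causal_conv1d (from the pip packages mamba-ssm and causal-conv1d) can be found; the helper name is hypothetical, and a package being installed does not guarantee its CUDA build matches your GPU and PyTorch version, so transformers still performs its own checks before using the fast path.

```python
import importlib.util


def fast_path_packages():
    """Report which optional Mamba kernel packages are importable (hypothetical helper)."""
    # Import names of the pip packages mamba-ssm and causal-conv1d
    packages = ("mamba_ssm", "causal_conv1d")
    # find_spec locates a top-level module without executing it
    return {name: importlib.util.find_spec(name) is not None for name in packages}


status = fast_path_packages()
for name, available in status.items():
    print(f"{name}: {'installed' if available else 'missing'}")
```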
    

MambaConfig

autodoc MambaConfig

MambaModel

autodoc MambaModel - forward

MambaForCausalLM

autodoc MambaForCausalLM - forward