
# Bamba

PyTorch | FlashAttention | SDPA

## Overview

Bamba-9B is a decoder-only language model based on the Mamba-2 architecture and is designed to handle a wide range of text generation tasks. It is trained from scratch using a two-stage training approach. In the first stage, the model is trained on 2 trillion tokens from the Dolma v1.7 dataset. In the second stage, it undergoes additional training on 200 billion tokens, leveraging a carefully curated blend of high-quality data to further refine its performance and enhance output quality.

Check out all Bamba-9B model checkpoints here.

## BambaConfig

| Model | Params     | # Layers | Hidden Dim. | Attention Heads | GQA | KV Heads | Context Length | Tied Embeddings |
|-------|------------|----------|-------------|-----------------|-----|----------|----------------|-----------------|
| Bamba | 9B (9.78B) | 32       | 4096        | 32              | Yes | 8        | 4096           | True            |

[[autodoc]] BambaConfig
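The architecture numbers in the table above can be read back from the checkpoint's config. The snippet below is a minimal sketch; the attribute names follow the usual `transformers` config conventions and are assumed here rather than taken from the Bamba docs:

```python
from transformers import AutoConfig

# Load the Bamba-9B config and print the fields listed in the table above.
# Attribute names follow standard transformers conventions (assumed).
config = AutoConfig.from_pretrained("ibm-fms/Bamba-9B")
print("layers:          ", config.num_hidden_layers)
print("hidden dim:      ", config.hidden_size)
print("attention heads: ", config.num_attention_heads)
print("kv heads:        ", config.num_key_value_heads)
print("tied embeddings: ", config.tie_word_embeddings)
```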

## BambaForCausalLM

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the Bamba-9B checkpoint and its tokenizer.
model = AutoModelForCausalLM.from_pretrained("ibm-fms/Bamba-9B")
tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B")

# Generate a short completion for a single prompt.
message = ["Mamba is a snake with following properties  "]
inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
response = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
```

## Padding-Free Training

Bamba supports padding-free training, in which distinct training examples are concatenated into a single sequence while the inputs are still processed as though they belonged to separate batches. When the examples have varying lengths, padding-free training can provide significant speedups and memory savings compared to batching the examples together with padding, since the unnecessary compute and memory spent on padding tokens is avoided entirely. The gains depend on factors such as the model and the data distribution, but throughput improvements of up to ~2x are commonly seen.

Using padding-free training with Bamba requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d` packages, and the following arguments must be passed to the model in addition to `input_ids` and `labels`:

- `position_ids: torch.LongTensor`: the position index of each token in each sequence.
- `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
- Each of the [`FlashAttentionKwargs`]:
  - `cu_seq_lens_q: torch.LongTensor`: the cumulative sequence lengths of all queries.
  - `cu_seq_lens_k: torch.LongTensor`: the cumulative sequence lengths of all keys.
  - `max_length_q: int`: the longest query length in the batch.
  - `max_length_k: int`: the longest key length in the batch.

The `attention_mask` inputs should not be provided. The [`DataCollatorWithFlattening`] can be used to programmatically generate the above set of additional arguments using `return_seq_idx=True` and `return_flash_attn_kwargs=True`, as in the sketch below. See this blog post for additional information.
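A minimal sketch of preparing a padding-free batch with [`DataCollatorWithFlattening`]; the example texts are illustrative, and the collator flags are the ones named above:

```python
from transformers import AutoTokenizer, DataCollatorWithFlattening

tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B")

# Two examples of different lengths that will be concatenated, not padded.
examples = [
    {"input_ids": tokenizer("Mamba is a snake.")["input_ids"]},
    {"input_ids": tokenizer("Bamba is a hybrid Mamba-2/attention model.")["input_ids"]},
]

collator = DataCollatorWithFlattening(
    return_seq_idx=True,            # adds seq_idx
    return_flash_attn_kwargs=True,  # adds cu_seq_lens_q/k and max_length_q/k
)
batch = collator(examples)

# The batch carries input_ids, labels, position_ids, seq_idx and the
# FlashAttention kwargs, but no attention_mask.
print(batch.keys())
```

The resulting batch can be passed directly to the model during training in place of a padded batch.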

[[autodoc]] BambaForCausalLM
    - forward

This HF implementation was contributed by ani300 and fabianlim.