# Bamba
*PyTorch · FlashAttention · SDPA*
## Overview

Bamba-9B is a decoder-only language model based on the [Mamba-2](https://github.com/state-spaces/mamba) architecture and is designed to handle a wide range of text generation tasks. It is trained from scratch using a two-stage training approach. In the first stage, the model is trained on 2 trillion tokens from the Dolma v1.7 dataset. In the second stage, it undergoes additional training on 200 billion tokens, leveraging a carefully curated blend of high-quality data to further refine its performance and enhance output quality.

Check out all Bamba-9B model checkpoints [here](https://github.com/foundation-model-stack/bamba).

## BambaConfig

| Model | Params     | # Layers | Hidden Dim. | Attention Heads | GQA | KV Heads | Context Length | Tied Embeddings |
|-------|------------|----------|-------------|-----------------|-----|----------|----------------|-----------------|
| Bamba | 9B (9.78B) | 32       | 4096        | 32              | Yes | 8        | 4096           | True            |

[[autodoc]] BambaConfig

## BambaForCausalLM

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ibm-fms/Bamba-9B")
tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B")

message = ["Mamba is a snake with following properties "]
inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
response = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
```

## Padding-Free Training

Bamba supports padding-free training, in which distinct training examples are concatenated together while the inputs are still processed as though they belonged to separate batches. When the examples have varying lengths, padding-free training can provide significant speedups and memory savings compared to batching the examples together and padding them, since the compute and memory overhead of padding is avoided entirely. The gains depend on factors such as the model and the data distribution, but throughput improvements of up to [~2x are commonly seen](https://github.com/huggingface/transformers/pull/35861#issue-2807873129).

Using padding-free training with Bamba requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d` packages, and the following arguments must be passed to the model in addition to `input_ids` and `labels`:

* `position_ids: torch.LongTensor`: the position index of each token in each sequence.
* `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
* Each of the [`FlashAttentionKwargs`]:
    * `cu_seq_lens_q: torch.LongTensor`: the cumulative sequence lengths of all queries.
    * `cu_seq_lens_k: torch.LongTensor`: the cumulative sequence lengths of all keys.
    * `max_length_q: int`: the longest query length in the batch.
    * `max_length_k: int`: the longest key length in the batch.

The `attention_mask` inputs should not be provided. [`DataCollatorWithFlattening`] can be used to programmatically generate this set of additional arguments using `return_seq_idx=True` and `return_flash_attn_kwargs=True`. See [this blog post](https://huggingface.co/blog/packing-with-FA2) for additional information, and the sketch at the end of this page for a minimal collation example.

[[autodoc]] BambaForCausalLM
    - forward

This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).
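
As referenced in the Padding-Free Training section above, here is a minimal sketch of building a flattened batch with [`DataCollatorWithFlattening`]. The checkpoint name and collator flags are the ones described above; the example texts are arbitrary placeholders, and the exact tensor dtypes returned by the collator may vary across `transformers` versions.

```python
from transformers import AutoTokenizer, DataCollatorWithFlattening

tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B")

# Two examples of different lengths; they are concatenated rather than padded.
features = [
    {"input_ids": tokenizer("Mamba is a snake with following properties ")["input_ids"]},
    {"input_ids": tokenizer("Padding-free training avoids wasted compute on pad tokens.")["input_ids"]},
]

# return_seq_idx / return_flash_attn_kwargs add `seq_idx` and the cumulative
# sequence-length arguments listed above; no `attention_mask` is produced.
collator = DataCollatorWithFlattening(return_seq_idx=True, return_flash_attn_kwargs=True)
batch = collator(features)

print(sorted(batch.keys()))
# expected keys, per the argument list above:
# ['cu_seq_lens_k', 'cu_seq_lens_q', 'input_ids', 'labels',
#  'max_length_k', 'max_length_q', 'position_ids', 'seq_idx']
```

The resulting batch can then be passed to the model's forward pass (with the `flash-attn`, `mamba-ssm`, and `causal-conv1d` packages installed and the model loaded with `attn_implementation="flash_attention_2"`), or the collator can be supplied as the `data_collator` of a [`Trainer`].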