# Bamba
Bamba is a PyTorch model with support for the FlashAttention-2 and SDPA attention backends.
## Overview

Bamba-9B is a decoder-only language model based on the [Mamba-2](https://github.com/state-spaces/mamba) architecture and is designed to handle a wide range of text generation tasks. It is trained from scratch using a two-stage training approach. In the first stage, the model is trained on 2 trillion tokens from the Dolma v1.7 dataset. In the second stage, it undergoes additional training on 200 billion tokens, leveraging a carefully curated blend of high-quality data to further refine its performance and enhance output quality.

Check out all Bamba-9B model checkpoints [here](https://github.com/foundation-model-stack/bamba).

This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).

## BambaConfig

| Model | Params     | # Layers | Hidden Dim. | Attention Heads | GQA | KV Heads | Context Length | Tied Embeddings |
|-------|------------|----------|-------------|-----------------|-----|----------|----------------|-----------------|
| Bamba | 9B (9.78B) | 32       | 4096        | 32              | Yes | 8        | 4096           | True            |

A sketch of building this configuration by hand appears at the end of this page.

[[autodoc]] BambaConfig

## BambaForCausalLM

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ibm-fms/Bamba-9B")
tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B")

message = ["Mamba is a snake with following properties "]
inputs = tokenizer(message, return_tensors="pt", return_token_type_ids=False)
response = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
```

[[autodoc]] BambaForCausalLM
    - forward
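
As a hands-on illustration of the table in the [BambaConfig](#bambaconfig) section, the following is a minimal sketch of instantiating the configuration by hand. The argument names here are assumptions based on the usual `transformers` conventions (as in `LlamaConfig`); check them against the `BambaConfig` reference above before relying on them.

```python
from transformers import BambaConfig

# Sketch: a config mirroring the table above. Argument names follow
# common transformers conventions and are assumptions, not verified
# against the BambaConfig signature.
config = BambaConfig(
    hidden_size=4096,               # Hidden Dim.
    num_hidden_layers=32,           # # Layers
    num_attention_heads=32,         # Attention Heads
    num_key_value_heads=8,          # KV Heads (grouped-query attention)
    max_position_embeddings=4096,   # Context Length
    tie_word_embeddings=True,       # Tied Embeddings
)
print(config)

# Passing this config to BambaForCausalLM(config) would build a
# randomly initialized 9B-scale model (roughly 40 GB in float32).
```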
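
Since this model supports FlashAttention-2 and SDPA, an attention backend can be selected at load time through the standard `attn_implementation` argument to `from_pretrained`. The sketch below assumes a GPU setup; `"flash_attention_2"` additionally requires the `flash-attn` package and half-precision weights, and `device_map="auto"` requires `accelerate`.

```python
import torch
from transformers import AutoModelForCausalLM

# Sketch: picking an attention backend at load time.
# "sdpa" uses PyTorch's scaled_dot_product_attention;
# "flash_attention_2" needs the flash-attn package and fp16/bf16 weights.
model = AutoModelForCausalLM.from_pretrained(
    "ibm-fms/Bamba-9B",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # or "flash_attention_2"
    device_map="auto",           # requires the accelerate package
)
```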