
# Bamba
## Overview
Bamba-9B is a decoder-only language model based on the Mamba-2 architecture and is designed to handle a wide range of text generation tasks. It is trained from scratch using a two-stage training approach. In the first stage, the model is trained on 2 trillion tokens from the Dolma v1.7 dataset. In the second stage, it undergoes additional training on 200 billion tokens, leveraging a carefully curated blend of high-quality data to further refine its performance and enhance output quality.
Check out all Bamba-9B model checkpoints here.
## BambaConfig
| Model | Params     | # Layers | Hidden Dim. | Attention Heads | GQA | KV Heads | Context Length | Tied Embeddings |
|-------|------------|----------|-------------|-----------------|-----|----------|----------------|-----------------|
| Bamba | 9B (9.78B) | 32       | 4096        | 32              | Yes | 8        | 4096           | True            |
[[autodoc]] BambaConfig
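As a quick sanity check, the architecture hyperparameters in the table above can be read back from the loaded configuration. The sketch below assumes the standard Transformers config attribute names apply to `BambaConfig`; treat it as illustrative rather than a reference:

```python
from transformers import AutoConfig

# Load the configuration of the released 9B checkpoint.
config = AutoConfig.from_pretrained("ibm-fms/Bamba-9B")

# These fields should mirror the table above (attribute names assumed to
# follow the usual Transformers config conventions).
print(config.num_hidden_layers)    # number of layers
print(config.hidden_size)          # hidden dimension
print(config.num_attention_heads)  # attention heads
print(config.num_key_value_heads)  # KV heads (grouped-query attention)
print(config.tie_word_embeddings)  # tied embeddings
```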
## BambaForCausalLM
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and tokenizer from the 9B checkpoint.
model = AutoModelForCausalLM.from_pretrained("ibm-fms/Bamba-9B")
tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B")

# Generate a short continuation of the prompt.
message = ["Mamba is a snake with following properties "]
inputs = tokenizer(message, return_tensors="pt", return_token_type_ids=False)
response = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
```
## Padding-Free Training
Bamba supports padding-free training, in which distinct training examples are concatenated into a single sequence while the inputs are still processed as though they belonged to separate batches. When the examples have varying lengths, padding-free training can provide significant speedups and memory savings compared to batching the examples together with padding, since the compute and memory otherwise spent on padding tokens are avoided entirely. The gains depend on factors such as the model and the data distribution, but throughput improvements of up to ~2x are commonly seen.
Using padding-free training with Bamba requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d` packages, and the following arguments must be passed to the model in addition to `input_ids` and `labels` (see the sketch after this list):

- `position_ids: torch.LongTensor`: the position index of each token in each sequence.
- `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
- Each of the [`FlashAttentionKwargs`]:
  - `cu_seq_lens_q: torch.LongTensor`: the cumulative sequence lengths of all queries.
  - `cu_seq_lens_k: torch.LongTensor`: the cumulative sequence lengths of all keys.
  - `max_length_q: int`: the longest query length in the batch.
  - `max_length_k: int`: the longest key length in the batch.
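For concreteness, here is a minimal sketch that packs two sequences into a single batch row and constructs each of these arguments by hand. The example texts, dtype and device choices are illustrative assumptions (the `flash-attn`, `mamba-ssm`, and `causal-conv1d` kernels require a CUDA device); in practice the collator described below generates the same arguments automatically.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Padding-free training uses the Flash Attention 2 code path.
model = AutoModelForCausalLM.from_pretrained(
    "ibm-fms/Bamba-9B",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B")

# Two examples of different lengths, concatenated into one row with no padding.
texts = ["Mamba is a snake.", "Bamba is a hybrid Mamba-2 / attention model."]
seqs = [tokenizer(t)["input_ids"] for t in texts]
lengths = [len(s) for s in seqs]

input_ids = torch.tensor([sum(seqs, [])], device="cuda")  # (1, total_len)
labels = input_ids.clone()
# Position indices restart at 0 for each packed sequence.
position_ids = torch.tensor([[p for n in lengths for p in range(n)]], device="cuda")
# Which sequence each token belongs to.
seq_idx = torch.tensor(
    [[i for i, n in enumerate(lengths) for _ in range(n)]],
    dtype=torch.int32, device="cuda",
)
# Cumulative sequence lengths: [0, len_0, len_0 + len_1].
cu_seq_lens = torch.tensor([0, lengths[0], lengths[0] + lengths[1]], device="cuda")

loss = model(
    input_ids=input_ids,
    labels=labels,
    position_ids=position_ids,
    seq_idx=seq_idx,
    cu_seq_lens_q=cu_seq_lens,
    cu_seq_lens_k=cu_seq_lens,
    max_length_q=max(lengths),
    max_length_k=max(lengths),
).loss
```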
The `attention_mask` inputs should not be provided. The [`DataCollatorWithFlattening`] can be used to programmatically generate the above set of additional arguments, using `return_seq_idx=True` and `return_flash_attn_kwargs=True`. See this blog post for additional information.
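Here is a minimal sketch of the collator usage, assuming pre-tokenized features that contain only `input_ids` (the token ids below are placeholders):

```python
from transformers import DataCollatorWithFlattening

collator = DataCollatorWithFlattening(
    return_seq_idx=True,
    return_flash_attn_kwargs=True,
)

# Pre-tokenized examples of different lengths; the token ids are placeholders.
features = [
    {"input_ids": [101, 2023, 2003, 102]},
    {"input_ids": [101, 2178, 2742, 2007, 2062, 19204, 2015, 102]},
]

batch = collator(features)
# The flattened batch carries the padding-free arguments listed above:
# input_ids, labels, position_ids, seq_idx,
# cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k
print(list(batch.keys()))
```

The resulting batch can be fed directly to the model, or the collator can be passed as `data_collator` to a `Trainer`.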
[[autodoc]] BambaForCausalLM - forward
This HF implementation was contributed by ani300 and fabianlim.