<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Mamba

## Overview

The Mamba model was proposed in [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by Albert Gu and Tri Dao.

This model is based on a new architecture paradigm, `state-space-models`. You can read more about the intuition behind these models [here](https://srush.github.io/annotated-s4/).

The abstract from the paper is the following:

*Foundation models, now powering most of the exciting applications in deep learning, are almost universally based on the Transformer architecture and its core attention module. Many subquadratic-time architectures such as linear attention, gated convolution and recurrent models, and structured state space models (SSMs) have been developed to address Transformers' computational inefficiency on long sequences, but they have not performed as well as attention on important modalities such as language. We identify that a key weakness of such models is their inability to perform content-based reasoning, and make several improvements. First, simply letting the SSM parameters be functions of the input addresses their weakness with discrete modalities, allowing the model to selectively propagate or forget information along the sequence length dimension depending on the current token. Second, even though this change prevents the use of efficient convolutions, we design a hardware-aware parallel algorithm in recurrent mode. We integrate these selective SSMs into a simplified end-to-end neural network architecture without attention or even MLP blocks (Mamba). Mamba enjoys fast inference (5× higher throughput than Transformers) and linear scaling in sequence length, and its performance improves on real data up to million-length sequences. As a general sequence model backbone, Mamba achieves state-of-the-art performance across several modalities such as language, audio, and genomics. On language modeling, our Mamba-3B model outperforms Transformers of the same size and matches Transformers twice its size, both in pretraining and downstream evaluation.*

Tips:

- Mamba is a new `state space model` architecture that rivals the classic Transformers. It is based on the line of progress on structured state space models, with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
- Mamba stacks `mixer` layers, which are the equivalent of `Attention` layers. The core logic of `mamba` is held in the `MambaMixer` class.
- Two implementations coexist: one is optimized and uses fast CUDA kernels, while the other one is naive but can run on any device!
- The current implementation leverages the original CUDA kernels: the equivalents of FlashAttention for Mamba are hosted in the [`mamba-ssm`](https://github.com/state-spaces/mamba) and [`causal_conv1d`](https://github.com/Dao-AILab/causal-conv1d) repositories. Make sure to install them if your hardware supports them! A quick way to check whether they are available is shown below.
- Contributions to make the naive path faster are welcome 🤗

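If you are not sure whether the optimized kernels are installed, a minimal check such as the one below (assuming the packages are importable as `mamba_ssm` and `causal_conv1d`, matching their pip distributions) tells you which path the model will take:

```python
import importlib.util

# The fast path needs both packages; if either is missing, the model falls back
# to the slower, pure PyTorch ("naive") implementation.
for module in ("mamba_ssm", "causal_conv1d"):
    available = importlib.util.find_spec(module) is not None
    print(f"{module}: {'installed (fast CUDA path)' if available else 'missing (naive path)'}")
```
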
This model was contributed by [ArthurZ](https://huggingface.co/ArthurZ).
The original code can be found [here](https://github.com/state-spaces/mamba).

## Usage

### A simple generation example:
```python
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("ArthurZ/mamba-130m")
tokenizer.pad_token = tokenizer.eos_token

# the fast CUDA path expects float32 weights
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", vocab_size=50280, num_hidden_layers=24, torch_dtype=torch.float32)
model.config.use_cache = True  # cache the SSM states between generation steps
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]

out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
```

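The example above imports `MambaConfig` without using it; if you want to experiment with a randomly initialized model rather than a pretrained checkpoint, a minimal sketch looks like the following (the sizes are illustrative, not a recommended configuration):

```python
from transformers import MambaConfig, MambaForCausalLM

# Illustrative sizes only; pick values that fit your task and hardware.
config = MambaConfig(vocab_size=50280, hidden_size=768, num_hidden_layers=24)
model = MambaForCausalLM(config)  # randomly initialized, no weights downloaded
```
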
### PEFT fine-tuning
The slow (pure PyTorch) path is not very stable for training, and the fast path requires the weights in `float32`!

```python
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

model_id = "ArthurZ/mamba-2.8b"
tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="<s>")
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("Abirate/english_quotes", split="train")

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir="./logs",
    logging_steps=10,
    learning_rate=2e-3,
)
# apply LoRA adapters to all linear layers of the model
lora_config = LoraConfig(
    r=8,
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    bias="none",
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)
trainer.train()
```

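Since the fast path expects `float32` weights, a minimal sanity check before launching training can help; this is plain PyTorch, nothing Mamba-specific is assumed:

```python
import torch

# `from_pretrained` loads in float32 by default, but if you experimented with
# `torch_dtype`, make sure the weights are back in float32 before training.
if next(model.parameters()).dtype != torch.float32:
    model = model.to(torch.float32)
```
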
## MambaConfig

[[autodoc]] MambaConfig

## MambaModel

[[autodoc]] MambaModel
    - forward

## MambaForCausalLM

[[autodoc]] MambaForCausalLM
    - forward