mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-05 13:50:13 +06:00

* Add jamba arch * apply "make fix-copies" changes * fix link to model in JambaConfig docstring * Add n_ctx in modeling file because repo-consistency wants that * Add jamba to flash attention and sdpa documentation * mamba dt_proj quant fix now works for LoRA as well * override test_left_padding_compatibility and use a more permissive tolerance. left padding numerical difference are accentuated by mamba layers * add jamba to tokenization auto * fix comments of shape (PR #24 in the model page: https://huggingface.co/ai21labs/Jamba-v0.1/discussions/24) * simple PR fixes * remove unnecessary kwargs from JambaAttentionDecoderLayer and JambaMambaDecoderLayer * remove the LoRA hack for the mamba dt_proj bias. It was solved in huggingface/peft#1530 (https://github.com/huggingface/peft/pull/1530) * Add copied comment on JambaMLP (it's the same as MixtralMLP) * remove padding_mask warnings. It's not supported anymore * fix docstring. Float instead of int * A few more minor PR fixes * (1) lowercase names for mamba layernorms (2) remove _apply_inner_layernorms and do it directly in the forward pass * Return None attention weights from mamba layers. Append to all attentions only if not None. * remove some leftover jamba archive lists * Better separation between expert vs non-expert layers. non-expert layers return None as router_logits, and it is not concatenated to all_router_logits returned from JambaModel * no need to take router_logits at config.expert_layer_offset anymore. result.router_logits now holds results only for expert layers * Add Jamba paper on READMEs * (1) rename n_ctx -> max_position_embeddings (2) don't use it in the modeling file since it's not needed (set it as an exception to check_config_attributes) * Add copied from comment * remove the code path for apply_inner_layernorms=False. 
Jamba always has the inner mamba layernorms * clearer docstring for _convert_to_standard_cache * style fixes * Change calc_logits_for_entire_prompt (bool) to num_logits_to_keep (int). Adapt assisted decoding code tp use it. Also small change in low memory beam search decoding path to support this new int value in model_inputs * rename test so it still overrides what its meant to override * draft * oups * nit * remove more complexe logic * fix names used in config * fix fix fix * style * fix some more failing tests * generate did not init the cache 🙃 * more small nits * typo * config.mamba_expand * config.hidden_size for the intermediate size of the mamba shapes * fix init of pkv with torch.tensor() * empty tensor * fix some init issues * stupid changes required by generate because it does not even support it's own DynamicCache class * more fixes * fix general assisted gen cache_position bug * tests passing * Add offsets and periods as SPECIAL_CASES_TO_ALLOW in check_config_attributes.py * fix reorder_cache to reorder mamba states and override some more functions in HybridMambaAttentionDynamicCache * no need to override test_past_key_values_format() and _check_past_key_values_for_generate() in tests anymore * fix docstrings and typehints for past_key_values * style fixes * fix docs * change typehint due to copy from Mixtral * forgot import * import order * Add configuration_jamba and modeling_jamba to not_doctested because the model is too big to download (in docstring of JambaForCausalLM.forward) * Add integration test with tiny tandom Jamba model on hub * fix flash attention cache shapes * bring back forgotten hidden states * rename HybridMambaAttentionDynamicCache.seqlen_offset to has_previous_state (and make bool) and bugfix - it should be set to True after a finished forward pass of the entire model * align integration test after modeling fixes * bugfix - mamba can use precomputed states only of forward pass is on a single token * bugfix - mamba can use 
precomputed states only if they match the batch size * typo * remove making _prepare_4d_causal_attention_mask a leaf function * stop using past_seq_len.get_seq_length(). Use cache positions instead. Adjust test (test_decoder_model_past_with_large_inputs) accordingly --------- Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com> Co-authored-by: Joao Gante <joao@huggingface.co>
123 lines
6.2 KiB
Markdown
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Jamba

## Overview

Jamba is a state-of-the-art, hybrid SSM-Transformer LLM. It is the first production-scale Mamba implementation, which opens up interesting research and application opportunities. While this initial experimentation shows encouraging gains, we expect these to be further enhanced with future optimizations and explorations.

For full details of this model, please read the [release blog post](https://www.ai21.com/blog/announcing-jamba).

### Model Details

Jamba is a pretrained, mixture-of-experts (MoE) generative text model, with 12B active parameters and a total of 52B parameters across all experts. It supports a 256K context length, and can fit up to 140K tokens on a single 80GB GPU.

As depicted in the diagram below, Jamba's architecture uses a blocks-and-layers approach that integrates the Transformer and Mamba architectures. Each Jamba block contains either an attention or a Mamba layer, followed by a multi-layer perceptron (MLP), producing an overall ratio of one Transformer layer out of every eight total layers.

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/jamba_architecture.png"
alt="drawing" width="600"/>

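The one-in-eight ratio described above can be illustrated with a small sketch. The helper name `layer_types`, the depth of 32 layers, and the offset of 4 are assumptions chosen for illustration; only the period of eight comes from the text, and the actual schedule is whatever `JambaConfig` defines.

```python
from collections import Counter


def layer_types(num_layers: int, attn_period: int = 8, attn_offset: int = 4) -> list[str]:
    """Label each layer as attention or Mamba (hypothetical helper).

    One layer in every `attn_period` layers is an attention layer; the rest
    are Mamba layers. `attn_offset` (an assumed value) selects which position
    inside each period holds the attention layer.
    """
    return [
        "attention" if i % attn_period == attn_offset else "mamba"
        for i in range(num_layers)
    ]


layout = layer_types(num_layers=32)  # 32 layers is an assumed depth
counts = Counter(layout)
print(counts["attention"], counts["mamba"])  # 4 28
```

With these assumed values, 4 of the 32 layers are attention layers, which matches the stated ratio of one Transformer layer per eight layers.
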
## Usage

### Prerequisites

Jamba requires `transformers` version 4.39.0 or higher:
```bash
pip install "transformers>=4.39.0"
```

To run the optimized Mamba implementations, you first need to install `mamba-ssm` and `causal-conv1d`:
```bash
pip install mamba-ssm "causal-conv1d>=1.2.0"
```
The model must also be on a CUDA device.

You can run the model without the optimized Mamba kernels, but this is **not** recommended because it results in significantly higher latency. To do so, specify `use_mamba_kernels=False` when loading the model.

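As a rough sketch of the choice described above, the snippet below checks whether the two optional kernel packages are importable and derives a value for `use_mamba_kernels` from that. The helper name `fast_mamba_kernels_installed` is hypothetical, and the import names `mamba_ssm` and `causal_conv1d` are assumed to correspond to the pip packages. The check does not verify that a CUDA device is present, which is a separate requirement.

```python
import importlib.util


def fast_mamba_kernels_installed() -> bool:
    """Return True if both optional kernel packages are importable (hypothetical helper)."""
    # Assumed import names for the `mamba-ssm` and `causal-conv1d` pip packages.
    required_modules = ("mamba_ssm", "causal_conv1d")
    return all(importlib.util.find_spec(name) is not None for name in required_modules)


# Keyword arguments that could be passed to `from_pretrained`; this falls back
# to the slower non-kernel path when either package is missing.
load_kwargs = {"use_mamba_kernels": fast_mamba_kernels_installed()}
print(load_kwargs)
```

These keyword arguments could then be forwarded as `AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", **load_kwargs)`.
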
### Run the model
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

# Tokenize the prompt and move the input IDs to the model's device
input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]

# Generate up to 216 new tokens
outputs = model.generate(input_ids, max_new_tokens=216)

print(tokenizer.batch_decode(outputs))
# ["<|startoftext|>In the recent Super Bowl LVIII, the Kansas City Chiefs emerged victorious, defeating the San Francisco 49ers in a thrilling overtime showdown. The game was a nail-biter, with both teams showcasing their skills and determination.\n\nThe Chiefs, led by their star quarterback Patrick Mahomes, displayed their offensive prowess, while the 49ers, led by their strong defense, put up a tough fight. The game went into overtime, with the Chiefs ultimately securing the win with a touchdown.\n\nThe victory marked the Chiefs' second Super Bowl win in four years, solidifying their status as one of the top teams in the NFL. The game was a testament to the skill and talent of both teams, and a thrilling end to the NFL season.\n\nThe Super Bowl is not just about the game itself, but also about the halftime show and the commercials. This year's halftime show featured a star-studded lineup, including Usher, Alicia Keys, and Lil Jon. The show was a spectacle of music and dance, with the performers delivering an energetic and entertaining performance.\n"]
```

<details>
<summary><strong>Loading the model in half precision</strong></summary>

The published checkpoint is saved in BF16. To load it into RAM in BF16/FP16, you need to specify `torch_dtype`:

```python
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", torch_dtype=torch.bfloat16)
# you can also use torch_dtype=torch.float16
```

When using half precision, you can enable the [FlashAttention2](https://github.com/Dao-AILab/flash-attention) implementation of the attention blocks. To use it, the model must also be on a CUDA device. Because the model is too big to fit on a single 80GB GPU in this precision, you also need to parallelize it using [accelerate](https://huggingface.co/docs/accelerate/index):
```python
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             torch_dtype=torch.bfloat16,
                                             attn_implementation="flash_attention_2",
                                             device_map="auto")
```

</details>
<details><summary><strong>Load the model in 8-bit</strong></summary>

**Using 8-bit precision, it is possible to fit sequences of up to 140K tokens on a single 80GB GPU.** You can easily quantize the model to 8-bit using [bitsandbytes](https://huggingface.co/docs/bitsandbytes/index). To avoid degrading model quality, we recommend excluding the Mamba blocks from quantization:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Skip the Mamba blocks so that they are not quantized
quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", quantization_config=quantization_config
)
```
</details>

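The single-GPU limits mentioned in the precision notes above follow from simple arithmetic on the 52B total parameter count. The sketch below is a back-of-the-envelope estimate only: the helper `weight_memory_gb` is hypothetical, it counts weight storage alone (no activations, attention cache, or Mamba states), and it uses decimal gigabytes.

```python
def weight_memory_gb(num_params: int, bytes_per_param: float) -> float:
    """Approximate weight storage in decimal gigabytes (hypothetical helper)."""
    return num_params * bytes_per_param / 1e9


TOTAL_PARAMS = 52_000_000_000  # 52B total parameters, as stated in Model Details

for label, bytes_per_param in [("fp32", 4), ("bf16/fp16", 2), ("int8", 1)]:
    print(f"{label}: ~{weight_memory_gb(TOTAL_PARAMS, bytes_per_param):.0f} GB")
# fp32: ~208 GB
# bf16/fp16: ~104 GB
# int8: ~52 GB
```

At roughly 104 GB, the half-precision weights alone exceed a single 80GB GPU, which is why that path needs `device_map="auto"`. The 8-bit figure is a lower bound, since any modules excluded from quantization (such as the Mamba blocks) keep their higher-precision weights.
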
## JambaConfig

[[autodoc]] JambaConfig


## JambaModel

[[autodoc]] JambaModel
    - forward


## JambaForCausalLM

[[autodoc]] JambaForCausalLM
    - forward


## JambaForSequenceClassification

[[autodoc]] transformers.JambaForSequenceClassification
    - forward