<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Llama4

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
    </div>
</div>

Llama 4, developed by Meta, introduces a new auto-regressive Mixture-of-Experts (MoE) architecture.
This generation includes two models:
- The highly capable Llama 4 Maverick with 17B active parameters out of ~400B total, with 128 experts.
- The efficient Llama 4 Scout also has 17B active parameters out of ~109B total, using just 16 experts.

Both models leverage early fusion for native multimodality, enabling them to process text and image inputs.
Maverick and Scout are both trained on up to 40 trillion tokens on data encompassing 200 languages
(with specific fine-tuning support for 12 languages including Arabic, Spanish, German, and Hindi).

For deployment, Llama 4 Scout is designed for accessibility, fitting on a single server-grade GPU via
on-the-fly 4-bit or 8-bit quantization, while Maverick is available in BF16 and FP8 formats.
These models are released under the custom Llama 4 Community License Agreement, available on the model repositories.
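
As an illustration of what on-the-fly low-bit loading can look like, the sketch below uses a generic `BitsAndBytesConfig`. It assumes the `bitsandbytes` backend is installed and supports this architecture on your hardware, so treat it as a sketch rather than the officially supported path.

```py
from transformers import AutoTokenizer, BitsAndBytesConfig, Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

# Quantize the linear layers to 4-bit on the fly, keeping compute in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=quantization_config,
)
```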

You can find all the original Llama checkpoints under the [meta-llama](https://huggingface.co/meta-llama) organization.

> [!TIP]
> The Llama 4 family of models comes in two flavors: 109B and 402B parameters. Both of these flavors are extremely
> large and won't fit on your run-of-the-mill device. See below for some examples to reduce the memory usage of the
> model.
>
> For the download to be faster and more resilient, we recommend installing the `hf_xet` dependency as follows:
> `pip install transformers[hf_xet]`

The examples below demonstrate how to generate with [`Pipeline`] or the [`AutoModel`] class. We additionally add an example
showcasing how to toggle the right attributes to enable very long-context generations, as some flavors of Llama 4
have context lengths going up to 10 million tokens.

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
from transformers import pipeline
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

messages = [
    {"role": "user", "content": "what is the recipe of mayonnaise?"},
]

pipe = pipeline(
    "text-generation",
    model=model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

output = pipe(messages, do_sample=False, max_new_tokens=200)
print(output[0]["generated_text"][-1]["content"])
```

</hfoption>
<hfoption id="AutoModel - Text only">

```py
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)

messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
# Decode only the newly generated tokens, skipping the prompt
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
print(outputs[0])
```

</hfoption>
<hfoption id="AutoModel - Multimodal">

```py
from transformers import AutoProcessor, Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

processor = AutoProcessor.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": img_url},
            {"type": "text", "text": "Describe this image in two sentences."},
        ]
    },
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(
    **inputs,
    max_new_tokens=256,
)

response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
print(response)
```

</hfoption>
<hfoption id="AutoModel - Multimodal with multiple images">

```py
from transformers import AutoProcessor, Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

processor = AutoProcessor.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

url1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
url2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": url1},
            {"type": "image", "url": url2},
            {"type": "text", "text": "Can you describe how these two images are similar, and how they differ?"},
        ]
    },
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(
    **inputs,
    max_new_tokens=256,
)

response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
print(response)
```

</hfoption>
<hfoption id="AutoModel - Long context">

Beware: the example below uses both `device_map="auto"` and flex-attention.
Please use `torchrun` to run this example in tensor-parallel mode.

We will work to enable running with `device_map="auto"` and flex-attention without
tensor-parallel in the future.
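
For reference, a typical `torchrun` launch looks like `torchrun --nproc_per_node=8 your_script.py`; the script name and the number of processes per node here are placeholders to adapt to your setup.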

```py
from transformers import Llama4ForConditionalGeneration, AutoTokenizer
import torch
import time

file = "very_long_context_prompt.txt"
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

with open(file, "r") as f:
    very_long_text = "\n".join(f.readlines())

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="flex_attention",
    torch_dtype=torch.bfloat16
)

messages = [
    {"role": "user", "content": f"Look at the following texts: [{very_long_text}]\n\n\n\nWhat are the books, and who wrote them? Make me a nice list."},
]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")

torch.cuda.synchronize()
start = time.time()
out = model.generate(
    input_ids.to(model.device),
    prefill_chunk_size=2048*8,  # process the long prompt in chunks to bound peak memory
    max_new_tokens=300,
    cache_implementation="hybrid",  # hybrid cache for the mix of chunked- and global-attention layers
)
print(time.time()-start)
print(tokenizer.batch_decode(out[:, input_ids.shape[-1]:]))
print(f"{torch.cuda.max_memory_allocated(model.device) / 1024**3:.2f} GiB")
```

</hfoption>
</hfoptions>

## Efficiency: how to get the best out of Llama 4

### Attention methods

Updating the default attention function can significantly improve compute performance as well as memory usage. Refer to the [Attention Interface](../attention_interface) overview for an in-depth explanation of our interface.

As of release, the Llama 4 model supports the following attention methods: `eager`, `flex_attention`, `sdpa`. We recommend using `flex_attention` for best results.
Switching the attention mechanism is done at model initialization:

<hfoptions id="Attention">
<hfoption id="Flex Attention">

Setting Flex Attention ensures the best results with the very long context the model can handle.

> [!TIP]
> Beware: the example below uses both `device_map="auto"` and flex-attention.
> Please use `torchrun` to run this example in tensor-parallel mode.
>
> We will work to enable running with `device_map="auto"` and flex-attention without
> tensor-parallel in the future.

```py
from transformers import Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    attn_implementation="flex_attention",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
```
</hfoption>
<hfoption id="SDPA">

The `sdpa` attention method is generally more compute-efficient than the `eager` method.

```py
from transformers import Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    attn_implementation="sdpa",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
```
</hfoption>
<hfoption id="Eager">

The `eager` attention method is set by default, so nothing extra is needed when loading the model:

```py
from transformers import Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
```
</hfoption>
</hfoptions>

### Quantization

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for available quantization backends.
At the time of release, both FBGEMM and LLM-Compressor are supported; more quantization methods will be supported in the days that follow the release.

See below for examples using both.

Here is an example loading a BF16 model in FP8 using the FBGEMM approach:

<hfoptions id="Quantization">
<hfoption id="FBGEMM">

```python
from transformers import AutoTokenizer, Llama4ForConditionalGeneration, FbgemmFp8Config
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)

messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=FbgemmFp8Config()
)

outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
print(outputs[0])
```
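
If you plan to reuse the FP8 weights, you can optionally persist them after loading so later runs skip the on-the-fly quantization step. The output directory below is purely illustrative:

```python
# Hypothetical output directory; adjust to your setup.
save_dir = "Llama-4-Scout-17B-16E-Instruct-fp8-fbgemm"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
```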

</hfoption>
<hfoption id="LLM-Compressor">

To use the LLM-Compressor technique, we recommend leveraging the pre-quantized FP8 checkpoint available with the release:

```python
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"

tokenizer = AutoTokenizer.from_pretrained(model_id)

messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    tp_plan="auto",  # shard the model across the available GPUs with tensor parallelism
    torch_dtype=torch.bfloat16,
)

outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
print(outputs[0])
```
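
Note that `tp_plan="auto"` relies on a distributed launch, for example `torchrun --nproc_per_node=8 your_script.py` (the script name and process count are placeholders for your setup).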

</hfoption>
</hfoptions>

### Offloading

Enabling CPU offloading means that components of the model may be moved to the CPU instead of the GPU when the available GPU memory isn't sufficient to load the entire model.
At inference, the different components are loaded to and unloaded from the GPU on the fly. This ensures that the model can be loaded on smaller machines as long as there is sufficient CPU memory.
However, this also slows down inference as it adds communication overhead.

To enable CPU offloading, simply set `device_map="auto"` when loading the model:

```py
from transformers import Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
```
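
If you want finer control over how much of the model stays on each GPU, `device_map="auto"` can be combined with a `max_memory` budget. This is a sketch with illustrative limits; adapt them to your hardware:

```py
from transformers import Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    # Illustrative budgets: whatever doesn't fit within these limits is offloaded to CPU RAM.
    max_memory={0: "70GiB", "cpu": "200GiB"},
    torch_dtype=torch.bfloat16,
)
```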

## Llama4Config

[[autodoc]] Llama4Config

## Llama4TextConfig

[[autodoc]] Llama4TextConfig

## Llama4VisionConfig

[[autodoc]] Llama4VisionConfig

## Llama4Processor

[[autodoc]] Llama4Processor

## Llama4ImageProcessorFast

[[autodoc]] Llama4ImageProcessorFast

## Llama4ForConditionalGeneration

[[autodoc]] Llama4ForConditionalGeneration
    - forward

## Llama4ForCausalLM

[[autodoc]] Llama4ForCausalLM
    - forward

## Llama4TextModel

[[autodoc]] Llama4TextModel
    - forward

## Llama4VisionModel

[[autodoc]] Llama4VisionModel
    - forward