transformers/docs/source/en/model_doc/llama4.md
Arthur 25b7f27234
Add llama4 (#37307)
2025-04-05 22:02:22 +02:00


Llama4

PyTorch FlashAttention

Llama 4, developed by Meta, introduces a new auto-regressive Mixture-of-Experts (MoE) architecture. This generation includes two models:

  • The highly capable Llama 4 Maverick with 17B active parameters out of ~400B total, with 128 experts.
  • The efficient Llama 4 Scout also has 17B active parameters out of ~109B total, using just 16 experts.
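As a quick illustration of what "active parameters" means for the two models listed above, the sketch below computes the share of parameters used per token. It only uses the approximate figures quoted in this section; the `MODELS` table and the `active_fraction` helper are illustrative names, not part of the library.

```python
# Figures quoted above, in billions of parameters (totals are approximate).
MODELS = {
    "Llama 4 Scout": {"active": 17, "total": 109, "experts": 16},
    "Llama 4 Maverick": {"active": 17, "total": 400, "experts": 128},
}


def active_fraction(active_b: float, total_b: float) -> float:
    """Share of the parameters that is active for a single token."""
    return active_b / total_b


for name, spec in MODELS.items():
    frac = active_fraction(spec["active"], spec["total"])
    print(f"{name}: {frac:.1%} of parameters active per token")
```

Both models run the same number of active parameters per token, so the difference in total size comes from the parameters that stay inactive for any given token.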

Both models leverage early fusion for native multimodality, enabling them to process text and image inputs. Maverick and Scout are both trained on up to 40 trillion tokens of data spanning 200 languages (with specific fine-tuning support for 12 languages including Arabic, Spanish, German, and Hindi).

For deployment, Llama 4 Scout is designed for accessibility, fitting on a single server-grade GPU via on-the-fly 4-bit or 8-bit quantization, while Maverick is available in BF16 and FP8 formats. These models are released under the custom Llama 4 Community License Agreement, available on the model repositories.

You can find all the original Llama checkpoints under the meta-llama organization.

Tip

The Llama 4 family of models comes in two flavors: 109B and 402B parameters. Both flavors are extremely large and won't fit on your run-of-the-mill device. See below for some examples of how to reduce the memory usage of the model.
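To get a feel for why these models do not fit on a typical device, here is a back-of-the-envelope estimate of the memory needed for the weights alone. This is a rough sketch: the `BYTES_PER_PARAM` table and `weight_memory_gb` helper are illustrative, and the estimate ignores the KV cache, activations, quantization metadata, and any layers kept in higher precision.

```python
# Storage cost per parameter for a few common formats (weights only).
BYTES_PER_PARAM = {"bf16": 2.0, "fp8": 1.0, "int8": 1.0, "int4": 0.5}


def weight_memory_gb(num_params_billion: float, dtype: str) -> float:
    """Approximate weight storage in GB (1 GB = 1e9 bytes)."""
    # billions of parameters * bytes per parameter = gigabytes
    return num_params_billion * BYTES_PER_PARAM[dtype]


for name, params in [("Scout", 109), ("Maverick", 402)]:
    for dtype in ("bf16", "fp8", "int4"):
        print(f"{name} {dtype}: ~{weight_memory_gb(params, dtype):.1f} GB")
```

Under these assumptions, Scout needs roughly 218 GB of weights in BF16 and about 54.5 GB in 4-bit, which is why quantization matters for single-GPU use.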

For the download to be faster and more resilient, we recommend installing the hf_xet dependency as follows: `pip install transformers[hf_xet]`

The examples below demonstrate how to generate with [Pipeline] or the [AutoModel]. We also include an example showing how to toggle the right attributes to enable very long-context generation, as some flavors of Llama 4 have context lengths of up to 10 million tokens.

from transformers import pipeline
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

messages = [
    {"role": "user", "content": "what is the recipe of mayonnaise?"},
]

pipe = pipeline(
    "text-generation",
    model=model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

output = pipe(messages, do_sample=False, max_new_tokens=200)
print(output[0]["generated_text"][-1]["content"])

# Text-only generation with Llama4ForConditionalGeneration
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)

messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
print(outputs[0])

# Multimodal generation with a single image
from transformers import AutoProcessor, Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

processor = AutoProcessor.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": img_url},
            {"type": "text", "text": "Describe this image in two sentences."},
        ]
    },
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(
    **inputs,
    max_new_tokens=256,
)

response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
print(response)

# Multimodal generation with multiple images
from transformers import AutoProcessor, Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

processor = AutoProcessor.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

url1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
url2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": url1},
            {"type": "image", "url": url2},
            {"type": "text", "text": "Can you describe how these two images are similar, and how they differ?"},
        ]
    },
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(
    **inputs,
    max_new_tokens=256,
)

response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
print(response)
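The multimodal examples above share the same message layout: a single user turn whose content lists the images first and the text prompt last. As a convenience, that layout can be built with a small helper. `build_user_message` is a hypothetical name introduced here for illustration, and the URL is a placeholder.

```python
from typing import Any, Dict, List, Sequence


def build_user_message(prompt: str, image_urls: Sequence[str] = ()) -> List[Dict[str, Any]]:
    """Return a one-turn conversation with images first, then the text prompt."""
    content: List[Dict[str, str]] = [{"type": "image", "url": url} for url in image_urls]
    content.append({"type": "text", "text": prompt})
    return [{"role": "user", "content": content}]


messages = build_user_message(
    "Describe this image in two sentences.",
    image_urls=["https://example.com/rabbit.jpg"],  # placeholder URL
)
print(messages)
```

The resulting list has the same shape as the `messages` passed to `processor.apply_chat_template` in the examples above.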

Beware: the example below uses both device_map="auto" and flex-attention. Please use torchrun to run this example in tensor-parallel mode.

We will work to enable running with device_map="auto" and flex-attention without tensor-parallel in the future.

from transformers import Llama4ForConditionalGeneration, AutoTokenizer
import torch
import time

file = "very_long_context_prompt.txt"
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

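with open(file, "r") as f:
    # Read the whole prompt at once; lines already end with "\n",
    # so joining readlines() with "\n" would double every line break.
    very_long_text = f.read()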

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="flex_attention",
    torch_dtype=torch.bfloat16
)

messages = [
    {"role": "user", "content": f"Look at the following texts: [{very_long_text}]\n\n\n\nWhat are the books, and who wrote them? Make me a nice list."},
]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")

torch.cuda.synchronize()
start = time.time()
out = model.generate(
    input_ids.to(model.device),
    prefill_chunk_size=2048*8,
    max_new_tokens=300,
    cache_implementation="hybrid",
)
print(time.time()-start)
print(tokenizer.batch_decode(out[:, input_ids.shape[-1]:]))
print(f"{torch.cuda.max_memory_allocated(model.device) / 1024**3:.2f} GiB")
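The `prefill_chunk_size` argument in the example above suggests that the prompt is processed in pieces rather than in a single forward pass. The sketch below only illustrates that idea by computing the token ranges such a split would produce. It is not the library's internal code, and `prefill_chunks` and the 40,000-token prompt length are made up for the illustration.

```python
from typing import Iterator, Tuple


def prefill_chunks(prompt_length: int, chunk_size: int) -> Iterator[Tuple[int, int]]:
    """Yield (start, end) token ranges that cover the prompt in order."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    for start in range(0, prompt_length, chunk_size):
        yield start, min(start + chunk_size, prompt_length)


chunk_size = 2048 * 8  # same value as prefill_chunk_size in the example above
chunks = list(prefill_chunks(prompt_length=40_000, chunk_size=chunk_size))
print(len(chunks), chunks[-1])
```

With these numbers the prompt is covered by three ranges, the last one shorter than the others. Smaller chunks lower the peak memory of the prefill step at the cost of more forward passes.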

Efficiency: how to get the best out of Llama 4

The Attention methods

Updating the default attention function can significantly improve compute performance as well as memory usage. Refer to the Attention Interface overview for an in-depth explanation of our interface.

As of release, the Llama 4 model supports the following attention methods: eager, flex_attention, sdpa. We recommend using flex_attention for best results. The attention mechanism is selected when the model is initialized:

Setting Flex Attention ensures the best results with the very long context the model can handle.

Tip

Beware: the example below uses both device_map="auto" and flex-attention. Please use torchrun to run this example in tensor-parallel mode.

We will work to enable running with device_map="auto" and flex-attention without tensor-parallel in the future.

from transformers import Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    attn_implementation="flex_attention",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

The `sdpa` attention method is generally more compute-efficient than the `eager` method.

from transformers import Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    attn_implementation="sdpa",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

The `eager` attention method is set by default, so no need for anything different when loading the model:

from transformers import Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
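For readers unfamiliar with the term, "eager" attention refers to computing attention step by step, materializing every score before the softmax, whereas `sdpa` and `flex_attention` delegate the same math to fused kernels. The pure-Python sketch below shows that explicit computation for a single head with a causal mask. It is a teaching aid under those simplifying assumptions and leaves out everything specific to Llama 4.

```python
import math
from typing import List

Matrix = List[List[float]]


def eager_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Causal scaled dot-product attention that materializes every score."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, q_i in enumerate(q):
        # Causal mask: position i may only attend to positions 0..i.
        scores = [scale * sum(a * b for a, b in zip(q_i, k[j])) for j in range(i + 1)]
        peak = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        out.append([
            sum(w / total * v[j][d] for j, w in enumerate(weights))
            for d in range(len(v[0]))
        ])
    return out


q = k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(eager_attention(q, k, v))
```

The score list grows with the sequence position, which is the memory cost that fused implementations avoid storing in full.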

Quantization

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the Quantization overview for available quantization backends. At time of release, both FBGEMM and LLM-Compressor are supported; more quantization methods will be supported in the days that follow the release.

See below for examples using both:

Here is an example of loading a BF16 model in FP8 using the FBGEMM approach:

from transformers import AutoTokenizer, Llama4ForConditionalGeneration, FbgemmFp8Config
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)

messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=FbgemmFp8Config()
)

outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
print(outputs[0])

To use the LLM-Compressor technique, we recommend leveraging the pre-quantized FP8 checkpoint available with the release:

from transformers import AutoTokenizer, Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"

tokenizer = AutoTokenizer.from_pretrained(model_id)

messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    tp_plan="auto",
    torch_dtype=torch.bfloat16,
)

outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
print(outputs[0])
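To make the idea of "representing the weights in a lower precision" more concrete, here is a minimal sketch of a generic absmax integer quantization round trip. This is not the scheme used by FBGEMM or LLM-Compressor (FP8 is a floating-point format, and scaling details differ between backends); the helper names are illustrative only.

```python
from typing import List, Tuple


def quantize_absmax_int8(weights: List[float]) -> Tuple[List[int], float]:
    """Map floats to integers in [-127, 127] using one shared scale."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    return [round(w / scale) for w in weights], scale


def dequantize(q: List[int], scale: float) -> List[float]:
    """Recover approximate float weights from the integers and the scale."""
    return [x * scale for x in q]


weights = [0.5, -1.27, 0.03, 1.0]
q, scale = quantize_absmax_int8(weights)
restored = dequantize(q, scale)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q, scale, max_error)
```

Each weight is stored as a small integer plus a shared scale, and the rounding error per weight is bounded by half the scale. That trade of a little accuracy for a much smaller footprint is the general principle behind the backends shown above.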

Offloading

Enabling CPU-offloading means that components of the model might be moved to CPU instead of GPU in case the GPU-memory available isn't sufficient to load the entire model. At inference, different components will be loaded/unloaded from/to the GPU on the fly. This ensures that the model can be loaded on smaller machines as long as the CPU-memory is sufficient. However, this also slows down inference as it adds communication overhead.

To enable CPU offloading, set device_map to "auto" when loading the model:

from transformers import Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
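Conceptually, offloading comes down to deciding which parts of the model live on the GPU and which stay on the CPU. The sketch below illustrates that decision with a deliberately simplified greedy rule. It is not how `device_map="auto"` is actually computed, since the real placement logic also handles multiple GPUs and module boundaries. `place_layers`, the layer sizes, and the budget are hypothetical.

```python
from typing import Dict, List


def place_layers(layer_sizes_gb: List[float], gpu_budget_gb: float) -> Dict[int, str]:
    """Assign layers to the GPU in order until the budget is used, then to CPU."""
    placement: Dict[int, str] = {}
    used = 0.0
    on_gpu = True
    for idx, size in enumerate(layer_sizes_gb):
        if on_gpu and used + size <= gpu_budget_gb:
            placement[idx] = "cuda:0"
            used += size
        else:
            on_gpu = False  # once one layer spills, keep the rest on CPU too
            placement[idx] = "cpu"
    return placement


placement = place_layers([4.0] * 6, gpu_budget_gb=18.0)
print(placement)
```

In this toy setup the first four layers fit on the GPU and the remaining two are kept on the CPU. Those are the layers that would have to be moved to the GPU on the fly at inference, which is the source of the slowdown described above.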

Llama4Config

autodoc Llama4Config

Llama4TextConfig

autodoc Llama4TextConfig

Llama4VisionConfig

autodoc Llama4VisionConfig

Llama4Processor

autodoc Llama4Processor

Llama4ImageProcessorFast

autodoc Llama4ImageProcessorFast

Llama4ForConditionalGeneration

autodoc Llama4ForConditionalGeneration

  • forward

Llama4ForCausalLM

autodoc Llama4ForCausalLM

  • forward

Llama4TextModel

autodoc Llama4TextModel

  • forward

Llama4VisionModel

autodoc Llama4VisionModel

  • forward