Add llama4 (#37307)

* remove one of the last deps

* update fast image processor after refactor

* styling

* more quality of life improvements

* nit

* update

* cleanups

* some cleanups

* vllm updates

* update fake image token

* [convert] Fix typo

* [convert] Strip extraneous bytes from shards

* [convert] Minor fixes

* [convert] Use num_experts

* multi-image fixes in modeling + processor

* fixup size

* 128 experts

* Use default rope

* Unfuse mlp

* simplify inputs embeds merging a lot

* remove .item() 👀

* fix from review

* Address feedback

* Use None "default" for rope_scaling. Add eot.

* set seed

* return aspect ratios and bug fixes

* Moe 128 rebased (#8)

* 128 experts

* Use default rope

* Unfuse mlp

* Address feedback

* Use None "default" for rope_scaling. Add eot.

* Meta/llama quant compat (#7)

* add quant compatible model & conversion code for llama4

* fix a few issues

* fix a few issues

* minor type mapping fix

---------

Co-authored-by: Lu Fang <fanglu@fb.com>

* use a new config parameter to determine which model definition to use for MoE

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Lu Fang <fanglu@fb.com>

* un-comment write_tokenizer from converting script

* remove unused imports

* [llama4] Pop aspect_ratios from image processor output in Llama4Processor

Signed-off-by: Jon Swenson <jmswen@gmail.com>

* Fix parameter_count name

* Update src/transformers/models/llama4/configuration_llama4.py

* nit

* Add changes for no_rope, moe_layers, chunked attention. Just need to test all

* Update src/transformers/models/llama4/image_processing_llama4_fast.py

* nit

* fix post merge with main

* support flex attention

* fixes

* fix

* add layer

* small updates

* rebase and delete llm_compressor

* nit

* [llama4/mm] Add back <|image|> token that delimits global tile

* [llama4/mm] Fix Llama 4 image processing unit tests

* add explicit dtype

Signed-off-by: Jon Swenson <jmswen@gmail.com>

* sdpa works

* comment todo small

* fix model loading

Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>

* revert

* nits

* small fix for TP on 1 node

* Read new params from config

* Add <|eom|>

* lol don't know how this got here

* adding fp8

* Save processor, fix chat template

* style

* Add boi/eoi tokens

We don't use them.

* fixes for now flex seems to work :)

* updates

* nits

* updates

* missing keys

* add context parallel

* update

* update

* fix

* nits

* add worldsize and make eager attn work for vision

* Ignore new key present in base models

* add tp_plan

* fix nope

Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>

* minor fix

Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>

* Clean up Llama4 vision model

* current updates

* add support for `attn_temperature_tuning`

* add floor scale

* add missing attn scales

* push what works, dirty trick for the device synch

* oops

* Fix pad_token_id

See
https://huggingface.co/ll-re/Llama-4-Scout-17B-16E/discussions/2/files
Confirmed in the original codebase.

* fix CausalLM loading

* rm

* fix tied-weights

* fix sdpa

* push current version

* should work with both short and long

* add compressed_tensors & fix fbgemm tp

* Fix flex impl

* style

* chunking

* try to revert the potentially breaking change

* fix auto factory

* fix shapes in general

* rm processing

* commit cache utils cleanup

* Fix context length

* fix

* allocate

* update tp_plan

* fix SDPA!

* Add support for sparse `Llama4TextMoe` layer from the kernel hub

* cleanup

* better merge

* update

* still broken fixing now

* nits

* revert print

* Write max_position_embeddings and max_model_length

* Update modeling_llama4.py

* Save attention_chunk_size

* Sync eos terminators

* Read initializer_range

* style

* remove `dict`

* fix

* eager should use `chunked_attention_mask`

* revert

* fixup

* fix config

* Revert "Merge pull request #36 from huggingface/sparse-llama4-moe"

This reverts commit ccda19f050, reversing
changes made to a515579aed.

* Fix typo and remove warning with compiled flex and chunked prefill

* Fix MoE vs FF (#41)

* fix

* Use correct no_rope_layers if provided one is empty list

* update tests

* fix

* skipping some tests

* fix fp8 loading

Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>

* fix text generation pipeline

Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>

* eager needs 4D mask

* fix

* Some cleanup

* fix

* update

* fix

* replace correctly module

* patch

* modulelist

* update

* update

* clean up

* Don't move to `cuda:0` in distributed mode

* restrict to compressed tensors for now

* rm print

* Docs!

* Fixes

* Update docs/source/en/model_doc/llama4.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fixes

* cuda graph fix

* revert some stuff

* fixup

* styling

* Update src/transformers/models/llama4/modeling_llama4.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

* commit licence, cleanup here and there and style

* more styling changes

* fix dummies

* fix and clean docstrings

* remove comment

* remove warning

* Only fast image processor is supported

* nit

* trigger CI

* fix issue with flex encoder

* fix dynamic cache

* Code quality

* Code quality

* fix more tests for now

* Code quality

* Code quality

* Nuke bunch of failing stuff

* Code quality

* Code quality

* cleanup removal of slow image processor

* ruff fix fast image processor

* fix

* fix styling

* Docs

* Repo consistency

* Repo consistency

* fix sliding window issue

* separate llama cache

* styling

* Repo consistency

* Repo consistency

* push what works

* L4 Repo consistency

* Docs

* fix last remaining issues

---------

Signed-off-by: Jon Swenson <jmswen@gmail.com>
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Co-authored-by: yonigozlan <yoni.gozlan10@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Pablo Montalvo <pablo.montalvo.leroux@gmail.com>
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Co-authored-by: Keyun Tong <tongkeyun@gmail.com>
Co-authored-by: Zijing Liu <liuzijing2014@users.noreply.github.com>
Co-authored-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Zijing Liu <liuzijing2014@gmail.com>
Co-authored-by: Jon Swenson <jmswen@gmail.com>
Co-authored-by: jmswen <jmswen@users.noreply.github.com>
Co-authored-by: MekkCyber <mekk.cyber@gmail.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>
Co-authored-by: Yong Hoon Shin <yhshin@meta.com>
Co-authored-by: Marc Sun <marc@huggingface.co>
Co-authored-by: drisspg <drisspguessous@gmail.com>
Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
Co-authored-by: Daniël de Kok <me@danieldk.eu>
Co-authored-by: Lysandre <hi@lysand.re>
Co-authored-by: Ye (Charlotte) Qi <ye.charlotte.qi@gmail.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Arthur 2025-04-05 22:02:22 +02:00 committed by GitHub
parent aa40fda346
commit 25b7f27234
45 changed files with 5526 additions and 221 deletions

View File

@ -507,6 +507,8 @@
title: Llama2
- local: model_doc/llama3
title: Llama3
- local: model_doc/llama4
title: Llama4
- local: model_doc/longformer
title: Longformer
- local: model_doc/longt5

View File

@ -0,0 +1,442 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Llama4
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
</div>
</div>
Llama 4, developed by Meta, introduces a new auto-regressive Mixture-of-Experts (MoE) architecture.
This generation includes two models:
- The highly capable Llama 4 Maverick, with 17B active parameters out of ~400B total and 128 experts.
- The efficient Llama 4 Scout, also with 17B active parameters out of ~109B total, using just 16 experts.
Both models leverage early fusion for native multimodality, enabling them to process text and image inputs.
Maverick and Scout are both trained on up to 40 trillion tokens on data encompassing 200 languages
(with specific fine-tuning support for 12 languages including Arabic, Spanish, German, and Hindi).
For deployment, Llama 4 Scout is designed for accessibility, fitting on a single server-grade GPU via
on-the-fly 4-bit or 8-bit quantization, while Maverick is available in BF16 and FP8 formats.
These models are released under the custom Llama 4 Community License Agreement, available on the model repositories.
You can find all the original Llama checkpoints under the [meta-llama](https://huggingface.co/meta-llama) organization.
> [!TIP]
> The Llama 4 family of models comes in two flavors: 109B and 402B parameters. Both of these flavors are extremely
> large and won't fit on your run-of-the-mill device. See below for some examples to reduce the memory usage of the
> model.
>
> For the download to be faster and more resilient, we recommend installing the `hf_xet` dependency as follows:
> `pip install transformers[hf_xet]`
The examples below demonstrate how to generate with [`Pipeline`] or the [`AutoModel`] classes. We additionally add an example
showcasing how to toggle the right attributes to enable very long-context generation, as some flavors of Llama 4
have context lengths of up to 10 million tokens.
<hfoptions id="usage">
<hfoption id="Pipeline">
```py
from transformers import pipeline
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
messages = [
{"role": "user", "content": "what is the recipe of mayonnaise?"},
]
pipe = pipeline(
"text-generation",
model=model_id,
device_map="auto",
torch_dtype=torch.bfloat16
)
output = pipe(messages, do_sample=False, max_new_tokens=200)
print(output[0]["generated_text"][-1]["content"])
```
</hfoption>
<hfoption id="AutoModel - Text only">
```py
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
messages = [
{"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16
)
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
print(outputs[0])
```
</hfoption>
<hfoption id="AutoModel - Multimodal">
```py
from transformers import AutoProcessor, Llama4ForConditionalGeneration
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
)
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": img_url},
{"type": "text", "text": "Describe this image in two sentences."},
]
},
]
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=256,
)
response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
print(response)
```
</hfoption>
<hfoption id="AutoModel - Multimodal with multiple images">
```py
from transformers import AutoProcessor, Llama4ForConditionalGeneration
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
)
url1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
url2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": url1},
{"type": "image", "url": url2},
{"type": "text", "text": "Can you describe how these two images are similar, and how they differ?"},
]
},
]
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=256,
)
response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
print(response)
```
</hfoption>
<hfoption id="AutoModel - Long context">
Beware: the example below uses both `device_map="auto"` and flex-attention.
Please use `torchrun` to run this example in tensor-parallel mode.
We will work to enable running with `device_map="auto"` and flex-attention without
tensor-parallel in the future.
```py
from transformers import Llama4ForConditionalGeneration, AutoTokenizer
import torch
import time
file = "very_long_context_prompt.txt"
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
with open(file, "r") as f:
very_long_text = "\n".join(f.readlines())
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
attn_implementation="flex_attention",
torch_dtype=torch.bfloat16
)
messages = [
{"role": "user", "content": f"Look at the following texts: [{very_long_text}]\n\n\n\nWhat are the books, and who wrote them? Make me a nice list."},
]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
torch.cuda.synchronize()
start = time.time()
out = model.generate(
input_ids.to(model.device),
prefill_chunk_size=2048*8,
max_new_tokens=300,
cache_implementation="hybrid",
)
print(time.time()-start)
print(tokenizer.batch_decode(out[:, input_ids.shape[-1]:]))
print(f"{torch.cuda.max_memory_allocated(model.device) / 1024**3:.2f} GiB")
```
</hfoption>
</hfoptions>
## Efficiency: how to get the best out of Llama 4
### The Attention methods
Updating the default attention function can significantly improve compute performance as well as memory usage. Refer to the [Attention Interface](../attention_interface) overview for an in-depth explanation of our interface.
As of release, the Llama 4 model supports the following attention methods: `eager`, `flex_attention`, `sdpa`. We recommend using `flex_attention` for best results.
Switching the attention mechanism is done at model initialization:
<hfoptions id="Attention">
<hfoption id="Flex Attention">
Setting Flex Attention ensures the best results with the very long context the model can handle.
> [!TIP]
> Beware: the example below uses both `device_map="auto"` and flex-attention.
> Please use `torchrun` to run this example in tensor-parallel mode.
>
> We will work to enable running with `device_map="auto"` and flex-attention without
> tensor-parallel in the future.
```py
from transformers import Llama4ForConditionalGeneration
import torch
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
attn_implementation="flex_attention",
device_map="auto",
torch_dtype=torch.bfloat16,
)
```
</hfoption>
<hfoption id="SDPA">
The `sdpa` attention method is generally more compute-efficient than the `eager` method.
```py
from transformers import Llama4ForConditionalGeneration
import torch
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
attn_implementation="sdpa",
device_map="auto",
torch_dtype=torch.bfloat16,
)
```
</hfoption>
<hfoption id="Eager">
The `eager` attention method is used by default, so nothing extra is needed when loading the model:
```py
from transformers import Llama4ForConditionalGeneration
import torch
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
)
```
</hfoption>
</hfoptions>
### Quantization
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for available quantization backends.
At the time of release, both FBGEMM and LLM-Compressor are supported; more quantization methods will be added in the days following the release.
See below for examples using both.
Here is an example loading a BF16 model in FP8 using the FBGEMM approach:
<hfoptions id="Quantization">
<hfoption id="FBGEMM">
```python
from transformers import AutoTokenizer, Llama4ForConditionalGeneration, FbgemmFp8Config
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
messages = [
{"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
quantization_config=FbgemmFp8Config()
)
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
print(outputs[0])
```
</hfoption>
<hfoption id="LLM-Compressor">
To use the LLM-Compressor technique, we recommend leveraging the pre-quantized FP8 checkpoint available with the release:
```python
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
import torch
model_id = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
tokenizer = AutoTokenizer.from_pretrained(model_id)
messages = [
{"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
tp_plan="auto",
torch_dtype=torch.bfloat16,
)
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
print(outputs[0])
```
</hfoption>
</hfoptions>
### Offloading
Enabling CPU-offloading means that components of the model might be moved to CPU instead of GPU in case the GPU-memory available isn't sufficient to load the entire model.
At inference, different components will be loaded/unloaded from/to the GPU on the fly. This ensures that the model can be loaded on smaller machines as long as the CPU-memory is sufficient.
However, this also slows down inference as it adds communication overhead.
To enable CPU offloading, simply set `device_map="auto"` when loading the model:
```py
from transformers import Llama4ForConditionalGeneration
import torch
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
)
```
## Llama4Config
[[autodoc]] Llama4Config
## Llama4TextConfig
[[autodoc]] Llama4TextConfig
## Llama4VisionConfig
[[autodoc]] Llama4VisionConfig
## Llama4Processor
[[autodoc]] Llama4Processor
## Llama4ImageProcessorFast
[[autodoc]] Llama4ImageProcessorFast
## Llama4ForConditionalGeneration
[[autodoc]] Llama4ForConditionalGeneration
- forward
## Llama4ForCausalLM
[[autodoc]] Llama4ForCausalLM
- forward
## Llama4TextModel
[[autodoc]] Llama4TextModel
- forward
## Llama4VisionModel
[[autodoc]] Llama4VisionModel
- forward

View File

@ -562,6 +562,12 @@ _import_structure = {
"models.levit": ["LevitConfig"],
"models.lilt": ["LiltConfig"],
"models.llama": ["LlamaConfig"],
"models.llama4": [
"Llama4Config",
"Llama4Processor",
"Llama4TextConfig",
"Llama4VisionConfig",
],
"models.llava": [
"LlavaConfig",
"LlavaProcessor",
@ -1354,6 +1360,7 @@ else:
_import_structure["models.detr"].append("DetrImageProcessorFast")
_import_structure["models.gemma3"].append("Gemma3ImageProcessorFast")
_import_structure["models.got_ocr2"].append("GotOcr2ImageProcessorFast")
_import_structure["models.llama4"].append("Llama4ImageProcessorFast")
_import_structure["models.llava"].append("LlavaImageProcessorFast")
_import_structure["models.llava_next"].append("LlavaNextImageProcessorFast")
_import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast")
@ -2510,6 +2517,15 @@ else:
"GlmPreTrainedModel",
]
)
_import_structure["models.llama4"].extend(
[
"Llama4ForCausalLM",
"Llama4ForConditionalGeneration",
"Llama4TextModel",
"Llama4VisionModel",
"Llama4PreTrainedModel",
]
)
_import_structure["models.glpn"].extend(
[
"GLPNForDepthEstimation",
@ -5807,6 +5823,12 @@ if TYPE_CHECKING:
from .models.levit import LevitConfig
from .models.lilt import LiltConfig
from .models.llama import LlamaConfig
from .models.llama4 import (
Llama4Config,
Llama4Processor,
Llama4TextConfig,
Llama4VisionConfig,
)
from .models.llava import (
LlavaConfig,
LlavaProcessor,
@ -6646,6 +6668,7 @@ if TYPE_CHECKING:
from .models.detr import DetrImageProcessorFast
from .models.gemma3 import Gemma3ImageProcessorFast
from .models.got_ocr2 import GotOcr2ImageProcessorFast
from .models.llama4 import Llama4ImageProcessorFast
from .models.llava import LlavaImageProcessorFast
from .models.llava_next import LlavaNextImageProcessorFast
from .models.llava_onevision import LlavaOnevisionImageProcessorFast
@ -7827,6 +7850,13 @@ if TYPE_CHECKING:
LlamaModel,
LlamaPreTrainedModel,
)
from .models.llama4 import (
Llama4ForCausalLM,
Llama4ForConditionalGeneration,
Llama4PreTrainedModel,
Llama4TextModel,
Llama4VisionModel,
)
from .models.llava import (
LlavaForConditionalGeneration,
LlavaPreTrainedModel,

View File

@ -1811,6 +1811,200 @@ class HybridCache(Cache):
self.value_cache[layer_idx].zero_()
class HybridChunkedCache(Cache):
"""
Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention
and global attention in every other layer. Under the hood, Hybrid Cache leverages [`SlidingWindowCache`] for sliding window attention
and [`StaticCache`] for global attention. For more information, see the documentation of each subcomponent cache class.
Parameters:
config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used.
max_cache_len (`int`, *optional*):
The maximum sequence length with which the model will be used.
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. If you're using more than 1 computation device, you
should pass the `layer_device_map` argument instead.
dtype (torch.dtype, *optional*, defaults to `torch.bfloat16`):
The default `dtype` to use when initializing the layer.
layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]`, *optional*):
Mapping between the layers and their devices. This is required when you are manually initializing the cache
and the model is split between different GPUs. You can find out which layer is mapped to which device by
checking the associated device_map: `model.hf_device_map`.
Example:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
>>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
>>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
HybridCache()
```
"""
# TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
# ALL changes from the PR that commented the line below when reactivating it.
# is_compileable = True
def __init__(
self,
config: PretrainedConfig,
max_batch_size: int,
max_cache_len: Optional[int] = None,
device: Union[torch.device, str, None] = None,
dtype: torch.dtype = torch.bfloat16,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super().__init__()
if not hasattr(config, "sliding_window") or config.sliding_window is None:
self.sliding_window = getattr(config.get_text_config(), "attention_chunk_size", 8192)
else:
self.sliding_window = config.sliding_window
self.max_cache_len = max_cache_len
self.max_batch_size = max_batch_size
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self._dtype = dtype
if hasattr(config.get_text_config(), "no_rope_layers"):
self.is_sliding = config.no_rope_layers
else:
layer_switch = getattr(config, "sliding_window_pattern", 2)
self.is_sliding = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
self.cumulative_length = [0 for _ in range(config.num_hidden_layers)]
def initialise_cache_layer(self, layer_idx, key_states):
if len(self.key_cache) > layer_idx:
return
num_key_value_heads = key_states.shape[1]
device = key_states.device
global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim)
sliding_cache_shape = (
self.max_batch_size,
num_key_value_heads,
self.sliding_window,
self.head_dim,
)
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
cumulative_length = self.cumulative_length[layer_idx]
is_full = cumulative_length >= max_cache_len
if is_full:
full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2)
full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2)
elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len:
full_key_states = torch.cat((k_out[:, :, :cumulative_length, :], key_states), dim=-2)
full_value_states = torch.cat((v_out[:, :, :cumulative_length, :], value_states), dim=-2)
else:
self.key_cache[layer_idx].index_copy_(2, cache_position, key_states)
self.value_cache[layer_idx].index_copy_(2, cache_position, value_states)
self.cumulative_length[layer_idx] += key_states.shape[-2]
return self.key_cache[layer_idx], self.value_cache[layer_idx]
self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :])
self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :])
self.cumulative_length[layer_idx] += key_states.shape[-2]
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building kv cache instead of just throwing away tokens outside of the window
return full_key_states, full_value_states
def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
return k_out, v_out
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if cache_kwargs is None:
cache_kwargs = {}
cache_position = cache_kwargs.get("cache_position")
self.initialise_cache_layer(layer_idx, key_states)
# These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
# when the cache is initialized in the forward pass (e.g. Gemma2)
if self.key_cache[layer_idx].device != key_states.device:
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
if self.value_cache[layer_idx].device != value_states.device:
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)
if self.is_sliding[layer_idx]:
update_fn = self._sliding_update
else:
update_fn = self._static_update
return update_fn(
cache_position,
layer_idx,
key_states,
value_states,
k_out,
v_out,
k_out.shape[2],
)
def get_max_cache_shape(self) -> Optional[int]:
return self.max_cache_len
def get_seq_length(self, layer_idx: Optional[int] = 0):
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: deprecate this function in favor of `cache_position`
if layer_idx != 0:
raise ValueError(
"`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
"Using the `layer_idx` argument is not supported."
)
if len(self.key_cache) == 0:
return 0
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
def reset(self):
"""Resets the cache values while preserving the objects"""
for layer_idx in range(len(self.key_cache)):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
self.cumulative_length = [0 for _ in range(len(self.cumulative_length))]
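For reference, here is a minimal usage sketch (not part of this diff) that pre-allocates the new `HybridChunkedCache` and prefills it with one forward pass, mirroring the `HybridCache` example in the docstring above. The checkpoint name and sizes are illustrative assumptions.

```python
# Illustrative sketch only: pre-allocate a HybridChunkedCache and prefill it with one forward pass.
# The model id is an assumption; any config exposing `attention_chunk_size` (or `sliding_window`)
# in its text config follows the same code path.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import HybridChunkedCache

model_id = "meta-llama/Llama-4-Scout-17B-16E"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

inputs = tokenizer("My name is Llama", return_tensors="pt").to(model.device)
max_generated_length = inputs.input_ids.shape[1] + 10  # leave room for 10 new tokens

past_key_values = HybridChunkedCache(
    config=model.config.get_text_config(),
    max_batch_size=1,
    max_cache_len=max_generated_length,
    dtype=model.dtype,
)
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
# In generate(), requesting cache_implementation="hybrid" on a llama4 model is remapped to this
# "hybrid_chunked" implementation (see the change in generation/utils.py further down).
```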
class MambaCache:
"""
Cache for mamba model which does not have attention mechanism and key value states.

View File

@ -801,18 +801,19 @@ class PretrainedConfig(PushToHubMixin):
def to_diff_dict(self) -> dict[str, Any]:
"""
Removes all attributes from config which correspond to the default config attributes for better readability and
serializes to a Python dictionary.
Removes all attributes from the configuration that correspond to the default config attributes for
better readability, while always retaining the `config` attribute from the class. Serializes to a
Python dictionary.
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
"""
config_dict = self.to_dict()
# get the default config dict
# Get the default config dict (from a fresh PreTrainedConfig instance)
default_config_dict = PretrainedConfig().to_dict()
# get class specific config dict
# Get class-specific config dict if not part of a composition
class_config_dict = self.__class__().to_dict() if not self.is_composition else {}
serializable_config_dict = {}
@ -847,8 +848,7 @@ class PretrainedConfig(PushToHubMixin):
if not isinstance(self.quantization_config, dict)
else self.quantization_config
)
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
# Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
_ = serializable_config_dict.pop("_pre_quantization_dtype", None)
self.dict_torch_dtype_to_str(serializable_config_dict)

View File

@ -52,6 +52,7 @@ if is_torch_available():
from ..cache_utils import (
HQQQuantizedCache,
HybridCache,
HybridChunkedCache,
MambaCache,
OffloadedStaticCache,
QuantizedCacheConfig,
@ -69,6 +70,7 @@ if is_torch_available():
"offloaded_static": OffloadedStaticCache,
"sliding_window": SlidingWindowCache,
"hybrid": HybridCache,
"hybrid_chunked": HybridChunkedCache,
"mamba": MambaCache,
}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
@ -416,6 +418,7 @@ class GenerationConfig(PushToHubMixin):
if isinstance(self.cache_config, dict):
self.cache_config = cache_config_class.from_dict(self.cache_config)
self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)
self.prefill_chunk_size = kwargs.pop("prefill_chunk_size", None)
# Parameters for manipulation of the model output logits
self.temperature = kwargs.pop("temperature", 1.0)

View File

@ -1830,6 +1830,9 @@ class GenerationMixin:
Returns the resulting cache object.
"""
if cache_implementation == "hybrid" and "llama4" in getattr(self.config, "model_type", ""):
cache_implementation = "hybrid_chunked"
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
@ -3405,7 +3408,12 @@ class GenerationMixin:
os.environ["TOKENIZERS_PARALLELISM"] = "0"
model_forward = self.get_compiled_call(generation_config.compile_config)
is_prefill = True
if generation_config.prefill_chunk_size is not None:
model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
is_prefill = False
else:
is_prefill = True
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
@ -4855,6 +4863,45 @@ class GenerationMixin:
else:
return input_ids
def _prefill_chunking(self, input_ids: torch.LongTensor, generation_config: GenerationConfig, **model_kwargs):
# Even if we are not compiling the forward, flex is always compiled when used. With chunk prefill, we may
# end up needing just a bit more graphs than the default (which is 8). Doing this avoids very cryptic warnings
torch._dynamo.config.cache_size_limit = 64
chunk_size = generation_config.prefill_chunk_size
# Only chunk up to the token just before the last one, so that decoding is performed entirely outside this function
# (here we simply prefill the cache)
input_chunks = torch.split(input_ids[:, :-1], chunk_size, dim=-1)
if "past_key_values" not in model_kwargs:
raise ValueError("Cannot use prefill chunkink without a cache")
model_forward = self.get_compiled_call(generation_config.compile_config)
attention_mask = model_kwargs.pop("attention_mask", None)
past_length = 0
for input_chunk in input_chunks:
current_length = past_length + input_chunk.shape[-1]
# Prepare inputs
if attention_mask is not None:
model_kwargs["attention_mask"] = attention_mask[:, :current_length]
model_kwargs["cache_position"] = torch.arange(
past_length, current_length, dtype=torch.long, device=input_chunk.device
)
model_kwargs["position_ids"] = model_kwargs["cache_position"].unsqueeze(0)
model_inputs = self.prepare_inputs_for_generation(input_chunk, **model_kwargs)
outputs = model_forward(**model_inputs, return_dict=True)
model_kwargs["past_key_values"] = outputs.past_key_values
past_length = current_length
model_kwargs["attention_mask"] = attention_mask
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
_ = model_kwargs.pop("position_ids", None)
return model_kwargs
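As a quick sanity check of the chunking logic above, here is a tiny standalone illustration (toy values, not from the diff) of how the prompt is split and how `cache_position` advances per chunk:

```python
# Standalone illustration of the chunk splitting used by _prefill_chunking (assumed toy values).
import torch

input_ids = torch.arange(10).unsqueeze(0)  # a 10-token "prompt", batch size 1
chunk_size = 4

# All tokens except the last are prefilled in chunks; the final token is left for normal decoding.
past_length = 0
for chunk in torch.split(input_ids[:, :-1], chunk_size, dim=-1):
    current_length = past_length + chunk.shape[-1]
    cache_position = torch.arange(past_length, current_length)  # slots written into the static cache
    print(chunk.tolist(), cache_position.tolist())
    past_length = current_length
# [[0, 1, 2, 3]] [0, 1, 2, 3]
# [[4, 5, 6, 7]] [4, 5, 6, 7]
# [[8]]          [8]
```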
def _speculative_sampling(
candidate_input_ids,

View File

@ -53,7 +53,7 @@ _import_structure = {
"unset_hf_deepspeed_config",
],
"eetq": ["replace_with_eetq_linear"],
"fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
"fbgemm_fp8": ["FbgemmFp8Linear", "FbgemmFp8Llama4TextExperts", "replace_with_fbgemm_fp8_linear"],
"finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear"],
"fsdp": ["is_fsdp_managed_module"],
"ggml": [
@ -192,7 +192,7 @@ if TYPE_CHECKING:
unset_hf_deepspeed_config,
)
from .eetq import replace_with_eetq_linear
from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear
from .fbgemm_fp8 import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts, replace_with_fbgemm_fp8_linear
from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear
from .fsdp import is_fsdp_managed_module
from .ggml import (

View File

@ -0,0 +1,54 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.utils import is_torch_available
if is_torch_available():
import torch
import torch.nn as nn
from transformers.models.llama4.modeling_llama4 import Llama4TextMLP
def skip(*args, **kwargs):
pass
class CompressedExpertsLinear(nn.Module):
"""
A module that implements a compressed version of a list of expert modules.
This is specifically designed to work with Llama4TextExperts in MoE layers.
"""
def __init__(self, config):
# Skip random weight initialization for experts. Otherwise,
# the init of this module would take over minutes. For a model
# with tens of layers of experts, it would easily take over 20 minutes.
nn.init.kaiming_uniform_ = skip
nn.init.uniform_ = skip
nn.init.normal_ = skip
super().__init__()
self.num_experts = config.num_local_experts
self.expert_modules = nn.ModuleList([Llama4TextMLP(config) for _ in range(self.num_experts)])
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
hidden_states = hidden_states.reshape(self.num_experts, -1, hidden_states.shape[-1])
expert_routed_out_list = []
for expert_idx in range(self.num_experts):
expert_routed_out_list.append(self.expert_modules[expert_idx](hidden_states[expert_idx]))
routed_out = torch.cat(expert_routed_out_list, dim=0)
return routed_out
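A hypothetical smoke test for this module is sketched below; the tiny config values are made up for illustration and assume `Llama4TextConfig` accepts them as keyword overrides.

```python
# Hypothetical usage sketch, run in the same module as CompressedExpertsLinear.
# Config values are illustrative; real Llama 4 configs are far larger.
import torch
from transformers.models.llama4 import Llama4TextConfig

config = Llama4TextConfig(hidden_size=64, intermediate_size=128, num_local_experts=4)
layer = CompressedExpertsLinear(config)

tokens_per_expert = 3
hidden_states = torch.randn(config.num_local_experts * tokens_per_expert, config.hidden_size)
routed_out = layer(hidden_states)   # each expert processes its contiguous slice of tokens
print(routed_out.shape)             # torch.Size([12, 64])
```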

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..activations import ACT2FN
from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
@ -28,36 +29,36 @@ if is_fbgemm_gpu_available():
logger = logging.get_logger(__name__)
class FbgemmFp8Linear(torch.nn.Module):
class FbgemmFp8Linear(torch.nn.Linear):
def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32):
super().__init__()
super().__init__(in_features, out_features, bias)
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
self.register_buffer("weight_scale", torch.zeros((out_features, 1), dtype=weight_dtype))
self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=weight_dtype))
self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
if bias:
self.register_buffer("bias", torch.zeros((self.out_features), dtype=weight_dtype))
self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=weight_dtype))
else:
self.bias = None
def forward(self, x):
num_tokens = None
# quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here
output_shape = (*x.shape[:-1], -1)
# x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
# https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
x.view(-1, x.shape[-1]), num_tokens, self.input_scale_ub
x.view(-1, x.shape[-1]), scale_ub=self.input_scale_ub
)
# moving x_quantized, x_scale here creates gibberish output ... However, if we move the output, it works
# x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
# The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
weight_scale_float32 = self.weight_scale.to(torch.float32)
output = torch.ops.fbgemm.f8f8bf16_rowwise(
x_quantized, self.weight, x_scale, self.weight_scale, use_fast_accum=True
x_quantized, self.weight, x_scale, weight_scale_float32, use_fast_accum=True
)
output = output + self.bias if self.bias is not None else output
# Hacky for now, we move the output to the device of x
@ -67,6 +68,92 @@ class FbgemmFp8Linear(torch.nn.Module):
return output
class FbgemmFp8Llama4TextExperts(nn.Module):
def __init__(self, config, dtype=torch.float32):
super().__init__()
self.num_experts = config.num_local_experts
self.intermediate_size = config.intermediate_size
self.hidden_size = config.hidden_size
self.expert_dim = self.intermediate_size
self.act_fn = ACT2FN[config.hidden_act]
# Register FP8 buffers for gate_up_proj
self.gate_up_proj = torch.nn.Parameter(
torch.zeros((self.num_experts, self.hidden_size, 2 * self.expert_dim), dtype=torch.float8_e4m3fn)
)
self.gate_up_proj_scale = torch.nn.Parameter(
torch.zeros((self.num_experts, 1, self.expert_dim * 2), dtype=torch.float32)
)
# Register FP8 buffers for down_proj
self.down_proj = torch.nn.Parameter(
torch.zeros((self.num_experts, self.expert_dim, self.hidden_size), dtype=torch.float8_e4m3fn)
)
self.down_proj_scale = torch.nn.Parameter(
torch.zeros((self.num_experts, self.hidden_size, 1), dtype=torch.float32)
)
# Register input scale upper bound
self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
def forward(self, hidden_states):
"""
Args:
hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
Returns:
torch.Tensor: (batch_size * token_num, hidden_size)
"""
# Reshape hidden states for expert computation
hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
num_tokens = None
# Pre-allocate tensor for all expert outputs with same shape as hidden_states
next_states = torch.empty_like(hidden_states)
for i in range(self.num_experts):
# Extract expert's hidden states
expert_hidden = hidden_states[i]
expert_hidden_reshaped = expert_hidden.reshape(-1, self.hidden_size)
# Quantize for this expert
expert_quantized, expert_scale = torch.ops.fbgemm.quantize_fp8_per_row(
expert_hidden_reshaped, num_tokens, self.input_scale_ub
)
sharded_expert_dim = self.gate_up_proj.shape[-1] // 2
gate_up_proj_scale_float32 = self.gate_up_proj_scale.to(torch.float32)
gate = torch.ops.fbgemm.f8f8bf16_rowwise(
expert_quantized,
self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous(),
expert_scale,
gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous(),
use_fast_accum=True,
)
up = torch.ops.fbgemm.f8f8bf16_rowwise(
expert_quantized,
self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous(),
expert_scale,
gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
use_fast_accum=True,
)
activated = up * self.act_fn(gate)
activated_quantized, activated_scale = torch.ops.fbgemm.quantize_fp8_per_row(
activated, num_tokens, self.input_scale_ub
)
down_proj_scale_float32 = self.down_proj_scale.to(torch.float32)
expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
activated_quantized,
self.down_proj[i].transpose(0, 1).contiguous(),
activated_scale,
down_proj_scale_float32[i].view(-1, 1).contiguous(),
use_fast_accum=True,
)
next_states[i] = expert_output
next_states = next_states.to(hidden_states.device)
return next_states.view(-1, self.hidden_size)
def _replace_with_fbgemm_fp8_linear(
model,
modules_to_not_convert=None,
@ -74,12 +161,17 @@ def _replace_with_fbgemm_fp8_linear(
quantization_config=None,
has_been_replaced=False,
pre_quantized=False,
config=None,
tp_plan=None,
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""
import re
if current_key_name is None:
current_key_name = []
@ -105,9 +197,27 @@ def _replace_with_fbgemm_fp8_linear(
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
# set non-persistent buffer outside of init_empty_weights
model._modules[name].input_scale_ub = torch.tensor(
[quantization_config.activation_scale_ub],
dtype=torch.float,
)
if module.__class__.__name__ == "Llama4TextExperts" and name not in modules_to_not_convert:
current_key_name_str = ".".join(current_key_name)
if not any(
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
):
with init_empty_weights(include_buffers=True):
tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj_scale")] = tp_plan[
re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj")
]
tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None
model._modules[name] = FbgemmFp8Llama4TextExperts(
config.text_config,
)
model._modules[name].input_scale_ub = torch.tensor(
[quantization_config.activation_scale_ub], dtype=torch.float
)
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_fbgemm_fp8_linear(
module,
@ -116,6 +226,8 @@ def _replace_with_fbgemm_fp8_linear(
quantization_config,
has_been_replaced=has_been_replaced,
pre_quantized=pre_quantized,
config=config,
tp_plan=tp_plan,
)
# Remove the last key for recursion
current_key_name.pop(-1)
@ -123,7 +235,13 @@ def _replace_with_fbgemm_fp8_linear(
def replace_with_fbgemm_fp8_linear(
model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
pre_quantized=False,
config=None,
tp_plan=None,
):
"""
A helper function to replace all `torch.nn.Linear` modules by `FbgemmFp8Linear` modules.
@ -151,9 +269,14 @@ def replace_with_fbgemm_fp8_linear(
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
modules_to_not_convert = list(set(modules_to_not_convert))
model, has_been_replaced = _replace_with_fbgemm_fp8_linear(
model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
model,
modules_to_not_convert,
current_key_name,
quantization_config,
pre_quantized=pre_quantized,
config=config,
tp_plan=tp_plan,
)
if not has_been_replaced:
logger.warning(
"You are loading your model using FP8 quantization but no linear modules were found in your model."

View File

@ -34,10 +34,7 @@ from ..utils import is_torch_flex_attn_available
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import (
BlockMask,
flex_attention,
)
from torch.nn.attention.flex_attention import BlockMask, flex_attention
from torch.nn.attention.flex_attention import (
create_block_mask as create_block_causal_mask_flex,
)
@ -64,14 +61,23 @@ class WrappedFlexAttention:
Initialize or update the singleton instance.
"""
if self._is_flex_compiled is False:
self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
self._compiled_flex_attention = torch.compile(flex_attention, backend="inductor")
self._is_flex_compiled = True
def __call__(self):
return self._compiled_flex_attention
def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
Offset = Union[torch.Tensor, int]
def make_flex_block_causal_mask(
attention_mask_2d: torch.Tensor,
attention_chunk_size: Optional[int] = None,
query_length=None,
key_length=None,
offsets: Optional[Tuple[Offset, Offset]] = None,
) -> "BlockMask":
"""
Create a block causal document mask for a batch of sequences, both packed and unpacked.
Creates the block-causal mask logic and passes it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
@ -94,10 +100,13 @@ def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
Returns:
BlockMask
"""
attention_mask_2d = torch.nn.functional.pad(attention_mask_2d, value=0, pad=(0, key_length))
device = attention_mask_2d.device
document_ids = attention_mask_2d.clone()
document_ids = attention_mask_2d
batch_size, total_seq_len = document_ids.shape
if attention_chunk_size is not None:
# we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (attention_chunk_size)
# Instead of passing a tensor mask, flex attention requires a mask_mod function
# that determines which elements of QK^T should be included in the attention
@ -112,18 +121,30 @@ def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
causal_mask = q_idx >= kv_idx
causal_mask = q_idx >= kv_idx # not valid when decoding
document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
padding_mask = document_ids[batch_idx, q_idx] > 0
return causal_mask & document_mask & padding_mask
padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
final_mask = causal_mask & padding_mask & document_mask
return final_mask
if offsets is not None:
q_offset = offsets[0]
kv_offset = offsets[1]
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
offset_q = q_idx + q_offset
offset_kv = kv_idx + kv_offset
return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
else:
mask_mod = causal_mask_mod
return create_block_causal_mask_flex(
mask_mod=causal_mask_mod,
B=batch_size,
mask_mod=mask_mod,
B=1,
H=None, # attention head
Q_LEN=total_seq_len,
KV_LEN=total_seq_len,
Q_LEN=query_length,
KV_LEN=key_length,
device=device,
_compile=True,
)
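A quick standalone check (toy values, not from the diff) of the chunk-id trick used for `attention_chunk_size` above:

```python
# Toy check of the document-id computation used when attention_chunk_size is set.
import torch

attention_mask_2d = torch.ones(1, 12, dtype=torch.long)
attention_chunk_size = 3
document_ids = (attention_mask_2d.fill_(1).cumsum(-1) - 1) // attention_chunk_size
print(document_ids.tolist())  # [[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]]
# Because mask_mod requires equal document ids, each token can only attend within its own chunk.
```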
@ -144,6 +165,18 @@ def compile_friendly_flex_attention(
)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def flex_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
@ -174,14 +207,25 @@ def flex_attention_forward(
score = score + head_mask[batch_idx][head_idx][0][0]
return score
enable_gqa = True
num_local_query_heads = query.shape[1]
# When running TP this helps:
if not ((num_local_query_heads & (num_local_query_heads - 1)) == 0):
key = repeat_kv(key, query.shape[1] // key.shape[1])
value = repeat_kv(value, query.shape[1] // value.shape[1])
enable_gqa = False
kernel_options = kwargs.get("kernel_options", None)
attn_output, attention_weights = compile_friendly_flex_attention(
query,
key,
value,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=True,
enable_gqa=enable_gqa,
scale=scaling,
kernel_options=kernel_options,
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
return_lse=True,

View File

@ -31,7 +31,7 @@ def sdpa_attention_forward(
value = repeat_kv(value, module.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
if attention_mask is not None and causal_mask.ndim == 4:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
# SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions

View File

@ -61,6 +61,21 @@ def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> Li
return [single_size] * blocks
str_to_torch_dtype = {
"BOOL": torch.bool,
"U8": torch.uint8,
"I8": torch.int8,
"I16": torch.int16,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I32": torch.int32,
"F32": torch.float32,
"F64": torch.float64,
"I64": torch.int64,
"F8_E4M3": torch.float8_e4m3fn,
}
def get_packed_weights(param, empty_param, device_mesh, rank, dim):
"""
When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share.
@ -106,6 +121,12 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
tensors_slices += range(block_offset + start, block_offset + stop)
block_offset += block_size
slice_dtype = slice_.get_dtype()
# Handle F8_E4M3 dtype by converting to float16 before slicing
# Without upcasting, the slicing causes : RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn'
if slice_dtype == "F8_E4M3":
slice_ = slice_[...].to(torch.float16)
if dim == 0:
tensor = slice_[tensors_slices, ...]
elif dim == 1 or dim == -2:
@ -114,7 +135,7 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
tensor = slice_[..., tensors_slices]
else:
raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
return tensor
return tensor.to(str_to_torch_dtype[slice_dtype])
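A small repro sketch (CPU, illustrative shapes) of why the upcast above is needed before fancy-indexing an F8_E4M3 slice:

```python
# Sketch: fancy indexing is not implemented for float8 on CPU (per the comment above),
# so the slice is taken in float16 and cast back afterwards. Shapes are illustrative.
import torch

weight = torch.zeros(8, 4, dtype=torch.float8_e4m3fn)
rows = [0, 2, 5]
# weight[rows, ...]  # would raise: "index_cpu" not implemented for 'Float8_e4m3fn'
shard = weight.to(torch.float16)[rows, ...].to(torch.float8_e4m3fn)
print(shard.dtype, shard.shape)  # torch.float8_e4m3fn torch.Size([3, 4])
```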
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
@ -199,11 +220,12 @@ class GatherParallel(TensorParallelLayer):
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
if isinstance(inputs[0], DTensor):
inputs[0] = inputs[0].to_local()
inputs = inputs[0].to_local()
return inputs
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
# this op cannot be asynch, otherwise it completely breaks the outputs of models
torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False)
return outputs
@ -266,7 +288,7 @@ class ColwiseParallel(TensorParallelLayer):
# transform the input layouts to the desired layouts of ColwiseParallel
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
return input_tensor
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
@ -291,7 +313,7 @@ class ColwiseParallel(TensorParallelLayer):
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
# outputs is a shard on last dimension DTensor, i.e. Shard(-1)
if outputs.placements != output_layouts:
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
outputs = outputs.redistribute(placements=output_layouts, async_op=False)
# back to local tensor
return outputs.to_local() if use_local_output else outputs
@ -343,16 +365,6 @@ class RowwiseParallel(TensorParallelLayer):
self.use_local_output = use_local_output
self.use_dtensor = use_dtensor
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
return input_tensor
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
# means Rowwise as nn.Linear is input * weight^T + bias, where
@ -371,6 +383,20 @@ class RowwiseParallel(TensorParallelLayer):
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
return nn.Parameter(parameter)
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
if hasattr(mod, "bias") and mod.bias is not None:
mod._bias = mod.bias
mod.bias = None
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
return input_tensor
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
# Rowwise sharding produces partial output, depending on output layouts:
@ -378,6 +404,8 @@ class RowwiseParallel(TensorParallelLayer):
# 2. to shard -> reduce_scatter
if outputs.placements != output_layouts:
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
if hasattr(mod, "_bias"):
outputs += mod._bias
# back to local tensor if use_local_output is True
return outputs.to_local() if use_local_output else outputs
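# Single-process sketch of why the `_bias` shuffle above is needed (toy sizes, no actual
# distributed calls): with rowwise sharding each rank only holds a slice of the weight's
# input dimension, so every rank produces a partial sum and the bias must be added exactly
# once after the reduction, not once per rank.
_lin = nn.Linear(8, 4)
_x = torch.randn(2, 8)
_w0, _w1 = _lin.weight.chunk(2, dim=1)    # two "ranks", each owning half of the input features
_x0, _x1 = _x.chunk(2, dim=1)
_partial_sum = _x0 @ _w0.T + _x1 @ _w1.T  # what all-reducing the rank-local outputs yields
assert torch.allclose(_partial_sum + _lin.bias, _lin(_x), atol=1e-5)  # bias added once, as in `outputs += mod._bias`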
@ -418,6 +446,90 @@ class PackedRowwiseParallel(RowwiseParallel):
return nn.Parameter(parameter)
class SequenceParallel(TensorParallelLayer):
"""
SequenceParallel replicates a compatible ``nn.Module`` parameters and runs the sharded computation with
input sharded on the sequence dimension. This currently supports ``nn.LayerNorm``, ``nn.Dropout``, and the
`RMSNorm python implementation <https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34>`__
This style implements the operation that is described in the paper
`Reducing Activation Recomputation in Large Transformer Models <https://arxiv.org/abs/2205.05198>`__
If the input passed in to this ``nn.Module`` is a :class:`torch.Tensor`, it assumes that the input is already sharded
on the sequence dimension and converts the input to a :class:`DTensor` sharded on the sequence dimension. If the input
passed in to this ``nn.Module`` is already a :class:`DTensor` but is not sharded on the sequence dimension, it would
redistribute the input to be sharded on the sequence dimension.
The output of the ``nn.Module`` will be sharded on the sequence dimension.
Keyword Args:
sequence_dim (int, optional):
The sequence dimension of the input tensor for the ``nn.Module``, this is used to annotate the input tensor to
become a DTensor that is sharded on the sequence dimension, default: 1.
use_local_output (bool, optional):
Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False.
Returns:
A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``.
Example::
>>> # xdoctest: +SKIP(failing)
>>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim
>>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}),
>>> ...
.. note:: SequenceParallel style assumes ones initialization if there are weights in the nn.Module (i.e.
``nn.LayerNorm`` or ``RMSNorm``, and they by default have ones initialization). If you have custom
inits for the weights on those modules, you need to broadcast the weights before/after parallelizing
to ensure that they are replicated.
"""
def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False):
super().__init__()
self.input_layouts = (Replicate(),)
self.desired_input_layouts = (Shard(1),)
self.output_layouts = (Replicate(),)
self.use_local_output = use_local_output
self.use_dtensor = True
self.sequence_sharding = (Shard(sequence_dim),)
self.use_local_output = use_local_output
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
return input_tensor
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
outputs = outputs.redistribute(
placements=(Replicate(),), async_op=True
) # maybe we have to replicate ? because next layer is not sharded
return outputs.to_local() # if use_local_output else outputs
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# SequenceParallel keeps the (norm) weight replicated on every rank; only the
# activations are sharded along the sequence dimension, so the parameter is
# simply wrapped as a replicated DTensor below.
parameter = param[:]
parameter = parameter.to(param_casting_dtype)
if to_contiguous:
parameter = parameter.contiguous()
if self.use_dtensor:
parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False)
return nn.Parameter(parameter)
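# Usage sketch (module paths are illustrative; the Llama4 text config later in this diff
# follows the same pattern): norm weights map to the new "sequence_parallel" style, so the
# parameters stay replicated while activations are sharded along the sequence dimension.
_example_tp_plan = {
    "layers.*.input_layernorm.weight": "sequence_parallel",
    "layers.*.post_attention_layernorm.weight": "sequence_parallel",
    "norm.weight": "sequence_parallel",
}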
SUPPORTED_TP_STYLES = {
"colwise",
"rowwise",
@ -428,6 +540,7 @@ SUPPORTED_TP_STYLES = {
"local",
"gather",
"local_packed_rowwise",
"sequence_parallel",
}
@ -459,6 +572,8 @@ def translate_to_torch_parallel_style(style: str):
return GatherParallel()
elif style == "local_packed_rowwise":
return PackedRowwiseParallel(use_dtensor=False)
elif style == "sequence_parallel":
return SequenceParallel()
else:
raise ValueError(f"Unsupported parallel style value: {style}")
@ -518,6 +633,7 @@ def shard_and_distribute_module(
tp_plan = model._tp_plan
module_to_tp = model.get_submodule(param_name)
current_module_plan = None
rank = int(rank)
generic_param_name = re.sub(r"\d+", "*", parameter_name)
if generic_param_name in tp_plan:
current_module_plan = tp_plan[generic_param_name]
@ -531,12 +647,18 @@ def shard_and_distribute_module(
module_to_tp._is_hooked = True
if current_module_plan is not None:
tp_layer = translate_to_torch_parallel_style(current_module_plan)
param = tp_layer.partition_tensor(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)
try:
tp_layer = translate_to_torch_parallel_style(current_module_plan)
param = tp_layer.partition_tensor(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)
except NotImplementedError as e:
print(
f"Trying to prepare {parameter_name}, but it is not supported. Corresponding module: {module_to_tp}. Fix its TP plan, current layer: {tp_layer}: {e}"
)
else:
# TODO log no plan modules in set
# print("No plan for", parameter_name,end ="\n")
param = param[...].to(param_casting_dtype)
if is_contiguous:
param = param.contiguous()


@ -484,6 +484,7 @@ str_to_torch_dtype = {
"F32": torch.float32,
"F64": torch.float64,
"I64": torch.int64,
"F8_E4M3": torch.float8_e4m3fn,
}
if is_torch_greater_or_equal("2.1.0"):
@ -1914,16 +1915,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
if self.base_model is self:
self._pp_plan = (
self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
)
self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
else:
self._tp_plan = self._tp_plan or {}
for name, module in self.named_children():
if plan := getattr(module, "_tp_plan", None):
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()})
self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
for name, module in self.named_children():
if plan := getattr(module, "_tp_plan", None):
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
for _, v in self._tp_plan.items():
@ -4054,6 +4050,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
import sys
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
# This is the easiest way to dispatch to the current process device
device_map = tp_device
# Assuming sharding the model onto the world
@ -4238,6 +4235,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
device_map = hf_quantizer.update_device_map(device_map)
config = hf_quantizer.update_tp_plan(config)
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
if hasattr(hf_quantizer.quantization_config.quant_method, "value"):
@ -4370,9 +4368,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if hf_quantizer is not None:
hf_quantizer.preprocess_model(
model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules
model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules, config=config
)
# We store the original dtype for quantized models as we cannot easily retrieve it
# once the weights have been quantized
# Note that once you have loaded a quantized model, you can't change its dtype so this will
@ -4901,7 +4898,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
name,
casting_dtype,
to_contiguous,
tp_device.index,
os.environ["RANK"],
device_mesh,
)
@ -5174,6 +5171,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
(where we want the speed-ups of compiled version with static shapes)."""
# Only reset it if not present or different from previous config
if "llama4" in self.config.model_type: # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT
return self.__call__
default_config = getattr(self.generation_config, "compile_config", CompileConfig())
if (
not hasattr(self, "_compiled_call")


@ -148,6 +148,7 @@ from . import (
levit,
lilt,
llama,
llama4,
llava,
llava_next,
llava_next_video,


@ -544,10 +544,6 @@ class _BaseAutoModelClass:
if kwargs_orig.get("quantization_config", None) is not None:
kwargs["quantization_config"] = kwargs_orig["quantization_config"]
# AutoClass-specific config manipulation
config = copy.deepcopy(config)
config = cls._prepare_config_for_auto_class(config)
has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
has_local_code = type(config) in cls._model_mapping.keys()
trust_remote_code = resolve_trust_remote_code(
@ -570,6 +566,8 @@ class _BaseAutoModelClass:
)
elif type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
if model_class.config_class == config.sub_configs.get("text_config", None):
config = config.get_text_config()
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
)
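A hedged usage sketch of what the two added lines enable (the checkpoint id is taken from the Llama4 config docstrings later in this diff; availability and exact repo layout are assumptions): `AutoModelForCausalLM` resolves `llama4` to `Llama4ForCausalLM`, whose `config_class` matches the nested `text_config`, so only the text decoder config is passed down.

from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-4-Scout-17B-16E")  # Llama4Config with a nested text_config
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-4-Scout-17B-16E")
# Because the resolved class's config_class matches type(config.text_config), from_pretrained
# receives config.get_text_config() and loads only the text-to-text module.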


@ -170,6 +170,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("levit", "LevitConfig"),
("lilt", "LiltConfig"),
("llama", "LlamaConfig"),
("llama4", "Llama4Config"),
("llama4_text", "Llama4TextConfig"),
("llava", "LlavaConfig"),
("llava_next", "LlavaNextConfig"),
("llava_next_video", "LlavaNextVideoConfig"),
@ -519,6 +521,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
("llama", "LLaMA"),
("llama2", "Llama2"),
("llama3", "Llama3"),
("llama4", "Llama4"),
("llama4_text", "Llama4ForCausalLM"),
("llava", "LLaVa"),
("llava_next", "LLaVA-NeXT"),
("llava_next_video", "LLaVa-NeXT-Video"),
@ -776,6 +780,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
("rt_detr_resnet", "rt_detr"),
("granitevision", "llava_next"),
("sam_vision_model", "sam"),
("llama4_text", "llama4"),
]
)


@ -104,6 +104,7 @@ else:
("layoutlmv2", ("LayoutLMv2ImageProcessor",)),
("layoutlmv3", ("LayoutLMv3ImageProcessor",)),
("levit", ("LevitImageProcessor",)),
("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")),
("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")),
("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")),
("llava_next_video", ("LlavaNextVideoImageProcessor",)),


@ -17,7 +17,6 @@
import warnings
from collections import OrderedDict
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from .auto_factory import (
_BaseAutoBackboneClass,
@ -161,6 +160,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("levit", "LevitModel"),
("lilt", "LiltModel"),
("llama", "LlamaModel"),
("llama4", "Llama4ForConditionalGeneration"),
("longformer", "LongformerModel"),
("longt5", "LongT5Model"),
("luke", "LukeModel"),
@ -547,6 +547,8 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("jamba", "JambaForCausalLM"),
("jetmoe", "JetMoeForCausalLM"),
("llama", "LlamaForCausalLM"),
("llama4", "Llama4ForCausalLM"),
("llama4_text", "Llama4ForCausalLM"),
("mamba", "MambaForCausalLM"),
("mamba2", "Mamba2ForCausalLM"),
("marian", "MarianForCausalLM"),
@ -634,6 +636,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
("ijepa", "IJepaModel"),
("imagegpt", "ImageGPTModel"),
("levit", "LevitModel"),
("llama4", "Llama4VisionModel"),
("mllama", "MllamaVisionModel"),
("mobilenet_v1", "MobileNetV1Model"),
("mobilenet_v2", "MobileNetV2Model"),
@ -849,6 +852,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
("idefics3", "Idefics3ForConditionalGeneration"),
("instructblip", "InstructBlipForConditionalGeneration"),
("kosmos-2", "Kosmos2ForConditionalGeneration"),
("llama4", "Llama4ForConditionalGeneration"),
("llava", "LlavaForConditionalGeneration"),
("llava_next", "LlavaNextForConditionalGeneration"),
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
@ -1492,6 +1496,7 @@ MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
("emu3", "Emu3TextModel"),
("flaubert", "FlaubertModel"),
("ibert", "IBertModel"),
("llama4", "Llama4TextModel"),
("longformer", "LongformerModel"),
("mllama", "MllamaTextModel"),
("mobilebert", "MobileBertModel"),
@ -1678,30 +1683,6 @@ _AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="languag
class AutoModelForCausalLM(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
@classmethod
def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
"""
Additional autoclass-specific config post-loading manipulation. In this specific autoclass, if the config has
a nested text decoder section, uses that section instead.
Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own
config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM.
"""
possible_text_config_names = ("decoder", "generator", "text_config")
text_config_names = []
for text_config_name in possible_text_config_names:
if hasattr(config, text_config_name):
text_config_names += [text_config_name]
text_config = config.get_text_config(decoder=True)
if text_config_names and type(text_config) in cls._model_mapping.keys():
warnings.warn(
"Loading a multimodal model with `AutoModelForCausalLM` is deprecated and will be removed in v5. "
"`AutoModelForCausalLM` will be used to load only the text-to-text generation module.",
FutureWarning,
)
return config
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")


@ -77,6 +77,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("kosmos-2", "Kosmos2Processor"),
("layoutlmv2", "LayoutLMv2Processor"),
("layoutlmv3", "LayoutLMv3Processor"),
("llama4", "Llama4Processor"),
("llava", "LlavaProcessor"),
("llava_next", "LlavaNextProcessor"),
("llava_next_video", "LlavaNextVideoProcessor"),


@ -292,6 +292,20 @@ else:
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"llama4",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"llama4_text",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
("llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("llava_next", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("llava_next_video", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),


@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_llama4 import *
from .image_processing_llama4_fast import *
from .modeling_llama4 import *
from .processing_llama4 import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)


@ -0,0 +1,432 @@
# coding=utf-8
# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class Llama4VisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Llama4VisionModel`]. It is used to instantiate a
Llama4 vision model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Llama4 109B.
e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
num_hidden_layers (`int`, *optional*, defaults to 34):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
Number of channels in the input image.
intermediate_size (`int`, *optional*, defaults to 5632):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
vision_output_dim (`int`, *optional*, defaults to 7680):
Dimensionality of the vision model output. Includes output of transformer
encoder with intermediate layers and global transformer encoder.
image_size (`int`, *optional*, defaults to 448):
The size (resolution) of each image *tile*.
patch_size (`int`, *optional*, defaults to 14):
The size (resolution) of each patch.
norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the layer normalization layers.
vision_feature_layer (`int`, *optional*, defaults to -1): TODO
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): TODO
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
pixel_shuffle_ratio (`float`, *optional*, defaults to 0.5): TODO
projector_input_dim (`int`, *optional*, defaults to 4096): TODO
projector_output_dim (`int`, *optional*, defaults to 4096): TODO
multi_modal_projector_bias (`bool`, *optional*, defaults to `False`): TODO
projector_dropout (`float`, *optional*, defaults to 0.0): TODO
attention_dropout (`float`, *optional*, defaults to 0.0): TODO
rope_theta (`float`, *optional*, defaults to 10000): TODO
"""
base_model_tp_plan = {
"model.layers.*.self_attn.q_proj": "colwise",
"model.layers.*.self_attn.k_proj": "colwise",
"model.layers.*.self_attn.v_proj": "colwise",
"model.layers.*.self_attn.o_proj": "rowwise",
"vision_adapter.mlp.fc1": "colwise",
"vision_adapter.mlp.fc2": "rowwise",
"patch_embedding.linear": "colwise_rep",
}
model_type = "llama4_vision_model"
base_config_key = "vision_config"
def __init__(
self,
hidden_size: int = 768,
hidden_act: str = "gelu",
num_hidden_layers: int = 34,
num_attention_heads: int = 16,
num_channels: int = 3,
intermediate_size: int = 5632,
vision_output_dim: int = 7680,
image_size: int = 448,
patch_size: int = 14,
norm_eps: float = 1e-5,
vision_feature_layer=-1,
vision_feature_select_strategy="default",
initializer_range: float = 0.02,
pixel_shuffle_ratio=0.5,
projector_input_dim=4096,
projector_output_dim=4096,
multi_modal_projector_bias=False,
projector_dropout=0.0,
attention_dropout=0.0,
rope_theta=10000,
**kwargs,
):
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.num_hidden_layers = num_hidden_layers
self.num_channels = num_channels
self.intermediate_size = intermediate_size
self.image_size = image_size
self.vision_output_dim = vision_output_dim
self.patch_size = patch_size
self.norm_eps = norm_eps
self.num_attention_heads = num_attention_heads
self.initializer_range = initializer_range
self.pixel_shuffle_ratio = pixel_shuffle_ratio
self.projector_input_dim = projector_input_dim
self.projector_output_dim = projector_output_dim
self.multi_modal_projector_bias = multi_modal_projector_bias
self.projector_dropout = projector_dropout
self.attention_dropout = attention_dropout
self.vision_feature_layer = vision_feature_layer
self.vision_feature_select_strategy = vision_feature_select_strategy
self.rope_theta = rope_theta
super().__init__(**kwargs)
class Llama4TextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Llama4TextModel`]. It is used to instantiate a
Llama4 text model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Llama4 109B.
e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 202048):
Vocabulary size of the Llama4 text model. Defines the maximum number of different tokens that can be represented
by the `inputs_ids` passed when calling [`Llama4TextModel`].
hidden_size (`int`, *optional*, defaults to 5120):
Dimensionality of the embeddings and hidden states.
intermediate_size (`int`, *optional*, defaults to 8192):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
intermediate_size_mlp (`int`, *optional*, defaults to 16384): TODO
num_hidden_layers (`int`, *optional*, defaults to 48):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 40):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If not
specified, will default to `num_attention_heads`.
head_dim (`int`, *optional*, defaults to 128): TODO
hidden_act (`str` or `Callable`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the encoder and pooler.
max_position_embeddings (`int`, *optional*, defaults to 131072):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions.
pad_token_id (`int`, *optional*, defaults to 128004):
The id of the padding token.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the beginning of sentence token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the end of sentence token.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to `500000.0`):
The base period of the RoPE embeddings.
attention_dropout (`float`, *optional*, defaults to 0.0): TODO
num_experts_per_tok (`int`, *optional*, defaults to 1): TODO
num_local_experts (`int`, *optional*, defaults to 16): TODO
moe_layers (`List[int]`, *optional*):
Indices of the decoder layers that use an MoE feed-forward block. If not given, it is derived
from `interleave_moe_layer_step`.
interleave_moe_layer_step (`int`, *optional*, defaults to 1): TODO
use_qk_norm (`bool`, *optional*, defaults to `True`): TODO
output_router_logits (`bool`, *optional*, defaults to `False`): TODO
router_aux_loss_coef (`float`, *optional*, defaults to 0.001): TODO
router_jitter_noise (`float`, *optional*, defaults to 0.0): TODO
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
<TODO>
<TODO>
no_rope_layers (`List[int]`, *optional*):
One value per layer: 1 keeps RoPE for that layer, 0 disables it. If not given, it is derived
from `no_rope_layer_interval` (every `no_rope_layer_interval`-th layer skips RoPE).
no_rope_layer_interval (`int`, *optional*, defaults to 4): TODO
attention_chunk_size (`int`, *optional*, defaults to 8192):
<TODO>
attn_temperature_tuning (`int`, *optional*, defaults to 4): TODO
floor_scale (`int`, *optional*, defaults to 8192): TODO
attn_scale (`float`, *optional*, defaults to 0.1): TODO
Example:
"""
model_type = "llama4_text"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.input_layernorm.weight": "sequence_parallel",
"layers.*.post_attention_layernorm.weight": "sequence_parallel",
"norm.weight": "sequence_parallel",
"layers.*.feed_forward.shared_expert.gate_proj": "local_colwise",
"layers.*.feed_forward.shared_expert.up_proj": "local_colwise",
"layers.*.feed_forward.shared_expert.down_proj": "local_rowwise",
"layers.*.feed_forward.experts.gate_up_proj": "local_packed_rowwise", # row because not linear
"layers.*.feed_forward.experts.down_proj": "local_colwise", # col because not linear
"layers.*.feed_forward.experts": "local",
"layers.*.feed_forward.gate_proj": "local_colwise",
"layers.*.feed_forward.up_proj": "local_colwise",
"layers.*.feed_forward.down_proj": "local_rowwise",
"layers.*.feed_forward": "gather",
}
def __init__(
self,
vocab_size=202048,
hidden_size=5120,
intermediate_size=8192,
intermediate_size_mlp=16384,
num_hidden_layers=48,
num_attention_heads=40,
num_key_value_heads=8,
head_dim=128,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=500000,
attention_dropout=0.0,
num_experts_per_tok=1,
num_local_experts=16,
moe_layers=None,
interleave_moe_layer_step=1,
use_qk_norm=True,
output_router_logits=False,
router_aux_loss_coef=0.001,
router_jitter_noise=0.0,
rope_scaling=None,
no_rope_layers=None,
no_rope_layer_interval=4,
attention_chunk_size=8192,
attn_temperature_tuning=4,
floor_scale=8192,
attn_scale=0.1,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.attn_temperature_tuning = attn_temperature_tuning
self.attn_scale = attn_scale
self.floor_scale = floor_scale
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.intermediate_size_mlp = intermediate_size_mlp
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.rope_scaling = rope_scaling
self.attention_bias = False
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
self.use_qk_norm = use_qk_norm
self.num_experts_per_tok = num_experts_per_tok
self.num_local_experts = num_local_experts
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.router_jitter_noise = router_jitter_noise
default_no_rope_layers = [
int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(self.num_hidden_layers)
]
# no_rope_layers == [] is invalid as we cannot have 0 layers
self.no_rope_layers = no_rope_layers if no_rope_layers else default_no_rope_layers
self.interleave_moe_layer_step = interleave_moe_layer_step
self.moe_layers = (
moe_layers
if moe_layers is not None
else list(range(interleave_moe_layer_step - 1, num_hidden_layers, interleave_moe_layer_step))
)
self.attention_chunk_size = attention_chunk_size
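# A small sanity sketch of the derived defaults above (arguments are illustrative; the values
# follow directly from the constructor logic): every `no_rope_layer_interval`-th layer is
# marked 0 (no RoPE) and `moe_layers` is derived from `interleave_moe_layer_step`.
_example_cfg = Llama4TextConfig(num_hidden_layers=8, no_rope_layer_interval=4, interleave_moe_layer_step=2)
assert _example_cfg.no_rope_layers == [1, 1, 1, 0, 1, 1, 1, 0]
assert _example_cfg.moe_layers == [1, 3, 5, 7]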
class Llama4Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Llama4Model`]. It is used to instantiate an
Llama4 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Llama4 109B.
e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`Llama4VisionConfig`, *optional*):
The Llama4 Vision config.
text_config (`Llama4TextConfig`, *optional*):
The Llama4 Text config.
boi_token_index (`int`, *optional*, defaults to 200080):
The begin-of-image token index to wrap the image prompt.
eoi_token_index (`int`, *optional*, defaults to 200081):
The end-of-image token index to wrap the image prompt.
image_token_index (`int`, *optional*, defaults to 200092):
The image token index to encode the image prompt.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
```python
>>> from transformers import Llama4Model, Llama4Config
>>> # Initializing a Llama4 7B style configuration
>>> configuration = Llama4Config()
>>> # Initializing a model from the Llama4 7B style configuration
>>> model = Llama4Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "llama4"
sub_configs = {"text_config": Llama4TextConfig, "vision_config": Llama4VisionConfig}
base_model_tp_plan = {
"multi_modal_projector.linear_1": "colwise_rep",
}
def __init__(
self,
vision_config=None,
text_config=None,
boi_token_index=200080,
eoi_token_index=200081,
image_token_index=200092,
tie_word_embeddings=False,
**kwargs,
):
if vision_config is None:
self.vision_config = Llama4VisionConfig()
logger.info("vision_config is None, using default llama4 vision config")
elif isinstance(vision_config, dict):
self.vision_config = Llama4VisionConfig(**vision_config)
elif isinstance(vision_config, Llama4VisionConfig):
self.vision_config = vision_config
self.boi_token_index = boi_token_index
self.eoi_token_index = eoi_token_index
self.image_token_index = image_token_index
if text_config is None:
self.text_config = Llama4TextConfig()
logger.info("text_config is None, using default llama4 text config")
elif isinstance(text_config, dict):
self.text_config = Llama4TextConfig(**text_config)
elif isinstance(text_config, Llama4TextConfig):
self.text_config = text_config
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
__all__ = ["Llama4Config", "Llama4TextConfig", "Llama4VisionConfig"]

File diff suppressed because one or more lines are too long


@ -0,0 +1,480 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for Got-OCR-2."""
import math
from collections import defaultdict
from functools import lru_cache
from typing import List, Optional, Set, Tuple, Union
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
ImageInput,
PILImageResampling,
SizeDict,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
)
if is_torch_available():
import torch
if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
def get_factors(dividend: int) -> Set[int]:
"""
Calculate all factors of a given number, i.e. every divisor that leaves
no remainder. For example, if dividend=12, it will return {1, 2, 3, 4, 6, 12}.
Args:
dividend (int): The number to find factors for.
Returns:
set: A set containing all factors of the number.
"""
factors_set = set()
for i in range(1, int(dividend**0.5) + 1):
if dividend % i == 0:
factors_set.add(i)
factors_set.add(dividend // i)
return factors_set
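# Quick check of the helper above, mirroring the docstring example:
assert get_factors(12) == {1, 2, 3, 4, 6, 12}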
def get_max_res_without_distortion(
image_size: Tuple[int, int],
target_size: Tuple[int, int],
) -> Tuple[int, int]:
"""
Determines the maximum resolution to which an image can be resized to without distorting its
aspect ratio, based on the target resolution.
Args:
image_size (Tuple[int, int]): The original resolution of the image (height, width).
target_size (Tuple[int, int]): The desired resolution to fit the image into (height, width).
Returns:
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
Example:
>>> get_max_res_without_distortion([200, 300], target_size=[450, 200])
(133, 200)
>>> get_max_res_without_distortion([800, 600], target_size=[450, 1300])
(450, 337)
"""
original_height, original_width = image_size
target_height, target_width = target_size
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.floor(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.floor(original_width * scale_h), target_width)
return new_height, new_width
class Llama4ImageProcessorKwargs(DefaultFastImageProcessorKwargs):
max_patches: Optional[int]
resize_to_max_canvas: Optional[bool]
def split_to_tiles(images: torch.Tensor, num_tiles_height: int, num_tiles_width: int) -> torch.Tensor:
# Split image into number of required tiles (width x height)
batch_size, num_channels, height, width = images.size()
images = images.view(
batch_size,
num_channels,
num_tiles_height,
height // num_tiles_height,
num_tiles_width,
width // num_tiles_width,
)
# Permute dimensions to reorder the axes
image = images.permute(0, 2, 4, 1, 3, 5).contiguous()
# Reshape into (batch_size, num_tiles_height * num_tiles_width, num_channels, height // num_tiles_height, width // num_tiles_width)
image = image.view(
batch_size,
num_tiles_width * num_tiles_height,
num_channels,
height // num_tiles_height,
width // num_tiles_width,
)
return image
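# Quick shape sketch for the helper above (sizes are illustrative): a (1, 3, 448, 672)
# canvas split into a 2 x 3 tile grid yields 6 tiles of 224 x 224 each.
_tiles = split_to_tiles(torch.randn(1, 3, 448, 672), num_tiles_height=2, num_tiles_width=3)
assert _tiles.shape == (1, 6, 3, 224, 224)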
@lru_cache(maxsize=1)
def find_supported_resolutions(max_num_chunks: int, patch_size: SizeDict) -> torch.Tensor:
"""
Computes all of the allowed resolutions for a fixed number of chunks
and patch_size. Useful for when dividing an image into chunks.
Args:
max_num_chunks (int): Maximum number of chunks for processing.
patch_size (SizeDict): Size of the (square) patch; its side length is used.
Returns:
torch.Tensor: List of possible resolutions as tuples (height, width).
Example:
>>> max_num_chunks = 4
>>> patch_size = 224
>>> find_supported_resolutions(max_num_chunks, patch_size)
tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
(672, 224), (224, 448), (448, 224)])
Given max_num_chunks=4, patch_size=224, it will create a dictionary:
{
0.25: [(1, 4)],
1.0: [(2, 2), (1, 1)],
4.0: [(4, 1)],
0.33: [(1, 3)],
3.0: [(3, 1)],
0.5: [(1, 2)],
2.0: [(2, 1)]
}
and return the resolutions multiplied by the patch_size:
[(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
"""
height, width = patch_size.height, patch_size.width
if height != width:
raise ValueError("`size` must be square.")
patch_size = height
asp_dict = defaultdict(list)
for chunk_size in range(max_num_chunks, 0, -1):
_factors = sorted(get_factors(chunk_size))
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
for height, width in _asp_ratios:
ratio_float = height / width
asp_dict[ratio_float].append((height, width))
# get the resolutions multiplied by the patch_size
possible_resolutions = []
for key, value in asp_dict.items():
for height, depth in value:
possible_resolutions.append((height * patch_size, depth * patch_size))
return possible_resolutions
def pad_to_best_fit(
images: "torch.Tensor",
target_size: Tuple[int, int],
background_color: Union[int, Tuple[int, int, int]] = 0,
) -> "torch.Tensor":
"""
Pads an image to fit the target size.
Args:
images (`torch.Tensor`):
The images to pad.
target_size (`Tuple[int, int]`):
The target (height, width) to pad the images to.
background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
The color to use for the padding. Can be an integer for single channel or a
tuple of integers for multi-channel images. If an integer is passed for a
multi-channel image, the remaining channels default to `0`.
Returns:
`torch.Tensor`: The padded images.
"""
num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0]
if isinstance(background_color, int):
background_color = [background_color] + [0] * (num_channels - 1)
elif len(background_color) != num_channels:
raise ValueError(
f"background_color must have no more than {num_channels} elements to match the number of channels"
)
height, width = images.shape[-2:]
target_height, target_width = target_size
paste_x_right = target_width - width
paste_y_right = target_height - height
padded_images = F.pad(images, padding=[0, 0, paste_x_right, paste_y_right], fill=background_color)
return padded_images
def get_best_fit(
image_size: Tuple[int, int],
possible_resolutions: torch.Tensor,
resize_to_max_canvas: bool = False,
) -> Tuple[int, int]:
"""
Determines the best canvas from a list of possible resolutions to resize an image to
without distortion.
For each possible resolution, calculates the scaling factors for
width and height, and selects the smallest one, which is the limiting side.
E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
If upscaling is possible (any of the scaling factors is greater than 1),
then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
reduce downscaling as much as possible.
If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
has more padding.
Args:
image_size (Tuple[int, int]): A tuple containing the height and width of the image.
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
row represents a possible resolution (height, width).
resize_to_max_canvas (bool): If True, will return the largest upscaling resolution.
Returns:
List[int]: The best resolution [height, width] for the given image.
Example:
>>> image_size = (200, 300)
>>> possible_resolutions = torch.tensor([[224, 672],
... [672, 224],
... [224, 448],
... [448, 224],
... [224, 224]])
>>> get_best_fit(image_size, possible_resolutions)
[224, 448]
We have:
scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
Only one of the scales > 1:
upscaling_possible = tensor([1.1200, 1.1200])
smallest_rescale = tensor(1.1200)
So we pick the resolution with the smallest area:
areas = tensor([150528, 100352]) # [224, 672], [224, 448]
optimal_canvas = tensor([224, 448])
"""
original_height, original_width = image_size
# get all possible resolutions heights/widths
target_heights, target_widths = (
possible_resolutions[:, 0],
possible_resolutions[:, 1],
)
# get scaling factors to resize the image without distortion
scale_w = target_widths / original_width
scale_h = target_heights / original_height
# get the min scale between width and height (limiting side -> no distortion)
scales = torch.where(scale_h > scale_w, scale_w, scale_h)
# filter only scales that allow upscaling
upscaling_options = scales[scales >= 1]
if len(upscaling_options) > 0:
if resize_to_max_canvas:
selected_scale = torch.max(upscaling_options)
else:
selected_scale = torch.min(upscaling_options)
else:
# no upscaling possible,
# get the minimum downscaling (max scale for scales<1)
downscaling_options = scales[scales < 1]
selected_scale = torch.max(downscaling_options)
# get all resolutions that support this scaling factor,
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
chosen_canvas = possible_resolutions[scales == selected_scale]
# if there are multiple resolutions,
# get the one with minimum area to reduce padding
if len(chosen_canvas) > 1:
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
optimal_idx = torch.argmin(areas)
optimal_canvas = chosen_canvas[optimal_idx]
else:
optimal_canvas = chosen_canvas[0]
return tuple(optimal_canvas.tolist())
@add_start_docstrings(
"Constructs a fast Llama4 image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
max_patches (`int`, *optional*, defaults to 16):
The maximum number of patches to be extracted from the image.
Can be overridden by the `max_patches` parameter in the `preprocess` method.
resize_to_max_canvas (`bool`, *optional*, defaults to False):
Whether to resize the image to the maximum canvas size.
If True, picks the canvas that allows the largest resizing without distortion.
If False, downsample as little as possible, including no resizing at all,
but never upsample, unless the image is smaller than the patch size.
""",
)
class Llama4ImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
image_mean = [0.5, 0.5, 0.5]
image_std = [0.5, 0.5, 0.5]
size = {"height": 336, "width": 336}
do_resize = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
max_patches = 16
resize_to_max_canvas = False
valid_kwargs = Llama4ImageProcessorKwargs
def __init__(self, **kwargs: Unpack[Llama4ImageProcessorKwargs]):
super().__init__(**kwargs)
def rescale_and_normalize(
self,
images: "torch.Tensor",
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Union[float, List[float]],
image_std: Union[float, List[float]],
) -> "torch.Tensor":
"""
Rescale and normalize images.
Override to rescale and normalize the images in torch.bfloat16 as in the original implementation
"""
if do_rescale and do_normalize:
images = images.to(dtype=torch.bfloat16) * rescale_factor
images = self.normalize(images, image_mean, image_std)
elif do_rescale:
images = images * rescale_factor
elif do_normalize:
images = self.normalize(images, image_mean, image_std)
return images
@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
"""
max_patches (`int`, *optional*, defaults to 16):
The maximum number of patches to be extracted from the image.
Can be overridden by the `max_patches` parameter in the `preprocess` method.
resize_to_max_canvas (`bool`, *optional*, defaults to False):
Whether to resize the image to the maximum canvas size.
If True, picks the canvas that allows the largest resizing without distortion.
If False, downsample as little as possible, including no resizing at all,
but never upsample, unless the image is smaller than the patch size.
""",
)
def preprocess(self, images: ImageInput, **kwargs: Unpack[Llama4ImageProcessorKwargs]) -> BatchFeature:
return super().preprocess(images, **kwargs)
def _preprocess(
self,
images: List["torch.Tensor"],
size: SizeDict,
max_patches: int,
resize_to_max_canvas: bool,
interpolation: Optional["F.InterpolationMode"],
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
possible_resolutions = find_supported_resolutions(max_num_chunks=max_patches, patch_size=size)
possible_resolutions = torch.tensor(possible_resolutions)
# process images by batch, grouped by shape
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_processed_images = {}
grouped_aspect_ratios = {}
for shape, stacked_images in grouped_images.items():
image_size = stacked_images.shape[-2:]
target_size = get_best_fit(image_size, possible_resolutions, resize_to_max_canvas=resize_to_max_canvas)
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
max_upscaling_size = None if resize_to_max_canvas else size.height
if max_upscaling_size is not None:
new_target_height = min(max(image_size[0], max_upscaling_size), target_size[0])
new_target_width = min(max(image_size[1], max_upscaling_size), target_size[1])
target_size_without_distortion = (new_target_height, new_target_width)
# resize to target_size while preserving aspect ratio
new_size_without_distortion = get_max_res_without_distortion(image_size, target_size_without_distortion)
new_size_without_distortion = SizeDict(
height=max(new_size_without_distortion[0], 1), width=max(new_size_without_distortion[1], 1)
)
processed_images = self.resize(
stacked_images,
new_size_without_distortion,
interpolation=interpolation,
)
# pad to target_size to be able to split into tiles
processed_images = pad_to_best_fit(processed_images, target_size)
processed_images = self.rescale_and_normalize(
processed_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
ratio_h, ratio_w = (
target_size[0] // size.height,
target_size[1] // size.width,
)
# split into tiles
processed_images = split_to_tiles(processed_images, ratio_h, ratio_w)
grouped_processed_images[shape] = processed_images
grouped_aspect_ratios[shape] = torch.tensor([[ratio_h, ratio_w]] * stacked_images.shape[0])
# add a global tile to the processed tile if there are more than one tile
if ratio_h * ratio_w > 1:
global_tiles = self.resize(
stacked_images,
size,
interpolation=interpolation,
)
global_tiles = self.rescale_and_normalize(
global_tiles, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
grouped_processed_images[shape] = torch.cat([processed_images, global_tiles.unsqueeze(1)], dim=1)
processed_images = reorder_images(grouped_processed_images, grouped_images_index)
aspect_ratios_list = reorder_images(grouped_aspect_ratios, grouped_images_index)
processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
aspect_ratios = torch.stack(aspect_ratios_list, dim=0) if return_tensors else aspect_ratios_list
return BatchFeature(
data={"pixel_values": processed_images, "aspect_ratios": aspect_ratios}, tensor_type=return_tensors
)
__all__ = ["Llama4ImageProcessorFast"]

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long


@ -981,6 +981,8 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
else:
self.device = device if device is not None else -1
if torch.distributed.is_initialized():
self.device = self.model.device
logger.warning(f"Device set to use {self.device}")
self.binary_output = binary_output


@ -1178,10 +1178,6 @@ class ProcessorMixin(PushToHubMixin):
unused_kwargs = {}
unused_keys = set(kwargs_from_config) - set(valid_kwargs)
if unused_keys:
unused_key_str = ", ".join(unused_keys)
logger.warning(
f"Some kwargs in processor config are unused and will not have any effect: {unused_key_str}. "
)
unused_kwargs = {k: processor_config[k] for k in unused_keys}
return unused_kwargs


@ -43,8 +43,7 @@ is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_d
_torch_distributed_available = torch.distributed.is_available()
if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
pass
def softmax_backward_data(parent, grad_output, output, dim, self):
@ -335,29 +334,6 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
return torch.isin(elements, test_elements)
# TODO need to add the __repr__ that shows that it is a colwise parallel
# See https://github.com/pytorch/pytorch/issues/145726
def translate_to_torch_parallel_style(style: str):
"""
In model configurations, we use a neutral type (string) to specify parallel
styles, here we translate them into torch.distributed tensor-parallel
types.
"""
if not isinstance(style, str):
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
if style == "colwise":
return ColwiseParallel()
elif style == "rowwise":
return RowwiseParallel()
elif style == "colwise_rep":
return ColwiseParallel(output_layouts=Replicate())
elif style == "rowwise_rep":
return RowwiseParallel(input_layouts=Replicate())
else:
raise ValueError(f"Unsupported parallel style value: {style}")
def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs):
"""
LRU cache decorator from standard functools library, but with a workaround to disable
@ -382,88 +358,3 @@ def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs):
return wrapper
return decorator
def distribute_module(
module: nn.Module,
device_mesh=None,
partition_fn=None,
input_fn=None,
output_fn=None,
) -> nn.Module:
"""
This function expose three functions to control the parameters/inputs/outputs of the module:
1. To perform sharding on the module before runtime execution by specifying the
``partition_fn`` (i.e. allow user to convert Module parameters to :class:`DTensor`
parameters according to the `partition_fn` specified).
2. To control the inputs or outputs of the module during runtime execution by
specifying the ``input_fn`` and ``output_fn``. (i.e. convert the input to
:class:`DTensor`, convert the output back to ``torch.Tensor``)
Args:
module (:class:`nn.Module`): user module to be partitioned.
device_mesh (:class:`DeviceMesh`): the device mesh to place the module.
partition_fn (Callable): the function to partition parameters (i.e. shard certain
parameters across the ``device_mesh``). If ``partition_fn`` is not specified,
by default we replicate all module parameters of ``module`` across the mesh.
input_fn (Callable): specify the input distribution, i.e. could control how the
input of the module is sharded. ``input_fn`` will be installed as a module
``forward_pre_hook`` (pre forward hook).
output_fn (Callable): specify the output distribution, i.e. could control how the
output is sharded, or convert it back to torch.Tensor. ``output_fn`` will be
installed as a module ``forward_hook`` (post forward hook).
Returns:
A module that contains parameters/buffers that are all ``DTensor`` s.
.. note::
When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_module``
return nn.Module with PyTorch/XLA SPMD annotated parameters. See
`this issue <https://github.com/pytorch/pytorch/issues/92909>`__
for more details. The XLA integration is experimental and subject to change.
"""
torch._C._log_api_usage_once("torch.dtensor.distribute_module")
device_mesh = device_mesh
# register input_fn as module forward pre hook
if input_fn is not None:
# check the input_fn signature
num_args = len(inspect.signature(input_fn).parameters)
if num_args == 2:
# input_fn only takes in inputs and device mesh
logger.warning(
"Deprecating input_fn that takes two arguments (inputs, device_mesh), "
"please use input_fn that takes in (module, inputs, device_mesh) instead!",
FutureWarning,
stacklevel=2,
)
module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[call-arg]
elif num_args == 3:
# input_fn takes in module, inputs, device mesh
module.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, device_mesh))
else:
raise ValueError(f"input_fn should take in 3 arguments, but got {num_args} arguments!")
# register output_fn as module forward hook
if output_fn is not None:
num_args = len(inspect.signature(output_fn).parameters)
if num_args == 2:
# output_fn only takes in outputs and device mesh
logger.warning(
"Deprecating output_fn that takes two arguments (inputs, device_mesh), "
"please use output_fn that takes in (module, inputs, device_mesh) instead!",
stacklevel=2,
)
module.register_forward_hook(
lambda mod, inputs, outputs: output_fn(outputs, device_mesh) # type: ignore[call-arg]
)
elif num_args == 3:
module.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh))
else:
raise ValueError(f"output_fn should take in 3 arguments, but got {num_args} arguments!")
return module
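Since this mirrors `torch.distributed.tensor.distribute_module` (the docstring above is the upstream one), here is a usage sketch against that documented behavior. The mesh, module, sizes, and replicated placement are assumptions; whether this stripped-down copy applies `partition_fn` identically to upstream PyTorch is not shown in this hunk.

```python
# Illustrative sketch: DTensor in via input_fn, plain torch.Tensor out via output_fn.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, distribute_tensor

def replicate_params(name, module, device_mesh):
    # partition_fn: replicate every parameter of the module across the mesh as DTensors.
    for param_name, param in list(module.named_parameters(recurse=False)):
        module.register_parameter(
            param_name, nn.Parameter(distribute_tensor(param.data, device_mesh, [Replicate()]))
        )

def to_dtensor_inputs(module, inputs, device_mesh):
    # input_fn: convert plain tensors to replicated DTensors before the forward pass.
    return tuple(
        DTensor.from_local(x, device_mesh, [Replicate()]) if torch.is_tensor(x) else x
        for x in inputs
    )

def to_local_output(module, outputs, device_mesh):
    # output_fn: convert a DTensor output back to a plain, local torch.Tensor.
    return outputs.to_local() if isinstance(outputs, DTensor) else outputs

# mesh = init_device_mesh("cuda", (world_size,))
# layer = distribute_module(nn.Linear(16, 16), mesh, replicate_params, to_dtensor_inputs, to_local_output)
# y = layer(torch.randn(2, 16, device="cuda"))  # plain tensor in, plain tensor out
```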

src/transformers/quantizers/base.py  (Executable file → Normal file)

@ -15,7 +15,8 @@ from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from ..utils import is_torch_available
from ..utils.quantization_config import QuantizationConfigMixin
from ..utils.quantization_config import QuantizationConfigMixin, QuantizationMethod
from .quantizers_utils import get_module_from_name
if TYPE_CHECKING:
@ -23,6 +24,9 @@ if TYPE_CHECKING:
if is_torch_available():
import torch
from torch.nn import ModuleList
else:
ModuleList = str
class HfQuantizer(ABC):
@ -198,6 +202,10 @@ class HfQuantizer(ABC):
"""
return
def update_tp_plan(self, config):
"updates the tp plan for the scales"
return config
def preprocess_model(self, model: "PreTrainedModel", **kwargs):
"""
Setting model attributes and/or converting model before weights loading. At this point
@ -212,6 +220,7 @@ class HfQuantizer(ABC):
"""
model.is_quantized = True
model.quantization_method = self.quantization_config.quant_method
self._convert_model_for_quantization(model)
return self._process_model_before_weight_loading(model, **kwargs)
def postprocess_model(self, model: "PreTrainedModel", **kwargs):
@ -288,3 +297,44 @@ class HfQuantizer(ABC):
@property
@abstractmethod
def is_trainable(self): ...
def _convert_model_for_quantization(self, model):
from accelerate import init_empty_weights
for name, module in model.named_modules():
module_class_name = module.__class__.__name__
if (
module_class_name in MODULES_TO_PATCH_FOR_QUANTIZATION.keys()
and self.quantization_config.quant_method == QuantizationMethod.COMPRESSED_TENSORS
):
with init_empty_weights():
parent_module, name = get_module_from_name(model, name)
parent_module._modules[name] = MODULES_TO_PATCH_FOR_QUANTIZATION[module_class_name](
model.config.get_text_config()
)
class SequentialLlama4TextExperts(ModuleList):
"""
A ModuleList of per-expert MLPs used in place of the fused `Llama4TextExperts` in MoE layers.
This is specifically used so that compressed-tensors quantization can handle each expert separately.
"""
def __init__(self, config):
from transformers.models.llama4.modeling_llama4 import Llama4TextMLP
super().__init__([Llama4TextMLP(config) for _ in range(config.num_local_experts)])
self.num_experts = config.num_local_experts
def forward(
self,
hidden_states: "torch.Tensor",
) -> "torch.Tensor":
hidden_states = hidden_states.reshape(self.num_experts, -1, hidden_states.shape[-1])
routed_out = torch.zeros_like(hidden_states)
for expert_idx in range(self.num_experts):
routed_out[expert_idx] = self[expert_idx](hidden_states[expert_idx])
return routed_out
MODULES_TO_PATCH_FOR_QUANTIZATION = {"Llama4TextExperts": SequentialLlama4TextExperts}
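A hedged toy example of the shape contract for `SequentialLlama4TextExperts.forward` above: the input is expected to arrive already grouped by expert, and the output keeps the per-expert grouping. The config values are made up.

```python
# Hypothetical shape walkthrough (made-up config values; assumes the class above is in scope).
import torch
from transformers import Llama4TextConfig

config = Llama4TextConfig(hidden_size=64, intermediate_size=128, num_local_experts=4)
experts = SequentialLlama4TextExperts(config)

tokens_per_expert = 3
# forward() expects tokens already grouped by expert: (num_experts * tokens, hidden_size).
hidden_states = torch.randn(config.num_local_experts * tokens_per_expert, config.hidden_size)
routed_out = experts(hidden_states)
print(routed_out.shape)  # torch.Size([4, 3, 64]) -> (num_experts, tokens_per_expert, hidden_size)
```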


@ -146,6 +146,19 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
self.compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN
self.compressor.decompress(model_path=cache_path, model=model)
def update_tp_plan(self, config):
additional_plan = {
"layers.*.feed_forward.experts.*.gate_proj.weight": "local_colwise",
"layers.*.feed_forward.experts.*.gate_proj.weight_scale": "local_colwise",
"layers.*.feed_forward.experts.*.up_proj.weight": "local_colwise",
"layers.*.feed_forward.experts.*.up_proj.weight_scale": "local_colwise",
"layers.*.feed_forward.experts.*.down_proj.weight": "local_rowwise",
}
if config.get_text_config() is not None and config.get_text_config().base_model_tp_plan is not None:
config.get_text_config().base_model_tp_plan.update(additional_plan)
return config
@property
def is_trainable(self):
return True
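The keys in `additional_plan` above are glob-style parameter patterns (each `*` standing in for an index segment such as a layer or expert number). Below is a hedged, standalone sketch of matching such a pattern against a parameter name; the helper is an illustration, not the transformers matcher.

```python
# Illustrative only: matching a tp-plan glob pattern against a parameter name.
import re

def matches_tp_pattern(param_name: str, pattern: str) -> bool:
    # Treat "*" in the plan as a single dotted path segment (e.g. a layer or expert index).
    regex = re.escape(pattern).replace(r"\*", r"[^.]+")
    return re.fullmatch(regex, param_name) is not None

print(matches_tp_pattern(
    "layers.3.feed_forward.experts.7.gate_proj.weight",
    "layers.*.feed_forward.experts.*.gate_proj.weight",
))  # True
```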


@ -116,7 +116,7 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
state_dict: Dict[str, Any],
**kwargs,
):
from ..integrations import FbgemmFp8Linear
from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
module, tensor_name = get_module_from_name(model, param_name)
@ -129,6 +129,13 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
return True
if isinstance(module, FbgemmFp8Llama4TextExperts):
if self.pre_quantized or tensor_name == "bias":
return False
else:
if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
return True
return False
def create_quantized_param(
@ -143,12 +150,52 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
"""
Quantizes weights into weight and weight_scale
"""
new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value)
from ..integrations import FbgemmFp8Llama4TextExperts
module, tensor_name = get_module_from_name(model, param_name)
module._buffers[tensor_name] = new_value.to(target_device)
# to have the right output shape -> (out_features, 1)
module._buffers["weight_scale"] = weight_scale.view(weight_scale.shape[0], 1).to(target_device)
if isinstance(module, FbgemmFp8Llama4TextExperts):
if tensor_name == "gate_up_proj":
# Process each expert separately
# Transpose the second and third dimension
transposed_param = param_value.transpose(1, 2)
# Reshape to 2D for quantization
original_shape = transposed_param.shape
flattened_param = transposed_param.reshape(-1, original_shape[-1])
# Quantize using per row instead of per column
new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
# Reshape back to original dimensions
new_value = new_value_flat.reshape(original_shape)
new_value = new_value.transpose(1, 2)
weight_scale = weight_scale_flat.reshape(original_shape[0], 1, original_shape[1])
elif tensor_name == "down_proj":
# Process each expert separately
# Transpose the weights for proper quantization
transposed_param = param_value.transpose(1, 2)
# Reshape to 2D for quantization
original_shape = transposed_param.shape
flattened_param = transposed_param.reshape(-1, original_shape[-1])
# Quantize using per column
new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
# Reshape back to original dimensions
new_value = new_value_flat.reshape(original_shape)
new_value = new_value.transpose(1, 2)
weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)
module._parameters[f"{tensor_name}_scale"] = torch.nn.Parameter(weight_scale.to(target_device))
else:
new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value)
module._parameters[f"{tensor_name}_scale"] = torch.nn.Parameter(
weight_scale.view(weight_scale.shape[0], 1).to(target_device)
)
module._parameters[tensor_name] = torch.nn.Parameter(new_value.to(target_device))
if unexpected_keys is not None and param_name in unexpected_keys:
unexpected_keys.remove(param_name)
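To make the 3D expert-weight handling in `create_quantized_param` above easier to follow, here is a hedged shape walkthrough using a stand-in per-row fp8 quantizer in place of `torch.ops.fbgemm.quantize_fp8_per_row` (made-up sizes; the scale convention is an assumption):

```python
# Illustrative shape walkthrough for the 3D gate_up_proj expert weights (stand-in quantizer).
import torch

def fake_quantize_fp8_per_row(x: torch.Tensor):
    # Stand-in for torch.ops.fbgemm.quantize_fp8_per_row: one scale per row, fp8 values.
    scale = x.abs().amax(dim=-1).clamp(min=1e-12) / 448.0  # 448 = max of float8_e4m3fn
    q = (x / scale[:, None]).to(torch.float8_e4m3fn)
    return q, scale

num_experts, hidden, two_inter = 4, 64, 256  # gate_up_proj fuses gate and up: 2 * intermediate
gate_up_proj = torch.randn(num_experts, hidden, two_inter)

transposed = gate_up_proj.transpose(1, 2)            # (experts, 2*inter, hidden)
flat = transposed.reshape(-1, transposed.shape[-1])  # (experts * 2*inter, hidden)
q_flat, scale_flat = fake_quantize_fp8_per_row(flat)
q = q_flat.reshape(transposed.shape).transpose(1, 2)   # back to (experts, hidden, 2*inter)
scale = scale_flat.reshape(num_experts, 1, two_inter)  # as in the diff: (experts, 1, 2*inter)

print(q.shape, scale.shape)  # torch.Size([4, 64, 256]) torch.Size([4, 1, 256])
```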
@ -165,25 +212,29 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
):
from ..integrations import replace_with_fbgemm_fp8_linear
tp_plan = model._tp_plan
self.modules_to_not_convert = self.get_modules_to_not_convert(
model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
)
config = model.config
model = replace_with_fbgemm_fp8_linear(
model,
modules_to_not_convert=self.modules_to_not_convert,
quantization_config=self.quantization_config,
pre_quantized=self.pre_quantized,
config=config,
tp_plan=tp_plan,
)
model.config.quantization_config = self.quantization_config
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
from ..integrations import FbgemmFp8Linear
from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, FbgemmFp8Linear):
if isinstance(module, FbgemmFp8Linear) or isinstance(module, FbgemmFp8Llama4TextExperts):
for missing in missing_keys:
if (
(name in missing or name in f"{prefix}.{missing}")


@ -3950,7 +3950,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
verbose (`bool`): Whether or not to print more information and warnings.
"""
if max_length is None and len(ids) > self.model_max_length and verbose:
if max_length is None and len(ids) > self.model_max_length and verbose and self.model_max_length != 0:
if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
logger.warning(
"Token indices sequence length is longer than the specified maximum sequence length "


@ -5823,6 +5823,41 @@ class LlamaPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class Llama4ForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Llama4ForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Llama4PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Llama4TextModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Llama4VisionModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LlavaForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]


@ -72,6 +72,13 @@ class GotOcr2ImageProcessorFast(metaclass=DummyObject):
requires_backends(self, ["torchvision"])
class Llama4ImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])
class LlavaImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]


@ -408,6 +408,13 @@ class LevitImageProcessor(metaclass=DummyObject):
requires_backends(self, ["vision"])
class Llama4ImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class LlavaImageProcessor(metaclass=DummyObject):
_backends = ["vision"]


@ -0,0 +1,128 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_torch_available():
pass
if is_vision_available() and is_torchvision_available():
from transformers import Llama4ImageProcessorFast
class Llama4ImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
image_size=18,
min_resolution=30,
max_resolution=400,
max_patches=1,
do_resize=True,
size=None,
do_normalize=True,
do_pad=False,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
do_convert_rgb=True,
):
super().__init__()
size = size if size is not None else {"height": 20, "width": 20}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.max_patches = max_patches
self.do_resize = do_resize
self.size = size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_pad = do_pad
self.do_convert_rgb = do_convert_rgb
def prepare_image_processor_dict(self):
return {
"max_patches": self.max_patches,
"do_resize": self.do_resize,
"size": self.size,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_convert_rgb": self.do_convert_rgb,
"do_pad": self.do_pad,
}
def expected_output_image_shape(self, images):
return self.num_channels, self.size["height"], self.size["width"]
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)
@require_torch
@require_vision
class Llama4ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
test_slow_image_processor = False
fast_image_processing_class = Llama4ImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = Llama4ImageProcessingTester(self)
@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processor, "do_resize"))
self.assertTrue(hasattr(image_processor, "size"))
self.assertTrue(hasattr(image_processor, "do_normalize"))
self.assertTrue(hasattr(image_processor, "image_mean"))
self.assertTrue(hasattr(image_processor, "image_std"))
self.assertTrue(hasattr(image_processor, "do_convert_rgb"))
def test_split_tiles(self):
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class(**self.image_processor_dict)
image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)[0]
processed_images = image_processor(
image,
max_patches=16,
)
self.assertEqual(len(processed_images.pixel_values), 1)
self.assertEqual(processed_images.pixel_values[0].shape[0], 17)
self.assertEqual(processed_images.pixel_values[0].shape[-2:], (20, 20))
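For reference on the counts asserted above: with `max_patches=16` the processor produces up to 16 local tiles plus one global tile, hence the 17 (this reading is inferred from the test). A hedged standalone sketch, assuming `Llama4ImageProcessorFast` can be instantiated directly as in the processor tests:

```python
# Hedged sketch mirroring the unit test above (tile count inferred: 16 local tiles + 1 global tile).
import torch
from transformers import Llama4ImageProcessorFast

processor = Llama4ImageProcessorFast(size={"height": 20, "width": 20})
image = torch.randint(0, 256, (3, 400, 400), dtype=torch.uint8)  # made-up square image

out = processor(image, max_patches=16)
tiles = out.pixel_values[0]
print(tiles.shape[0], tuple(tiles.shape[-2:]))  # expected: 17 (20, 20)
```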


@ -0,0 +1,121 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Llama4 model."""
import unittest
from transformers import is_torch_available
from transformers.testing_utils import (
require_read_token,
require_torch_large_gpu,
slow,
torch_device,
)
if is_torch_available():
import torch
from transformers import (
Llama4ForConditionalGeneration,
Llama4Processor,
)
@slow
@require_torch_large_gpu
@require_read_token
class Llama4IntegrationTest(unittest.TestCase):
model_id = "ll-re/Llama-4-17B-Omni-Instruct"
# This variable is used to determine which CUDA device we are using for our runners (A10 or T4)
# Depending on the hardware we get different logits / generations
cuda_compute_capability_major_version = None
@classmethod
def setUpClass(cls):
if is_torch_available() and torch.cuda.is_available():
# 8 is for A100 / A10 and 7 for T4
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
cls.model = Llama4ForConditionalGeneration.from_pretrained(
"ll-re/Llama-4-17B-Omni-Instruct", device_map="auto", torch_dtype=torch.float32
)
def setUp(self):
self.processor = Llama4Processor.from_pretrained("ll-re/Llama-4-17B-Omni-Instruct", padding_side="left")
url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
self.messages = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
{
"role": "user",
"content": [
{"type": "image", "url": url},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
def test_model_17b_16e_fp16(self):
EXPECTED_TEXT = [
"The capital of France is Paris, which is located in the north-central part of the country. Paris is known for its iconic landmarks such as the",
"Roses are red, violets are blue, and this poem is about you. Roses are red, violets are blue, and I love",
]
messages = [
{"role": "user", "content": "Who are you?"},
]
inputs = self.processor.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
).to(torch_device)
output = self.model.generate(**inputs, max_new_tokens=100)
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
print(output_text)
self.assertEqual(output_text, EXPECTED_TEXT)
def test_model_17b_16e_batch(self):
messages_2 = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
{
"role": "user",
"content": [
{
"type": "image",
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
},
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "Are these images identical?"},
],
},
]
inputs = self.processor.apply_chat_template(
[self.messages, messages_2],
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
add_generation_prompt=True,
).to(torch_device)
output = self.model.generate(**inputs, max_new_tokens=30, do_sample=False)
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = [
'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like',
"user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow"
] # fmt: skip
self.assertEqual(output_text, EXPECTED_TEXTS)


@ -0,0 +1,65 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
from typing import Optional
from transformers import AutoProcessor, Llama4Processor, PreTrainedTokenizerFast
from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from transformers import Llama4ImageProcessorFast
@require_vision
class Llama4ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = Llama4Processor
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = Llama4ImageProcessorFast(max_patches=1, size={"height": 20, "width": 20})
tokenizer = PreTrainedTokenizerFast.from_pretrained("unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit")
processor_kwargs = self.prepare_processor_dict()
processor = Llama4Processor(image_processor, tokenizer, **processor_kwargs)
processor.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
def tearDown(self):
shutil.rmtree(self.tmpdirname)
# Override as Llama4Processor needs image tokens in prompts
def prepare_text_inputs(self, batch_size: Optional[int] = None):
if batch_size is None:
return "lower newer <image>"
if batch_size < 1:
raise ValueError("batch_size must be greater than 0")
if batch_size == 1:
return ["lower newer <image>"]
return ["lower newer <image>", "<image> upper older longer string"] + ["<image> lower newer"] * (
batch_size - 2
)


@ -236,6 +236,16 @@ SPECIAL_CASES_TO_ALLOW = {
"text_config",
"vision_config",
],
"Llama4Config": ["boi_token_index", "eoi_token_index"],
"Llama4TextConfig": [
"interleave_moe_layer_step",
"no_rope_layer_interval",
"no_rope_layers",
"output_router_logits",
"router_aux_loss_coef",
"router_jitter_noise",
],
"Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
}
@ -358,6 +368,8 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
"rope_theta",
"partial_rotary_factor",
"pretraining_tp",
"boi_token_index",
"eoi_token_index",
]
attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]


@ -67,6 +67,7 @@ _re_parse_description = re.compile(r"\*optional\*, defaults to (.*)$")
# docstrings instead. If formatting should be ignored for the docstring, you can put a comment # no-format on the
# line before the docstring.
OBJECTS_TO_IGNORE = [
"Llama4Processor",
# Deprecated
"InputExample",
"InputFeatures",


@ -223,13 +223,20 @@ def check_dummies(overwrite: bool = False):
f.write(dummy_files[backend])
else:
# Temporary fix to help people identify which newly introduced objects are not correctly protected.
found = False
for _actual, _dummy in zip(
actual_dummies["torch"].split("class"), dummy_files["torch"].split("class")
):
if _actual != _dummy:
actual_broken = _actual
dummy_broken = _dummy
found = True
break
if not found:
print("A transient error was found with the dummies, please investigate.")
continue
raise ValueError(
"The main __init__ has objects that are not present in "
f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py.\n"


@ -144,6 +144,8 @@ IGNORE_NON_TESTED = (
"Qwen2_5_VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5_VLForConditionalGeneration.
"MllamaTextModel", # Building part of bigger (tested) model. # TODO: add tests
"MllamaVisionModel", # Building part of bigger (tested) model. # TODO: add tests
"Llama4TextModel", # Building part of bigger (tested) model. # TODO: add tests
"Llama4VisionModel", # Building part of bigger (tested) model. # TODO: add tests
"Emu3VQVAE", # Building part of bigger (tested) model
"Emu3TextModel", # Building part of bigger (tested) model
]
@ -170,6 +172,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
"models/decision_transformer/test_modeling_decision_transformer.py",
"models/bark/test_modeling_bark.py",
"models/shieldgemma2/test_modeling_shieldgemma2.py",
"models/llama4/test_modeling_llama4.py",
]
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and