Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 04:40:06 +06:00)
Add llama4 (#37307)
* remove one of the last deps * update fast image processor after refactor * styling * more quality of life improvements * nit * update * cleanups * some cleanups * vllm updates * update fake image token * [convert] Fix typo * [convert] Strip extraneous bytes from shards * [convert] Minor fixes * [convert] Use num_experts * multi-image fixes in modeling + processor * fixup size * 128 experts * Use default rope * Unfuse mlp * simplify a lot inputs embeds merging * remove .item() 👀 * fix from review * Address feedback * Use None "default" for rope_scaling. Add eot. * set seed * return aspect ratios and bug fixes * Moe 128 rebased (#8) * 128 experts * Use default rope * Unfuse mlp * Address feedback * Use None "default" for rope_scaling. Add eot. * Meta/llama quant compat (#7) * add quant compatible model & conversion code for llama4 * fix a few issues * fix a few issues * minor type mapping fix --------- Co-authored-by: Lu Fang <fanglu@fb.com> * use a new config parameter to determine which model definition to use for MoE --------- Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: Lu Fang <fanglu@fb.com> * un-comment write_tokenizer from converting script * remove un-used imports * [llama4] Pop aspect_ratios from image processor output in Llama4Processor Signed-off-by: Jon Swenson <jmswen@gmail.com> * Fix parameter_count name * Update src/transformers/models/llama4/configuration_llama4.py * nit * Add changes for no_rope, moe_layers, chunked attention. Just need to test all * Update src/transformers/models/llama4/image_processing_llama4_fast.py * nit * fix post merge with main * support flex attention * fixes * fix * add layer * small updates * rebase and delete llm_compressor * nit * [llama4/mm] Add back <|image|> token that delimits global tile * [llama4/mm] Fix Llama 4 image processing unit tests * add explicit dtype Signed-off-by: Jon Swenson <jmswen@gmail.com> * sdpa works * comment todo small * fix model loading Signed-off-by: Zijing Liu <liuzijing2014@gmail.com> * revert * nits * small fix for TP on 1 node * Read new params from config * Add <|eom|> * lol don't know how this got here * adding fp8 * Save processor, fix chat template * style * Add boi/eoi tokens We don't use them. * fixes for now flex seems to work :) * updates * nits * updates * missking keys * add context parallel * update * update * fix * nits * add worldsize and make eager attn work for vision * Ignore new key present in base models * add tp_plan * fix nope Signed-off-by: Zijing Liu <liuzijing2014@gmail.com> * minor fix Signed-off-by: Zijing Liu <liuzijing2014@gmail.com> * Clean up Llama4 vision model * current updates * add support for `attn_temperature_tuning` * add floor scale * add missing attn scales * push what works, dirty trick for the device synch * oups * Fix pad_token_id See https://huggingface.co/ll-re/Llama-4-Scout-17B-16E/discussions/2/files Confirmed in the original codebase. * fix causallml loading * rm * fix tied-weights * fix sdpa * push current version * should work with both short and long * add compressed_tensos & fix fbgemm tp * Fix flex impl * style * chunking * try to revert the potentially breaking change * fix auto factory * fix shapes in general * rm processing * commit cache utils cleanup * Fix context length * fix * allocate * update tp_plan * fix SDPA! 
* Add support for sparse `Llama4TextMoe` layer from the kernel hub * cleanup * better merge * update * still broken fixing now * nits * revert print * Write max_position_embeddings and max_model_length * Update modeling_llama4.py * Save attention_chunk_size * Sync eos terminators * Read initializer_range * style * remove `dict` * fix * eager should use `chunked_attention_mask` * revert * fixup * fix config * Revert "Merge pull request #36 from huggingface/sparse-llama4-moe" This reverts commitccda19f050
, reversing changes made toa515579aed
. * Fix typo and remove warning with compiled flex and chunked prefill * Fix MoE vs FF (#41) * fix * Use correct no_rope_layers if provided one is empty list * update tests * fix * skipping some tests * fix fp8 loading Signed-off-by: Zijing Liu <liuzijing2014@gmail.com> * fix text geneartion pipeline Signed-off-by: Zijing Liu <liuzijing2014@gmail.com> * eager needs 4D mask * fix * Some cleanup * fix * update * fix * replace correctly module * patch * modulelist * update * update * clean up * Don't move to `cuda:0` in distributed mode * restrict to compressed tensors for now * rm print * Docs! * Fixes * Update docs/source/en/model_doc/llama4.md Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Fixes * cuda graph fix * revert some stuff * fixup * styling * Update src/transformers/models/llama4/modeling_llama4.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fixup * commit licence, cleanup here and there and style * more styling changes * fix dummies * fix and clean docstrings * remove comment * remove warning * Only fast image processor is supported * nit * trigger CI * fix issue with flex encoder * fix dynamic cache * Code quality * Code quality * fix more tests for now * Code quality * Code quality * Nuke bunch of failing stuff * Code quality * Code quality * cleanup removal of slow image processor * ruff fix fast image processor * fix * fix styling * Docs * Repo consistency * Repo consistency * fix sliding window issue * separate llama cache * styling * Repo consistency * Repo consistency * push waht works * L4 Repo consistency * Docs * fix last last alst alst alst alstsaltlsltlaslt --------- Signed-off-by: Jon Swenson <jmswen@gmail.com> Signed-off-by: Zijing Liu <liuzijing2014@gmail.com> Co-authored-by: yonigozlan <yoni.gozlan10@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: Pablo Montalvo <pablo.montalvo.leroux@gmail.com> Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Co-authored-by: Keyun Tong <tongkeyun@gmail.com> Co-authored-by: Zijing Liu <liuzijing2014@users.noreply.github.com> Co-authored-by: Lu Fang <fanglu@fb.com> Co-authored-by: Zijing Liu <liuzijing2014@gmail.com> Co-authored-by: Jon Swenson <jmswen@gmail.com> Co-authored-by: jmswen <jmswen@users.noreply.github.com> Co-authored-by: MekkCyber <mekk.cyber@gmail.com> Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com> Co-authored-by: Yong Hoon Shin <yhshin@meta.com> Co-authored-by: Marc Sun <marc@huggingface.co> Co-authored-by: drisspg <drisspguessous@gmail.com> Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com> Co-authored-by: Daniël de Kok <me@danieldk.eu> Co-authored-by: Lysandre <hi@lysand.re> Co-authored-by: Ye (Charlotte) Qi <ye.charlotte.qi@gmail.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent aa40fda346
commit 25b7f27234
@@ -507,6 +508,8 @@
      title: Llama2
    - local: model_doc/llama3
      title: Llama3
    - local: model_doc/llama4
      title: Llama4
    - local: model_doc/longformer
      title: Longformer
    - local: model_doc/longt5
docs/source/en/model_doc/llama4.md (new file, 442 lines)
@@ -0,0 +1,442 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Llama4

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
    </div>
</div>

Llama 4, developed by Meta, introduces a new auto-regressive Mixture-of-Experts (MoE) architecture.
This generation includes two models:
- The highly capable Llama 4 Maverick with 17B active parameters out of ~400B total, using 128 experts.
- The efficient Llama 4 Scout, also with 17B active parameters out of ~109B total, using just 16 experts.

Both models leverage early fusion for native multimodality, enabling them to process text and image inputs.
Maverick and Scout are both trained on up to 40 trillion tokens of data encompassing 200 languages
(with specific fine-tuning support for 12 languages including Arabic, Spanish, German, and Hindi).

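The MoE design is what keeps the number of active parameters at 17B despite the much larger totals: a learned router selects a small subset of experts per token, and only those experts (plus a shared expert) run in the forward pass. The snippet below is a minimal, self-contained sketch of token-choice top-1 routing for intuition only; the class name and dimensions are invented and do not mirror the actual `Llama4TextMoe` implementation.

```py
import torch
import torch.nn as nn

class ToyMoELayer(nn.Module):
    """Toy token-choice MoE: each token is routed to a single expert out of `num_experts`."""

    def __init__(self, hidden_size=64, intermediate_size=128, num_experts=16):
        super().__init__()
        self.router = nn.Linear(hidden_size, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(hidden_size, intermediate_size), nn.SiLU(), nn.Linear(intermediate_size, hidden_size))
            for _ in range(num_experts)
        )
        self.shared_expert = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size), nn.SiLU(), nn.Linear(intermediate_size, hidden_size)
        )

    def forward(self, hidden_states):  # (num_tokens, hidden_size)
        scores = self.router(hidden_states)             # (num_tokens, num_experts)
        top_expert = scores.argmax(dim=-1)              # one expert id per token
        gate = torch.sigmoid(scores.gather(-1, top_expert[:, None]))
        routed = torch.zeros_like(hidden_states)
        for idx, expert in enumerate(self.experts):     # only the selected expert runs for each token
            mask = top_expert == idx
            if mask.any():
                routed[mask] = expert(hidden_states[mask])
        return self.shared_expert(hidden_states) + gate * routed

moe = ToyMoELayer()
print(moe(torch.randn(5, 64)).shape)  # torch.Size([5, 64])
```
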
For deployment, Llama 4 Scout is designed for accessibility, fitting on a single server-grade GPU via
on-the-fly 4-bit or 8-bit quantization, while Maverick is available in BF16 and FP8 formats.
These models are released under the custom Llama 4 Community License Agreement, available on the model repositories.

You can find all the original Llama checkpoints under the [meta-llama](https://huggingface.co/meta-llama) organization.

> [!TIP]
> The Llama 4 family of models comes in two flavors: 109B and 402B parameters. Both of these flavors are extremely
> large and won't fit on your run-of-the-mill device. See below for some examples to reduce the memory usage of the
> model.
>
> For the download to be faster and more resilient, we recommend installing the `hf_xet` dependency as follows:
> `pip install transformers[hf_xet]`

The examples below demonstrate how to generate with [`Pipeline`] or the [`AutoModel`] class. We additionally add an example
showcasing how to toggle the right attributes to enable very long-context generation, as some flavors of Llama 4
have context lengths going up to 10 million tokens.

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
from transformers import pipeline
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

messages = [
    {"role": "user", "content": "what is the recipe of mayonnaise?"},
]

pipe = pipeline(
    "text-generation",
    model=model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

output = pipe(messages, do_sample=False, max_new_tokens=200)
print(output[0]["generated_text"][-1]["content"])
```

</hfoption>
<hfoption id="AutoModel - Text only">

```py
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)

messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
print(outputs[0])
```

</hfoption>
<hfoption id="AutoModel - Multimodal">

```py
from transformers import AutoProcessor, Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

processor = AutoProcessor.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": img_url},
            {"type": "text", "text": "Describe this image in two sentences."},
        ]
    },
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(
    **inputs,
    max_new_tokens=256,
)

response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
print(response)
```

</hfoption>
<hfoption id="AutoModel - Multimodal with multiple images">

```py
from transformers import AutoProcessor, Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

processor = AutoProcessor.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

url1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
url2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": url1},
            {"type": "image", "url": url2},
            {"type": "text", "text": "Can you describe how these two images are similar, and how they differ?"},
        ]
    },
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(
    **inputs,
    max_new_tokens=256,
)

response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
print(response)
```

</hfoption>
<hfoption id="AutoModel - Long context">

Beware: the example below uses both `device_map="auto"` and flex-attention.
Please use `torchrun` to run this example in tensor-parallel mode.

We will work to enable running with `device_map="auto"` and flex-attention without
tensor-parallel in the future.

```py
from transformers import Llama4ForConditionalGeneration, AutoTokenizer
import torch
import time

file = "very_long_context_prompt.txt"
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

with open(file, "r") as f:
    very_long_text = "\n".join(f.readlines())

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="flex_attention",
    torch_dtype=torch.bfloat16
)

messages = [
    {"role": "user", "content": f"Look at the following texts: [{very_long_text}]\n\n\n\nWhat are the books, and who wrote them? Make me a nice list."},
]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")

torch.cuda.synchronize()
start = time.time()
out = model.generate(
    input_ids.to(model.device),
    prefill_chunk_size=2048*8,
    max_new_tokens=300,
    cache_implementation="hybrid",
)
print(time.time()-start)
print(tokenizer.batch_decode(out[:, input_ids.shape[-1]:]))
print(f"{torch.cuda.max_memory_allocated(model.device) / 1024**3:.2f} GiB")
```

</hfoption>
</hfoptions>

## Efficiency: how to get the best out of Llama 4

### The Attention methods

Updating the default attention function can significantly improve compute performance as well as memory usage. Refer to the [Attention Interface](../attention_interface) overview for an in-depth explanation of our interface.

As of release, the Llama 4 model supports the following attention methods: `eager`, `flex_attention`, `sdpa`. We recommend using `flex_attention` for best results.
Switching attention mechanisms is done at model initialization:

<hfoptions id="Attention">
<hfoption id="Flex Attention">

Setting Flex Attention ensures the best results with the very long context the model can handle.

> [!TIP]
> Beware: the example below uses both `device_map="auto"` and flex-attention.
> Please use `torchrun` to run this example in tensor-parallel mode.
>
> We will work to enable running with `device_map="auto"` and flex-attention without
> tensor-parallel in the future.

```py
from transformers import Llama4ForConditionalGeneration
import torch

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    attn_implementation="flex_attention",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
```
</hfoption>
<hfoption id="SDPA">
The `sdpa` attention method is generally more compute-efficient than the `eager` method.

```py
from transformers import Llama4ForConditionalGeneration
import torch

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    attn_implementation="sdpa",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
```
</hfoption>
<hfoption id="Eager">
The `eager` attention method is used by default, so nothing extra is needed when loading the model:

```py
from transformers import Llama4ForConditionalGeneration
import torch

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
```
</hfoption>
</hfoptions>


### Quantization

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for available quantization backends.
At the time of release, both FBGEMM and LLM-Compressor are supported; more quantization methods will be supported in the days that follow the release.

See below for examples using both.

Here is an example loading a BF16 model in FP8 using the FBGEMM approach:

<hfoptions id="Quantization">
<hfoption id="FBGEMM">

```python
from transformers import AutoTokenizer, Llama4ForConditionalGeneration, FbgemmFp8Config
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)

messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=FbgemmFp8Config()
)

outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
print(outputs[0])
```

</hfoption>
<hfoption id="LLM-Compressor">

To use the LLM-Compressor technique, we recommend leveraging the pre-quantized FP8 checkpoint available with the release:

```python
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"

tokenizer = AutoTokenizer.from_pretrained(model_id)

messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    tp_plan="auto",
    torch_dtype=torch.bfloat16,
)

outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
print(outputs[0])
```
</hfoption>
</hfoptions>

### Offloading

Enabling CPU offloading means that components of the model may be moved to CPU instead of GPU in case the available GPU memory isn't sufficient to load the entire model.
At inference, different components will be loaded/unloaded from/to the GPU on the fly. This ensures that the model can be loaded on smaller machines as long as the CPU memory is sufficient.
However, this also slows down inference as it adds communication overhead.

In order to enable CPU offloading, you simply need to set `device_map="auto"` when loading the model:

```py
from transformers import Llama4ForConditionalGeneration
import torch

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
```
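
If `device_map="auto"` alone does not give you enough control, `from_pretrained` also accepts a `max_memory` mapping (handled by Accelerate) that caps how much each device may hold; anything that does not fit under the caps is kept in CPU RAM and streamed to the GPU at inference. The limits below are placeholders to adapt to your hardware, so treat this as a sketch of the mechanism rather than a recommended configuration:

```py
from transformers import Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    # Illustrative caps: keep at most ~70 GiB on GPU 0, offload the rest to CPU RAM.
    max_memory={0: "70GiB", "cpu": "200GiB"},
    torch_dtype=torch.bfloat16,
)
```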

## Llama4Config

[[autodoc]] Llama4Config

## Llama4TextConfig

[[autodoc]] Llama4TextConfig

## Llama4VisionConfig

[[autodoc]] Llama4VisionConfig

## Llama4Processor

[[autodoc]] Llama4Processor

## Llama4ImageProcessorFast

[[autodoc]] Llama4ImageProcessorFast

## Llama4ForConditionalGeneration

[[autodoc]] Llama4ForConditionalGeneration
    - forward

## Llama4ForCausalLM

[[autodoc]] Llama4ForCausalLM
    - forward

## Llama4TextModel

[[autodoc]] Llama4TextModel
    - forward

## Llama4VisionModel

[[autodoc]] Llama4VisionModel
    - forward
@@ -562,6 +562,12 @@ _import_structure = {
    "models.levit": ["LevitConfig"],
    "models.lilt": ["LiltConfig"],
    "models.llama": ["LlamaConfig"],
    "models.llama4": [
        "Llama4Config",
        "Llama4Processor",
        "Llama4TextConfig",
        "Llama4VisionConfig",
    ],
    "models.llava": [
        "LlavaConfig",
        "LlavaProcessor",
@@ -1354,6 +1360,7 @@ else:
    _import_structure["models.detr"].append("DetrImageProcessorFast")
    _import_structure["models.gemma3"].append("Gemma3ImageProcessorFast")
    _import_structure["models.got_ocr2"].append("GotOcr2ImageProcessorFast")
    _import_structure["models.llama4"].append("Llama4ImageProcessorFast")
    _import_structure["models.llava"].append("LlavaImageProcessorFast")
    _import_structure["models.llava_next"].append("LlavaNextImageProcessorFast")
    _import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast")
@@ -2510,6 +2517,15 @@ else:
            "GlmPreTrainedModel",
        ]
    )
    _import_structure["models.llama4"].extend(
        [
            "Llama4ForCausalLM",
            "Llama4ForConditionalGeneration",
            "Llama4TextModel",
            "Llama4VisionModel",
            "Llama4PreTrainedModel",
        ]
    )
    _import_structure["models.glpn"].extend(
        [
            "GLPNForDepthEstimation",
@@ -5807,6 +5823,12 @@ if TYPE_CHECKING:
    from .models.levit import LevitConfig
    from .models.lilt import LiltConfig
    from .models.llama import LlamaConfig
    from .models.llama4 import (
        Llama4Config,
        Llama4Processor,
        Llama4TextConfig,
        Llama4VisionConfig,
    )
    from .models.llava import (
        LlavaConfig,
        LlavaProcessor,
@@ -6646,6 +6668,7 @@ if TYPE_CHECKING:
    from .models.detr import DetrImageProcessorFast
    from .models.gemma3 import Gemma3ImageProcessorFast
    from .models.got_ocr2 import GotOcr2ImageProcessorFast
    from .models.llama4 import Llama4ImageProcessorFast
    from .models.llava import LlavaImageProcessorFast
    from .models.llava_next import LlavaNextImageProcessorFast
    from .models.llava_onevision import LlavaOnevisionImageProcessorFast
@@ -7827,6 +7850,13 @@ if TYPE_CHECKING:
        LlamaModel,
        LlamaPreTrainedModel,
    )
    from .models.llama4 import (
        Llama4ForCausalLM,
        Llama4ForConditionalGeneration,
        Llama4PreTrainedModel,
        Llama4TextModel,
        Llama4VisionModel,
    )
    from .models.llava import (
        LlavaForConditionalGeneration,
        LlavaPreTrainedModel,
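
For reference, the registrations above are what make the new symbols importable from the top-level package; a quick, illustrative sanity check (nothing here downloads weights):

```python
from transformers import (
    Llama4Config,
    Llama4ForCausalLM,
    Llama4ForConditionalGeneration,
    Llama4ImageProcessorFast,
    Llama4Processor,
    Llama4TextConfig,
    Llama4VisionConfig,
)

# The lazy import structure resolves each name to the class defined under models/llama4.
print(Llama4Config.model_type)  # expected: "llama4"
```
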
@@ -1811,6 +1811,200 @@ class HybridCache(Cache):
        self.value_cache[layer_idx].zero_()


class HybridChunkedCache(Cache):
    """
    Hybrid chunked Cache class to be used with `torch.compile` for models that alternate between local (sliding-window
    or chunked) attention and global attention in every other layer, such as Gemma2 and Llama 4. Under the hood, it
    leverages a ["SlidingWindowCache"]-style rolling buffer for the local-attention layers and a ["StaticCache"]-style
    buffer for global attention. For more information, see the documentation of each subcomponent cache class.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
            smaller batch size is used.
        max_cache_len (`int`, *optional*):
            The maximum sequence length with which the model will be used.
        device (`torch.device` or `str`, *optional*):
            The device on which the cache should be initialized. If you're using more than 1 computation device, you
            should pass the `layer_device_map` argument instead.
        dtype (torch.dtype, *optional*, defaults to `torch.bfloat16`):
            The default `dtype` to use when initializing the layer.
        layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]`, *optional*):
            Mapping between the layers and its device. This is required when you are manually initializing the cache
            and the model is split between different gpus. You can know which layers are mapped to which device by
            checking the associated device_map: `model.hf_device_map`.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridChunkedCache

        >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

        >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = HybridChunkedCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        HybridChunkedCache()
        ```
    """

    # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
    # ALL changes from the PR that commented the line below when reactivating it.
    # is_compileable = True

    def __init__(
        self,
        config: PretrainedConfig,
        max_batch_size: int,
        max_cache_len: Optional[int] = None,
        device: Union[torch.device, str, None] = None,
        dtype: torch.dtype = torch.bfloat16,
        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
    ) -> None:
        super().__init__()
        if not hasattr(config, "sliding_window") or config.sliding_window is None:
            self.sliding_window = getattr(config.get_text_config(), "attention_chunk_size", 8192)
        else:
            self.sliding_window = config.sliding_window
        self.max_cache_len = max_cache_len
        self.max_batch_size = max_batch_size
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self._dtype = dtype

        if hasattr(config.get_text_config(), "no_rope_layers"):
            self.is_sliding = config.no_rope_layers
        else:
            layer_switch = getattr(config, "sliding_window_pattern", 2)
            self.is_sliding = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]

        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.cumulative_length = [0 for _ in range(config.num_hidden_layers)]

    def initialise_cache_layer(self, layer_idx, key_states):
        if len(self.key_cache) > layer_idx:
            return

        num_key_value_heads = key_states.shape[1]
        device = key_states.device
        global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim)
        sliding_cache_shape = (
            self.max_batch_size,
            num_key_value_heads,
            self.sliding_window,
            self.head_dim,
        )
        # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
        # breaks when updating the cache.
        cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape
        new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
        new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
        torch._dynamo.mark_static_address(new_layer_key_cache)
        torch._dynamo.mark_static_address(new_layer_value_cache)
        self.key_cache.append(new_layer_key_cache)
        self.value_cache.append(new_layer_value_cache)

    def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
        cumulative_length = self.cumulative_length[layer_idx]
        is_full = cumulative_length >= max_cache_len
        if is_full:
            full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2)
            full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2)
        elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len:
            full_key_states = torch.cat((k_out[:, :, :cumulative_length, :], key_states), dim=-2)
            full_value_states = torch.cat((v_out[:, :, :cumulative_length, :], value_states), dim=-2)
        else:
            self.key_cache[layer_idx].index_copy_(2, cache_position, key_states)
            self.value_cache[layer_idx].index_copy_(2, cache_position, value_states)
            self.cumulative_length[layer_idx] += key_states.shape[-2]
            return self.key_cache[layer_idx], self.value_cache[layer_idx]

        self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :])
        self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :])
        self.cumulative_length[layer_idx] += key_states.shape[-2]
        # we should return the whole states instead of k_out, v_out to take the whole prompt
        # into consideration when building kv cache instead of just throwing away tokens outside of the window
        return full_key_states, full_value_states

    def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states

        self.key_cache[layer_idx] = k_out
        self.value_cache[layer_idx] = v_out
        return k_out, v_out

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if cache_kwargs is None:
            cache_kwargs = {}
        cache_position = cache_kwargs.get("cache_position")
        self.initialise_cache_layer(layer_idx, key_states)

        # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
        # when the cache is initialized in the forward pass (e.g. Gemma2)
        if self.key_cache[layer_idx].device != key_states.device:
            self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
        if self.value_cache[layer_idx].device != value_states.device:
            self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)

        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]
        key_states = key_states.to(k_out.dtype)
        value_states = value_states.to(v_out.dtype)

        if self.is_sliding[layer_idx]:
            update_fn = self._sliding_update
        else:
            update_fn = self._static_update

        return update_fn(
            cache_position,
            layer_idx,
            key_states,
            value_states,
            k_out,
            v_out,
            k_out.shape[2],
        )

    def get_max_cache_shape(self) -> Optional[int]:
        return self.max_cache_len

    def get_seq_length(self, layer_idx: Optional[int] = 0):
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
        # limit the check to the first batch member and head dimension.
        # TODO: deprecate this function in favor of `cache_position`
        if layer_idx != 0:
            raise ValueError(
                "`get_seq_length` on `HybridChunkedCache` may get inconsistent results depending on the layer index. "
                "Using the `layer_idx` argument is not supported."
            )
        if len(self.key_cache) == 0:
            return 0
        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

    def reset(self):
        """Resets the cache values while preserving the objects"""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()
        self.cumulative_length = [0 for _ in range(len(self.cumulative_length))]
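
For orientation, here is a short usage sketch consistent with the Llama 4 model documentation above: this cache is normally not constructed by hand but selected through `generate`, which remaps `cache_implementation="hybrid"` to `"hybrid_chunked"` for Llama 4 models (see the `GenerationMixin` change further down).

```python
import torch
from transformers import AutoTokenizer, Llama4ForConditionalGeneration

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

inputs = tokenizer("Briefly explain chunked attention.", return_tensors="pt").to(model.device)
out = model.generate(
    **inputs,
    max_new_tokens=64,
    cache_implementation="hybrid",  # resolved to HybridChunkedCache for llama4-type models
)
print(tokenizer.decode(out[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))
```
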

class MambaCache:
    """
    Cache for mamba model which does not have attention mechanism and key value states.
@@ -801,18 +801,19 @@ class PretrainedConfig(PushToHubMixin):

    def to_diff_dict(self) -> dict[str, Any]:
        """
        Removes all attributes from config which correspond to the default config attributes for better readability and
        serializes to a Python dictionary.
        Removes all attributes from the configuration that correspond to the default config attributes for
        better readability, while always retaining the `config` attribute from the class. Serializes to a
        Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
            Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
        """
        config_dict = self.to_dict()

        # get the default config dict
        # Get the default config dict (from a fresh PreTrainedConfig instance)
        default_config_dict = PretrainedConfig().to_dict()

        # get class specific config dict
        # Get class-specific config dict if not part of a composition
        class_config_dict = self.__class__().to_dict() if not self.is_composition else {}

        serializable_config_dict = {}
@@ -847,8 +848,7 @@ class PretrainedConfig(PushToHubMixin):
            if not isinstance(self.quantization_config, dict)
            else self.quantization_config
        )

        # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
        # Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
        _ = serializable_config_dict.pop("_pre_quantization_dtype", None)

        self.dict_torch_dtype_to_str(serializable_config_dict)
@@ -52,6 +52,7 @@ if is_torch_available():
    from ..cache_utils import (
        HQQQuantizedCache,
        HybridCache,
        HybridChunkedCache,
        MambaCache,
        OffloadedStaticCache,
        QuantizedCacheConfig,
@@ -69,6 +70,7 @@ if is_torch_available():
        "offloaded_static": OffloadedStaticCache,
        "sliding_window": SlidingWindowCache,
        "hybrid": HybridCache,
        "hybrid_chunked": HybridChunkedCache,
        "mamba": MambaCache,
    }
    QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
@@ -416,6 +418,7 @@ class GenerationConfig(PushToHubMixin):
        if isinstance(self.cache_config, dict):
            self.cache_config = cache_config_class.from_dict(self.cache_config)
        self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)
        self.prefill_chunk_size = kwargs.pop("prefill_chunk_size", None)

        # Parameters for manipulation of the model output logits
        self.temperature = kwargs.pop("temperature", 1.0)
|
@@ -1830,6 +1830,9 @@ class GenerationMixin:

        Returns the resulting cache object.
        """
        if cache_implementation == "hybrid" and "llama4" in getattr(self.config, "model_type", ""):
            cache_implementation = "hybrid_chunked"

        cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
        requires_cross_attention_cache = (
            self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
@@ -3405,7 +3408,12 @@ class GenerationMixin:
        os.environ["TOKENIZERS_PARALLELISM"] = "0"
        model_forward = self.get_compiled_call(generation_config.compile_config)

        is_prefill = True
        if generation_config.prefill_chunk_size is not None:
            model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
            is_prefill = False
        else:
            is_prefill = True

        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
@@ -4855,6 +4863,45 @@ class GenerationMixin:
        else:
            return input_ids

    def _prefill_chunking(self, input_ids: torch.LongTensor, generation_config: GenerationConfig, **model_kwargs):
        # Even if we are not compiling the forward, flex is always compiled when used. With chunk prefill, we may
        # end up needing just a bit more graphs than the default (which is 8). Doing this avoids very cryptic warnings
        torch._dynamo.config.cache_size_limit = 64

        chunk_size = generation_config.prefill_chunk_size
        # Only chunk up to the token just before last, so that decoding is completely performed outside this function
        # (here we simply prefill the cache)
        input_chunks = torch.split(input_ids[:, :-1], chunk_size, dim=-1)

        if "past_key_values" not in model_kwargs:
            raise ValueError("Cannot use prefill chunking without a cache")

        model_forward = self.get_compiled_call(generation_config.compile_config)
        attention_mask = model_kwargs.pop("attention_mask", None)

        past_length = 0
        for input_chunk in input_chunks:
            current_length = past_length + input_chunk.shape[-1]
            # Prepare inputs
            if attention_mask is not None:
                model_kwargs["attention_mask"] = attention_mask[:, :current_length]
            model_kwargs["cache_position"] = torch.arange(
                past_length, current_length, dtype=torch.long, device=input_chunk.device
            )
            model_kwargs["position_ids"] = model_kwargs["cache_position"].unsqueeze(0)
            model_inputs = self.prepare_inputs_for_generation(input_chunk, **model_kwargs)

            outputs = model_forward(**model_inputs, return_dict=True)

            model_kwargs["past_key_values"] = outputs.past_key_values
            past_length = current_length

        model_kwargs["attention_mask"] = attention_mask
        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
        _ = model_kwargs.pop("position_ids", None)

        return model_kwargs

|
||||

def _speculative_sampling(
    candidate_input_ids,
@@ -53,7 +53,7 @@ _import_structure = {
        "unset_hf_deepspeed_config",
    ],
    "eetq": ["replace_with_eetq_linear"],
    "fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
    "fbgemm_fp8": ["FbgemmFp8Linear", "FbgemmFp8Llama4TextExperts", "replace_with_fbgemm_fp8_linear"],
    "finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear"],
    "fsdp": ["is_fsdp_managed_module"],
    "ggml": [
@@ -192,7 +192,7 @@ if TYPE_CHECKING:
        unset_hf_deepspeed_config,
    )
    from .eetq import replace_with_eetq_linear
    from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear
    from .fbgemm_fp8 import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts, replace_with_fbgemm_fp8_linear
    from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear
    from .fsdp import is_fsdp_managed_module
    from .ggml import (

src/transformers/integrations/compressed_tensors.py (new file, 54 lines)
@@ -0,0 +1,54 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.utils import is_torch_available


if is_torch_available():
    import torch
    import torch.nn as nn

    from transformers.models.llama4.modeling_llama4 import Llama4TextMLP


def skip(*args, **kwargs):
    pass


class CompressedExpertsLinear(nn.Module):
    """
    A module that implements a compressed version of a list of expert modules.
    This is specifically designed to work with Llama4TextExperts in MoE layers.
    """

    def __init__(self, config):
        # Skip random weight initialization for experts. Otherwise,
        # the init of this module would take over minutes. For a model
        # with tens of layers of experts, it would easily take over 20 minutes.
        nn.init.kaiming_uniform_ = skip
        nn.init.uniform_ = skip
        nn.init.normal_ = skip
        super().__init__()
        self.num_experts = config.num_local_experts
        self.expert_modules = nn.ModuleList([Llama4TextMLP(config) for _ in range(self.num_experts)])

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = hidden_states.reshape(self.num_experts, -1, hidden_states.shape[-1])
        expert_routed_out_list = []
        for expert_idx in range(self.num_experts):
            expert_routed_out_list.append(self.expert_modules[expert_idx](hidden_states[expert_idx]))
        routed_out = torch.cat(expert_routed_out_list, dim=0)
        return routed_out
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ..activations import ACT2FN
from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging


@@ -28,36 +29,36 @@ if is_fbgemm_gpu_available():
logger = logging.get_logger(__name__)


class FbgemmFp8Linear(torch.nn.Module):
class FbgemmFp8Linear(torch.nn.Linear):
    def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32):
        super().__init__()
        super().__init__(in_features, out_features, bias)
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer("weight", torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
        self.register_buffer("weight_scale", torch.zeros((out_features, 1), dtype=weight_dtype))
        self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
        self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=weight_dtype))
        self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)

        if bias:
            self.register_buffer("bias", torch.zeros((self.out_features), dtype=weight_dtype))
            self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=weight_dtype))
        else:
            self.bias = None

    def forward(self, x):
        num_tokens = None
        # quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here
        output_shape = (*x.shape[:-1], -1)
        # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
        # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
        x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
            x.view(-1, x.shape[-1]), num_tokens, self.input_scale_ub
            x.view(-1, x.shape[-1]), scale_ub=self.input_scale_ub
        )
        # moving x_quantized, x_scale here creates gibberish output ... However, if we move the output, it works
        # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)

        # The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
        weight_scale_float32 = self.weight_scale.to(torch.float32)
        output = torch.ops.fbgemm.f8f8bf16_rowwise(
            x_quantized, self.weight, x_scale, self.weight_scale, use_fast_accum=True
            x_quantized, self.weight, x_scale, weight_scale_float32, use_fast_accum=True
        )
        output = output + self.bias if self.bias is not None else output
        # Hacky for now, we move the output to the device of x
@@ -67,6 +68,92 @@ class FbgemmFp8Linear(torch.nn.Module):
        return output


class FbgemmFp8Llama4TextExperts(nn.Module):
    def __init__(self, config, dtype=torch.float32):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.intermediate_size = config.intermediate_size
        self.hidden_size = config.hidden_size
        self.expert_dim = self.intermediate_size
        self.act_fn = ACT2FN[config.hidden_act]
        # Register FP8 buffers for gate_up_proj
        self.gate_up_proj = torch.nn.Parameter(
            torch.zeros((self.num_experts, self.hidden_size, 2 * self.expert_dim), dtype=torch.float8_e4m3fn)
        )
        self.gate_up_proj_scale = torch.nn.Parameter(
            torch.zeros((self.num_experts, 1, self.expert_dim * 2), dtype=torch.float32)
        )
        # Register FP8 buffers for down_proj
        self.down_proj = torch.nn.Parameter(
            torch.zeros((self.num_experts, self.expert_dim, self.hidden_size), dtype=torch.float8_e4m3fn)
        )
        self.down_proj_scale = torch.nn.Parameter(
            torch.zeros((self.num_experts, self.hidden_size, 1), dtype=torch.float32)
        )
        # Register input scale upper bound
        self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)

    def forward(self, hidden_states):
        """
        Args:
            hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
        Returns:
            torch.Tensor: (batch_size * token_num, hidden_size)
        """
        # Reshape hidden states for expert computation
        hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
        num_tokens = None

        # Pre-allocate tensor for all expert outputs with same shape as hidden_states
        next_states = torch.empty_like(hidden_states)

        for i in range(self.num_experts):
            # Extract expert's hidden states
            expert_hidden = hidden_states[i]
            expert_hidden_reshaped = expert_hidden.reshape(-1, self.hidden_size)
            # Quantize for this expert
            expert_quantized, expert_scale = torch.ops.fbgemm.quantize_fp8_per_row(
                expert_hidden_reshaped, num_tokens, self.input_scale_ub
            )
            sharded_expert_dim = self.gate_up_proj.shape[-1] // 2
            gate_up_proj_scale_float32 = self.gate_up_proj_scale.to(torch.float32)

            gate = torch.ops.fbgemm.f8f8bf16_rowwise(
                expert_quantized,
                self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous(),
                expert_scale,
                gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous(),
                use_fast_accum=True,
            )

            up = torch.ops.fbgemm.f8f8bf16_rowwise(
                expert_quantized,
                self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous(),
                expert_scale,
                gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
                use_fast_accum=True,
            )

            activated = up * self.act_fn(gate)

            activated_quantized, activated_scale = torch.ops.fbgemm.quantize_fp8_per_row(
                activated, num_tokens, self.input_scale_ub
            )

            down_proj_scale_float32 = self.down_proj_scale.to(torch.float32)
            expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
                activated_quantized,
                self.down_proj[i].transpose(0, 1).contiguous(),
                activated_scale,
                down_proj_scale_float32[i].view(-1, 1).contiguous(),
                use_fast_accum=True,
            )

            next_states[i] = expert_output
        next_states = next_states.to(hidden_states.device)
        return next_states.view(-1, self.hidden_size)


def _replace_with_fbgemm_fp8_linear(
    model,
    modules_to_not_convert=None,
@@ -74,12 +161,17 @@
    quantization_config=None,
    has_been_replaced=False,
    pre_quantized=False,
    config=None,
    tp_plan=None,
):
    """
    Private method that wraps the recursion for module replacement.

    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
    """

    import re

    if current_key_name is None:
        current_key_name = []

@@ -105,9 +197,27 @@ def _replace_with_fbgemm_fp8_linear(
                # Force requires grad to False to avoid unexpected errors
                model._modules[name].requires_grad_(False)
                # set non persistant buffer outside of init_empty_weights
                model._modules[name].input_scale_ub = torch.tensor(
                    [quantization_config.activation_scale_ub],
                    dtype=torch.float,
                )
        if module.__class__.__name__ == "Llama4TextExperts" and name not in modules_to_not_convert:
            current_key_name_str = ".".join(current_key_name)
            if not any(
                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
            ):
                with init_empty_weights(include_buffers=True):
                    tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj_scale")] = tp_plan[
                        re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj")
                    ]
                    tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None
                    model._modules[name] = FbgemmFp8Llama4TextExperts(
                        config.text_config,
                    )
                model._modules[name].input_scale_ub = torch.tensor(
                    [quantization_config.activation_scale_ub], dtype=torch.float
                )

        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_fbgemm_fp8_linear(
                module,
@@ -116,6 +226,8 @@ def _replace_with_fbgemm_fp8_linear(
                quantization_config,
                has_been_replaced=has_been_replaced,
                pre_quantized=pre_quantized,
                config=config,
                tp_plan=tp_plan,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
@@ -123,7 +235,13 @@ def _replace_with_fbgemm_fp8_linear(


def replace_with_fbgemm_fp8_linear(
    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    pre_quantized=False,
    config=None,
    tp_plan=None,
):
    """
    A helper function to replace all `torch.nn.Linear` modules by `FbgemmFp8Linear` modules.
@@ -151,9 +269,14 @@ def replace_with_fbgemm_fp8_linear(
        modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
    modules_to_not_convert = list(set(modules_to_not_convert))
    model, has_been_replaced = _replace_with_fbgemm_fp8_linear(
        model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
        model,
        modules_to_not_convert,
        current_key_name,
        quantization_config,
        pre_quantized=pre_quantized,
        config=config,
        tp_plan=tp_plan,
    )

    if not has_been_replaced:
        logger.warning(
            "You are loading your model using FP8 quantization but no linear modules were found in your model."
|
@ -34,10 +34,7 @@ from ..utils import is_torch_flex_attn_available
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import (
|
||||
BlockMask,
|
||||
flex_attention,
|
||||
)
|
||||
from torch.nn.attention.flex_attention import BlockMask, flex_attention
|
||||
from torch.nn.attention.flex_attention import (
|
||||
create_block_mask as create_block_causal_mask_flex,
|
||||
)
|
||||
@ -64,14 +61,23 @@ class WrappedFlexAttention:
|
||||
Initialize or update the singleton instance.
|
||||
"""
|
||||
if self._is_flex_compiled is False:
|
||||
self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
|
||||
self._compiled_flex_attention = torch.compile(flex_attention, backend="inductor")
|
||||
self._is_flex_compiled = True
|
||||
|
||||
def __call__(self):
|
||||
return self._compiled_flex_attention
|
||||
|
||||
|
||||
def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
|
||||
Offset = Union[torch.Tensor, int]
|
||||
|
||||
|
||||
def make_flex_block_causal_mask(
|
||||
attention_mask_2d: torch.Tensor,
|
||||
attention_chunk_size: Optional[int] = None,
|
||||
query_length=None,
|
||||
key_length=None,
|
||||
offsets: Optional[Tuple[Offset, Offset]] = None,
|
||||
) -> "BlockMask":
|
||||
"""
|
||||
Create a block causal document mask for a batch of sequences, both packed and unpacked.
|
||||
Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
|
||||
@@ -94,10 +100,13 @@ def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
|
||||
Returns:
|
||||
BlockMask
|
||||
"""
|
||||
attention_mask_2d = torch.nn.functional.pad(attention_mask_2d, value=0, pad=(0, key_length))
|
||||
device = attention_mask_2d.device
|
||||
document_ids = attention_mask_2d.clone()
|
||||
|
||||
document_ids = attention_mask_2d
|
||||
batch_size, total_seq_len = document_ids.shape
|
||||
if attention_chunk_size is not None:
|
||||
# we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
|
||||
document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (attention_chunk_size)
|
||||
|
||||
# Instead of passing a tensor mask, flex attention requires a mask_mod function
|
||||
# that determines which elements of QK^T should be included in the attention
|
||||
@@ -112,18 +121,30 @@ def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
|
||||
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
|
||||
for an illustration.
|
||||
"""
|
||||
causal_mask = q_idx >= kv_idx
|
||||
causal_mask = q_idx >= kv_idx # not valid when decoding
|
||||
document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
|
||||
padding_mask = document_ids[batch_idx, q_idx] > 0
|
||||
return causal_mask & document_mask & padding_mask
|
||||
padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
|
||||
final_mask = causal_mask & padding_mask & document_mask
|
||||
return final_mask
|
||||
|
||||
if offsets is not None:
|
||||
q_offset = offsets[0]
|
||||
kv_offset = offsets[1]
|
||||
|
||||
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
||||
offset_q = q_idx + q_offset
|
||||
offset_kv = kv_idx + kv_offset
|
||||
return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
|
||||
else:
|
||||
mask_mod = causal_mask_mod
|
||||
return create_block_causal_mask_flex(
|
||||
mask_mod=causal_mask_mod,
|
||||
B=batch_size,
|
||||
mask_mod=mask_mod,
|
||||
B=1,
|
||||
H=None, # attention head
|
||||
Q_LEN=total_seq_len,
|
||||
KV_LEN=total_seq_len,
|
||||
Q_LEN=query_length,
|
||||
KV_LEN=key_length,
|
||||
device=device,
|
||||
_compile=True,
|
||||
)
|
||||
|
||||
|
||||
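For readers skimming the diff, here is a minimal standalone sketch (not part of the PR) of the chunked-attention trick used in `make_flex_block_causal_mask` above: the attention mask is turned into per-chunk "document ids", so the mask_mod only lets a query attend to keys in the same chunk.

```python
import torch

# illustrative only: reproduces the document-id computation from the hunk above
attention_chunk_size = 3
attention_mask_2d = torch.ones(1, 9, dtype=torch.long)
document_ids = (attention_mask_2d.clone().fill_(1).cumsum(-1) - 1) // attention_chunk_size
print(document_ids)  # tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2]])

# mask_mod-style check: causal AND same chunk
q_idx, kv_idx = 4, 2
allowed = (q_idx >= kv_idx) and bool(document_ids[0, q_idx] == document_ids[0, kv_idx])
print(allowed)  # False: key position 2 lives in an earlier chunk than query position 4
```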
@@ -144,6 +165,18 @@ def compile_friendly_flex_attention(
|
||||
)
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def flex_attention_forward(
|
||||
module: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
@@ -174,14 +207,25 @@ def flex_attention_forward(
|
||||
score = score + head_mask[batch_idx][head_idx][0][0]
|
||||
return score
|
||||
|
||||
enable_gqa = True
|
||||
num_local_query_heads = query.shape[1]
|
||||
|
||||
# When running TP this helps:
|
||||
if not ((num_local_query_heads & (num_local_query_heads - 1)) == 0):
|
||||
key = repeat_kv(key, query.shape[1] // key.shape[1])
|
||||
value = repeat_kv(value, query.shape[1] // value.shape[1])
|
||||
enable_gqa = False
|
||||
|
||||
kernel_options = kwargs.get("kernel_options", None)
|
||||
attn_output, attention_weights = compile_friendly_flex_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
score_mod=score_mod,
|
||||
block_mask=block_mask,
|
||||
enable_gqa=True,
|
||||
enable_gqa=enable_gqa,
|
||||
scale=scaling,
|
||||
kernel_options=kernel_options,
|
||||
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
|
||||
# For simplification, we thus always return it as no additional computations are introduced.
|
||||
return_lse=True,
|
||||
|
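As a hedged aside on the TP workaround above: `(n & (n - 1)) == 0` is the usual power-of-two test, so the `enable_gqa` fast path is only kept when the per-rank number of query heads is a power of two; otherwise the key/value heads are repeated explicitly.

```python
def is_power_of_two(n: int) -> bool:
    # same bit trick as `(num_local_query_heads & (num_local_query_heads - 1)) == 0` above
    return n > 0 and (n & (n - 1)) == 0

print(is_power_of_two(8))   # True  -> keep enable_gqa=True
print(is_power_of_two(10))  # False -> repeat_kv(...) and enable_gqa=False
```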
@@ -31,7 +31,7 @@ def sdpa_attention_forward(
|
||||
value = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
if attention_mask is not None and causal_mask.ndim == 4:
|
||||
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
|
||||
|
@@ -61,6 +61,21 @@ def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> Li
|
||||
return [single_size] * blocks
|
||||
|
||||
|
||||
str_to_torch_dtype = {
|
||||
"BOOL": torch.bool,
|
||||
"U8": torch.uint8,
|
||||
"I8": torch.int8,
|
||||
"I16": torch.int16,
|
||||
"F16": torch.float16,
|
||||
"BF16": torch.bfloat16,
|
||||
"I32": torch.int32,
|
||||
"F32": torch.float32,
|
||||
"F64": torch.float64,
|
||||
"I64": torch.int64,
|
||||
"F8_E4M3": torch.float8_e4m3fn,
|
||||
}
|
||||
|
||||
|
||||
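A quick sanity check of the safetensors-style dtype mapping added above (assuming `torch` is imported as elsewhere in this module):

```python
# not part of the diff; only confirms the intent of str_to_torch_dtype
assert str_to_torch_dtype["BF16"] is torch.bfloat16
assert str_to_torch_dtype["F8_E4M3"] is torch.float8_e4m3fn
```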
def get_packed_weights(param, empty_param, device_mesh, rank, dim):
|
||||
"""
|
||||
When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share.
|
||||
@@ -106,6 +121,12 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
|
||||
tensors_slices += range(block_offset + start, block_offset + stop)
|
||||
block_offset += block_size
|
||||
|
||||
slice_dtype = slice_.get_dtype()
|
||||
# Handle F8_E4M3 dtype by converting to float16 before slicing
|
||||
# Without upcasting, the slicing causes : RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn'
|
||||
if slice_dtype == "F8_E4M3":
|
||||
slice_ = slice_[...].to(torch.float16)
|
||||
|
||||
if dim == 0:
|
||||
tensor = slice_[tensors_slices, ...]
|
||||
elif dim == 1 or dim == -2:
|
||||
@@ -114,7 +135,7 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
|
||||
tensor = slice_[..., tensors_slices]
|
||||
else:
|
||||
raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
|
||||
return tensor
|
||||
return tensor.to(str_to_torch_dtype[slice_dtype])
|
||||
|
||||
|
||||
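To make the packed-weight handling above easier to follow, here is a toy, self-contained sketch (the helper below is hypothetical, not the PR's implementation): `gate_up_proj` stores the gate and up blocks concatenated along the sharding dimension, so each rank must take its slice of each block rather than one contiguous chunk.

```python
import torch

world_size, block_rows = 2, 4
gate = torch.arange(0, block_rows)      # rows 0..3 stand in for the gate block
up = torch.arange(10, 10 + block_rows)  # rows 10..13 stand in for the up block
packed = torch.cat([gate, up])          # packed [gate | up], shape (8,)

def toy_packed_shard(packed, world_size, rank, block_rows):
    # take this rank's slice of *each* block, then re-concatenate
    per_rank = block_rows // world_size
    slices = []
    for block in range(2):
        start = block * block_rows + rank * per_rank
        slices.append(packed[start : start + per_rank])
    return torch.cat(slices)

print(toy_packed_shard(packed, world_size, 0, block_rows))  # tensor([ 0,  1, 10, 11])
print(toy_packed_shard(packed, world_size, 1, block_rows))  # tensor([ 2,  3, 12, 13])
```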
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
|
||||
@@ -199,11 +220,12 @@ class GatherParallel(TensorParallelLayer):
|
||||
@staticmethod
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
||||
if isinstance(inputs[0], DTensor):
|
||||
inputs[0] = inputs[0].to_local()
|
||||
inputs = inputs[0].to_local()
|
||||
return inputs
|
||||
|
||||
@staticmethod
|
||||
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
|
||||
# this op cannot be async, otherwise it completely breaks the outputs of models
|
||||
torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False)
|
||||
return outputs
|
||||
|
||||
@@ -266,7 +288,7 @@ class ColwiseParallel(TensorParallelLayer):
|
||||
|
||||
# transform the input layouts to the desired layouts of ColwiseParallel
|
||||
if input_layouts != desired_input_layouts:
|
||||
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
|
||||
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
|
||||
return input_tensor
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
@@ -291,7 +313,7 @@ class ColwiseParallel(TensorParallelLayer):
|
||||
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
|
||||
# outputs is a shard on last dimension DTensor, i.e. Shard(-1)
|
||||
if outputs.placements != output_layouts:
|
||||
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
|
||||
outputs = outputs.redistribute(placements=output_layouts, async_op=False)
|
||||
# back to local tensor
|
||||
return outputs.to_local() if use_local_output else outputs
|
||||
|
||||
@@ -343,16 +365,6 @@ class RowwiseParallel(TensorParallelLayer):
|
||||
self.use_local_output = use_local_output
|
||||
self.use_dtensor = use_dtensor
|
||||
|
||||
@staticmethod
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
||||
input_tensor = inputs[0]
|
||||
if not isinstance(input_tensor, DTensor):
|
||||
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
|
||||
|
||||
if input_layouts != desired_input_layouts:
|
||||
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
|
||||
return input_tensor
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
|
||||
# means Rowwise as nn.Linear is input * weight^T + bias, where
|
||||
@@ -371,6 +383,20 @@ class RowwiseParallel(TensorParallelLayer):
|
||||
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
|
||||
return nn.Parameter(parameter)
|
||||
|
||||
@staticmethod
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
||||
if hasattr(mod, "bias") and mod.bias is not None:
|
||||
mod._bias = mod.bias
|
||||
mod.bias = None
|
||||
|
||||
input_tensor = inputs[0]
|
||||
if not isinstance(input_tensor, DTensor):
|
||||
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
|
||||
|
||||
if input_layouts != desired_input_layouts:
|
||||
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
|
||||
return input_tensor
|
||||
|
||||
@staticmethod
|
||||
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
|
||||
# Rowwise sharding produces partial output, depending on output layouts:
|
||||
@@ -378,6 +404,8 @@ class RowwiseParallel(TensorParallelLayer):
|
||||
# 2. to shard -> reduce_scatter
|
||||
if outputs.placements != output_layouts:
|
||||
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
|
||||
if hasattr(mod, "_bias"):
|
||||
outputs += mod._bias
|
||||
# back to local tensor if use_local_output is True
|
||||
return outputs.to_local() if use_local_output else outputs
|
||||
|
||||
@@ -418,6 +446,90 @@ class PackedRowwiseParallel(RowwiseParallel):
|
||||
return nn.Parameter(parameter)
|
||||
|
||||
|
||||
class SequenceParallel(TensorParallelLayer):
|
||||
"""
|
||||
SequenceParallel replicates a compatible ``nn.Module``'s parameters and runs the sharded computation with
|
||||
input sharded on the sequence dimension. This currently supports ``nn.LayerNorm``, ``nn.Dropout``, and the
|
||||
`RMSNorm python implementation <https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34>`__
|
||||
|
||||
This style implements the operation that is described in the paper
|
||||
`Reducing Activation Recomputation in Large Transformer Models <https://arxiv.org/abs/2205.05198>`__
|
||||
|
||||
If the input passed in to this ``nn.Module`` is a :class:`torch.Tensor`, it assumes that the input is already sharded
|
||||
on the sequence dimension and converts the input to a :class:`DTensor` sharded on the sequence dimension. If the input
|
||||
passed in to this ``nn.Module`` is already a :class:`DTensor` but is not sharded on the sequence dimension, it would
|
||||
redistribute the input to be sharded on the sequence dimension.
|
||||
|
||||
The output of the ``nn.Module`` will be sharded on the sequence dimension.
|
||||
|
||||
Keyword Args:
|
||||
sequence_dim (int, optional):
|
||||
The sequence dimension of the input tensor for the ``nn.Module``, this is used to annotate the input tensor to
|
||||
become a DTensor that is sharded on the sequence dimension, default: 1.
|
||||
use_local_output (bool, optional):
|
||||
Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False.
|
||||
Returns:
|
||||
A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``.
|
||||
|
||||
Example::
|
||||
>>> # xdoctest: +SKIP(failing)
|
||||
>>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
|
||||
>>> from torch.distributed.device_mesh import init_device_mesh
|
||||
>>> ...
|
||||
>>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
|
||||
>>> tp_mesh = init_device_mesh("cuda", (8,))
|
||||
>>>
|
||||
>>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim
|
||||
>>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`.
|
||||
>>>
|
||||
>>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}),
|
||||
>>> ...
|
||||
|
||||
.. note:: SequenceParallel style assumes ones initialization if there are weights in the nn.Module (i.e.
|
||||
``nn.LayerNorm`` or ``RMSNorm``, and they by default have ones initialization). If you have custom
|
||||
inits for the weights on those modules, you need to broadcast the weights before/after parallelizing
|
||||
to ensure that they are replicated.
|
||||
"""
|
||||
|
||||
def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False):
|
||||
super().__init__()
|
||||
self.input_layouts = (Replicate(),)
|
||||
self.desired_input_layouts = (Shard(1),)
|
||||
self.output_layouts = (Replicate(),)
|
||||
self.use_local_output = use_local_output
|
||||
self.use_dtensor = True
|
||||
self.sequence_sharding = (Shard(sequence_dim),)
|
||||
self.use_local_output = use_local_output
|
||||
|
||||
@staticmethod
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
||||
input_tensor = inputs[0]
|
||||
if not isinstance(input_tensor, DTensor):
|
||||
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
|
||||
if input_layouts != desired_input_layouts:
|
||||
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
|
||||
return input_tensor
|
||||
|
||||
@staticmethod
|
||||
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
|
||||
outputs = outputs.redistribute(
|
||||
placements=(Replicate(),), async_op=True
|
||||
) # maybe we have to replicate ? because next layer is not sharded
|
||||
return outputs.to_local() # if use_local_output else outputs
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# SequenceParallel keeps the parameter replicated on every rank (see the [Replicate()]
# placement below); only the activations are sharded along the sequence dimension.
|
||||
parameter = param[:]
|
||||
parameter = parameter.to(param_casting_dtype)
|
||||
if to_contiguous:
|
||||
parameter = parameter.contiguous()
|
||||
if self.use_dtensor:
|
||||
parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False)
|
||||
return nn.Parameter(parameter)
|
||||
|
||||
|
||||
SUPPORTED_TP_STYLES = {
|
||||
"colwise",
|
||||
"rowwise",
|
||||
@@ -428,6 +540,7 @@ SUPPORTED_TP_STYLES = {
|
||||
"local",
|
||||
"gather",
|
||||
"local_packed_rowwise",
|
||||
"sequence_parallel",
|
||||
}
|
||||
|
||||
|
||||
@@ -459,6 +572,8 @@ def translate_to_torch_parallel_style(style: str):
|
||||
return GatherParallel()
|
||||
elif style == "local_packed_rowwise":
|
||||
return PackedRowwiseParallel(use_dtensor=False)
|
||||
elif style == "sequence_parallel":
|
||||
return SequenceParallel()
|
||||
else:
|
||||
raise ValueError(f"Unsupported parallel style value: {style}")
|
||||
|
||||
@@ -518,6 +633,7 @@ def shard_and_distribute_module(
|
||||
tp_plan = model._tp_plan
|
||||
module_to_tp = model.get_submodule(param_name)
|
||||
current_module_plan = None
|
||||
rank = int(rank)
|
||||
generic_param_name = re.sub(r"\d+", "*", parameter_name)
|
||||
if generic_param_name in tp_plan:
|
||||
current_module_plan = tp_plan[generic_param_name]
|
||||
@@ -531,12 +647,18 @@ def shard_and_distribute_module(
|
||||
module_to_tp._is_hooked = True
|
||||
|
||||
if current_module_plan is not None:
|
||||
tp_layer = translate_to_torch_parallel_style(current_module_plan)
|
||||
param = tp_layer.partition_tensor(
|
||||
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
|
||||
)
|
||||
try:
|
||||
tp_layer = translate_to_torch_parallel_style(current_module_plan)
|
||||
param = tp_layer.partition_tensor(
|
||||
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
|
||||
)
|
||||
except NotImplementedError as e:
|
||||
print(
|
||||
f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
|
||||
)
|
||||
else:
|
||||
# TODO log no plan modules in set
|
||||
# print("No plan for", parameter_name,end ="\n")
|
||||
param = param[...].to(param_casting_dtype)
|
||||
if is_contiguous:
|
||||
param = param.contiguous()
|
||||
|
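A minimal illustration (not from the diff) of how `shard_and_distribute_module` generalizes a concrete parameter name before looking it up in the TP plan; the module path below is hypothetical, the `\d+` to `*` wildcarding is the point.

```python
import re

parameter_name = "model.layers.7.feed_forward.experts.gate_up_proj"
generic_param_name = re.sub(r"\d+", "*", parameter_name)
print(generic_param_name)  # model.layers.*.feed_forward.experts.gate_up_proj

tp_plan = {"model.layers.*.feed_forward.experts.gate_up_proj": "local_packed_rowwise"}
assert generic_param_name in tp_plan  # plan entry found -> partition_tensor is used
```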
@@ -484,6 +484,7 @@ str_to_torch_dtype = {
|
||||
"F32": torch.float32,
|
||||
"F64": torch.float64,
|
||||
"I64": torch.int64,
|
||||
"F8_E4M3": torch.float8_e4m3fn,
|
||||
}
|
||||
|
||||
if is_torch_greater_or_equal("2.1.0"):
|
||||
@@ -1914,16 +1915,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
)
|
||||
|
||||
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
|
||||
if self.base_model is self:
|
||||
self._pp_plan = (
|
||||
self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
|
||||
)
|
||||
self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
|
||||
else:
|
||||
self._tp_plan = self._tp_plan or {}
|
||||
for name, module in self.named_children():
|
||||
if plan := getattr(module, "_tp_plan", None):
|
||||
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()})
|
||||
self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
|
||||
self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
|
||||
for name, module in self.named_children():
|
||||
if plan := getattr(module, "_tp_plan", None):
|
||||
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
|
||||
|
||||
if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
|
||||
for _, v in self._tp_plan.items():
|
||||
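The `_tp_plan` merging above boils down to re-keying each child module's plan under the child's attribute name; a small sketch with assumed keys (not taken from the PR):

```python
parent_plan = {"lm_head": "colwise_rep"}                          # hypothetical parent entry
child_plans = {"model": {"layers.*.self_attn.q_proj": "colwise"}}

for name, plan in child_plans.items():
    parent_plan.update({f"{name}.{k}": v for k, v in plan.items()})

print(parent_plan)
# {'lm_head': 'colwise_rep', 'model.layers.*.self_attn.q_proj': 'colwise'}
```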
@@ -4054,6 +4050,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
import sys
|
||||
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
sys.stderr = open(os.devnull, "w")
|
||||
# This is the easiest way to dispatch to the current process device
|
||||
device_map = tp_device
|
||||
# Assuming sharding the model onto the world
|
||||
@@ -4238,6 +4235,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
)
|
||||
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
|
||||
device_map = hf_quantizer.update_device_map(device_map)
|
||||
config = hf_quantizer.update_tp_plan(config)
|
||||
|
||||
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
|
||||
if hasattr(hf_quantizer.quantization_config.quant_method, "value"):
|
||||
@@ -4370,9 +4368,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.preprocess_model(
|
||||
model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules
|
||||
model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules, config=config
|
||||
)
|
||||
|
||||
# We store the original dtype for quantized models as we cannot easily retrieve it
|
||||
# once the weights have been quantized
|
||||
# Note that once you have loaded a quantized model, you can't change its dtype so this will
|
||||
@@ -4901,7 +4898,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
name,
|
||||
casting_dtype,
|
||||
to_contiguous,
|
||||
tp_device.index,
|
||||
os.environ["RANK"],
|
||||
device_mesh,
|
||||
)
|
||||
|
||||
@@ -5174,6 +5171,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
|
||||
(where we want the speed-ups of compiled version with static shapes)."""
|
||||
# Only reset it if not present or different from previous config
|
||||
if "llama4" in self.config.model_type: # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT
|
||||
return self.__call__
|
||||
default_config = getattr(self.generation_config, "compile_config", CompileConfig())
|
||||
if (
|
||||
not hasattr(self, "_compiled_call")
|
||||
|
@@ -148,6 +148,7 @@ from . import (
|
||||
levit,
|
||||
lilt,
|
||||
llama,
|
||||
llama4,
|
||||
llava,
|
||||
llava_next,
|
||||
llava_next_video,
|
||||
|
@@ -544,10 +544,6 @@ class _BaseAutoModelClass:
|
||||
if kwargs_orig.get("quantization_config", None) is not None:
|
||||
kwargs["quantization_config"] = kwargs_orig["quantization_config"]
|
||||
|
||||
# AutoClass-specific config manipulation
|
||||
config = copy.deepcopy(config)
|
||||
config = cls._prepare_config_for_auto_class(config)
|
||||
|
||||
has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
|
||||
has_local_code = type(config) in cls._model_mapping.keys()
|
||||
trust_remote_code = resolve_trust_remote_code(
|
||||
@@ -570,6 +566,8 @@ class _BaseAutoModelClass:
|
||||
)
|
||||
elif type(config) in cls._model_mapping.keys():
|
||||
model_class = _get_model_class(config, cls._model_mapping)
|
||||
if model_class.config_class == config.sub_configs.get("text_config", None):
|
||||
config = config.get_text_config()
|
||||
return model_class.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
|
||||
)
|
||||
|
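Roughly, the auto-factory change above means a composite Llama 4 config is swapped for its nested text config before dispatching to `Llama4ForCausalLM`. A hedged sketch of the config side, assuming `get_text_config` resolves `text_config` as it does for other composite configs:

```python
from transformers import Llama4Config

config = Llama4Config()                  # composite config: vision_config + text_config
text_config = config.get_text_config()  # expected to return the Llama4TextConfig sub-config
print(type(text_config).__name__)        # "Llama4TextConfig"
```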
@@ -170,6 +170,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("levit", "LevitConfig"),
|
||||
("lilt", "LiltConfig"),
|
||||
("llama", "LlamaConfig"),
|
||||
("llama4", "Llama4Config"),
|
||||
("llama4_text", "Llama4TextConfig"),
|
||||
("llava", "LlavaConfig"),
|
||||
("llava_next", "LlavaNextConfig"),
|
||||
("llava_next_video", "LlavaNextVideoConfig"),
|
||||
@@ -519,6 +521,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("llama", "LLaMA"),
|
||||
("llama2", "Llama2"),
|
||||
("llama3", "Llama3"),
|
||||
("llama4", "Llama4"),
|
||||
("llama4_text", "Llama4ForCausalLM"),
|
||||
("llava", "LLaVa"),
|
||||
("llava_next", "LLaVA-NeXT"),
|
||||
("llava_next_video", "LLaVa-NeXT-Video"),
|
||||
@@ -776,6 +780,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
|
||||
("rt_detr_resnet", "rt_detr"),
|
||||
("granitevision", "llava_next"),
|
||||
("sam_vision_model", "sam"),
|
||||
("llama4_text", "llama4"),
|
||||
]
|
||||
)
|
||||
|
||||
|
@@ -104,6 +104,7 @@ else:
|
||||
("layoutlmv2", ("LayoutLMv2ImageProcessor",)),
|
||||
("layoutlmv3", ("LayoutLMv3ImageProcessor",)),
|
||||
("levit", ("LevitImageProcessor",)),
|
||||
("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")),
|
||||
("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")),
|
||||
("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")),
|
||||
("llava_next_video", ("LlavaNextVideoImageProcessor",)),
|
||||
|
@ -17,7 +17,6 @@
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from .auto_factory import (
|
||||
_BaseAutoBackboneClass,
|
||||
@@ -161,6 +160,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("levit", "LevitModel"),
|
||||
("lilt", "LiltModel"),
|
||||
("llama", "LlamaModel"),
|
||||
("llama4", "Llama4ForConditionalGeneration"),
|
||||
("longformer", "LongformerModel"),
|
||||
("longt5", "LongT5Model"),
|
||||
("luke", "LukeModel"),
|
||||
@@ -547,6 +547,8 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("jamba", "JambaForCausalLM"),
|
||||
("jetmoe", "JetMoeForCausalLM"),
|
||||
("llama", "LlamaForCausalLM"),
|
||||
("llama4", "Llama4ForCausalLM"),
|
||||
("llama4_text", "Llama4ForCausalLM"),
|
||||
("mamba", "MambaForCausalLM"),
|
||||
("mamba2", "Mamba2ForCausalLM"),
|
||||
("marian", "MarianForCausalLM"),
|
||||
@@ -634,6 +636,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
|
||||
("ijepa", "IJepaModel"),
|
||||
("imagegpt", "ImageGPTModel"),
|
||||
("levit", "LevitModel"),
|
||||
("llama4", "Llama4VisionModel"),
|
||||
("mllama", "MllamaVisionModel"),
|
||||
("mobilenet_v1", "MobileNetV1Model"),
|
||||
("mobilenet_v2", "MobileNetV2Model"),
|
||||
@@ -849,6 +852,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
|
||||
("idefics3", "Idefics3ForConditionalGeneration"),
|
||||
("instructblip", "InstructBlipForConditionalGeneration"),
|
||||
("kosmos-2", "Kosmos2ForConditionalGeneration"),
|
||||
("llama4", "Llama4ForConditionalGeneration"),
|
||||
("llava", "LlavaForConditionalGeneration"),
|
||||
("llava_next", "LlavaNextForConditionalGeneration"),
|
||||
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
|
||||
@@ -1492,6 +1496,7 @@ MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
|
||||
("emu3", "Emu3TextModel"),
|
||||
("flaubert", "FlaubertModel"),
|
||||
("ibert", "IBertModel"),
|
||||
("llama4", "Llama4TextModel"),
|
||||
("longformer", "LongformerModel"),
|
||||
("mllama", "MllamaTextModel"),
|
||||
("mobilebert", "MobileBertModel"),
|
||||
@@ -1678,30 +1683,6 @@ _AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="languag
|
||||
class AutoModelForCausalLM(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
|
||||
|
||||
@classmethod
|
||||
def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
|
||||
"""
|
||||
Additional autoclass-specific config post-loading manipulation. In this specific autoclass, if the config has
|
||||
a nested text decoder section, uses that section instead.
|
||||
|
||||
Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own
|
||||
config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM.
|
||||
"""
|
||||
possible_text_config_names = ("decoder", "generator", "text_config")
|
||||
text_config_names = []
|
||||
for text_config_name in possible_text_config_names:
|
||||
if hasattr(config, text_config_name):
|
||||
text_config_names += [text_config_name]
|
||||
|
||||
text_config = config.get_text_config(decoder=True)
|
||||
if text_config_names and type(text_config) in cls._model_mapping.keys():
|
||||
warnings.warn(
|
||||
"Loading a multimodal model with `AutoModelForCausalLM` is deprecated and will be removed in v5. "
|
||||
"`AutoModelForCausalLM` will be used to load only the text-to-text generation module.",
|
||||
FutureWarning,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
|
||||
|
||||
|
@@ -77,6 +77,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("kosmos-2", "Kosmos2Processor"),
|
||||
("layoutlmv2", "LayoutLMv2Processor"),
|
||||
("layoutlmv3", "LayoutLMv3Processor"),
|
||||
("llama4", "Llama4Processor"),
|
||||
("llava", "LlavaProcessor"),
|
||||
("llava_next", "LlavaNextProcessor"),
|
||||
("llava_next_video", "LlavaNextVideoProcessor"),
|
||||
|
@ -292,6 +292,20 @@ else:
|
||||
"LlamaTokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"llama4",
|
||||
(
|
||||
"LlamaTokenizer" if is_sentencepiece_available() else None,
|
||||
"LlamaTokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"llama4_text",
|
||||
(
|
||||
"LlamaTokenizer" if is_sentencepiece_available() else None,
|
||||
"LlamaTokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
("llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("llava_next", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("llava_next_video", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
|
src/transformers/models/llama4/__init__.py (new file, 29 lines)
@@ -0,0 +1,29 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_llama4 import *
|
||||
from .image_processing_llama4_fast import *
|
||||
from .modeling_llama4 import *
|
||||
from .processing_llama4 import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
src/transformers/models/llama4/configuration_llama4.py (new file, 432 lines)
@@ -0,0 +1,432 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Llama4VisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Llama4VisionModel`]. It is used to instantiate a
|
||||
Llama4 vision model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the Llama4 109B.
|
||||
|
||||
e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 34):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
Number of channels in the input image.
|
||||
intermediate_size (`int`, *optional*, defaults to 5632):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
|
||||
vision_output_dim (`int`, *optional*, defaults to 7680):
|
||||
Dimensionality of the vision model output. Includes output of transformer
|
||||
encoder with intermediate layers and global transformer encoder.
|
||||
image_size (`int`, *optional*, defaults to 448):
|
||||
The size (resolution) of each image *tile*.
|
||||
patch_size (`int`, *optional*, defaults to 14):
|
||||
The size (resolution) of each patch.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the layer normalization layers.
|
||||
vision_feature_layer (``, *optional*, defaults to -1): TODO
|
||||
vision_feature_select_strategy (`int`, *optional*, defaults to `"default"`): TODO
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
pixel_shuffle_ratio (`float`, *optional*, defaults to 0.5): TODO
|
||||
projector_input_dim (`int`, *optional*, defaults to 4096): TODO
|
||||
projector_output_dim (`int`, *optional*, defaults to 4096): TODO
|
||||
multi_modal_projector_bias (`int`, *optional*, defaults to `False`): TODO
|
||||
projector_dropout (`int`, *optional*, defaults to 0.0): TODO
|
||||
attention_dropout (`int`, *optional*, defaults to 0.0): TODO
|
||||
rope_theta (`int`, *optional*, defaults to 10000): TODO
|
||||
"""
|
||||
|
||||
base_model_tp_plan = {
|
||||
"model.layers.*.self_attn.q_proj": "colwise",
|
||||
"model.layers.*.self_attn.k_proj": "colwise",
|
||||
"model.layers.*.self_attn.v_proj": "colwise",
|
||||
"model.layers.*.self_attn.o_proj": "rowwise",
|
||||
"vision_adapter.mlp.fc1": "colwise",
|
||||
"vision_adapter.mlp.fc2": "rowwise",
|
||||
"patch_embedding.linear": "colwise_rep",
|
||||
}
|
||||
model_type = "llama4_vision_model"
|
||||
base_config_key = "vision_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 768,
|
||||
hidden_act: str = "gelu",
|
||||
num_hidden_layers: int = 34,
|
||||
num_attention_heads: int = 16,
|
||||
num_channels: int = 3,
|
||||
intermediate_size: int = 5632,
|
||||
vision_output_dim: int = 7680,
|
||||
image_size: int = 448,
|
||||
patch_size: int = 14,
|
||||
norm_eps: float = 1e-5,
|
||||
vision_feature_layer=-1,
|
||||
vision_feature_select_strategy="default",
|
||||
initializer_range: float = 0.02,
|
||||
pixel_shuffle_ratio=0.5,
|
||||
projector_input_dim=4096,
|
||||
projector_output_dim=4096,
|
||||
multi_modal_projector_bias=False,
|
||||
projector_dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
rope_theta=10000,
|
||||
**kwargs,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_channels = num_channels
|
||||
self.intermediate_size = intermediate_size
|
||||
self.image_size = image_size
|
||||
self.vision_output_dim = vision_output_dim
|
||||
self.patch_size = patch_size
|
||||
self.norm_eps = norm_eps
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.initializer_range = initializer_range
|
||||
self.pixel_shuffle_ratio = pixel_shuffle_ratio
|
||||
self.projector_input_dim = projector_input_dim
|
||||
self.projector_output_dim = projector_output_dim
|
||||
self.multi_modal_projector_bias = multi_modal_projector_bias
|
||||
self.projector_dropout = projector_dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.vision_feature_layer = vision_feature_layer
|
||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||
self.rope_theta = rope_theta
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class Llama4TextConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Llama4TextModel`]. It is used to instantiate a
|
||||
Llama4 text model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the Llama4 109B.
|
||||
|
||||
e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 202048):
|
||||
Vocabulary size of the Llama4 text model. Defines the maximum number of different tokens that can be represented
|
||||
by the `inputs_ids` passed when calling [`Llama4TextModel`].
|
||||
hidden_size (`int`, *optional*, defaults to 5120):
|
||||
Dimensionality of the embeddings and hidden states.
|
||||
intermediate_size (`int`, *optional*, defaults to 8192):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
|
||||
intermediate_size_mlp (`int`, *optional*, defaults to 16384): TODO
|
||||
num_hidden_layers (`int`, *optional*, defaults to 48):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 40):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If not
|
||||
specified, will default to `num_attention_heads`.
|
||||
head_dim (`int`, *optional*, defaults to 128): TODO
|
||||
hidden_act (`str` or `Callable`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 131072):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions.
|
||||
pad_token_id (`int`, *optional*, defaults to 128004):
|
||||
The id of the padding token.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
The id of the beginning of sentence token.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
The id of the end of sentence token.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to `500000.0`):
|
||||
The base period of the RoPE embeddings.
|
||||
attention_dropout (`int`, *optional*, defaults to 0.0): TODO
|
||||
num_experts_per_tok (`int`, *optional*, defaults to 1): TODO
|
||||
num_local_experts (`int`, *optional*, defaults to 16): TODO
|
||||
moe_layers (`int`, *optional*): TODO
|
||||
interleave_moe_layer_step (`int`, *optional*, defaults to 1): TODO
|
||||
use_qk_norm (`bool`, *optional*, defaults to `True`): TODO
|
||||
output_router_logits (`bool`, *optional*, defaults to `False`): TODO
|
||||
router_aux_loss_coef (`int`, *optional*, defaults to 0.001): TODO
|
||||
router_jitter_noise (`int`, *optional*, defaults to 0.0): TODO
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
<TODO>
|
||||
<TODO>
|
||||
no_rope_layers (`int`, *optional*): TODO
|
||||
no_rope_layer_interval (`int`, *optional*, defaults to 4): TODO
|
||||
attention_chunk_size (`int`, *optional*, defaults to 8192):
|
||||
<TODO>
|
||||
attn_temperature_tuning (`int`, *optional*, defaults to 4): TODO
|
||||
floor_scale (`int`, *optional*, defaults to 8192): TODO
|
||||
attn_scale (`int`, *optional*, defaults to 0.1): TODO
|
||||
|
||||
Example:
|
||||
"""
|
||||
|
||||
model_type = "llama4_text"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.input_layernorm.weight": "sequence_parallel",
|
||||
"layers.*.post_attention_layernorm.weight": "sequence_parallel",
|
||||
"norm.weight": "sequence_parallel",
|
||||
"layers.*.feed_forward.shared_expert.gate_proj": "local_colwise",
|
||||
"layers.*.feed_forward.shared_expert.up_proj": "local_colwise",
|
||||
"layers.*.feed_forward.shared_expert.down_proj": "local_rowwise",
|
||||
"layers.*.feed_forward.experts.gate_up_proj": "local_packed_rowwise", # row because not linear
|
||||
"layers.*.feed_forward.experts.down_proj": "local_colwise", # col because not linear
|
||||
"layers.*.feed_forward.experts": "local",
|
||||
"layers.*.feed_forward.gate_proj": "local_colwise",
|
||||
"layers.*.feed_forward.up_proj": "local_colwise",
|
||||
"layers.*.feed_forward.down_proj": "local_rowwise",
|
||||
"layers.*.feed_forward": "gather",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=202048,
|
||||
hidden_size=5120,
|
||||
intermediate_size=8192,
|
||||
intermediate_size_mlp=16384,
|
||||
num_hidden_layers=48,
|
||||
num_attention_heads=40,
|
||||
num_key_value_heads=8,
|
||||
head_dim=128,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=4096 * 32,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-5,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=500000,
|
||||
attention_dropout=0.0,
|
||||
num_experts_per_tok=1,
|
||||
num_local_experts=16,
|
||||
moe_layers=None,
|
||||
interleave_moe_layer_step=1,
|
||||
use_qk_norm=True,
|
||||
output_router_logits=False,
|
||||
router_aux_loss_coef=0.001,
|
||||
router_jitter_noise=0.0,
|
||||
rope_scaling=None,
|
||||
no_rope_layers=None,
|
||||
no_rope_layer_interval=4,
|
||||
attention_chunk_size=8192,
|
||||
attn_temperature_tuning=4,
|
||||
floor_scale=8192,
|
||||
attn_scale=0.1,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
self.attn_temperature_tuning = attn_temperature_tuning
|
||||
self.attn_scale = attn_scale
|
||||
self.floor_scale = floor_scale
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.intermediate_size_mlp = intermediate_size_mlp
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = False
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_dropout = attention_dropout
|
||||
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
|
||||
self.use_qk_norm = use_qk_norm
|
||||
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.num_local_experts = num_local_experts
|
||||
|
||||
self.output_router_logits = output_router_logits
|
||||
self.router_aux_loss_coef = router_aux_loss_coef
|
||||
self.router_jitter_noise = router_jitter_noise
|
||||
default_no_rope_layers = [
|
||||
int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(self.num_hidden_layers)
|
||||
]
|
||||
|
||||
# no_rope_layers == [] is invalid as we cannot have 0 layers
|
||||
self.no_rope_layers = no_rope_layers if no_rope_layers else default_no_rope_layers
|
||||
|
||||
self.interleave_moe_layer_step = interleave_moe_layer_step
|
||||
self.moe_layers = (
|
||||
moe_layers
|
||||
if moe_layers is not None
|
||||
else list(range(interleave_moe_layer_step - 1, num_hidden_layers, interleave_moe_layer_step))
|
||||
)
|
||||
self.attention_chunk_size = attention_chunk_size
|
||||
|
||||
|
||||
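To visualize the defaults computed in `Llama4TextConfig.__init__` above, here is a small standalone sketch using a hypothetical 8-layer model instead of the real 48:

```python
num_hidden_layers, no_rope_layer_interval, interleave_moe_layer_step = 8, 4, 1

no_rope_layers = [
    int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(num_hidden_layers)
]
print(no_rope_layers)  # [1, 1, 1, 0, 1, 1, 1, 0] -> a 0 on every 4th layer

moe_layers = list(range(interleave_moe_layer_step - 1, num_hidden_layers, interleave_moe_layer_step))
print(moe_layers)      # [0, 1, 2, 3, 4, 5, 6, 7] -> with step 1, every layer is an MoE layer
```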
class Llama4Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Llama4Model`]. It is used to instantiate an
|
||||
Llama4 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the Llama4 109B.
|
||||
|
||||
e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vision_config (`Llama4VisionConfig`, *optional*):
|
||||
The Llama4 Vision config.
|
||||
text_config (`Llama4TextConfig`, *optional*):
|
||||
The Llama4 Text config.
|
||||
boi_token_index (`int`, *optional*, defaults to 200080):
|
||||
The begin-of-image token index to wrap the image prompt.
|
||||
eoi_token_index (`int`, *optional*, defaults to 200081):
|
||||
The end-of-image token index to wrap the image prompt.
|
||||
image_token_index (`int`, *optional*, defaults to 200092):
|
||||
The image token index to encode the image prompt.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
|
||||
```python
|
||||
>>> from transformers import Llama4Model, Llama4Config
|
||||
|
||||
>>> # Initializing a Llama4 7B style configuration
|
||||
>>> configuration = Llama4Config()
|
||||
|
||||
>>> # Initializing a model from the Llama4 7B style configuration
|
||||
>>> model = Llama4Model(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "llama4"
|
||||
sub_configs = {"text_config": Llama4TextConfig, "vision_config": Llama4VisionConfig}
|
||||
base_model_tp_plan = {
|
||||
"multi_modal_projector.linear_1": "colwise_rep",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
boi_token_index=200080,
|
||||
eoi_token_index=200081,
|
||||
image_token_index=200092,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs,
|
||||
):
|
||||
if vision_config is None:
|
||||
self.vision_config = Llama4VisionConfig()
|
||||
logger.info("vision_config is None, using default llama4 vision config")
|
||||
elif isinstance(vision_config, dict):
|
||||
self.vision_config = Llama4VisionConfig(**vision_config)
|
||||
elif isinstance(vision_config, Llama4VisionConfig):
|
||||
self.vision_config = vision_config
|
||||
|
||||
self.boi_token_index = boi_token_index
|
||||
self.eoi_token_index = eoi_token_index
|
||||
self.image_token_index = image_token_index
|
||||
|
||||
if text_config is None:
|
||||
self.text_config = Llama4TextConfig()
|
||||
logger.info("text_config is None, using default llama4 text config")
|
||||
elif isinstance(text_config, dict):
|
||||
self.text_config = Llama4TextConfig(**text_config)
|
||||
elif isinstance(text_config, Llama4TextConfig):
|
||||
self.text_config = text_config
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
||||
|
||||
__all__ = ["Llama4Config", "Llama4TextConfig", "Llama4VisionConfig"]
|
src/transformers/models/llama4/convert_llama4_weights_to_hf.py (new file, 736 lines)
File diff suppressed because one or more lines are too long

src/transformers/models/llama4/image_processing_llama4_fast.py (new file, 480 lines)
@@ -0,0 +1,480 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Fast Image processor class for Got-OCR-2."""
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional, Set, Tuple, Union
|
||||
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_processing_utils_fast import (
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
||||
BaseImageProcessorFast,
|
||||
DefaultFastImageProcessorKwargs,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
from ...image_utils import (
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
)
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
add_start_docstrings,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchvision_available():
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
else:
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
def get_factors(dividend: int) -> Set[int]:
|
||||
"""
|
||||
Calculate all factors of a given number, i.e. every divisor that leaves
|
||||
no remainder. For example, if dividend=12, it will return {1, 2, 3, 4, 6, 12}.
|
||||
|
||||
Args:
|
||||
dividend (int): The number to find factors for.
|
||||
|
||||
Returns:
|
||||
set: A set containing all factors of the number.
|
||||
"""
|
||||
factors_set = set()
|
||||
|
||||
for i in range(1, int(dividend**0.5) + 1):
|
||||
if dividend % i == 0:
|
||||
factors_set.add(i)
|
||||
factors_set.add(dividend // i)
|
||||
return factors_set
|
||||
|
||||
|
||||
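A runnable one-liner confirming the docstring example above (assuming `get_factors` is importable from this module):

```python
print(sorted(get_factors(12)))  # [1, 2, 3, 4, 6, 12]
```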
def get_max_res_without_distortion(
|
||||
image_size: Tuple[int, int],
|
||||
target_size: Tuple[int, int],
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determines the maximum resolution to which an image can be resized without distorting its
|
||||
aspect ratio, based on the target resolution.
|
||||
|
||||
Args:
|
||||
image_size (Tuple[int, int]): The original resolution of the image (height, width).
|
||||
target_size (Tuple[int, int]): The desired resolution to fit the image into (height, width).
|
||||
Returns:
|
||||
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
|
||||
Example:
|
||||
>>> get_max_res_without_distortion([200, 300], target_size = [450, 200])
|
||||
(134, 200)
|
||||
>>> get_max_res_without_distortion([800, 600], target_size = [450, 1300])
|
||||
(450, 338)
|
||||
"""
|
||||
|
||||
original_height, original_width = image_size
|
||||
target_height, target_width = target_size
|
||||
|
||||
scale_w = target_width / original_width
|
||||
scale_h = target_height / original_height
|
||||
|
||||
if scale_w < scale_h:
|
||||
new_width = target_width
|
||||
new_height = min(math.floor(original_height * scale_w), target_height)
|
||||
else:
|
||||
new_height = target_height
|
||||
new_width = min(math.floor(original_width * scale_h), target_width)
|
||||
|
||||
return new_height, new_width
|
||||
|
||||
|
||||
class Llama4ImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||
max_patches: Optional[int]
|
||||
resize_to_max_canvas: Optional[bool]
|
||||
|
||||
|
||||
def split_to_tiles(images: torch.Tensor, num_tiles_height: int, num_tiles_width: int) -> torch.Tensor:
|
||||
# Split image into number of required tiles (width x height)
|
||||
batch_size, num_channels, height, width = images.size()
|
||||
images = images.view(
|
||||
batch_size,
|
||||
num_channels,
|
||||
num_tiles_height,
|
||||
height // num_tiles_height,
|
||||
num_tiles_width,
|
||||
width // num_tiles_width,
|
||||
)
|
||||
# Permute dimensions to reorder the axes
|
||||
image = images.permute(0, 2, 4, 1, 3, 5).contiguous()
|
||||
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
|
||||
image = image.view(
|
||||
batch_size,
|
||||
num_tiles_width * num_tiles_height,
|
||||
num_channels,
|
||||
height // num_tiles_height,
|
||||
width // num_tiles_width,
|
||||
)
|
||||
return image
|
||||
|
||||
|
||||
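A quick shape sketch for `split_to_tiles` above, with illustrative sizes not taken from the PR:

```python
import torch

images = torch.randn(2, 3, 448, 896)  # batch of 2 RGB canvases, two tiles wide
tiles = split_to_tiles(images, num_tiles_height=1, num_tiles_width=2)
print(tiles.shape)  # torch.Size([2, 2, 3, 448, 448]) -> (batch, tiles, channels, tile_h, tile_w)
```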
@lru_cache(maxsize=1)
def find_supported_resolutions(max_num_chunks: int, patch_size: SizeDict) -> torch.Tensor:
    """
    Computes all of the allowed resolutions for a fixed number of chunks
    and patch_size. Useful when dividing an image into chunks.

    Args:
        max_num_chunks (int): Maximum number of chunks for processing.
        patch_size (`SizeDict`): Size of the side of the (square) patch.

    Returns:
        List[Tuple[int, int]]: Possible resolutions as tuples (height, width).

    Example:
        >>> max_num_chunks = 5
        >>> patch_size = 224
        >>> find_supported_resolutions(max_num_chunks, patch_size)
        tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
                (672, 224), (224, 448), (448, 224)])

    Given max_num_chunks=4, patch_size=224, it will create a dictionary:
        {
        0.25: [(1, 4)],
        1.0: [(2, 2), (1, 1)],
        4.0: [(4, 1)],
        0.33: [(1, 3)],
        3.0: [(3, 1)],
        0.5: [(1, 2)],
        2.0: [(2, 1)]
        }

    and return the resolutions multiplied by the patch_size:
        [(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
    """
    height, width = patch_size.height, patch_size.width
    if height != width:
        raise ValueError("`size` must be square.")

    patch_size = height

    asp_dict = defaultdict(list)
    for chunk_size in range(max_num_chunks, 0, -1):
        _factors = sorted(get_factors(chunk_size))
        _asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
        for height, width in _asp_ratios:
            ratio_float = height / width
            asp_dict[ratio_float].append((height, width))

    # get the resolutions multiplied by the patch_size
    possible_resolutions = []
    for key, value in asp_dict.items():
        for height, depth in value:
            possible_resolutions.append((height * patch_size, depth * patch_size))

    return possible_resolutions


def pad_to_best_fit(
    images: "torch.Tensor",
    target_size: Tuple[int, int],
    background_color: Union[int, Tuple[int, int, int]] = 0,
) -> "torch.Tensor":
    """
    Pads an image to fit the target size.

    Args:
        images (`torch.Tensor`):
            The images to pad.
        target_size (`Tuple[int, int]`):
            The target size (height, width) to pad the images to.
        background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
            The color to use for the padding. Can be an integer for single channel or a
            tuple of integers for multi-channel images. If passed as an integer
            in multi-channel mode, it will default to `0` in subsequent channels.
    Returns:
        `torch.Tensor`: The padded images.
    """

    num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0]
    if isinstance(background_color, int):
        background_color = [background_color] + [0] * (num_channels - 1)
    elif len(background_color) != num_channels:
        raise ValueError(
            f"background_color must have no more than {num_channels} elements to match the number of channels"
        )

    height, width = images.shape[-2:]
    target_height, target_width = target_size
    paste_x_right = target_width - width
    paste_y_right = target_height - height
    padded_images = F.pad(images, padding=[0, 0, paste_x_right, paste_y_right], fill=background_color)

    return padded_images


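# Illustrative sketch (not part of the upstream file): `pad_to_best_fit` only pads
# on the bottom/right, so the original pixels stay anchored at the top-left corner.
# Shapes are arbitrary; `torch` is assumed available as imported above.
def _demo_pad_to_best_fit():
    images = torch.zeros(1, 3, 200, 300)  # smaller than the target canvas
    padded = pad_to_best_fit(images, target_size=(224, 448))
    assert padded.shape == (1, 3, 224, 448)
    return padded

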
def get_best_fit(
    image_size: Tuple[int, int],
    possible_resolutions: torch.Tensor,
    resize_to_max_canvas: bool = False,
) -> Tuple[int, int]:
    """
    Determines the best canvas from a list of possible resolutions to which an image
    can be resized without distortion.

    For each possible resolution, calculates the scaling factors for
    width and height, and selects the smallest one, which is the limiting side.
    E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
    therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.

    If upscaling is possible (any of the scaling factors is greater than 1),
    then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.

    If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
    reduce downscaling as much as possible.

    If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
    to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
    has more padding.

    Args:
        image_size (Tuple[int, int]): A tuple containing the height and width of the image.
        possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
            row represents a possible resolution (height, width).
        resize_to_max_canvas (bool): If True, will return the largest upscaling resolution.

    Returns:
        List[int]: The best resolution [height, width] for the given image.

    Example:
        >>> image_size = (200, 300)
        >>> possible_resolutions = torch.tensor([[224, 672],
        ...                                      [672, 224],
        ...                                      [224, 448],
        ...                                      [448, 224],
        ...                                      [224, 224]])
        >>> get_best_fit(image_size, possible_resolutions)
        [224, 448]

    We have:
        scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
        scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
        scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
    Two of the scales are > 1:
        upscaling_possible = tensor([1.1200, 1.1200])
        smallest_rescale = tensor(1.1200)
    So we pick the resolution with the smallest area:
        areas = tensor([150528, 100352])  # [672, 224], [224, 448]
        optimal_canvas = tensor([224, 448])
    """

    original_height, original_width = image_size

    # get all possible resolutions heights/widths
    target_heights, target_widths = (
        possible_resolutions[:, 0],
        possible_resolutions[:, 1],
    )

    # get scaling factors to resize the image without distortion
    scale_w = target_widths / original_width
    scale_h = target_heights / original_height

    # get the min scale between width and height (limiting side -> no distortion)
    scales = torch.where(scale_h > scale_w, scale_w, scale_h)

    # filter only scales that allow upscaling
    upscaling_options = scales[scales >= 1]
    if len(upscaling_options) > 0:
        if resize_to_max_canvas:
            selected_scale = torch.max(upscaling_options)
        else:
            selected_scale = torch.min(upscaling_options)
    else:
        # no upscaling possible,
        # get the minimum downscaling (max scale for scales<1)
        downscaling_options = scales[scales < 1]
        selected_scale = torch.max(downscaling_options)

    # get all resolutions that support this scaling factor,
    # e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
    chosen_canvas = possible_resolutions[scales == selected_scale]

    # if there are multiple resolutions,
    # get the one with minimum area to reduce padding
    if len(chosen_canvas) > 1:
        areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
        optimal_idx = torch.argmin(areas)
        optimal_canvas = chosen_canvas[optimal_idx]
    else:
        optimal_canvas = chosen_canvas[0]

    return tuple(optimal_canvas.tolist())


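# Illustrative sketch (not part of the upstream file): how the helpers above combine
# for one image. The resolutions are the same toy values used in the docstring of
# `get_best_fit`; the resulting resize target is then padded to the chosen canvas.
def _demo_resolution_selection():
    possible_resolutions = torch.tensor([[224, 672], [672, 224], [224, 448], [448, 224], [224, 224]])
    image_size = (200, 300)  # (height, width)
    canvas = get_best_fit(image_size, possible_resolutions)  # -> (224, 448)
    resize_to = get_max_res_without_distortion(image_size, canvas)  # -> (224, 336), aspect ratio preserved
    return canvas, resize_to

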
@add_start_docstrings(
    "Constructs a fast Llama4 image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    """
        max_patches (`int`, *optional*, defaults to 16):
            The maximum number of patches to be extracted from the image.
            Can be overridden by the `max_patches` parameter in the `preprocess` method.
        resize_to_max_canvas (`bool`, *optional*, defaults to False):
            Whether to resize the image to the maximum canvas size.
            If True, picks the canvas that allows the largest resizing without distortion.
            If False, downsample as little as possible, including no resizing at all,
            but never upsample, unless the image is smaller than the patch size.
    """,
)
class Llama4ImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.BILINEAR
    image_mean = [0.5, 0.5, 0.5]
    image_std = [0.5, 0.5, 0.5]
    size = {"height": 336, "width": 336}
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True
    max_patches = 16
    resize_to_max_canvas = False
    valid_kwargs = Llama4ImageProcessorKwargs

    def __init__(self, **kwargs: Unpack[Llama4ImageProcessorKwargs]):
        super().__init__(**kwargs)

    def rescale_and_normalize(
        self,
        images: "torch.Tensor",
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Union[float, List[float]],
        image_std: Union[float, List[float]],
    ) -> "torch.Tensor":
        """
        Rescale and normalize images.
        Override to rescale and normalize the images in torch.bfloat16 as in the original implementation
        """
        if do_rescale and do_normalize:
            images = images.to(dtype=torch.bfloat16) * rescale_factor
            images = self.normalize(images, image_mean, image_std)
        elif do_rescale:
            images = images * rescale_factor
        elif do_normalize:
            images = self.normalize(images, image_mean, image_std)

        return images

    @add_start_docstrings(
        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
        """
            max_patches (`int`, *optional*, defaults to 16):
                The maximum number of patches to be extracted from the image.
                Can be overridden by the `max_patches` parameter in the `preprocess` method.
            resize_to_max_canvas (`bool`, *optional*, defaults to False):
                Whether to resize the image to the maximum canvas size.
                If True, picks the canvas that allows the largest resizing without distortion.
                If False, downsample as little as possible, including no resizing at all,
                but never upsample, unless the image is smaller than the patch size.
        """,
    )
    def preprocess(self, images: ImageInput, **kwargs: Unpack[Llama4ImageProcessorKwargs]) -> BatchFeature:
        return super().preprocess(images, **kwargs)

    def _preprocess(
        self,
        images: List["torch.Tensor"],
        size: SizeDict,
        max_patches: int,
        resize_to_max_canvas: bool,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, List[float]]],
        image_std: Optional[Union[float, List[float]]],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        possible_resolutions = find_supported_resolutions(max_num_chunks=max_patches, patch_size=size)
        possible_resolutions = torch.tensor(possible_resolutions)
        # process images by batch, grouped by shape
        grouped_images, grouped_images_index = group_images_by_shape(images)
        grouped_processed_images = {}
        grouped_aspect_ratios = {}
        for shape, stacked_images in grouped_images.items():
            image_size = stacked_images.shape[-2:]
            target_size = get_best_fit(image_size, possible_resolutions, resize_to_max_canvas=resize_to_max_canvas)
            # If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
            max_upscaling_size = None if resize_to_max_canvas else size.height
            if max_upscaling_size is not None:
                new_target_height = min(max(image_size[0], max_upscaling_size), target_size[0])
                new_target_width = min(max(image_size[1], max_upscaling_size), target_size[1])
                target_size_without_distortion = (new_target_height, new_target_width)

            # resize to target_size while preserving aspect ratio
            new_size_without_distortion = get_max_res_without_distortion(image_size, target_size_without_distortion)
            new_size_without_distortion = SizeDict(
                height=max(new_size_without_distortion[0], 1), width=max(new_size_without_distortion[1], 1)
            )
            processed_images = self.resize(
                stacked_images,
                new_size_without_distortion,
                interpolation=interpolation,
            )

            # pad to target_size to be able to split into tiles
            processed_images = pad_to_best_fit(processed_images, target_size)
            processed_images = self.rescale_and_normalize(
                processed_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )

            ratio_h, ratio_w = (
                target_size[0] // size.height,
                target_size[1] // size.height,
            )
            # split into tiles
            processed_images = split_to_tiles(processed_images, ratio_h, ratio_w)
            grouped_processed_images[shape] = processed_images
            grouped_aspect_ratios[shape] = torch.tensor([[ratio_h, ratio_w]] * stacked_images.shape[0])

            # add a global tile to the processed tiles if there is more than one tile
            if ratio_h * ratio_w > 1:
                global_tiles = self.resize(
                    stacked_images,
                    size,
                    interpolation=interpolation,
                )
                global_tiles = self.rescale_and_normalize(
                    global_tiles, do_rescale, rescale_factor, do_normalize, image_mean, image_std
                )
                grouped_processed_images[shape] = torch.cat([processed_images, global_tiles.unsqueeze(1)], dim=1)
        processed_images = reorder_images(grouped_processed_images, grouped_images_index)
        aspect_ratios_list = reorder_images(grouped_aspect_ratios, grouped_images_index)

        processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
        aspect_ratios = torch.stack(aspect_ratios_list, dim=0) if return_tensors else aspect_ratios_list
        return BatchFeature(
            data={"pixel_values": processed_images, "aspect_ratios": aspect_ratios}, tensor_type=return_tensors
        )


__all__ = ["Llama4ImageProcessorFast"]
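
# Illustrative usage sketch (not part of the upstream file): end-to-end call of the
# fast image processor on a random image. The exact tile count depends on the image
# size and `max_patches`; a global tile is appended whenever more than one tile is produced.
def _demo_llama4_image_processor():
    processor = Llama4ImageProcessorFast(max_patches=16, size={"height": 336, "width": 336})
    image = torch.randint(0, 256, (3, 500, 800), dtype=torch.uint8)
    outputs = processor(image, return_tensors="pt")
    # pixel_values: (num_tiles [+ 1 global tile], num_channels, 336, 336) for the single image
    return outputs["pixel_values"].shape, outputs["aspect_ratios"]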
1903  src/transformers/models/llama4/modeling_llama4.py  (Normal file)
File diff suppressed because it is too large
275  src/transformers/models/llama4/processing_llama4.py  (Normal file)
File diff suppressed because one or more lines are too long
@@ -981,6 +981,8 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
        else:
            self.device = device if device is not None else -1

        if torch.distributed.is_initialized():
            self.device = self.model.device
        logger.warning(f"Device set to use {self.device}")

        self.binary_output = binary_output
@@ -1178,10 +1178,6 @@ class ProcessorMixin(PushToHubMixin):
        unused_kwargs = {}
        unused_keys = set(kwargs_from_config) - set(valid_kwargs)
        if unused_keys:
            unused_key_str = ", ".join(unused_keys)
            logger.warning(
                f"Some kwargs in processor config are unused and will not have any effect: {unused_key_str}. "
            )
            unused_kwargs = {k: processor_config[k] for k in unused_keys}
        return unused_kwargs

@@ -43,8 +43,7 @@ is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_d
_torch_distributed_available = torch.distributed.is_available()

if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
    from torch.distributed.tensor import Replicate
    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
    pass


def softmax_backward_data(parent, grad_output, output, dim, self):
@@ -335,29 +334,6 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
    return torch.isin(elements, test_elements)


# TODO need to add the __repr__ that shows that it is a colwise parallel
# See https://github.com/pytorch/pytorch/issues/145726
def translate_to_torch_parallel_style(style: str):
    """
    In model configurations, we use a neutral type (string) to specify parallel
    styles, here we translate them into torch.distributed tensor-parallel
    types.
    """
    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    if style == "colwise":
        return ColwiseParallel()
    elif style == "rowwise":
        return RowwiseParallel()
    elif style == "colwise_rep":
        return ColwiseParallel(output_layouts=Replicate())
    elif style == "rowwise_rep":
        return RowwiseParallel(input_layouts=Replicate())
    else:
        raise ValueError(f"Unsupported parallel style value: {style}")


def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs):
    """
    LRU cache decorator from standard functools library, but with a workaround to disable
@@ -382,88 +358,3 @@ def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs):
        return wrapper

    return decorator


def distribute_module(
    module: nn.Module,
    device_mesh=None,
    partition_fn=None,
    input_fn=None,
    output_fn=None,
) -> nn.Module:
    """
    This function expose three functions to control the parameters/inputs/outputs of the module:

    1. To perform sharding on the module before runtime execution by specifying the
        ``partition_fn`` (i.e. allow user to convert Module parameters to :class:`DTensor`
        parameters according to the `partition_fn` specified).
    2. To control the inputs or outputs of the module during runtime execution by
        specifying the ``input_fn`` and ``output_fn``. (i.e. convert the input to
        :class:`DTensor`, convert the output back to ``torch.Tensor``)

    Args:
        module (:class:`nn.Module`): user module to be partitioned.
        device_mesh (:class:`DeviceMesh`): the device mesh to place the module.
        partition_fn (Callable): the function to partition parameters (i.e. shard certain
            parameters across the ``device_mesh``). If ``partition_fn`` is not specified,
            by default we replicate all module parameters of ``module`` across the mesh.
        input_fn (Callable): specify the input distribution, i.e. could control how the
            input of the module is sharded. ``input_fn`` will be installed as a module
            ``forward_pre_hook`` (pre forward hook).
        output_fn (Callable): specify the output distribution, i.e. could control how the
            output is sharded, or convert it back to torch.Tensor. ``output_fn`` will be
            installed as a module ``forward_hook`` (post forward hook).

    Returns:
        A module that contains parameters/buffers that are all ``DTensor`` s.

    .. note::
        When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_module``
        return nn.Module with PyTorch/XLA SPMD annotated parameters. See
        `this issue <https://github.com/pytorch/pytorch/issues/92909>`__
        for more details. The XLA integration is experimental and subject to change.

    """

    torch._C._log_api_usage_once("torch.dtensor.distribute_module")

    device_mesh = device_mesh

    # register input_fn as module forward pre hook
    if input_fn is not None:
        # check the input_fn signature
        num_args = len(inspect.signature(input_fn).parameters)
        if num_args == 2:
            # input_fn only takes in inputs and device mesh
            logger.warning(
                "Deprecating input_fn that takes two arguments (inputs, device_mesh), "
                "please use input_fn that takes in (module, inputs, device_mesh) instead!",
                FutureWarning,
                stacklevel=2,
            )
            module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh))  # type: ignore[call-arg]
        elif num_args == 3:
            # input_fn takes in module, inputs, device mesh
            module.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, device_mesh))
        else:
            raise ValueError(f"input_fn should take in 3 arguments, but got {num_args} arguments!")
    # register output_fn as module forward hook
    if output_fn is not None:
        num_args = len(inspect.signature(output_fn).parameters)
        if num_args == 2:
            # output_fn only takes in outputs and device mesh
            logger.warning(
                "Deprecating output_fn that takes two arguments (inputs, device_mesh), "
                "please use output_fn that takes in (module, inputs, device_mesh) instead!",
                FutureWarning,
                stacklevel=2,
            )
            module.register_forward_hook(
                lambda mod, inputs, outputs: output_fn(outputs, device_mesh)  # type: ignore[call-arg]
            )
        elif num_args == 3:
            module.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh))
        else:
            raise ValueError(f"output_fn should take in 3 arguments, but got {num_args} arguments!")

    return module
52  src/transformers/quantizers/base.py  (Executable file → Normal file)
@@ -15,7 +15,8 @@ from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from ..utils import is_torch_available
from ..utils.quantization_config import QuantizationConfigMixin
from ..utils.quantization_config import QuantizationConfigMixin, QuantizationMethod
from .quantizers_utils import get_module_from_name


if TYPE_CHECKING:
@@ -23,6 +24,9 @@ if TYPE_CHECKING:

if is_torch_available():
    import torch
    from torch.nn import ModuleList
else:
    ModuleList = str


class HfQuantizer(ABC):
@@ -198,6 +202,10 @@ class HfQuantizer(ABC):
        """
        return

    def update_tp_plan(self, config):
        "updates the tp plan for the scales"
        return config

    def preprocess_model(self, model: "PreTrainedModel", **kwargs):
        """
        Setting model attributes and/or converting model before weights loading. At this point
@@ -212,6 +220,7 @@ class HfQuantizer(ABC):
        """
        model.is_quantized = True
        model.quantization_method = self.quantization_config.quant_method
        self._convert_model_for_quantization(model)
        return self._process_model_before_weight_loading(model, **kwargs)

    def postprocess_model(self, model: "PreTrainedModel", **kwargs):
@@ -288,3 +297,44 @@ class HfQuantizer(ABC):
    @property
    @abstractmethod
    def is_trainable(self): ...

    def _convert_model_for_quantization(self, model):
        from accelerate import init_empty_weights

        for name, module in model.named_modules():
            module_class_name = module.__class__.__name__
            if (
                module_class_name in MODULES_TO_PATCH_FOR_QUANTIZATION.keys()
                and self.quantization_config.quant_method == QuantizationMethod.COMPRESSED_TENSORS
            ):
                with init_empty_weights():
                    parent_module, name = get_module_from_name(model, name)
                    parent_module._modules[name] = MODULES_TO_PATCH_FOR_QUANTIZATION[module_class_name](
                        model.config.get_text_config()
                    )


class SequentialLlama4TextExperts(ModuleList):
    """
    A module that implements a compressed version of a list of expert modules.
    This is specifically designed to work with Llama4TextExperts in MoE layers.
    """

    def __init__(self, config):
        from transformers.models.llama4.modeling_llama4 import Llama4TextMLP

        super().__init__([Llama4TextMLP(config) for _ in range(config.num_local_experts)])
        self.num_experts = config.num_local_experts

    def forward(
        self,
        hidden_states: "torch.Tensor",
    ) -> "torch.Tensor":
        hidden_states = hidden_states.reshape(self.num_experts, -1, hidden_states.shape[-1])
        routed_out = torch.zeros_like(hidden_states)
        for expert_idx in range(self.num_experts):
            routed_out[expert_idx] = self[expert_idx](hidden_states[expert_idx])
        return routed_out


MODULES_TO_PATCH_FOR_QUANTIZATION = {"Llama4TextExperts": SequentialLlama4TextExperts}
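# Illustrative sketch (not part of the upstream file): the reshape performed in
# SequentialLlama4TextExperts.forward groups tokens by expert along the first
# dimension before each expert MLP runs on its own slice. `torch` is assumed available.
def _demo_sequential_expert_reshape(num_experts: int = 4, tokens_per_expert: int = 2, hidden_size: int = 8):
    hidden_states = torch.randn(num_experts * tokens_per_expert, hidden_size)
    grouped = hidden_states.reshape(num_experts, -1, hidden_states.shape[-1])
    assert grouped.shape == (num_experts, tokens_per_expert, hidden_size)
    return grouped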
@@ -146,6 +146,19 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
        self.compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN
        self.compressor.decompress(model_path=cache_path, model=model)

    def update_tp_plan(self, config):
        additional_plan = {
            "layers.*.feed_forward.experts.*.gate_proj.weight": "local_colwise",
            "layers.*.feed_forward.experts.*.gate_proj.weight_scale": "local_colwise",
            "layers.*.feed_forward.experts.*.up_proj.weight": "local_colwise",
            "layers.*.feed_forward.experts.*.up_proj.weight_scale": "local_colwise",
            "layers.*.feed_forward.experts.*.down_proj.weight": "local_rowwise",
        }
        if config.get_text_config() is not None and config.get_text_config().base_model_tp_plan is not None:
            config.get_text_config().base_model_tp_plan.update(additional_plan)

        return config

    @property
    def is_trainable(self):
        return True

@@ -116,7 +116,7 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        from ..integrations import FbgemmFp8Linear
        from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts

        module, tensor_name = get_module_from_name(model, param_name)

@@ -129,6 +129,13 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
                if tensor_name == "weight_scale":
                    raise ValueError("Expect unquantized weights but got a quantized weight_scale")
                return True
        if isinstance(module, FbgemmFp8Llama4TextExperts):
            if self.pre_quantized or tensor_name == "bias":
                return False
            else:
                if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale":
                    raise ValueError("Expect unquantized weights but got a quantized weight_scale")
                return True
        return False

    def create_quantized_param(
@@ -143,12 +150,52 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
        """
        Quantizes weights into weight and weight_scale
        """
        new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value)

        from ..integrations import FbgemmFp8Llama4TextExperts

        module, tensor_name = get_module_from_name(model, param_name)
        module._buffers[tensor_name] = new_value.to(target_device)
        # to have the right output shape -> (out_features, 1)
        module._buffers["weight_scale"] = weight_scale.view(weight_scale.shape[0], 1).to(target_device)
        if isinstance(module, FbgemmFp8Llama4TextExperts):
            if tensor_name == "gate_up_proj":
                # Process each expert separately
                # Transpose the second and third dimension
                transposed_param = param_value.transpose(1, 2)

                # Reshape to 2D for quantization
                original_shape = transposed_param.shape
                flattened_param = transposed_param.reshape(-1, original_shape[-1])

                # Quantize using per row instead of per column
                new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)

                # Reshape back to original dimensions
                new_value = new_value_flat.reshape(original_shape)
                new_value = new_value.transpose(1, 2)
                weight_scale = weight_scale_flat.reshape(original_shape[0], 1, original_shape[1])
            elif tensor_name == "down_proj":
                # Process each expert separately
                # Transpose the weights for proper quantization
                transposed_param = param_value.transpose(1, 2)

                # Reshape to 2D for quantization
                original_shape = transposed_param.shape
                flattened_param = transposed_param.reshape(-1, original_shape[-1])

                # Quantize using per column
                new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)

                # Reshape back to original dimensions
                new_value = new_value_flat.reshape(original_shape)
                new_value = new_value.transpose(1, 2)
                weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)

            module._parameters[f"{tensor_name}_scale"] = torch.nn.Parameter(weight_scale.to(target_device))
        else:
            new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value)
            module._parameters[f"{tensor_name}_scale"] = torch.nn.Parameter(
                weight_scale.view(weight_scale.shape[0], 1).to(target_device)
            )

        module._parameters[tensor_name] = torch.nn.Parameter(new_value.to(target_device))

        if unexpected_keys is not None and param_name in unexpected_keys:
            unexpected_keys.remove(param_name)
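# Illustrative sketch (not part of the upstream file): the bookkeeping around the
# fbgemm per-row FP8 quantization of stacked expert weights. No fbgemm op is called
# here; this only checks that the transpose/flatten/restore round-trip preserves shape.
def _demo_expert_weight_flatten(num_experts: int = 2, in_dim: int = 4, out_dim: int = 6):
    weights = torch.randn(num_experts, in_dim, out_dim)
    transposed = weights.transpose(1, 2)                      # (E, out_dim, in_dim)
    flattened = transposed.reshape(-1, transposed.shape[-1])  # one row per output channel
    restored = flattened.reshape(transposed.shape).transpose(1, 2)
    assert torch.equal(restored, weights)
    return flattened.shape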
@@ -165,25 +212,29 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
    ):
        from ..integrations import replace_with_fbgemm_fp8_linear

        tp_plan = model._tp_plan
        self.modules_to_not_convert = self.get_modules_to_not_convert(
            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
        )

        config = model.config
        model = replace_with_fbgemm_fp8_linear(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
            pre_quantized=self.pre_quantized,
            config=config,
            tp_plan=tp_plan,
        )

        model.config.quantization_config = self.quantization_config

    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
        from ..integrations import FbgemmFp8Linear
        from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts

        not_missing_keys = []
        for name, module in model.named_modules():
            if isinstance(module, FbgemmFp8Linear):
            if isinstance(module, FbgemmFp8Linear) or isinstance(module, FbgemmFp8Llama4TextExperts):
                for missing in missing_keys:
                    if (
                        (name in missing or name in f"{prefix}.{missing}")
@@ -3950,7 +3950,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
            verbose (`bool`): Whether or not to print more information and warnings.

        """
        if max_length is None and len(ids) > self.model_max_length and verbose:
        if max_length is None and len(ids) > self.model_max_length and verbose and self.model_max_length != 0:
            if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
                logger.warning(
                    "Token indices sequence length is longer than the specified maximum sequence length "
@@ -5823,6 +5823,41 @@ class LlamaPreTrainedModel(metaclass=DummyObject):
        requires_backends(self, ["torch"])


class Llama4ForCausalLM(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class Llama4ForConditionalGeneration(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class Llama4PreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class Llama4TextModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class Llama4VisionModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class LlavaForConditionalGeneration(metaclass=DummyObject):
    _backends = ["torch"]

@@ -72,6 +72,13 @@ class GotOcr2ImageProcessorFast(metaclass=DummyObject):
        requires_backends(self, ["torchvision"])


class Llama4ImageProcessorFast(metaclass=DummyObject):
    _backends = ["torchvision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torchvision"])


class LlavaImageProcessorFast(metaclass=DummyObject):
    _backends = ["torchvision"]

@@ -408,6 +408,13 @@ class LevitImageProcessor(metaclass=DummyObject):
        requires_backends(self, ["vision"])


class Llama4ImageProcessor(metaclass=DummyObject):
    _backends = ["vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])


class LlavaImageProcessor(metaclass=DummyObject):
    _backends = ["vision"]

0    tests/models/llama4/__init__.py  (Normal file)
128  tests/models/llama4/test_image_processing_llama4.py  (Normal file)
@@ -0,0 +1,128 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs


if is_torch_available():
    pass

if is_vision_available() and is_torchvision_available():
    from transformers import Llama4ImageProcessorFast


class Llama4ImageProcessingTester(unittest.TestCase):
    def __init__(
        self,
        parent,
        batch_size=7,
        num_channels=3,
        image_size=18,
        min_resolution=30,
        max_resolution=400,
        max_patches=1,
        do_resize=True,
        size=None,
        do_normalize=True,
        do_pad=False,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
        do_convert_rgb=True,
    ):
        super().__init__()
        size = size if size is not None else {"height": 20, "width": 20}
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.max_patches = max_patches
        self.do_resize = do_resize
        self.size = size
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.do_pad = do_pad
        self.do_convert_rgb = do_convert_rgb

    def prepare_image_processor_dict(self):
        return {
            "max_patches": self.max_patches,
            "do_resize": self.do_resize,
            "size": self.size,
            "do_normalize": self.do_normalize,
            "image_mean": self.image_mean,
            "image_std": self.image_std,
            "do_convert_rgb": self.do_convert_rgb,
            "do_pad": self.do_pad,
        }

    def expected_output_image_shape(self, images):
        return self.num_channels, self.size["height"], self.size["width"]

    def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
        return prepare_image_inputs(
            batch_size=self.batch_size,
            num_channels=self.num_channels,
            min_resolution=self.min_resolution,
            max_resolution=self.max_resolution,
            equal_resolution=equal_resolution,
            numpify=numpify,
            torchify=torchify,
        )


@require_torch
@require_vision
class Llama4ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    test_slow_image_processor = False
    fast_image_processing_class = Llama4ImageProcessorFast if is_torchvision_available() else None

    def setUp(self):
        super().setUp()
        self.image_processor_tester = Llama4ImageProcessingTester(self)

    @property
    def image_processor_dict(self):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
        for image_processing_class in self.image_processor_list:
            image_processor = image_processing_class(**self.image_processor_dict)
            self.assertTrue(hasattr(image_processor, "do_resize"))
            self.assertTrue(hasattr(image_processor, "size"))
            self.assertTrue(hasattr(image_processor, "do_normalize"))
            self.assertTrue(hasattr(image_processor, "image_mean"))
            self.assertTrue(hasattr(image_processor, "image_std"))
            self.assertTrue(hasattr(image_processor, "do_convert_rgb"))

    def test_split_tiles(self):
        for image_processing_class in self.image_processor_list:
            image_processor = image_processing_class(**self.image_processor_dict)
            image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)[0]
            processed_images = image_processor(
                image,
                max_patches=16,
            )
            self.assertEqual(len(processed_images.pixel_values), 1)
            self.assertEqual(processed_images.pixel_values[0].shape[0], 17)
            self.assertEqual(processed_images.pixel_values[0].shape[-2:], (20, 20))
121  tests/models/llama4/test_modeling_llama4.py  (Normal file)
@@ -0,0 +1,121 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Llama4 model."""

import unittest

from transformers import is_torch_available
from transformers.testing_utils import (
    require_read_token,
    require_torch_large_gpu,
    slow,
    torch_device,
)


if is_torch_available():
    import torch

    from transformers import (
        Llama4ForConditionalGeneration,
        Llama4Processor,
    )


@slow
@require_torch_large_gpu
@require_read_token
class Llama4IntegrationTest(unittest.TestCase):
    model_id = "ll-re/Llama-4-17B-Omni-Instruct"
    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
    # Depending on the hardware we get different logits / generations
    cuda_compute_capability_major_version = None

    @classmethod
    def setUpClass(cls):
        if is_torch_available() and torch.cuda.is_available():
            # 8 is for A100 / A10 and 7 for T4
            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
        cls.model = Llama4ForConditionalGeneration.from_pretrained(
            "ll-re/Llama-4-17B-Omni-Instruct", device_map="auto", torch_dtype=torch.float32
        )

    def setUp(self):
        self.processor = Llama4Processor.from_pretrained("ll-re/Llama-4-17B-Omni-Instruct", padding_side="left")

        url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
        self.messages = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
            {
                "role": "user",
                "content": [
                    {"type": "image", "url": url},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]

    def test_model_17b_16e_fp16(self):
        EXPECTED_TEXT = [
            "The capital of France is Paris, which is located in the north-central part of the country. Paris is known for its iconic landmarks such as the",
            "Roses are red, violets are blue, and this poem is about you. Roses are red, violets are blue, and I love",
        ]

        messages = [
            {"role": "user", "content": "Who are you?"},
        ]
        inputs = self.processor.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
        ).to(torch_device)

        output = self.model.generate(**inputs, max_new_tokens=100)
        output_text = self.processor.batch_decode(output, skip_special_tokens=True)

        print(output_text)
        self.assertEqual(output_text, EXPECTED_TEXT)

    def test_model_17b_16e_batch(self):
        messages_2 = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
                    },
                    {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
                    {"type": "text", "text": "Are these images identical?"},
                ],
            },
        ]

        inputs = self.processor.apply_chat_template(
            [self.messages, messages_2],
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding=True,
            add_generation_prompt=True,
        ).to(torch_device)

        output = self.model.generate(**inputs, max_new_tokens=30, do_sample=False)
        output_text = self.processor.batch_decode(output, skip_special_tokens=True)

        EXPECTED_TEXTS = [
            'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like',
            "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow"
        ]  # fmt: skip
        self.assertEqual(output_text, EXPECTED_TEXTS)
65  tests/models/llama4/test_processor_llama4.py  (Normal file)
@@ -0,0 +1,65 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil
import tempfile
import unittest
from typing import Optional

from transformers import AutoProcessor, Llama4Processor, PreTrainedTokenizerFast
from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
    from transformers import Llama4ImageProcessorFast


@require_vision
class Llama4ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    processor_class = Llama4Processor

    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()

        image_processor = Llama4ImageProcessorFast(max_patches=1, size={"height": 20, "width": 20})
        tokenizer = PreTrainedTokenizerFast.from_pretrained("unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit")
        processor_kwargs = self.prepare_processor_dict()
        processor = Llama4Processor(image_processor, tokenizer, **processor_kwargs)
        processor.save_pretrained(self.tmpdirname)

    def get_tokenizer(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer

    def get_image_processor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    # Override as Llama4Processor needs image tokens in prompts
    def prepare_text_inputs(self, batch_size: Optional[int] = None):
        if batch_size is None:
            return "lower newer <image>"

        if batch_size < 1:
            raise ValueError("batch_size must be greater than 0")

        if batch_size == 1:
            return ["lower newer <image>"]
        return ["lower newer <image>", "<image> upper older longer string"] + ["<image> lower newer"] * (
            batch_size - 2
        )
@@ -236,6 +236,16 @@ SPECIAL_CASES_TO_ALLOW = {
        "text_config",
        "vision_config",
    ],
    "Llama4Config": ["boi_token_index", "eoi_token_index"],
    "Llama4TextConfig": [
        "interleave_moe_layer_step",
        "no_rope_layer_interval",
        "no_rope_layers",
        "output_router_logits",
        "router_aux_loss_coef",
        "router_jitter_noise",
    ],
    "Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
}


@@ -358,6 +368,8 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
        "rope_theta",
        "partial_rotary_factor",
        "pretraining_tp",
        "boi_token_index",
        "eoi_token_index",
    ]
    attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]

@@ -67,6 +67,7 @@ _re_parse_description = re.compile(r"\*optional\*, defaults to (.*)$")
# docstrings instead. If formatting should be ignored for the docstring, you can put a comment # no-format on the
# line before the docstring.
OBJECTS_TO_IGNORE = [
    "Llama4Processor",
    # Deprecated
    "InputExample",
    "InputFeatures",
@@ -223,13 +223,20 @@ def check_dummies(overwrite: bool = False):
                    f.write(dummy_files[backend])
            else:
                # Temporary fix to help people identify which objects introduced are not correctly protected.
                found = False
                for _actual, _dummy in zip(
                    actual_dummies["torch"].split("class"), dummy_files["torch"].split("class")
                ):
                    if _actual != _dummy:
                        actual_broken = _actual
                        dummy_broken = _dummy
                        found = True
                        break

                if not found:
                    print("A transient error was found with the dummies, please investigate.")
                    continue

                raise ValueError(
                    "The main __init__ has objects that are not present in "
                    f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py.\n"
@@ -144,6 +144,8 @@ IGNORE_NON_TESTED = (
    "Qwen2_5_VLModel",  # Building part of bigger (tested) model. Tested implicitly through Qwen2_5_VLForConditionalGeneration.
    "MllamaTextModel",  # Building part of bigger (tested) model. # TODO: add tests
    "MllamaVisionModel",  # Building part of bigger (tested) model. # TODO: add tests
    "Llama4TextModel",  # Building part of bigger (tested) model. # TODO: add tests
    "Llama4VisionModel",  # Building part of bigger (tested) model. # TODO: add tests
    "Emu3VQVAE",  # Building part of bigger (tested) model
    "Emu3TextModel",  # Building part of bigger (tested) model
]
@@ -170,6 +172,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
    "models/decision_transformer/test_modeling_decision_transformer.py",
    "models/bark/test_modeling_bark.py",
    "models/shieldgemma2/test_modeling_shieldgemma2.py",
    "models/llama4/test_modeling_llama4.py",
]

# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and