2025-06-29 22:48:08 -04:00

PyTorch FlashAttention Multimodal

LLaVA-NeXT

LLaVA-NeXT improves on LLaVA by increasing the input image resolution to 4x more pixels and by supporting three aspect ratios (up to 672x672, 336x1344, and 1344x336) to better capture visual details. It is also trained on an improved visual instruction tuning dataset that covers more scenarios and applications, which improves OCR and common sense reasoning.
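The higher resolution comes from tiling. Each image is matched to one of a fixed set of candidate resolutions and then split into 336x336 tiles.

The plain-Python sketch below illustrates that matching step under a few assumptions:

- The candidates are the three resolutions quoted above, read as (height, width) pairs. Real checkpoints store their own list in config.image_grid_pinpoints.
- The name best_resolution is hypothetical, and this is not the library's implementation.

The rule picks the candidate that keeps the most original pixels after aspect-preserving downscaling. Ties are broken by the least wasted (padded) area. The block also checks the "4x more pixels" claim.

```python
# Candidate resolutions from the text above, read as (height, width) pairs (an assumption).
CANDIDATES = [(672, 672), (336, 1344), (1344, 336)]
BASE = 336  # side length of a single tile, matching the original LLaVA input size


def best_resolution(original_size, candidates):
    """Hypothetical helper: pick the candidate that keeps the most original pixels,
    breaking ties by the smallest wasted (padded) area."""
    orig_h, orig_w = original_size
    best, best_effective, best_wasted = None, -1, float("inf")
    for cand_h, cand_w in candidates:
        # Scale the image to fit inside the candidate while preserving its aspect ratio.
        scale = min(cand_w / orig_w, cand_h / orig_h)
        down_w, down_h = int(orig_w * scale), int(orig_h * scale)
        # Upscaling adds no real detail, so cap the useful pixel count at the original area.
        effective = min(down_w * down_h, orig_w * orig_h)
        wasted = cand_w * cand_h - effective
        if effective > best_effective or (effective == best_effective and wasted < best_wasted):
            best, best_effective, best_wasted = (cand_h, cand_w), effective, wasted
    return best


# Every candidate holds 4x the pixels of a single 336x336 input.
for h, w in CANDIDATES:
    print((h, w), h * w // (BASE * BASE))  # 4 for each candidate

print(best_resolution((800, 800), CANDIDATES))   # square image -> (672, 672)
print(best_resolution((300, 1200), CANDIDATES))  # wide image   -> (336, 1344)
print(best_resolution((1200, 300), CANDIDATES))  # tall image   -> (1344, 336)
```

Wide and tall images land on the elongated grids. Square images use the 672x672 grid, so fewer pixels are spent on padding.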

You can find all the original LLaVA-NeXT checkpoints under the LLaVA-NeXT collection.

Tip

Click on the LLaVA-NeXT models in the right sidebar for more examples of how to apply LLaVA-NeXT to different multimodal tasks.

The example below demonstrates how to generate text based on an image with [Pipeline] or the [AutoModel] class.

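```python
import torch
from transformers import pipeline

# The image-text-to-text pipeline applies the chat template and inserts the <image> token for you
pipe = pipeline("image-text-to-text", model="llava-hf/llava-v1.6-mistral-7b-hf", device="cuda", torch_dtype=torch.float16)

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/llava_next_ocr.png"
messages = [
    {"role": "user", "content": [{"type": "image", "url": url}, {"type": "text", "text": "What does this chart show?"}]}
]

result = pipe(text=messages, max_new_tokens=100, return_full_text=False)
print(result[0]["generated_text"])
```
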
```python
import requests
import torch
from PIL import Image
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor

processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16
).to("cuda")

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/llava_next_ocr.png"
image = Image.open(requests.get(url, stream=True).raw)

# The {"type": "image"} entry marks where the <image> placeholder goes in the prompt
conversation = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What does this chart show?"}]}
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda")

output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
```

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the Quantization overview for more available quantization backends.

The example below uses bitsandbytes to quantize only the weights to 4-bit NF4.

```python
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig

# Store the weights in 4-bit NF4 and run the computation in float16
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
)

# device_map="auto" places the quantized layers on the available device(s)
model = AutoModelForImageTextToText.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf",
    quantization_config=quant_config,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
```
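To see why 4-bit weights matter for a model of this size, the sketch below estimates weight-only memory as parameters x bits / 8. It makes two assumptions:

- The parameter count is a hypothetical round 7 billion. The actual checkpoint's count differs.
- The estimate ignores activations, the KV cache, quantization constants, and any modules kept in higher precision.

Treat the numbers as a rough lower bound, not a measurement.

```python
def weight_memory_gb(num_params: float, bits_per_weight: int) -> float:
    """Rough weight-only memory estimate in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_weight / 8 / 1e9


# Hypothetical round parameter count, not the checkpoint's exact size.
NUM_PARAMS = 7e9

for label, bits in [("float16", 16), ("4-bit", 4)]:
    print(f"{label}: ~{weight_memory_gb(NUM_PARAMS, bits):.1f} GB")
# float16: ~14.0 GB
# 4-bit: ~3.5 GB
```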

Use the AttentionMaskVisualizer to better understand which tokens the model can and cannot attend to:

```python
from transformers.utils.attention_visualizer import AttentionMaskVisualizer

viz = AttentionMaskVisualizer("llava-hf/llava-v1.6-mistral-7b-hf")
viz("<image> What is shown in this image?")
```
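As a rough picture of what such a visualization encodes, the sketch below builds and prints a plain causal mask, in which each token may attend to itself and to every earlier token. It makes two assumptions:

- LLaVA-NeXT's language model uses an ordinary causal mask over the whole sequence, image placeholders included.
- The image span is shrunk to two hypothetical <image> tokens. A real prompt expands to hundreds of them.

This is an illustration, not output from AttentionMaskVisualizer.

```python
def causal_mask(n):
    """Return an n x n causal mask: row i may attend to columns 0..i."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]


# Hypothetical, heavily shortened token sequence.
tokens = ["<image>", "<image>", "What", "is", "shown", "?"]
for tok, row in zip(tokens, causal_mask(len(tokens))):
    # Filled square = can attend, empty square = masked out.
    print(f"{tok:>8} " + " ".join("■" if v else "□" for v in row))
```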

Notes

  • Different checkpoints (Mistral, Vicuna, etc.) require a specific prompt format depending on the underlying LLM. Always use [~ProcessorMixin.apply_chat_template] to ensure correct formatting. Refer to the Templates guide for more details.

  • Set padding_side="left" during batched generation for more accurate results.

    ```python
    processor.tokenizer.padding_side = "left"
    ```

  • LLaVA-NeXT uses a different number of patches for each image and pads the inputs inside the modeling code, unless padding was already done during processing. By default, it left-pads when the model is in eval() mode and right-pads otherwise.

  • LLaVA models after v4.46 raise warnings about adding processor.patch_size = {{patch_size}}, processor.num_additional_image_tokens = {{num_additional_image_tokens}}, and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}. It is strongly recommended to add these attributes to the processor if you own the model checkpoint, or to open a PR if you don't.

    With these attributes set, the processor infers the number of image tokens each image requires and expands the text with the same number of <image> placeholders. There are usually ~500 tokens per image, so make sure the text is not truncated; truncation causes a failure when the image and text embeddings are merged. The values for these attributes can be found in model.config.vision_config.patch_size and model.config.vision_feature_select_strategy.

    Set num_additional_image_tokens to 1 if the vision backbone adds a CLS token, or to 0 if it adds nothing extra.

  • The example below demonstrates inference with multiple input images.

```python
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from PIL import Image
import requests, torch

processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16
).to("cuda")

# Load multiple images
url1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/llava_next_ocr.png"
url2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/llava_next_comparison.png"

image1 = Image.open(requests.get(url1, stream=True).raw)
image2 = Image.open(requests.get(url2, stream=True).raw)

# One {"type": "image"} entry per image, in the same order as the images list below
conversation = [
    {"role": "user", "content": [{"type": "image"}, {"type": "image"}, {"type": "text", "text": "Compare these two images and describe the differences."}]}
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(images=[image1, image2], text=prompt, return_tensors="pt").to("cuda")

output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
```
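The reason left padding matters for batched generation is that generate() appends each new token after the last position of every row. With right padding, a shorter prompt would have pad tokens wedged between its text and its continuation.

The sketch below shows left padding and the matching attention mask in plain Python. It illustrates what the tokenizer does when padding_side="left"; it is not the tokenizer's code. The token ids and the pad id 0 are hypothetical.

```python
def left_pad(sequences, pad_id):
    """Left-pad token id lists to equal length and build the matching attention mask."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        # Pad on the left so every prompt ends at the final position,
        # which is where generation appends the next token.
        input_ids.append([pad_id] * n_pad + list(seq))
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask


# Hypothetical token ids for two prompts of different lengths.
batch = [[11, 12, 13], [21, 22, 23, 24, 25]]
ids, mask = left_pad(batch, pad_id=0)
print(ids)   # [[0, 0, 11, 12, 13], [21, 22, 23, 24, 25]]
print(mask)  # [[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]]
```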
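The patch_size, num_additional_image_tokens, and vision_feature_select_strategy attributes discussed in the notes above combine into a per-image token count. The sketch below shows that arithmetic for a single square tile of a ViT-style backbone, under these assumptions:

- The function name is hypothetical.
- The example values (336px input, patch size 14, one CLS token) are typical of a CLIP ViT-L/14 backbone.
- The "default" strategy drops the CLS feature while "full" keeps it.

LLaVA-NeXT's processor also counts tokens for the high-resolution grid tiles and newline tokens. The real number per image is therefore larger than this single-tile figure.

```python
def num_image_tokens(image_size, patch_size, num_additional_image_tokens, vision_feature_select_strategy):
    """Hypothetical helper: placeholder tokens produced for one square tile."""
    patches_per_side = image_size // patch_size
    tokens = patches_per_side * patches_per_side + num_additional_image_tokens
    # Assumed behavior: "default" drops the CLS feature, "full" keeps every feature.
    if vision_feature_select_strategy == "default":
        tokens -= 1
    return tokens


print(num_image_tokens(336, 14, 1, "default"))  # 24 * 24 = 576
print(num_image_tokens(336, 14, 1, "full"))     # 576 + 1 CLS = 577
```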

LlavaNextConfig

[[autodoc]] LlavaNextConfig

LlavaNextImageProcessor

[[autodoc]] LlavaNextImageProcessor
    - preprocess

LlavaNextImageProcessorFast

[[autodoc]] LlavaNextImageProcessorFast
    - preprocess

LlavaNextProcessor

[[autodoc]] LlavaNextProcessor

LlavaNextModel

[[autodoc]] LlavaNextModel

LlavaNextForConditionalGeneration

[[autodoc]] LlavaNextForConditionalGeneration
    - forward