transformers/docs/source/en/model_doc/paligemma.md
Raushan Turganbay 17742bd9c8
🔴 [VLM] Add base model without head (#37033)
2025-05-07 17:47:51 +02:00


PyTorch FlashAttention SDPA

PaliGemma

PaliGemma is a family of vision-language models (VLMs) that combines the SigLIP vision encoder with the Gemma 2B language model. The original PaliGemma has 3B parameters, and PaliGemma 2 is available in 3B, 10B, and 28B parameter sizes. The main purpose of PaliGemma is to provide an adaptable base VLM that is easy to transfer to other tasks. The SigLIP vision encoder is a "shape optimized" contrastively pretrained ViT that converts an image into a sequence of tokens, which is prepended to an optional text prompt. The Gemma 2B model is used as the decoder. PaliGemma uses full attention across all image tokens and the text prompt to maximize its capacity, while the generated text is decoded causally.
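Because SigLIP splits the image into square patches, the number of image tokens depends on the input resolution (the `224` in a checkpoint name such as `google/paligemma2-3b-mix-224`). The sketch below shows the arithmetic, assuming a patch size of 14 pixels as in the SigLIP So400m/14 encoder. `num_image_tokens` is a hypothetical helper for illustration, not a Transformers API.

```python
# Hypothetical helper (not a Transformers API): counts the patch tokens a
# ViT-style encoder produces for a square image split into square patches.
# Assumption: patch size of 14 pixels, as in SigLIP So400m/14.

def num_image_tokens(image_size: int, patch_size: int = 14) -> int:
    """Return the number of patch tokens for a square image_size x image_size input."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side


# Resolutions used in the PaliGemma checkpoint names
for resolution in (224, 448, 896):
    print(resolution, num_image_tokens(resolution))
```

Doubling the resolution quadruples the number of image tokens, so the 448 and 896 variants process much longer sequences than the 224 variants.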

PaliGemma 2 improves on the first model by using Gemma 2 (2B, 9B, and 27B parameter variants) as the decoder. Its checkpoints are available as pt (pretrained) or mix variants: the pt checkpoints are intended for further fine-tuning, while the mix checkpoints are ready to use out of the box.

You can find all the original PaliGemma checkpoints under the PaliGemma, PaliGemma 2, and PaliGemma 2 Mix collections.

Tip

Click on the PaliGemma models in the right sidebar for more examples of how to apply PaliGemma to different vision and language tasks.

The example below demonstrates how to generate text based on an image with [Pipeline] or the [PaliGemmaForConditionalGeneration] class.

import torch
from transformers import pipeline

# use a separate name so the `pipeline` factory function is not shadowed
pipe = pipeline(
    task="image-text-to-text",
    model="google/paligemma2-3b-mix-224",
    device=0,
    torch_dtype=torch.bfloat16
)
result = pipe(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
    text="What is in this image?"
)
print(result)

import torch
import requests
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model = PaliGemmaForConditionalGeneration.from_pretrained(
    "google/paligemma2-3b-mix-224",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa"
)
processor = AutoProcessor.from_pretrained(
    "google/paligemma2-3b-mix-224",
)

prompt = "What is in this image?"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
image = Image.open(requests.get(url, stream=True).raw)
# device_map="auto" places the model, so move the inputs to the same device
inputs = processor(image, prompt, return_tensors="pt").to(model.device)

output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
print(processor.decode(output[0], skip_special_tokens=True))
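The processor expands the prompt before tokenization so that the image features have placeholder positions to fill. The sketch below illustrates that layout under stated assumptions: `<image>` as the placeholder token, `<bos>` as the beginning-of-sequence token, 256 image tokens per 224x224 image, and a newline closing the prompt. `build_prompt` is a hypothetical helper for illustration only; in practice [PaliGemmaProcessor] builds this string for you.

```python
# Hypothetical helper (not a Transformers API) sketching a PaliGemma-style
# input layout: image placeholders, then BOS, then the prompt, then "\n".
# Assumptions: "<image>" placeholder, "<bos>" BOS token, 256 tokens per image.

def build_prompt(prompt: str, num_images: int = 1, image_seq_len: int = 256,
                 image_token: str = "<image>", bos_token: str = "<bos>") -> str:
    """Return the text string that would be passed to the tokenizer."""
    placeholders = image_token * (image_seq_len * num_images)
    return f"{placeholders}{bos_token}{prompt}\n"


text = build_prompt("What is in this image?")
print(text.count("<image>"))  # 256
print(repr(text.split("<bos>")[1]))
```

Reserving one placeholder per image token keeps the text sequence length in sync with the number of vision features that the model later writes into those positions.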

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the Quantization overview for more available quantization backends.

The example below uses torchao to quantize only the weights to int4.

# pip install torchao
import torch
import requests
from PIL import Image
from transformers import TorchAoConfig, AutoProcessor, PaliGemmaForConditionalGeneration

quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
model = PaliGemmaForConditionalGeneration.from_pretrained(
    "google/paligemma2-28b-mix-224",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config
)
processor = AutoProcessor.from_pretrained(
    "google/paligemma2-28b-mix-224",
)

prompt = "What is in this image?"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
image = Image.open(requests.get(url, stream=True).raw)
# device_map="auto" places the model, so move the inputs to the same device
inputs = processor(image, prompt, return_tensors="pt").to(model.device)

output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
print(processor.decode(output[0], skip_special_tokens=True))
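A rough estimate of the weight memory shows why int4 quantization matters for the 28B checkpoint. The sketch below assumes every parameter is quantized and that each group of `group_size` weights carries one bf16 scale and one bf16 zero point (4 bytes of overhead). In practice some layers stay in higher precision, and activations and the KV cache need additional memory. `weight_memory_gb` is a hypothetical helper.

```python
# Back-of-the-envelope weight memory estimate (weights only; activations,
# the KV cache, and layers left unquantized are ignored).
# Assumption: grouped quantization stores one bf16 scale and one bf16 zero
# point (4 bytes total) per group of `group_size` weights.
from typing import Optional


def weight_memory_gb(num_params: float, bits_per_weight: float,
                     group_size: Optional[int] = None,
                     group_overhead_bytes: int = 4) -> float:
    """Return the estimated weight storage in gigabytes (1 GB = 1e9 bytes)."""
    total_bytes = num_params * bits_per_weight / 8
    if group_size is not None:
        # per-group metadata amortized over every weight in the group
        total_bytes += num_params / group_size * group_overhead_bytes
    return total_bytes / 1e9


params = 28e9  # nominal parameter count of the 28B checkpoint
print(f"bf16: {weight_memory_gb(params, 16):.1f} GB")
print(f"int4 (group_size=128): {weight_memory_gb(params, 4, group_size=128):.1f} GB")
```

Under these assumptions, int4 weight-only quantization shrinks the weight footprint from about 56 GB to roughly 15 GB.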

Use the AttentionMaskVisualizer to better understand what tokens the model can and cannot attend to.

from transformers.utils.attention_visualizer import AttentionMaskVisualizer

visualizer = AttentionMaskVisualizer("google/paligemma2-3b-mix-224")
visualizer("<img> What is in this image?")
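PaliGemma follows a prefix-LM scheme: image tokens and the text prompt (the prefix) attend to each other bidirectionally, while generated tokens (the suffix) attend to the full prefix and only to earlier suffix tokens. The pure-Python sketch below builds such a mask for a toy sequence. It illustrates the pattern only and is not the masking code Transformers uses; `prefix_lm_mask` is a hypothetical helper.

```python
# Minimal sketch of a prefix-LM attention mask (pure Python).
# mask[i][j] is True when query position i may attend to key position j.
# Assumption: the sequence is [image tokens][prompt tokens][answer tokens];
# image + prompt tokens form the prefix.
from typing import List


def prefix_lm_mask(prefix_len: int, suffix_len: int) -> List[List[bool]]:
    total = prefix_len + suffix_len
    mask = []
    for i in range(total):
        # every position sees the whole prefix (bidirectional part);
        # beyond the prefix, a position only sees itself and earlier tokens (causal part)
        mask.append([j < prefix_len or j <= i for j in range(total)])
    return mask


# 3 prefix tokens (e.g. 2 image tokens + 1 prompt token) and 2 answer tokens
for row in prefix_lm_mask(prefix_len=3, suffix_len=2):
    print("".join("■" if allowed else "□" for allowed in row))
```

Prefix rows never see suffix positions, so the answer cannot leak into the representation of the image and prompt.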

Notes

  • PaliGemma is not a conversational model and works best when fine-tuned for specific downstream tasks such as image captioning, visual question answering (VQA), object detection, and document understanding.

  • [PaliGemmaProcessor] can prepare images, text, and optional labels for the model. Pass the suffix parameter to the processor to create labels for the model during fine-tuning.

    prompt = "What is in this image?"
    answer = "a pallas cat"
    inputs = processor(images=image, text=prompt, suffix=answer, return_tensors="pt")
    
  • PaliGemma can support multiple input images if it is fine-tuned to accept multiple images. For example, the NLVR2 checkpoint supports multiple images. Pass the images as a list to the processor.

    import requests
    from PIL import Image
    from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
    
    model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-ft-nlvr2-448")
    processor = AutoProcessor.from_pretrained("google/paligemma-3b-ft-nlvr2-448")
    
    prompt = "Are these two images the same?"
    cat_image = Image.open(
        requests.get("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg", stream=True).raw
    )
    cow_image = Image.open(
        requests.get(
            "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4=", stream=True
        ).raw
    )
    
    # nested list: the inner list holds every image that belongs to the single prompt
    inputs = processor(images=[[cat_image, cow_image]], text=prompt, return_tensors="pt")
    
    output = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
    print(processor.decode(output[0], skip_special_tokens=True))
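
When `suffix` is passed for fine-tuning, the labels should only supervise the answer tokens. The simplified sketch below shows that idea, assuming prefix positions (image tokens and prompt) are marked with token type 0, answer positions with 1, and masked positions use -100, the default `ignore_index` of `torch.nn.CrossEntropyLoss`. `make_labels` is a hypothetical helper and the token IDs are made up.

```python
# Simplified sketch of deriving fine-tuning labels from a prefix/suffix split.
# Assumptions: token type 0 = prefix (image + prompt), 1 = suffix (answer);
# -100 is the ignore index skipped by the loss.
from typing import List

IGNORE_INDEX = -100


def make_labels(input_ids: List[int], token_type_ids: List[int]) -> List[int]:
    """Keep answer tokens as targets and mask every prefix position."""
    return [tok if ttype == 1 else IGNORE_INDEX
            for tok, ttype in zip(input_ids, token_type_ids)]


# Toy example with made-up token IDs: 2 prefix tokens, 3 answer tokens
input_ids = [101, 102, 7, 8, 9]
token_type_ids = [0, 0, 1, 1, 1]
print(make_labels(input_ids, token_type_ids))  # [-100, -100, 7, 8, 9]
```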
    

PaliGemmaConfig

autodoc PaliGemmaConfig

PaliGemmaProcessor

autodoc PaliGemmaProcessor

PaliGemmaModel

autodoc PaliGemmaModel

PaliGemmaForConditionalGeneration

autodoc PaliGemmaForConditionalGeneration - forward