
SigLIP2
Overview
SigLIP2 is a family of multilingual vision-language encoders that builds on the SigLIP training recipe. It includes decoder-based pretraining, self-distillation, and masked prediction to improve dense prediction tasks (segmentation, depth estimation, etc.). This model is available in two variants:
- NaFlex supports different resolutions and maintains the native image aspect ratio
- FixRes supports fixed resolutions and is backwards compatible with SigLIP
You can find all the original SigLIP2 checkpoints under the SigLIP2 collection.
Tip
Click on the SigLIP2 models in the right sidebar for more examples of how to apply SigLIP2 to different image and text tasks.
The example below demonstrates zero-shot classification with [Pipeline] or the [AutoModel] class.
import torch
from transformers import pipeline
image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
candidate_labels = ["a Pallas cat", "a lion", "a Siberian tiger"]
# use a distinct variable name so the imported `pipeline` function is not shadowed
classifier = pipeline(task="zero-shot-image-classification", model="google/siglip2-base-patch16-224", device=0, torch_dtype=torch.bfloat16)
classifier(image, candidate_labels=candidate_labels)
import torch
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModel
model = AutoModel.from_pretrained("google/siglip2-base-patch16-224", torch_dtype=torch.float16, device_map="auto", attn_implementation="sdpa")
processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
image = Image.open(requests.get(url, stream=True).raw)
candidate_labels = ["a Pallas cat", "a lion", "a Siberian tiger"]
# follows the pipeline prompt template to get same results
texts = [f'This is a photo of {label}.' for label in candidate_labels]
# IMPORTANT: we pass `padding=max_length` and `max_length=64` since the model was trained with this
inputs = processor(text=texts, images=image, padding="max_length", max_length=64, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs)
logits_per_image = outputs.logits_per_image
probs = torch.sigmoid(logits_per_image)
print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")
import torch
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModel
model = AutoModel.from_pretrained("google/siglip2-base-patch16-naflex", torch_dtype=torch.float16, device_map="auto", attn_implementation="sdpa")
processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-naflex")
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
image = Image.open(requests.get(url, stream=True).raw)
candidate_labels = ["a Pallas cat", "a lion", "a Siberian tiger"]
texts = [f'This is a photo of {label}.' for label in candidate_labels]
# `max_num_patches` defaults to 256; pass a higher value (e.g. `max_num_patches=512`) to increase the resulting image resolution
inputs = processor(text=texts, images=image, padding="max_length", max_length=64, max_num_patches=256, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs)
logits_per_image = outputs.logits_per_image
probs = torch.sigmoid(logits_per_image)
print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the Quantization overview for more available quantization backends.
The example below uses bitsandbytes to quantize only the weights to int4.
import torch
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModel, BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModel.from_pretrained("google/siglip2-large-patch16-512", quantization_config=bnb_config, device_map="auto", attn_implementation="sdpa")
# load the processor from the same checkpoint as the model so the image resolution matches
processor = AutoProcessor.from_pretrained("google/siglip2-large-patch16-512")
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
image = Image.open(requests.get(url, stream=True).raw)
candidate_labels = ["a Pallas cat", "a lion", "a Siberian tiger"]
# follows the pipeline prompt template to get same results
texts = [f'This is a photo of {label}.' for label in candidate_labels]
# IMPORTANT: we pass `padding=max_length` and `max_length=64` since the model was trained with this
inputs = processor(text=texts, images=image, padding="max_length", max_length=64, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs)
logits_per_image = outputs.logits_per_image
probs = torch.sigmoid(logits_per_image)
print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")
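The examples above convert logits_per_image to probabilities with torch.sigmoid rather than a softmax. As a rough illustration (the logit values below are made up for the example and are not model output), this plain-Python sketch shows the practical consequence: sigmoid scores each image-text pair independently, so the scores across candidate labels do not have to sum to 1, whereas softmax scores always do.

```python
import math


def sigmoid(x):
    """Map a single logit to an independent probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def softmax(xs):
    """Normalize a list of logits into a distribution that sums to 1."""
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for one image against three candidate labels
# (illustrative values only, not produced by a SigLIP2 checkpoint).
logits = [2.0, -1.0, -3.0]

sigmoid_scores = [sigmoid(x) for x in logits]
softmax_scores = softmax(logits)

print("sigmoid:", [round(s, 3) for s in sigmoid_scores], "sum =", round(sum(sigmoid_scores), 3))
print("softmax:", [round(s, 3) for s in softmax_scores], "sum =", round(sum(softmax_scores), 3))
```

Because the sigmoid scores are independent, each label can be read on its own, for instance by comparing it against a threshold, without renormalizing over the other candidates.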
Notes
- Training is supported for DDP and FSDP on single-node multi-GPU setups. However, it does not use torch.distributed utilities, which may limit the scalability of batch size.
- When using the standalone [GemmaTokenizerFast], make sure to pass padding="max_length" and max_length=64, as that's how the model was trained.
- The model was trained with lowercased text, so make sure your text labels are preprocessed the same way.
- To get the same results as the [Pipeline], a prompt template of "This is a photo of {label}." should be passed to the processor.
- The NaFlex variant processes different types of images at the appropriate resolution (using a larger resolution to process document images, for example), while also minimizing the impact of aspect ratio distortion for certain inference tasks like OCR.
  NaFlex resizes the input image so the height and width are multiples of the patch size after resizing. It keeps the aspect ratio distortion as low as possible and produces a sequence length of at most the desired target sequence length (max_num_patches). After resizing, the image is split into a sequence of patches and a mask with padding information is added.
- Toggle the attn_implementation parameter to either "sdpa" or "flash_attention_2" to use a more memory-efficient attention.
  # pip install -U flash-attn --no-build-isolation
  import torch
  from transformers import SiglipModel
  model = SiglipModel.from_pretrained(
      "google/siglip2-so400m-patch14-384",
      attn_implementation="flash_attention_2",
      torch_dtype=torch.float16,
      device_map="auto",
  )
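The NaFlex resizing described in the notes above can be sketched in a few lines. This is a simplified illustration, not the implementation used by Siglip2ImageProcessor: the helper name naflex_target_size, the round-up rule, and the binary search over the scale factor are all assumptions made for the example. It looks for the largest uniform scale whose rounded height and width are multiples of the patch size and yield at most max_num_patches patches.

```python
import math


def naflex_target_size(height, width, patch_size=16, max_num_patches=256):
    """Hypothetical helper: choose a (height, width) for resizing.

    Both sides are multiples of `patch_size`, the patch count is at most
    `max_num_patches`, and one shared scale factor is applied to both sides
    so the aspect ratio changes only through rounding.
    """

    def scaled(scale):
        # Round each scaled side up to the next multiple of patch_size,
        # keeping at least one patch per side.
        h = max(patch_size, math.ceil(height * scale / patch_size) * patch_size)
        w = max(patch_size, math.ceil(width * scale / patch_size) * patch_size)
        return h, w

    # Binary search for the largest scale that still fits the patch budget.
    # At scale 0.0 the result is a single patch, which always fits.
    lo, hi = 0.0, 100.0
    for _ in range(60):
        mid = (lo + hi) / 2
        h, w = scaled(mid)
        if (h // patch_size) * (w // patch_size) <= max_num_patches:
            lo = mid  # fits: try a larger scale
        else:
            hi = mid  # too many patches: shrink
    return scaled(lo)


# Example: a 480x640 (height x width) image with the default patch budget.
target_h, target_w = naflex_target_size(480, 640)
num_patches = (target_h // 16) * (target_w // 16)
print(target_h, target_w, num_patches)
```

If the resized image produces fewer patches than max_num_patches, the remaining positions would be padding, which is why the notes mention a mask with padding information alongside the patch sequence.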
Siglip2Config
autodoc Siglip2Config
Siglip2TextConfig
autodoc Siglip2TextConfig
Siglip2VisionConfig
autodoc Siglip2VisionConfig
Siglip2ImageProcessor
autodoc Siglip2ImageProcessor
- preprocess
Siglip2ImageProcessorFast
autodoc Siglip2ImageProcessorFast
- preprocess
Siglip2Processor
autodoc Siglip2Processor
Siglip2Model
autodoc Siglip2Model
- forward
- get_text_features
- get_image_features
Siglip2TextModel
autodoc Siglip2TextModel
- forward
Siglip2VisionModel
autodoc Siglip2VisionModel
- forward
Siglip2ForImageClassification
autodoc Siglip2ForImageClassification
- forward