mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Add SigLIP 2 (#36323)
* Docs
* Inits
* Auto classes
* Add siglip base
* Add base tests
* Fix Siglip V1 for fix res version
* Add image processor
* Update conversion
* Experimenting with vectorized embeddings
* Fixup
* Add modular Siglip2Processor
* Add modular configuration
* Rename num patches
* Correct image and text features merging
* Working conversion script
* Refactoring conversion script
* Remove unused code in conversion script
* Shorten dict a bit
* Refactoring conversion
* Done conversion refactoring
* Fixup
* Modular siglip2
* Make model exportable and compilable without graph breaks
* Remove position_ids from image_processor
* Remove position ids from modeling file
* Update modular
* Type hint
* Fixup
* Set defaults to processor
* Add integration test
* Revert spatial shapes back to tensor
* Change order
* Fix most of the tests
* Fix docstring
* Remove interpolate_pos_encoding arg (not needed)
* Update docs
* Standardize processing
* Fix attention_mask in vision head
* Siglip v1: remove double transpose in FA2
* Update modular file
* Update FA2 test
* Update expected logits
* Fix interpolation for siglip2 image processor
* Skip init test
* Skip dispatch on flash test
* Fix modeling tests
* Fixup
* Add dummy objects
* Fix some docstrings
* Add siglip2 in index.md
* Fix consistency
* Add docs
* Remove size and data format
* Add image processor tests
* Fix
* Add fast image processor
* Fix style
* Fix
* Docs
* Set lowercase for tokenizer
* Adjust head size for Siglip v1
* Update siglip2 for consistency with siglip1
* Update siglip2 conversion
* Update pipeline
* Update checkpoints in tests
* Update checkpoint name
* Fix pooling for image classification model
* Fix FA2 test
* Update processor
* Fix check repo
* Update docs
* Fix typos
* Fix docstring for fast image processor
* Add siglip2 to FA2 docs
* Fix fast ip tests
* Fix consistency
* Fix tokenizer class for siglip v1
* Fix missing header
* Refactor scaling for clip, siglip, siglip2
* Remove unused imports
* Make fast IP default for siglip2
* Update docs
* Update checkpoints
* Update modular
* Update paper link
* Fixup
* Fix name in toctree
* Fix test
This commit is contained in:
parent
14552cbd7c
commit
a957b7911a
@@ -965,6 +965,8 @@
       title: Segment Anything
     - local: model_doc/siglip
       title: SigLIP
+    - local: model_doc/siglip2
+      title: SigLIP2
     - local: model_doc/smolvlm
       title: SmolVLM
     - local: model_doc/speech-encoder-decoder
@@ -317,6 +317,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [SEW](model_doc/sew) | ✅ | ❌ | ❌ |
 | [SEW-D](model_doc/sew-d) | ✅ | ❌ | ❌ |
 | [SigLIP](model_doc/siglip) | ✅ | ❌ | ❌ |
+| [SigLIP2](model_doc/siglip2) | ✅ | ❌ | ❌ |
 | [SmolVLM](model_doc/smolvlm) | ✅ | ❌ | ❌ |
 | [Speech Encoder decoder](model_doc/speech-encoder-decoder) | ✅ | ❌ | ✅ |
 | [Speech2Text](model_doc/speech_to_text) | ✅ | ✅ | ❌ |
276 docs/source/en/model_doc/siglip2.md Normal file
@@ -0,0 +1,276 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# SigLIP2

## Overview

The SigLIP2 model was proposed in [SigLIP 2: Multilingual Vision-Language Encoders with Improved Semantic Understanding, Localization, and Dense Features](https://huggingface.co/papers/2502.14786) by Michael Tschannen, Alexey Gritsenko, Xiao Wang, Muhammad Ferjad Naeem, Ibrahim Alabdulmohsin,
Nikhil Parthasarathy, Talfan Evans, Lucas Beyer, Ye Xia, Basil Mustafa, Olivier Hénaff, Jeremiah Harmsen,
Andreas Steiner and Xiaohua Zhai.

The model comes in two variants:

1) FixRes - the model works with fixed-resolution images (backward compatible with SigLIP v1)
2) NaFlex - the model works with variable image aspect ratios and resolutions (SigLIP2 in `transformers`)

The abstract from the paper is the following:

*We introduce SigLIP 2, a family of new multilingual vision-language encoders that build on the success
of the original SigLIP. In this second iteration, we extend the original image-text training objective with
several prior, independently developed techniques into a unified recipe—this includes decoder-based
pretraining, self-supervised losses (self-distillation, masked prediction) and online data curation. With
these changes, SigLIP 2 models outperform their SigLIP counterparts at all model scales in core capabilities,
including zero-shot classification (best SigLIP 2 ViT-g/16 achieves 85.0% ImageNet zero-shot
accuracy), image-text retrieval, and transfer performance when extracting visual representations for
Vision-Language Models (VLMs). Furthermore, the new training recipe leads to significant improvements
on localization and dense prediction tasks. We also train variants which support multiple resolutions
and preserve the input’s native aspect ratio. Finally, we train on a more diverse data-mixture that
includes de-biasing techniques, leading to much better multilingual understanding and improved fairness.
To provide users with the ability to trade-off inference cost with performance, we release model
checkpoints at four sizes (ViT-B/86M, L/303M, So400m/400M, and g/1B).*

## Usage tips

- Usage of SigLIP2 is similar to [SigLIP](siglip) and [CLIP](clip). The main difference from CLIP is the training loss, which does not require a global view of all the pairwise similarities of images and texts within a batch. One needs to apply the sigmoid activation function to the logits, rather than the softmax.
- Training is supported but does not use `torch.distributed` utilities, which may limit the scalability of the batch size. However, DDP and FSDP work in a single-node multi-GPU setup.
- When using the standalone [`GemmaTokenizerFast`] make sure to pass `padding="max_length"` and `max_length=64`, as that's how the model was trained (a short standalone sketch follows these tips).
- The model was trained with *lowercased* text, so make sure to apply the same preprocessing to your text labels.
- To get the same results as the pipeline, a prompt template of "this is a photo of {label}" should be used.
- The NaFlex variant supports processing images at higher resolutions by adjusting the `max_num_patches` parameter in the `Processor`. The default value is `max_num_patches=256`. Increasing `max_num_patches` to 1024 (4x) will approximately double the processed image height and width, while preserving the aspect ratio.
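
The tokenizer behavior from the tips above can be exercised on its own. Below is a minimal sketch, assuming the tokenizer files are hosted alongside the `google/siglip2-base-patch16-224` checkpoint; it is an illustration rather than the canonical recipe:

```python
# minimal sketch: lowercase the text, use the pipeline prompt template,
# and pad to the fixed length of 64 tokens used during training
from transformers import GemmaTokenizerFast

tokenizer = GemmaTokenizerFast.from_pretrained("google/siglip2-base-patch16-224")

labels = ["2 cats", "a plane"]
texts = [f"this is a photo of {label}." for label in labels]

inputs = tokenizer(texts, padding="max_length", max_length=64, return_tensors="pt")
print(inputs.input_ids.shape)  # torch.Size([2, 64])
```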

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/siglip2_metrics_table.png"
alt="drawing" width="600"/>

This model was contributed by [qubvel](https://huggingface.co/qubvel-hf).
The original code can be found [here](https://github.com/google-research/big_vision/tree/main).

## Usage example

There are 2 main ways to use SigLIP2: either use the pipeline API, which abstracts away all the complexity for you, or use the `Siglip2Model` class yourself.

### FixRes variant

**Pipeline API**

The pipeline allows you to use the model in a few lines of code:

```python
>>> from transformers import pipeline
>>> from PIL import Image
>>> import requests

>>> # load pipe
>>> image_classifier = pipeline(
...     task="zero-shot-image-classification",
...     model="google/siglip2-base-patch16-224",
... )

>>> # load image
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> # inference
>>> candidate_labels = ["2 cats", "a plane", "a remote"]
>>> outputs = image_classifier(image, candidate_labels=candidate_labels)
>>> outputs = [{"score": round(output["score"], 4), "label": output["label"] } for output in outputs]
>>> print(outputs)
[{'score': 0.1499, 'label': '2 cats'}, {'score': 0.0008, 'label': 'a remote'}, {'score': 0.0, 'label': 'a plane'}]
```

**Using the model yourself**

If you want to do the pre- and postprocessing yourself, here's how to do that:

```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, AutoModel
>>> import torch

>>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
>>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> candidate_labels = ["2 cats", "2 dogs"]
# follows the pipeline prompt template to get same results
>>> texts = [f"This is a photo of {label}." for label in candidate_labels]

# IMPORTANT: we pass `padding=max_length` and `max_length=64` since the model was trained with this
>>> inputs = processor(text=texts, images=image, padding="max_length", max_length=64, return_tensors="pt")

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> logits_per_image = outputs.logits_per_image
>>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
>>> print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")
15.0% that image 0 is '2 cats'
```

### NaFlex variant

NaFlex combines ideas from FlexiViT, i.e. supporting multiple, predefined sequence lengths
with a single ViT model, and NaViT, namely processing images at their native aspect ratio.
This enables processing different types of images at appropriate resolution, e.g. using a
larger resolution to process document images, while at the same time minimizing the impact
of aspect ratio distortion on certain inference tasks, e.g. on OCR.

Given a patch size and target sequence length, NaFlex preprocesses the data by first resizing
the input image such that the height and width after resizing are multiples of the patch size,
while

1. keeping the aspect ratio distortion as small as possible
2. producing a sequence length of at most the desired target sequence length (`max_num_patches`)

The resulting distortion in width and height is at most `(patch_size - 1) / width` and
`(patch_size - 1) / height`, respectively, which tends to be small for common resolutions and aspect ratios.
After resizing, the image is split into a sequence of patches, and a mask with padding information is added.
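
The resizing rule above can be sketched as follows. This is a simplified illustration of the described procedure, not the exact implementation used by the image processor:

```python
import math

def naflex_target_size(height: int, width: int, patch_size: int = 16, max_num_patches: int = 256):
    # scale the area so that roughly max_num_patches patches of size patch_size x patch_size fit
    scale = math.sqrt(max_num_patches * patch_size**2 / (height * width))
    # round each side to a multiple of the patch size, keeping the aspect ratio roughly intact
    target_height = max(patch_size, round(height * scale / patch_size) * patch_size)
    target_width = max(patch_size, round(width * scale / patch_size) * patch_size)
    # if rounding overshoots the patch budget, shrink the longer side by one patch
    while (target_height // patch_size) * (target_width // patch_size) > max_num_patches:
        if target_height >= target_width:
            target_height -= patch_size
        else:
            target_width -= patch_size
    return target_height, target_width

print(naflex_target_size(480, 640))  # (224, 288) -> 14 x 18 = 252 patches
```

The end-to-end NaFlex example below relies on the processor to do this resizing and patching for you: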

```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, AutoModel
>>> import torch

>>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-naflex")
>>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-naflex")

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> candidate_labels = ["2 cats", "2 dogs"]
# follows the pipeline prompt template to get same results
>>> texts = [f"This is a photo of {label}." for label in candidate_labels]

# the default value for `max_num_patches` is 256, but you can increase the resulting image resolution by providing
# higher values, e.g. `max_num_patches=512`
>>> inputs = processor(text=texts, images=image, max_num_patches=256, return_tensors="pt")

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> logits_per_image = outputs.logits_per_image
>>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
>>> print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")
21.1% that image 0 is '2 cats'
```
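
Beyond zero-shot classification, the vision tower can also back a standard classifier through [`Siglip2ForImageClassification`]. The following is a minimal sketch, assuming the NaFlex checkpoint above; the labels are purely illustrative and the classification head is randomly initialized until you fine-tune it:

```python
import torch
import requests
from PIL import Image
from transformers import AutoImageProcessor, Siglip2ForImageClassification

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# the classification head is newly initialized, so fine-tune before trusting the predictions
model = Siglip2ForImageClassification.from_pretrained(
    "google/siglip2-base-patch16-naflex",
    num_labels=2,
    id2label={0: "cat", 1: "dog"},
    label2id={"cat": 0, "dog": 1},
)
image_processor = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-naflex")

inputs = image_processor(image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])
```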

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with SigLIP2.

- [Zero-shot image classification task guide](../tasks/zero_shot_image_classification)
- A demo notebook for SigLIP2 can be found [here](https://github.com/qubvel/transformers-notebooks/tree/master/notebooks/SigLIP2_inference.ipynb). 🌎

If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.

## Combining SigLIP2 and Flash Attention 2

First, make sure to install the latest version of Flash Attention 2.

```bash
pip install -U flash-attn --no-build-isolation
```

Also make sure that your hardware is compatible with Flash Attention 2. Read more about it in the official documentation of the flash-attn repository. Also make sure to load your model in half-precision (e.g. `torch.float16`).

To load and run a model using Flash Attention 2, refer to the snippet below:

```python
>>> import torch
>>> import requests
>>> from PIL import Image
>>> from transformers import AutoProcessor, AutoModel
>>> device = "cuda" # the device to load the model onto

>>> model = AutoModel.from_pretrained(
...     "google/siglip2-so400m-patch14-384",
...     attn_implementation="flash_attention_2",
...     torch_dtype=torch.float16,
...     device_map=device,
... )
>>> processor = AutoProcessor.from_pretrained("google/siglip2-so400m-patch14-384")

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> candidate_labels = ["2 cats", "2 dogs"]
# follows the pipeline prompt template to get same results
>>> texts = [f'This is a photo of {label}.' for label in candidate_labels]
# important: we pass `padding=max_length` since the model was trained with this
>>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt").to(device)

>>> with torch.no_grad():
...     with torch.autocast(device):
...         outputs = model(**inputs)

>>> logits_per_image = outputs.logits_per_image
>>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
>>> print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")
19.8% that image 0 is '2 cats'
```

## Siglip2Config

[[autodoc]] Siglip2Config

## Siglip2TextConfig

[[autodoc]] Siglip2TextConfig

## Siglip2VisionConfig

[[autodoc]] Siglip2VisionConfig

## Siglip2ImageProcessor

[[autodoc]] Siglip2ImageProcessor
    - preprocess

## Siglip2ImageProcessorFast

[[autodoc]] Siglip2ImageProcessorFast
    - preprocess

## Siglip2Processor

[[autodoc]] Siglip2Processor

## Siglip2Model

[[autodoc]] Siglip2Model
    - forward
    - get_text_features
    - get_image_features

## Siglip2TextModel

[[autodoc]] Siglip2TextModel
    - forward

## Siglip2VisionModel

[[autodoc]] Siglip2VisionModel
    - forward

## Siglip2ForImageClassification

[[autodoc]] Siglip2ForImageClassification
    - forward
@@ -111,6 +111,7 @@ FlashAttention-2 is currently supported for the following architectures:
 * [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
 * [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
 * [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
+* [SigLIP2](https://huggingface.co/docs/transformers/model_doc/siglip2)
 * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
 * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
 * [helium](https://huggingface.co/docs/transformers/main/en/model_doc/helium#transformers.HeliumModel)
@@ -310,6 +311,7 @@ For now, Transformers supports SDPA inference and training for the following arc
 * [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel)
 * [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
 * [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
+* [SigLIP2](https://huggingface.co/docs/transformers/model_doc/siglip2)
 * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
 * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
 * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
@@ -776,6 +776,12 @@ _import_structure = {
         "SiglipTextConfig",
         "SiglipVisionConfig",
     ],
+    "models.siglip2": [
+        "Siglip2Config",
+        "Siglip2Processor",
+        "Siglip2TextConfig",
+        "Siglip2VisionConfig",
+    ],
     "models.smolvlm": ["SmolVLMConfig"],
     "models.speech_encoder_decoder": ["SpeechEncoderDecoderConfig"],
     "models.speech_to_text": [
@@ -1289,6 +1295,7 @@ else:
     _import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"])
     _import_structure["models.seggpt"].extend(["SegGptImageProcessor"])
     _import_structure["models.siglip"].append("SiglipImageProcessor")
+    _import_structure["models.siglip2"].append("Siglip2ImageProcessor")
     _import_structure["models.smolvlm"].extend(["SmolVLMImageProcessor"])
     _import_structure["models.superglue"].extend(["SuperGlueImageProcessor"])
     _import_structure["models.superpoint"].extend(["SuperPointImageProcessor"])
@@ -1330,6 +1337,7 @@ else:
     _import_structure["models.qwen2_vl"].append("Qwen2VLImageProcessorFast")
     _import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
     _import_structure["models.siglip"].append("SiglipImageProcessorFast")
+    _import_structure["models.siglip2"].append("Siglip2ImageProcessorFast")
     _import_structure["models.vit"].append("ViTImageProcessorFast")

 try:
@@ -3559,6 +3567,15 @@ else:
             "SiglipVisionModel",
         ]
     )
+    _import_structure["models.siglip2"].extend(
+        [
+            "Siglip2ForImageClassification",
+            "Siglip2Model",
+            "Siglip2PreTrainedModel",
+            "Siglip2TextModel",
+            "Siglip2VisionModel",
+        ]
+    )
     _import_structure["models.smolvlm"].extend(
         [
             "SmolVLMForConditionalGeneration",
@@ -5942,6 +5959,12 @@ if TYPE_CHECKING:
         SiglipTextConfig,
         SiglipVisionConfig,
     )
+    from .models.siglip2 import (
+        Siglip2Config,
+        Siglip2Processor,
+        Siglip2TextConfig,
+        Siglip2VisionConfig,
+    )
     from .models.smolvlm import SmolVLMConfig
     from .models.speech_encoder_decoder import SpeechEncoderDecoderConfig
     from .models.speech_to_text import (
@@ -6472,6 +6495,7 @@ if TYPE_CHECKING:
     from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor
     from .models.seggpt import SegGptImageProcessor
     from .models.siglip import SiglipImageProcessor
+    from .models.siglip2 import Siglip2ImageProcessor
     from .models.smolvlm import SmolVLMImageProcessor
     from .models.superglue import SuperGlueImageProcessor
     from .models.superpoint import SuperPointImageProcessor
@@ -6509,6 +6533,7 @@ if TYPE_CHECKING:
     from .models.qwen2_vl import Qwen2VLImageProcessorFast
     from .models.rt_detr import RTDetrImageProcessorFast
     from .models.siglip import SiglipImageProcessorFast
+    from .models.siglip2 import Siglip2ImageProcessorFast
     from .models.vit import ViTImageProcessorFast

 try:
@@ -8288,6 +8313,13 @@ if TYPE_CHECKING:
         SiglipTextModel,
         SiglipVisionModel,
     )
+    from .models.siglip2 import (
+        Siglip2ForImageClassification,
+        Siglip2Model,
+        Siglip2PreTrainedModel,
+        Siglip2TextModel,
+        Siglip2VisionModel,
+    )
     from .models.smolvlm import (
         SmolVLMForConditionalGeneration,
         SmolVLMModel,
@@ -245,6 +245,7 @@ from . import (
     sew,
     sew_d,
     siglip,
+    siglip2,
     smolvlm,
     speech_encoder_decoder,
     speech_to_text,
@@ -271,6 +271,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
        ("sew", "SEWConfig"),
        ("sew-d", "SEWDConfig"),
        ("siglip", "SiglipConfig"),
+        ("siglip2", "Siglip2Config"),
        ("siglip_vision_model", "SiglipVisionConfig"),
        ("smolvlm", "SmolVLMConfig"),
        ("smolvlm_vision", "SmolVLMVisionConfig"),
@@ -617,6 +618,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("sew", "SEW"),
        ("sew-d", "SEW-D"),
        ("siglip", "SigLIP"),
+        ("siglip2", "SigLIP2"),
+        ("siglip2_vision_model", "Siglip2VisionModel"),
        ("siglip_vision_model", "SiglipVisionModel"),
        ("smolvlm", "SmolVLM"),
        ("smolvlm_vision", "SmolVLMVisionTransformer"),
@@ -136,6 +136,7 @@ else:
        ("segformer", ("SegformerImageProcessor",)),
        ("seggpt", ("SegGptImageProcessor",)),
        ("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
+        ("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")),
        ("superglue", "SuperGlueImageProcessor"),
        ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
        ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),
@@ -250,6 +250,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("sew", "SEWModel"),
        ("sew-d", "SEWDModel"),
        ("siglip", "SiglipModel"),
+        ("siglip2", "Siglip2Model"),
        ("siglip_vision_model", "SiglipVisionModel"),
        ("smolvlm", "SmolVLMModel"),
        ("smolvlm_vision", "SmolVLMVisionTransformer"),
@@ -721,6 +722,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
        ("resnet", "ResNetForImageClassification"),
        ("segformer", "SegformerForImageClassification"),
        ("siglip", "SiglipForImageClassification"),
+        ("siglip2", "Siglip2ForImageClassification"),
        ("swiftformer", "SwiftFormerForImageClassification"),
        ("swin", "SwinForImageClassification"),
        ("swinv2", "Swinv2ForImageClassification"),
@@ -1403,6 +1405,7 @@ MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
        ("clip", "CLIPModel"),
        ("clipseg", "CLIPSegModel"),
        ("siglip", "SiglipModel"),
+        ("siglip2", "Siglip2Model"),
    ]
)

@@ -99,6 +99,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
        ("sew", "Wav2Vec2Processor"),
        ("sew-d", "Wav2Vec2Processor"),
        ("siglip", "SiglipProcessor"),
+        ("siglip2", "Siglip2Processor"),
        ("speech_to_text", "Speech2TextProcessor"),
        ("speech_to_text_2", "Speech2Text2Processor"),
        ("speecht5", "SpeechT5Processor"),
@@ -479,6 +479,13 @@ else:
            ),
        ),
        ("siglip", ("SiglipTokenizer" if is_sentencepiece_available() else None, None)),
+        (
+            "siglip2",
+            (
+                "GemmaTokenizer" if is_sentencepiece_available() else None,
+                "GemmaTokenizerFast" if is_tokenizers_available() else None,
+            ),
+        ),
        ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
        ("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
        ("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)),
@@ -1394,10 +1394,9 @@ class CLIPModel(CLIPPreTrainedModel):
         text_embeds = text_embeds / _get_vector_norm(text_embeds)

         # cosine similarity as logits
-        logit_scale = self.logit_scale.exp()
-        logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * logit_scale.to(
-            text_embeds.device
-        )
+        logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
+        logits_per_text = logits_per_text * self.logit_scale.exp().to(text_embeds.device)

         logits_per_image = logits_per_text.t()

         loss = None
@@ -59,6 +59,8 @@ class SiglipTextConfig(PretrainedConfig):
             The id of the beginning-of-sequence token in the vocabulary.
         eos_token_id (`int`, *optional*, defaults to 49407):
             The id of the end-of-sequence token in the vocabulary.
+        projection_size (`int`, *optional*, defaults to `hidden_size`):
+            The size of the projection head.

     Example:

@@ -94,6 +96,7 @@ class SiglipTextConfig(PretrainedConfig):
         pad_token_id=1,
         bos_token_id=49406,
         eos_token_id=49407,
+        projection_size=None,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -107,6 +110,7 @@ class SiglipTextConfig(PretrainedConfig):
         self.layer_norm_eps = layer_norm_eps
         self.hidden_act = hidden_act
         self.attention_dropout = attention_dropout
+        self.projection_size = projection_size if projection_size is not None else hidden_size


 class SiglipVisionConfig(PretrainedConfig):
@@ -19,7 +19,8 @@ URL: https://github.com/google-research/big_vision/tree/main

import argparse
import collections
from pathlib import Path
import os
from typing import Tuple

import numpy as np
import requests
@@ -28,7 +29,14 @@ from huggingface_hub import hf_hub_download
from numpy import load
from PIL import Image

from transformers import SiglipConfig, SiglipImageProcessor, SiglipModel, SiglipProcessor, SiglipTokenizer
from transformers import (
    GemmaTokenizerFast,
    SiglipConfig,
    SiglipImageProcessor,
    SiglipModel,
    SiglipProcessor,
    SiglipTokenizer,
)
from transformers.utils import logging

@@ -36,6 +44,33 @@ logging.set_verbosity_info()
logger = logging.get_logger(__name__)


MODEL_CONFIGS = {
    "base": {
        "hidden_size": 768,
        "intermediate_size": 3072,
        "num_hidden_layers": 12,
        "num_attention_heads": 12,
    },
    "large": {
        "hidden_size": 1024,
        "intermediate_size": 4096,
        "num_hidden_layers": 24,
        "num_attention_heads": 16,
    },
    "giant-opt": {
        "hidden_size": 1536,
        "intermediate_size": 6144,
        "num_hidden_layers": 40,
        "num_attention_heads": 16,
    },
    "so400m": {
        "hidden_size": 1152,
        "intermediate_size": 4304,
        "num_hidden_layers": 27,
        "num_attention_heads": 16,
    },
}

model_name_to_checkpoint = {
    # base checkpoints
    "siglip-base-patch16-224": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_224_63724782.npz",
@@ -49,56 +84,146 @@ model_name_to_checkpoint = {
    "siglip-base-patch16-256-i18n": "/Users/nielsrogge/Documents/SigLIP/webli_i18n_b16_256_66117334.npz",
    # so400m checkpoints
    "siglip-so400m-patch14-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_so400m_384_58765454.npz",
    # ----------------- v2 -----------------
    # base checkpoints
    "siglip2-base-patch32-256": "gv-hf/siglip2/siglip2_b32_256.npz",
    "siglip2-base-patch16-224": "gv-hf/siglip2/siglip2_b16_224.npz",
    "siglip2-base-patch16-256": "gv-hf/siglip2/siglip2_b16_256.npz",
    "siglip2-base-patch16-384": "gv-hf/siglip2/siglip2_b16_384.npz",
    "siglip2-base-patch16-512": "gv-hf/siglip2/siglip2_b16_512.npz",
    # large checkpoints
    "siglip2-large-patch16-256": "gv-hf/siglip2/siglip2_l16_256.npz",
    "siglip2-large-patch16-384": "gv-hf/siglip2/siglip2_l16_384.npz",
    "siglip2-large-patch16-512": "gv-hf/siglip2/siglip2_l16_512.npz",
    # giant opt checkpoints
    "siglip2-giant-opt-patch16-256": "gv-hf/siglip2/siglip2_g-opt16_256.npz",
    "siglip2-giant-opt-patch16-384": "gv-hf/siglip2/siglip2_g-opt16_384.npz",
    # so400m checkpoints
    "siglip2-so400m-patch14-224": "gv-hf/siglip2/siglip2_so400m14_224.npz",
    "siglip2-so400m-patch14-384": "gv-hf/siglip2/siglip2_so400m14_384.npz",
    "siglip2-so400m-patch16-256": "gv-hf/siglip2/siglip2_so400m16_256.npz",
    "siglip2-so400m-patch16-384": "gv-hf/siglip2/siglip2_so400m16_384.npz",
    "siglip2-so400m-patch16-512": "gv-hf/siglip2/siglip2_so400m16_512.npz",
}

model_name_to_image_size = {
    "siglip-base-patch16-224": 224,
    "siglip-base-patch16-256": 256,
    "siglip-base-patch16-384": 384,
    "siglip-base-patch16-512": 512,
    "siglip-large-patch16-256": 256,
    "siglip-large-patch16-384": 384,
    "siglip-base-patch16-256-i18n": 256,
    "siglip-so400m-patch14-384": 384,
}
# ------------------------------------------------------------------------------------------------------
# CONFIG
# ------------------------------------------------------------------------------------------------------


def get_image_size_from_model_name(model_name: str) -> int:
    if "-i18n" not in model_name:
        size = model_name.split("-")[-1]
    else:
        size = model_name.split("-")[-2]
    return int(size)


def get_patch_size_from_model_name(model_name: str) -> int:
    patch_str = [x for x in model_name.split("-") if "patch" in x][0]
    return int(patch_str[-2:])


def get_vocab_size_from_model_name(model_name: str) -> int:
    if "siglip2" in model_name:
        vocab_size = 256000
    elif "-i18n" in model_name:
        vocab_size = 250000
    else:
        vocab_size = 32000
    return vocab_size


def get_vocab_file_from_model_name(model_name: str) -> str:
    # get vocab file
    if "i18n" in model_name:
        vocab_file = "/Users/nielsrogge/Documents/SigLIP/multilingual_vocab/sentencepiece.model"
    else:
        vocab_file = "/Users/nielsrogge/Documents/SigLIP/english_vocab/sentencepiece.model"
    return vocab_file


def get_text_and_vision_vit_variants(model_name: str) -> Tuple[str, str]:
    variant = model_name.split("-")[1] if "giant-opt" not in model_name else "giant-opt"
    return {
        "base": ("base", "base"),
        "large": ("large", "large"),
        "so400m": ("so400m", "so400m"),
        # g-opt siglip2 is not symmetric
        "giant-opt": ("so400m", "giant-opt"),
    }[variant]


def get_siglip_config(model_name):
    config = SiglipConfig()
    text_variant, vision_variant = get_text_and_vision_vit_variants(model_name)
    text_config = MODEL_CONFIGS[text_variant].copy()
    vision_config = MODEL_CONFIGS[vision_variant].copy()

    vocab_size = 250000 if "i18n" in model_name else 32000
    image_size = model_name_to_image_size[model_name]
    patch_size = 16 if "patch16" in model_name else 14
    text_config["vocab_size"] = get_vocab_size_from_model_name(model_name)
    vision_config["image_size"] = get_image_size_from_model_name(model_name)
    vision_config["patch_size"] = get_patch_size_from_model_name(model_name)

    # size of the architecture
    config.vision_config.image_size = image_size
    config.vision_config.patch_size = patch_size
    config.text_config.vocab_size = vocab_size
    if text_config["hidden_size"] != vision_config["hidden_size"]:
        text_config["projection_size"] = vision_config["hidden_size"]

    if "base" in model_name:
        pass
    elif "large" in model_name:
        config.text_config.hidden_size = 1024
        config.text_config.intermediate_size = 4096
        config.text_config.num_hidden_layers = 24
        config.text_config.num_attention_heads = 16
        config.vision_config.hidden_size = 1024
        config.vision_config.intermediate_size = 4096
        config.vision_config.num_hidden_layers = 24
        config.vision_config.num_attention_heads = 16
    elif "so400m" in model_name:
        config.text_config.hidden_size = 1152
        config.text_config.intermediate_size = 4304
        config.text_config.num_hidden_layers = 27
        config.text_config.num_attention_heads = 16
        config.vision_config.hidden_size = 1152
        config.vision_config.intermediate_size = 4304
        config.vision_config.num_hidden_layers = 27
        config.vision_config.num_attention_heads = 16
    return SiglipConfig(text_config=text_config, vision_config=vision_config)


# ------------------------------------------------------------------------------------------------------
# PROCESSING
# ------------------------------------------------------------------------------------------------------


def get_tokenizer(model_name: str) -> GemmaTokenizerFast:
    if "siglip2" in model_name:
        tokenizer = GemmaTokenizerFast.from_pretrained(
            "google/gemma-2-9b-it",
            add_bos_token=False,
            add_eos_token=True,
            padding_side="right",
            do_lower_case=True,
            # important: make tokenizer NOT return attention_mask since original one doesn't require it
            model_input_names=["input_ids"],
        )
    else:
        raise ValueError("Model not supported")
        # for siglip v1
        vocab_file = get_vocab_file_from_model_name(model_name)
        # important: make tokenizer not return attention_mask since original one doesn't require it
        tokenizer = SiglipTokenizer(vocab_file=vocab_file, model_input_names=["input_ids"])
    return tokenizer

    return config

def get_image_processor(model_name: str) -> SiglipImageProcessor:
    image_size = get_image_size_from_model_name(model_name)
    size = {"height": image_size, "width": image_size}
    if "siglip2" in model_name:
        image_processor = SiglipImageProcessor(size=size, resample=2)  # bilinear resampling
    else:
        image_processor = SiglipImageProcessor(size=size)
    return image_processor


# ------------------------------------------------------------------------------------------------------
# CONVERT FUNCTIONS
# ------------------------------------------------------------------------------------------------------


def split_encoderblock_layers(state_dict: dict) -> dict:
    """
    Split the encoderblock weight into layers. In some cases they are concatenated in
    the original checkpoints.
    """
    # Make shallow copy
    state_dict = state_dict.copy()
    # Split encoderblock weight into layers
    keys = list(state_dict.keys())
    for key in keys:
        if "/encoderblock/" in key:
            weight = state_dict.pop(key)
            for i, weight_i in enumerate(weight):
                new_name = key.replace("encoderblock", f"encoderblock_{i}")
                state_dict[new_name] = weight_i
    return state_dict


def create_rename_keys(config):
@@ -258,23 +383,21 @@ def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, verify_logit
    Copy/paste/tweak model's weights to our SigLIP structure.
    """

    # define default SigLIP configuration
    # Define default SigLIP configuration
    config = get_siglip_config(model_name)

    # get checkpoint
    # Get checkpoint
    checkpoint = model_name_to_checkpoint[model_name]
    if not os.path.exists(checkpoint):
        org, repo_id, *filepath = checkpoint.split("/")
        checkpoint = hf_hub_download(repo_id=f"{org}/{repo_id}", filename="/".join(filepath))

    # get vocab file
    if "i18n" in model_name:
        vocab_file = "/Users/nielsrogge/Documents/SigLIP/multilingual_vocab/sentencepiece.model"
    else:
        vocab_file = "/Users/nielsrogge/Documents/SigLIP/english_vocab/sentencepiece.model"

    # load original state dict
    # Load original state dict
    data = load(checkpoint)
    state_dict = flatten_nested_dict(data)
    state_dict = split_encoderblock_layers(state_dict)

    # remove and rename some keys
    # Remove and rename some keys
    rename_keys = create_rename_keys(config)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest, config)
@@ -282,26 +405,28 @@ def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, verify_logit
    # qkv matrices of attention pooling head need special treatment
    read_in_q_k_v_head(state_dict, config)

    # load HuggingFace model
    # Load HuggingFace model
    model = SiglipModel(config).eval()
    model.load_state_dict(state_dict)

    # create processor
    # important: make tokenizer not return attention_mask since original one doesn't require it
    image_size = config.vision_config.image_size
    size = {"height": image_size, "width": image_size}
    image_processor = SiglipImageProcessor(size=size)
    tokenizer = SiglipTokenizer(vocab_file=vocab_file, model_input_names=["input_ids"])
    # Create processor
    image_processor = get_image_processor(model_name)
    tokenizer = get_tokenizer(model_name)
    processor = SiglipProcessor(image_processor=image_processor, tokenizer=tokenizer)

    # verify on dummy images and texts
    # Verify forward pass on dummy images and texts
    url_1 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-ipod.jpg"
    image_1 = Image.open(requests.get(url_1, stream=True).raw).convert("RGB")
    url_2 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-blank.jpg"
    image_2 = Image.open(requests.get(url_2, stream=True).raw).convert("RGB")
    texts = ["an apple", "a picture of an apple"]

    inputs = processor(images=[image_1, image_2], text=texts, return_tensors="pt", padding="max_length")
    inputs = processor(images=[image_1, image_2], text=texts, padding="max_length", max_length=64, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    if verify_logits:
        image_size = config.vision_config.image_size

        # verify input_ids against original ones
        if image_size == 224:
@@ -328,18 +453,13 @@ def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, verify_logit

    # note: we're testing with original pixel values here since we don't have exact pixel values
    with torch.no_grad():
        outputs = model(input_ids=inputs.input_ids, pixel_values=original_pixel_values)

    # with torch.no_grad():
    #     outputs = model(input_ids=inputs.input_ids, pixel_values=inputs.pixel_values)

        outputs = model(input_ids=original_input_ids, pixel_values=original_pixel_values)
    print(outputs.logits_per_image[:3, :3])

    probs = torch.sigmoid(outputs.logits_per_image)  # these are the probabilities
    print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
    print(f"{probs[0][1]:.1%} that image 0 is '{texts[1]}'")

    if verify_logits:
        if model_name == "siglip-base-patch16-224":
            expected_slice = torch.tensor(
                [[-2.9621, -2.1672], [-0.2713, 0.2910]],
@@ -375,15 +495,16 @@ def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, verify_logit
            print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        pytorch_dump_folder_path = os.path.join(pytorch_dump_folder_path, model_name)
        os.makedirs(pytorch_dump_folder_path, exist_ok=True)
        print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        print(f"Saving processor to {pytorch_dump_folder_path}")
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        model.push_to_hub(f"nielsr/{model_name}")
        processor.push_to_hub(f"nielsr/{model_name}")
        model.push_to_hub(f"s0225/{model_name}", private=True)
        processor.push_to_hub(f"s0225/{model_name}", private=True)


if __name__ == "__main__":
@@ -401,7 +522,7 @@ if __name__ == "__main__":
    )
    parser.add_argument(
        "--verify_logits",
        action="store_false",
        action="store_true",
        help="Whether to verify logits against the original implementation.",
    )
    parser.add_argument(
@@ -471,15 +471,9 @@ class SiglipFlashAttention2(SiglipAttention):
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
         # therefore we just need to keep the original shape
-        query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
-        # to be able to avoid many of these transpose/reshape/view.
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
+        query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
+        value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)

         dropout_rate = self.dropout if self.training else 0.0

@@ -936,7 +930,7 @@ class SiglipTextTransformer(nn.Module):
         self.encoder = SiglipEncoder(config)
         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

-        self.head = nn.Linear(embed_dim, embed_dim)
+        self.head = nn.Linear(embed_dim, config.projection_size)
         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

     @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
@@ -1415,10 +1409,11 @@ class SiglipModel(SiglipPreTrainedModel):
         text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

         # cosine similarity as logits
-        logits_per_text = (
-            torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * self.logit_scale.exp()
-            + self.logit_bias
-        )
+        logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
+
+        logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
+        logits_per_text = logits_per_text * logit_scale.exp() + logit_bias

         logits_per_image = logits_per_text.t()

         loss = None
@@ -41,7 +41,7 @@ class SiglipProcessor(ProcessorMixin):

     attributes = ["image_processor", "tokenizer"]
     image_processor_class = "SiglipImageProcessor"
-    tokenizer_class = "SiglipTokenizer"
+    tokenizer_class = "AutoTokenizer"

     def __init__(self, image_processor, tokenizer):
         super().__init__(image_processor, tokenizer)
@@ -113,7 +113,7 @@ class SiglipProcessor(ProcessorMixin):
             image_features = self.image_processor(images, return_tensors=return_tensors)

         if text is not None and images is not None:
-            encoding["pixel_values"] = image_features.pixel_values
+            encoding.update(image_features)
             return encoding
         elif text is not None:
             return encoding
30 src/transformers/models/siglip2/__init__.py Normal file
@@ -0,0 +1,30 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_siglip2 import *
    from .image_processing_siglip2 import *
    from .image_processing_siglip2_fast import *
    from .modeling_siglip2 import *
    from .processing_siglip2 import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
277 src/transformers/models/siglip2/configuration_siglip2.py Normal file
@@ -0,0 +1,277 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/siglip2/modular_siglip2.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_siglip2.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Siglip2TextConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Siglip2TextModel`]. It is used to instantiate a
|
||||
Siglip2 text encoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip2
|
||||
[google/siglip2-base-patch16-224](https://huggingface.co/google/siglip2-base-patch16-224) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the Siglip2 text model. Defines the number of different tokens that can be represented by
|
||||
the `inputs_ids` passed when calling [`Siglip2Model`].
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
intermediate_size (`int`, *optional*, defaults to 3072):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 64):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
pad_token_id (`int`, *optional*, defaults to 1):
|
||||
The id of the padding token in the vocabulary.
|
||||
bos_token_id (`int`, *optional*, defaults to 49406):
|
||||
The id of the beginning-of-sequence token in the vocabulary.
|
||||
eos_token_id (`int`, *optional*, defaults to 49407):
|
||||
The id of the end-of-sequence token in the vocabulary.
|
||||
projection_size (`int`, *optional*, defaults to `hidden_size`):
|
||||
The size of the projection head.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Siglip2TextConfig, Siglip2TextModel
|
||||
|
||||
>>> # Initializing a Siglip2TextConfig with google/siglip2-base-patch16-224 style configuration
|
||||
>>> configuration = Siglip2TextConfig()
|
||||
|
||||
>>> # Initializing a Siglip2TextModel (with random weights) from the google/siglip2-base-patch16-224 style configuration
|
||||
>>> model = Siglip2TextModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "siglip2_text_model"
|
||||
base_config_key = "text_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=768,
|
||||
intermediate_size=3072,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
max_position_embeddings=64,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
layer_norm_eps=1e-6,
|
||||
attention_dropout=0.0,
|
||||
# This differs from `CLIPTokenizer`'s default and from openai/siglip2
|
||||
# See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
|
||||
pad_token_id=1,
|
||||
bos_token_id=49406,
|
||||
eos_token_id=49407,
|
||||
projection_size=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.hidden_act = hidden_act
|
||||
self.attention_dropout = attention_dropout
|
||||
self.projection_size = projection_size if projection_size is not None else hidden_size
|
||||
|
||||
|
||||
class Siglip2VisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Siglip2VisionModel`]. It is used to instantiate a
|
||||
Siglip2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip2
|
||||
[google/siglip2-base-patch16-naflex](https://huggingface.co/google/siglip2-base-patch16-naflex) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
intermediate_size (`int`, *optional*, defaults to 3072):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
Number of channels in the input images.
|
||||
num_patches (`int`, *optional*, defaults to 256):
|
||||
The number of patches in the image with the size of (`patch_size`, `patch_size`).
|
||||
The image is resized to fill maximum of this number of patches, and to preserve
|
||||
the aspect ratio. In case the resulted number of patches is lower, the image is
|
||||
padded in "patch" dimension.
|
||||
patch_size (`int`, *optional*, defaults to 16):
|
||||
The size (resolution) of each patch.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Siglip2VisionConfig, Siglip2VisionModel
|
||||
|
||||
>>> # Initializing a Siglip2VisionConfig with google/siglip2-base-patch16-naflex style configuration
|
||||
>>> configuration = Siglip2VisionConfig()
|
||||
|
||||
>>> # Initializing a Siglip2VisionModel (with random weights) from the google/siglip2-base-patch16-naflex style configuration
|
||||
>>> model = Siglip2VisionModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "siglip2_vision_model"
|
||||
base_config_key = "vision_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=768,
|
||||
intermediate_size=3072,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
num_channels=3,
|
||||
num_patches=256,
|
||||
patch_size=16,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
layer_norm_eps=1e-6,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_channels = num_channels
|
||||
self.patch_size = patch_size
|
||||
self.attention_dropout = attention_dropout
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.hidden_act = hidden_act
|
||||
self.num_patches = num_patches
|
||||
|
||||
|
||||
class Siglip2Config(PretrainedConfig):
|
||||
r"""
|
||||
[`Siglip2Config`] is the configuration class to store the configuration of a [`Siglip2Model`]. It is used to
|
||||
instantiate a Siglip2 model according to the specified arguments, defining the text model and vision model configs.
|
||||
Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip2
|
||||
[google/siglip2-base-patch16-224](https://huggingface.co/google/siglip2-base-patch16-224) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
text_config (`dict`, *optional*):
|
||||
Dictionary of configuration options used to initialize [`Siglip2TextConfig`].
|
||||
vision_config (`dict`, *optional*):
|
||||
Dictionary of configuration options used to initialize [`Siglip2VisionConfig`].
|
||||
kwargs (*optional*):
|
||||
Dictionary of keyword arguments.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Siglip2Config, Siglip2Model
|
||||
|
||||
>>> # Initializing a Siglip2Config with google/siglip2-base-patch16-224 style configuration
|
||||
>>> configuration = Siglip2Config()
|
||||
|
||||
>>> # Initializing a Siglip2Model (with random weights) from the google/siglip2-base-patch16-224 style configuration
|
||||
>>> model = Siglip2Model(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
|
||||
>>> # We can also initialize a Siglip2Config from a Siglip2TextConfig and a Siglip2VisionConfig
|
||||
>>> from transformers import Siglip2TextConfig, Siglip2VisionConfig
|
||||
|
||||
>>> # Initializing a Siglip2Text and Siglip2Vision configuration
|
||||
>>> config_text = Siglip2TextConfig()
|
||||
>>> config_vision = Siglip2VisionConfig()
|
||||
|
||||
>>> config = Siglip2Config.from_text_vision_configs(config_text, config_vision)
|
||||
```"""
|
||||
|
||||
model_type = "siglip2"
|
||||
sub_configs = {"text_config": Siglip2TextConfig, "vision_config": Siglip2VisionConfig}
|
||||
|
||||
def __init__(self, text_config=None, vision_config=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if text_config is None:
|
||||
text_config = {}
|
||||
logger.info("`text_config` is `None`. Initializing the `Siglip2TextConfig` with default values.")
|
||||
|
||||
if vision_config is None:
|
||||
vision_config = {}
|
||||
logger.info("`vision_config` is `None`. Initializing the `Siglip2VisionConfig` with default values.")
|
||||
|
||||
self.text_config = Siglip2TextConfig(**text_config)
|
||||
self.vision_config = Siglip2VisionConfig(**vision_config)
|
||||
|
||||
self.initializer_factor = 1.0
|
||||
|
||||
@classmethod
|
||||
def from_text_vision_configs(cls, text_config: Siglip2TextConfig, vision_config: Siglip2VisionConfig, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`Siglip2Config`] (or a derived class) from a Siglip2 text model configuration and a Siglip2
vision model configuration.
|
||||
|
||||
Returns:
|
||||
[`Siglip2Config`]: An instance of a configuration object
|
||||
"""
|
||||
|
||||
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
|
||||
|
||||
|
||||
__all__ = ["Siglip2Config", "Siglip2TextConfig", "Siglip2VisionConfig"]
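
As a hedged sanity check of the composition pattern above (assuming the classes exported by this file), a `Siglip2Config` can be assembled from its two sub-configurations and read back:

```python
from transformers import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig

# Build the two sub-configurations explicitly, then compose them.
text_config = Siglip2TextConfig(hidden_size=768)
vision_config = Siglip2VisionConfig(hidden_size=768, patch_size=16, num_patches=256)
config = Siglip2Config.from_text_vision_configs(text_config, vision_config)

# Sub-configs are stored as nested config objects, not plain dicts.
assert config.vision_config.num_patches == 256
assert config.text_config.hidden_size == 768
```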
|
438
src/transformers/models/siglip2/convert_siglip2_to_hf.py
Normal file
@ -0,0 +1,438 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Siglip2 checkpoints from the original repository.
|
||||
|
||||
URL: https://github.com/google-research/big_vision/tree/main
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import os
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from transformers import GemmaTokenizerFast, Siglip2Config, Siglip2ImageProcessorFast, Siglip2Model, Siglip2Processor
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
COMMON_CONFIG_PARAMS = {
|
||||
"base": {
|
||||
"hidden_size": 768,
|
||||
"intermediate_size": 3072,
|
||||
"num_hidden_layers": 12,
|
||||
"num_attention_heads": 12,
|
||||
},
|
||||
"large": {
|
||||
"hidden_size": 1024,
|
||||
"intermediate_size": 4096,
|
||||
"num_hidden_layers": 24,
|
||||
"num_attention_heads": 16,
|
||||
},
|
||||
"so400m": {
|
||||
"hidden_size": 1152,
|
||||
"intermediate_size": 4304,
|
||||
"num_hidden_layers": 27,
|
||||
"num_attention_heads": 16,
|
||||
},
|
||||
}
|
||||
|
||||
MODEL_NAME_TO_CHECKPOINT_PATH = {
|
||||
# base checkpoints
|
||||
"siglip2-base-patch16-naflex": "gv-hf/siglip2/siglip2_b16_naflex.npz",
|
||||
"siglip2-so400m-patch16-naflex": "gv-hf/siglip2/siglip2_so400m16_naflex.npz",
|
||||
}
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_OUTPUTS = {
|
||||
"siglip2-base-patch16-naflex": torch.tensor([
|
||||
[ 1.0195, -0.0280, -1.4468],
|
||||
[ -4.5395, -6.2269, -1.5667],
|
||||
[ 4.1757, 5.0358, 3.5159],
|
||||
[ 9.4264, 10.1879, 6.3353],
|
||||
[ 2.4409, 3.1058, 4.5491],
|
||||
[-12.3230, -13.7355, -13.4632],
|
||||
[ 1.1520, 1.1687, -1.9647],
|
||||
]),
|
||||
"siglip2-so400m-patch16-naflex": torch.tensor([
|
||||
[ 0.9422, 0.5540, -2.4405],
|
||||
[ -7.3522, -9.4931, -6.3499],
|
||||
[ 5.7852, 6.7288, 7.7893],
|
||||
[ 9.9881, 10.8136, 9.2121],
|
||||
[ 5.3660, 5.7746, 8.4130],
|
||||
[-12.7218, -14.2631, -13.6442],
|
||||
[ 0.6384, 0.4278, -0.9022],
|
||||
]),
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
# fmt: off
|
||||
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
|
||||
# Vision embeddings
|
||||
r"params/img/embedding/kernel": r"vision_model.embeddings.patch_embedding.weight",
|
||||
r"params/img/embedding/bias": r"vision_model.embeddings.patch_embedding.bias",
|
||||
r"params/img/pos_embedding": r"vision_model.embeddings.position_embedding.weight",
|
||||
# Vision encoder
|
||||
r"params/img/Transformer/encoderblock_(\d+)/LayerNorm_0/scale": r"vision_model.encoder.layers.\1.layer_norm1.weight",
|
||||
r"params/img/Transformer/encoderblock_(\d+)/LayerNorm_0/bias": r"vision_model.encoder.layers.\1.layer_norm1.bias",
|
||||
r"params/img/Transformer/encoderblock_(\d+)/LayerNorm_1/scale": r"vision_model.encoder.layers.\1.layer_norm2.weight",
|
||||
r"params/img/Transformer/encoderblock_(\d+)/LayerNorm_1/bias": r"vision_model.encoder.layers.\1.layer_norm2.bias",
|
||||
r"params/img/Transformer/encoderblock_(\d+)/MlpBlock_0/Dense_0/kernel": r"vision_model.encoder.layers.\1.mlp.fc1.weight",
|
||||
r"params/img/Transformer/encoderblock_(\d+)/MlpBlock_0/Dense_0/bias": r"vision_model.encoder.layers.\1.mlp.fc1.bias",
|
||||
r"params/img/Transformer/encoderblock_(\d+)/MlpBlock_0/Dense_1/kernel": r"vision_model.encoder.layers.\1.mlp.fc2.weight",
|
||||
r"params/img/Transformer/encoderblock_(\d+)/MlpBlock_0/Dense_1/bias": r"vision_model.encoder.layers.\1.mlp.fc2.bias",
|
||||
r"params/img/Transformer/encoderblock_(\d+)/MultiHeadDotProductAttention_0/(q|k|v|out)[a-z]*/kernel": r"vision_model.encoder.layers.\1.self_attn.\2_proj.weight",
|
||||
r"params/img/Transformer/encoderblock_(\d+)/MultiHeadDotProductAttention_0/(q|k|v|out)[a-z]*/bias": r"vision_model.encoder.layers.\1.self_attn.\2_proj.bias",
|
||||
# Vision norm
|
||||
r"params/img/Transformer/encoder_norm/scale": r"vision_model.post_layernorm.weight",
|
||||
r"params/img/Transformer/encoder_norm/bias": r"vision_model.post_layernorm.bias",
|
||||
# Vision head
|
||||
r"params/img/MAPHead_0/probe": r"vision_model.head.probe",
|
||||
r"params/img/MAPHead_0/LayerNorm_0/scale": r"vision_model.head.layernorm.weight",
|
||||
r"params/img/MAPHead_0/LayerNorm_0/bias": r"vision_model.head.layernorm.bias",
|
||||
r"params/img/MAPHead_0/MlpBlock_0/Dense_0/kernel": r"vision_model.head.mlp.fc1.weight",
|
||||
r"params/img/MAPHead_0/MlpBlock_0/Dense_0/bias": r"vision_model.head.mlp.fc1.bias",
|
||||
r"params/img/MAPHead_0/MlpBlock_0/Dense_1/kernel": r"vision_model.head.mlp.fc2.weight",
|
||||
r"params/img/MAPHead_0/MlpBlock_0/Dense_1/bias": r"vision_model.head.mlp.fc2.bias",
|
||||
r"params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/kernel": r"vision_model.head.attention.out_proj.weight",
|
||||
r"params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/bias": r"vision_model.head.attention.out_proj.bias",
|
||||
r"params/img/MAPHead_0/MultiHeadDotProductAttention_0/qkv/kernel": r"vision_model.head.attention.in_proj_weight",
|
||||
r"params/img/MAPHead_0/MultiHeadDotProductAttention_0/qkv/bias": r"vision_model.head.attention.in_proj_bias",
|
||||
# Text embeddings
|
||||
r"params/txt/Embed_0/embedding": r"text_model.embeddings.token_embedding.weight",
|
||||
r"params/txt/pos_embedding": r"text_model.embeddings.position_embedding.weight",
|
||||
# Text encoder
|
||||
r"params/txt/Encoder_0/encoderblock_(\d+)/LayerNorm_0/scale": r"text_model.encoder.layers.\1.layer_norm1.weight",
|
||||
r"params/txt/Encoder_0/encoderblock_(\d+)/LayerNorm_0/bias": r"text_model.encoder.layers.\1.layer_norm1.bias",
|
||||
r"params/txt/Encoder_0/encoderblock_(\d+)/LayerNorm_1/scale": r"text_model.encoder.layers.\1.layer_norm2.weight",
|
||||
r"params/txt/Encoder_0/encoderblock_(\d+)/LayerNorm_1/bias": r"text_model.encoder.layers.\1.layer_norm2.bias",
|
||||
r"params/txt/Encoder_0/encoderblock_(\d+)/MlpBlock_0/Dense_0/kernel": r"text_model.encoder.layers.\1.mlp.fc1.weight",
|
||||
r"params/txt/Encoder_0/encoderblock_(\d+)/MlpBlock_0/Dense_0/bias": r"text_model.encoder.layers.\1.mlp.fc1.bias",
|
||||
r"params/txt/Encoder_0/encoderblock_(\d+)/MlpBlock_0/Dense_1/kernel": r"text_model.encoder.layers.\1.mlp.fc2.weight",
|
||||
r"params/txt/Encoder_0/encoderblock_(\d+)/MlpBlock_0/Dense_1/bias": r"text_model.encoder.layers.\1.mlp.fc2.bias",
|
||||
r"params/txt/Encoder_0/encoderblock_(\d+)/MultiHeadDotProductAttention_0/(q|k|v|out)[a-z]*/kernel": r"text_model.encoder.layers.\1.self_attn.\2_proj.weight",
|
||||
r"params/txt/Encoder_0/encoderblock_(\d+)/MultiHeadDotProductAttention_0/(q|k|v|out)[a-z]*/bias": r"text_model.encoder.layers.\1.self_attn.\2_proj.bias",
|
||||
# Text encoder norm and head
|
||||
r"params/txt/Encoder_0/encoder_norm/scale": r"text_model.final_layer_norm.weight",
|
||||
r"params/txt/Encoder_0/encoder_norm/bias": r"text_model.final_layer_norm.bias",
|
||||
r"params/txt/head/kernel": r"text_model.head.weight",
|
||||
r"params/txt/head/bias": r"text_model.head.bias",
|
||||
# learned temperature and bias
|
||||
r"params/t": r"logit_scale",
|
||||
r"params/b": r"logit_bias",
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Model objects: configuration, tokenizer, image processor
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_siglip2_config(model_name: str) -> Siglip2Config:
|
||||
"""
|
||||
Create a configuration for the Siglip2 model based on the model name.
|
||||
"""
|
||||
|
||||
_, variant, patch, _ = model_name.split("-")
|
||||
patch_size = int(patch[-2:])
|
||||
num_patches = 256
|
||||
|
||||
common_options = COMMON_CONFIG_PARAMS[variant]
|
||||
vision_config = {
|
||||
"patch_size": patch_size,
|
||||
"num_patches": num_patches,
|
||||
**common_options,
|
||||
}
|
||||
text_config = {
|
||||
"vocab_size": 256_000,
|
||||
**common_options,
|
||||
}
|
||||
config = Siglip2Config(
|
||||
vision_config=vision_config,
|
||||
text_config=text_config,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def get_siglip2_tokenizer() -> GemmaTokenizerFast:
|
||||
# Load pretrained tokenizer
|
||||
gemma_checkpoint = "google/gemma-7b"
|
||||
tokenizer = GemmaTokenizerFast.from_pretrained(
|
||||
gemma_checkpoint,
|
||||
add_bos_token=False,
|
||||
add_eos_token=True,
|
||||
padding_side="right",
|
||||
do_lower_case=True,
|
||||
# important: make tokenizer NOT return attention_mask since original one doesn't require it
|
||||
model_input_names=["input_ids"],
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_siglip2_image_processor(patch_size: int, max_num_patches: int) -> Siglip2ImageProcessorFast:
|
||||
image_processor = Siglip2ImageProcessorFast(
|
||||
patch_size=patch_size,
|
||||
max_num_patches=max_num_patches,
|
||||
do_resize=True,
|
||||
do_normalize=True,
|
||||
image_mean=[0.5, 0.5, 0.5],
|
||||
image_std=[0.5, 0.5, 0.5],
|
||||
do_rescale=True,
|
||||
rescale_factor=1 / 255,
|
||||
resample=Image.Resampling.BILINEAR,
|
||||
)
|
||||
return image_processor
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Helper functions for state dict conversion
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def flatten_nested_dict(params: dict, parent_key: str = "", sep: str = "/") -> dict:
|
||||
"""
|
||||
Flatten a nested original checkpoint dictionary into a flat dictionary.
|
||||
"""
|
||||
items = []
|
||||
for k, v in params.items():
|
||||
new_key = parent_key + sep + k if parent_key else k
|
||||
if isinstance(v, collections.abc.MutableMapping):
|
||||
items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
|
||||
else:
|
||||
items.append((new_key, v))
|
||||
return dict(items)
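
As a hedged illustration (assuming the helper above is in scope), flattening a miniature nested checkpoint dictionary produces slash-joined parameter paths:

```python
nested = {"params": {"img": {"pos_embedding": [1, 2, 3]}, "t": 0.5}}
flat = flatten_nested_dict(nested)
# Keys are joined with "/" in the same style as the original big_vision parameter paths.
assert flat == {"params/img/pos_embedding": [1, 2, 3], "params/t": 0.5}
```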
|
||||
|
||||
|
||||
def split_encoderblock_layers(state_dict: dict) -> dict:
|
||||
"""
|
||||
Split the encoderblock weight into layers. In some cases they are concatenated in
|
||||
the original checkpoints.
|
||||
"""
|
||||
# Make shallow copy
|
||||
state_dict = state_dict.copy()
|
||||
# Split encoderblock weight into layers
|
||||
keys = list(state_dict.keys())
|
||||
for key in keys:
|
||||
if "/encoderblock/" in key:
|
||||
weight = state_dict.pop(key)
|
||||
for i, weight_i in enumerate(weight):
|
||||
new_name = key.replace("encoderblock", f"encoderblock_{i}")
|
||||
state_dict[new_name] = weight_i
|
||||
return state_dict
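
A minimal sketch of the split (assuming the helper above is in scope): a parameter stacked over layers, of shape `(num_layers, ...)`, becomes one entry per layer:

```python
import numpy as np

stacked = {"params/img/Transformer/encoderblock/LayerNorm_0/scale": np.zeros((12, 768))}
per_layer = split_encoderblock_layers(stacked)
# One key per layer, each holding a (768,) slice of the stacked array.
assert per_layer["params/img/Transformer/encoderblock_0/LayerNorm_0/scale"].shape == (768,)
assert "params/img/Transformer/encoderblock_11/LayerNorm_0/scale" in per_layer
```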
|
||||
|
||||
|
||||
def merge_qkv_for_head(state_dict: dict, config: Siglip2Config) -> dict:
|
||||
"""
|
||||
Merge the q/k/v weights and biases for the attention head.
|
||||
"""
|
||||
# Make shallow copy
|
||||
state_dict = state_dict.copy()
|
||||
# Read and process q/k/v weights and biases
|
||||
qkv_weights, qkv_biases = [], []
|
||||
for name in ["query", "key", "value"]:
|
||||
prefix = f"params/img/MAPHead_0/MultiHeadDotProductAttention_0/{name}"
|
||||
weight = state_dict.pop(f"{prefix}/kernel").reshape(-1, config.vision_config.hidden_size)
|
||||
bias = state_dict.pop(f"{prefix}/bias").reshape(-1)
|
||||
qkv_weights.append(weight)
|
||||
qkv_biases.append(bias)
|
||||
|
||||
# Combine into single tensors
|
||||
state_dict["params/img/MAPHead_0/MultiHeadDotProductAttention_0/qkv/kernel"] = np.concatenate(qkv_weights, axis=1)
|
||||
state_dict["params/img/MAPHead_0/MultiHeadDotProductAttention_0/qkv/bias"] = np.concatenate(qkv_biases, axis=0)
|
||||
return state_dict
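
The merge mirrors the packed `in_proj` layout used by `torch.nn.MultiheadAttention` in the pooling head. Below is a hedged sketch with dummy arrays; the per-head kernel shape `(hidden, num_heads, head_dim)` is an assumption about the original checkpoint layout, and only the merged shapes are checked:

```python
import numpy as np

config = get_siglip2_config("siglip2-base-patch16-naflex")
hidden = config.vision_config.hidden_size  # 768 for the "base" variant

state_dict = {}
for name in ["query", "key", "value"]:
    prefix = f"params/img/MAPHead_0/MultiHeadDotProductAttention_0/{name}"
    state_dict[f"{prefix}/kernel"] = np.zeros((hidden, 12, hidden // 12))  # dummy values
    state_dict[f"{prefix}/bias"] = np.zeros((12, hidden // 12))

merged = merge_qkv_for_head(state_dict, config)
qkv_kernel = merged["params/img/MAPHead_0/MultiHeadDotProductAttention_0/qkv/kernel"]
qkv_bias = merged["params/img/MAPHead_0/MultiHeadDotProductAttention_0/qkv/bias"]
assert qkv_kernel.shape == (hidden, 3 * hidden)  # packed q/k/v
assert qkv_bias.shape == (3 * hidden,)
```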
|
||||
|
||||
|
||||
def convert_old_keys_to_new_keys(state_dict_keys: list) -> dict:
|
||||
"""
|
||||
This function should be applied only once, to the full list of keys, so that the regex-based
key mappings can be applied efficiently in a single pass.
|
||||
"""
|
||||
output_dict = {}
|
||||
if state_dict_keys is not None:
|
||||
old_text = "\n".join(state_dict_keys)
|
||||
new_text = old_text
|
||||
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
|
||||
if replacement is None:
|
||||
new_text = re.sub(pattern, "", new_text) # an empty line
|
||||
continue
|
||||
new_text = re.sub(pattern, replacement, new_text)
|
||||
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
|
||||
return output_dict
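
A hedged sketch of the renaming on two representative keys (assuming the mapping table and helper above are in scope):

```python
old_keys = [
    "params/img/Transformer/encoderblock_3/MlpBlock_0/Dense_0/kernel",
    "params/t",
]
renamed = convert_old_keys_to_new_keys(old_keys)
assert renamed["params/img/Transformer/encoderblock_3/MlpBlock_0/Dense_0/kernel"] == (
    "vision_model.encoder.layers.3.mlp.fc1.weight"
)
assert renamed["params/t"] == "logit_scale"  # the learned temperature
```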
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Helper functions for model verification
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_image(width, height):
|
||||
"""
|
||||
Helper function to create an image with a blue circle on a red background.
|
||||
"""
|
||||
image = Image.new("RGB", (width, height), color="red")
|
||||
draw = ImageDraw.Draw(image)
|
||||
center_x = image.width // 2
|
||||
center_y = image.height // 2
|
||||
radius = min(center_x, center_y) // 8 * 7
|
||||
draw.ellipse(
|
||||
(center_x - radius, center_y - radius, center_x + radius, center_y + radius),
|
||||
fill="blue",
|
||||
outline="green",
|
||||
width=image.width // 20,
|
||||
)
|
||||
return image
|
||||
|
||||
|
||||
def prepare_inputs():
|
||||
"""
|
||||
Prepare inputs for the model.
|
||||
"""
|
||||
text = [
|
||||
"circle",
|
||||
"ellipsoid",
|
||||
"blue circle on red background",
|
||||
"blue circle with green border on red background",
|
||||
"green circle on red background",
|
||||
"a dog",
|
||||
"a blue dog with a green border on a red background",
|
||||
]
|
||||
img224 = create_image(224, 224)
|
||||
img1024 = create_image(1024, 1024)
|
||||
img224_1024 = create_image(1024, 224)
|
||||
|
||||
images = [img224, img1024, img224_1024]
|
||||
return text, images
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Convert model
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_siglip2_checkpoint(model_name, pytorch_dump_folder_path, verify_logits=True, push_to_hub=False):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our Siglip2 structure.
|
||||
"""
|
||||
|
||||
# Define Siglip2 configuration
|
||||
config = get_siglip2_config(model_name)
|
||||
|
||||
checkpoint = MODEL_NAME_TO_CHECKPOINT_PATH[model_name]
|
||||
if not os.path.exists(checkpoint):
|
||||
org, repo_id, *filepath = checkpoint.split("/")
|
||||
checkpoint = hf_hub_download(repo_id=f"{org}/{repo_id}", filename="/".join(filepath))
|
||||
|
||||
print(f"Loading checkpoint from {checkpoint}...")
|
||||
data = np.load(checkpoint)
|
||||
state_dict = flatten_nested_dict(data)
|
||||
state_dict = split_encoderblock_layers(state_dict)
|
||||
state_dict = merge_qkv_for_head(state_dict, config)
|
||||
|
||||
# Rename and transform weights
|
||||
print("Renaming and transforming weights...")
|
||||
|
||||
original_keys = list(state_dict.keys())
|
||||
hf_keys = convert_old_keys_to_new_keys(original_keys)
|
||||
|
||||
new_state_dict = {}
|
||||
for original_key in original_keys:
|
||||
new_key = hf_keys[original_key]
|
||||
parameter = state_dict.pop(original_key)
|
||||
|
||||
hidden_size = config.vision_config.hidden_size if "vision" in new_key else config.text_config.hidden_size
|
||||
|
||||
if any(k in new_key for k in ("out_proj", "q_proj", "k_proj", "v_proj", "position_embedding")):
|
||||
parameter = parameter.reshape(-1, hidden_size)
|
||||
|
||||
# Transpose every weight except for position_embedding and token_embedding
|
||||
if new_key.endswith("weight") and "position_embedding" not in new_key and "token_embedding" not in new_key:
|
||||
parameter = parameter.T
|
||||
|
||||
# Reshape every bias
|
||||
if new_key.endswith("bias"):
|
||||
parameter = parameter.reshape(-1)
|
||||
|
||||
new_state_dict[new_key] = torch.from_numpy(parameter)
|
||||
|
||||
# load HuggingFace model
|
||||
print("Loading HuggingFace model...")
|
||||
model = Siglip2Model(config).eval()
|
||||
model.load_state_dict(new_state_dict)
|
||||
|
||||
# Create processor
|
||||
print("Creating processor...")
|
||||
# TODO: update with more checkpoints
|
||||
tokenizer = get_siglip2_tokenizer()
|
||||
image_processor = get_siglip2_image_processor(config.vision_config.patch_size, max_num_patches=256)
|
||||
processor = Siglip2Processor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
|
||||
# Verify logits
|
||||
if verify_logits:
|
||||
print(f"Verifying logits for {model_name}...")
|
||||
text, images = prepare_inputs()
|
||||
inputs = processor(text=text, images=images, padding="max_length", max_length=64, return_tensors="pt")
|
||||
outputs = model(**inputs)
|
||||
torch.testing.assert_close(outputs.logits_per_text, EXPECTED_OUTPUTS[model_name], atol=1e-3, rtol=1e-3)
|
||||
|
||||
# Save model
|
||||
if pytorch_dump_folder_path is not None:
|
||||
dst_dir = os.path.join(pytorch_dump_folder_path, model_name)
|
||||
print(f"Saving model {model_name} to {dst_dir}...")
|
||||
model.save_pretrained(dst_dir)
|
||||
print(f"Saving processor to {dst_dir}...")
|
||||
processor.save_pretrained(dst_dir)
|
||||
|
||||
if push_to_hub:
|
||||
print(f"Pushing model and processor for {model_name} to the HuggingFace Hub...")
|
||||
model.push_to_hub(f"qubvel-hf/{model_name}", private=True)
|
||||
processor.push_to_hub(f"qubvel-hf/{model_name}", private=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="siglip2-base-patch16-naflex",
|
||||
type=str,
|
||||
choices=MODEL_NAME_TO_CHECKPOINT_PATH.keys(),
|
||||
help="Name of the model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default="checkpoints/",
|
||||
type=str,
|
||||
help="Path to the output PyTorch model directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verify_logits",
|
||||
action="store_true",
|
||||
help="Whether to verify logits against the original implementation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_siglip2_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.verify_logits, args.push_to_hub)
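
The script is normally driven from the command line; a hedged Python-level equivalent (assuming the functions above are imported and the `gv-hf/siglip2` checkpoint repo is accessible) looks like this:

```python
# Convert the base NaFlex checkpoint and write the model + processor to checkpoints/<model_name>.
convert_siglip2_checkpoint(
    model_name="siglip2-base-patch16-naflex",
    pytorch_dump_folder_path="checkpoints/",
    verify_logits=False,  # set to True to compare against EXPECTED_OUTPUTS above
    push_to_hub=False,
)
```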
|
343
src/transformers/models/siglip2/image_processing_siglip2.py
Normal file
@ -0,0 +1,343 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Image processor class for SigLIP2."""
|
||||
|
||||
import math
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature
|
||||
from ...image_transforms import (
|
||||
convert_to_rgb,
|
||||
resize,
|
||||
to_channel_dimension_format,
|
||||
)
|
||||
from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
make_flat_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@lru_cache(maxsize=256)
|
||||
def get_image_size_for_max_num_patches(
|
||||
image_height: int, image_width: int, patch_size: int, max_num_patches: int, eps: float = 1e-5
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determine the target image size based on the maximum number of patches, ensuring the dimensions are divisible by the patch size and the image contains at least one patch.
|
||||
|
||||
Args:
|
||||
image_height (`int`):
|
||||
Original image height.
|
||||
image_width (`int`):
|
||||
Original image width.
|
||||
patch_size (`int`):
|
||||
Patch size for processing.
|
||||
max_num_patches (`int`):
|
||||
Maximum number of patches.
|
||||
eps (`float`):
|
||||
Small threshold for binary search.
|
||||
|
||||
Returns:
|
||||
Tuple: (target_height, target_width)
|
||||
"""
|
||||
|
||||
def get_scaled_image_size(scale: float, size: int, patch_size: int) -> int:
|
||||
scaled_size = size * scale
|
||||
scaled_size = math.ceil(scaled_size / patch_size) * patch_size # make divisible by patch_size
|
||||
scaled_size = max(patch_size, scaled_size) # ensure at least 1 patch
|
||||
return int(scaled_size)
|
||||
|
||||
# Binary search for optimal scale
|
||||
scale_min, scale_max = eps / 10, 100.0
|
||||
while (scale_max - scale_min) >= eps:
|
||||
scale = (scale_min + scale_max) / 2
|
||||
target_height = get_scaled_image_size(scale, image_height, patch_size)
|
||||
target_width = get_scaled_image_size(scale, image_width, patch_size)
|
||||
num_patches = (target_height / patch_size) * (target_width / patch_size)
|
||||
|
||||
if num_patches <= max_num_patches:
|
||||
scale_min = scale
|
||||
else:
|
||||
scale_max = scale
|
||||
|
||||
scale = scale_min
|
||||
target_height = get_scaled_image_size(scale, image_height, patch_size)
|
||||
target_width = get_scaled_image_size(scale, image_width, patch_size)
|
||||
return target_height, target_width
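
A hedged sketch of the guarantee this search provides (assuming the helper above is in scope): the returned dimensions are multiples of `patch_size` and their patch grid never exceeds `max_num_patches`, while the aspect ratio is approximately preserved:

```python
height, width = get_image_size_for_max_num_patches(
    image_height=300, image_width=600, patch_size=16, max_num_patches=256
)
assert height % 16 == 0 and width % 16 == 0   # dimensions divisible by the patch size
assert (height // 16) * (width // 16) <= 256  # patch budget is respected
```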
|
||||
|
||||
|
||||
def convert_image_to_patches(image: np.ndarray, patch_size: int) -> np.ndarray:
|
||||
"""
|
||||
Convert 3D array image of shape (image_height, image_width, num_channels) into 2D array of patches of shape
|
||||
(num_patches_height * num_patches_width, patch_size * patch_size * num_channels).
|
||||
"""
|
||||
image_height, image_width, num_channels = image.shape
|
||||
num_patches_height = image_height // patch_size
|
||||
num_patches_width = image_width // patch_size
|
||||
patched_image = image.reshape(num_patches_height, patch_size, num_patches_width, patch_size, num_channels)
|
||||
patched_image = patched_image.transpose(0, 2, 1, 3, 4)
|
||||
patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1)
|
||||
return patched_image
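
For example (hedged, assuming the helper above is in scope), a 32x48 RGB image with 16-pixel patches yields a 2x3 grid of flattened patches:

```python
import numpy as np

image = np.random.rand(32, 48, 3).astype(np.float32)  # (height, width, channels)
patches = convert_image_to_patches(image, patch_size=16)
assert patches.shape == (2 * 3, 16 * 16 * 3)  # (num_patches, patch_size**2 * channels)
```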
|
||||
|
||||
|
||||
def pad_along_first_dim(array: np.ndarray, target_length: int, pad_value: int = 0) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Pad the array along the first dimension.
|
||||
"""
|
||||
current_length = array.shape[0]
|
||||
padding_length = target_length - current_length
|
||||
mask = np.ones((target_length,), dtype=np.int32)
|
||||
if padding_length > 0:
|
||||
paddings = [(0, padding_length)] + [(0, 0)] * (array.ndim - 1)
|
||||
array = np.pad(array, paddings, mode="constant", constant_values=pad_value)
|
||||
mask[-padding_length:] = 0
|
||||
return array, mask
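
And the padding helper (same assumption) pads the patch dimension up to the target length and returns a matching validity mask:

```python
import numpy as np

patches = np.ones((6, 768), dtype=np.float32)  # 6 real patches
padded, mask = pad_along_first_dim(patches, target_length=256)
assert padded.shape == (256, 768)
assert mask.sum() == 6  # 1 for real patches, 0 for padding
```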
|
||||
|
||||
|
||||
class Siglip2ImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a SigLIP2 image processor.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's dimensions to fit `max_num_patches` according to given `patch_size`.
|
||||
Can be overridden by `do_resize` in the `preprocess` method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
||||
the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
||||
method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
|
||||
`do_normalize` in the `preprocess` method.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
patch_size (`int`, *optional*, defaults to 16):
|
||||
The size (resolution) of each patch the image will be split to.
|
||||
max_num_patches (`int`, *optional*, defaults to 256):
|
||||
The image will be resized to contain at most this number of patches,
and then padded along the patch dimension to match this number exactly.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values", "pixel_attention_mask", "spatial_shapes"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: float = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: Optional[bool] = None,
|
||||
patch_size: int = 16,
|
||||
max_num_patches: int = 256,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
|
||||
image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self.patch_size = patch_size
|
||||
self.max_num_patches = max_num_patches
|
||||
|
||||
@filter_out_non_signature_kwargs()
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_resize: Optional[bool] = None,
|
||||
resample: Optional[PILImageResampling] = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
do_convert_rgb: Optional[bool] = None,
|
||||
patch_size: Optional[int] = None,
|
||||
max_num_patches: Optional[int] = None,
|
||||
) -> "Image.Image":
|
||||
"""
|
||||
Preprocess an image or batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
resample (`int`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
patch_size (`int`, *optional*, defaults to `self.patch_size`):
|
||||
Patch size for processing, same as the patch size used in the model.
|
||||
max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`):
|
||||
Maximum number of patches per image, the image will be resized to have at most this number of patches.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
patch_size = patch_size if patch_size is not None else self.patch_size
|
||||
max_num_patches = max_num_patches if max_num_patches is not None else self.max_num_patches
|
||||
|
||||
# Explicitly specify data format to be channels last for image preprocessing.
|
||||
# Image processor does not support different output formats, because it returns patches.
|
||||
data_format = ChannelDimension.LAST
|
||||
|
||||
images = make_flat_list_of_images(images)
|
||||
|
||||
if not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
validate_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
)
|
||||
if do_convert_rgb:
|
||||
images = [convert_to_rgb(image) for image in images]
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if do_rescale and is_scaled_image(images[0]):
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
pixel_masks = []
|
||||
pixel_values = []
|
||||
spatial_shapes = []
|
||||
|
||||
for image in images:
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
|
||||
if do_resize:
|
||||
height, width = get_image_size_for_max_num_patches(
|
||||
image_height=image.shape[0],
|
||||
image_width=image.shape[1],
|
||||
patch_size=patch_size,
|
||||
max_num_patches=max_num_patches,
|
||||
)
|
||||
image = resize(image=image, size=(height, width), resample=resample, input_data_format=data_format)
|
||||
|
||||
if do_rescale:
|
||||
image = self.rescale(image=image, scale=rescale_factor, input_data_format=data_format)
|
||||
|
||||
if do_normalize:
|
||||
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=data_format)
|
||||
|
||||
patches = convert_image_to_patches(image, patch_size)
|
||||
patches, mask = pad_along_first_dim(patches, max_num_patches)
|
||||
num_patches_height = image.shape[0] // patch_size
|
||||
num_patches_width = image.shape[1] // patch_size
|
||||
|
||||
spatial_shapes.append((num_patches_height, num_patches_width))
|
||||
pixel_values.append(patches)
|
||||
pixel_masks.append(mask)
|
||||
|
||||
batch_feature = BatchFeature(
|
||||
data={
|
||||
"pixel_values": pixel_values,
|
||||
"pixel_attention_mask": pixel_masks,
|
||||
"spatial_shapes": spatial_shapes,
|
||||
},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
return batch_feature
|
||||
|
||||
|
||||
__all__ = ["Siglip2ImageProcessor"]
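
Putting the pieces together, a hedged end-to-end sketch of the slow image processor (constructed directly rather than from a checkpoint):

```python
import numpy as np
from transformers import Siglip2ImageProcessor

processor = Siglip2ImageProcessor(patch_size=16, max_num_patches=256)
image = np.random.randint(0, 256, size=(224, 320, 3), dtype=np.uint8)

features = processor(images=image, return_tensors="np")
print(features["pixel_values"].shape)          # (1, 256, 768) -> (batch, max_num_patches, 16 * 16 * 3)
print(features["pixel_attention_mask"].shape)  # (1, 256) -> 1 for real patches, 0 for padding
print(features["spatial_shapes"])              # [[num_patches_height, num_patches_width]]
```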
|
322
src/transformers/models/siglip2/image_processing_siglip2_fast.py
Normal file
@ -0,0 +1,322 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Fast Image processor class for SigLIP2."""
|
||||
|
||||
import math
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_processing_utils_fast import BaseImageProcessorFast
|
||||
from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
TensorType,
|
||||
)
|
||||
from ...utils import (
|
||||
filter_out_non_signature_kwargs,
|
||||
is_torch_available,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
@lru_cache(maxsize=256)
|
||||
# Copied from transformers.models.siglip2.image_processing_siglip2.get_image_size_for_max_num_patches
|
||||
def get_image_size_for_max_num_patches(
|
||||
image_height: int, image_width: int, patch_size: int, max_num_patches: int, eps: float = 1e-5
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determine the target image size based on the maximum number of patches, ensuring the dimensions are divisible by the patch size and the image contains at least one patch.
|
||||
|
||||
Args:
|
||||
image_height (`int`):
|
||||
Original image height.
|
||||
image_width (`int`):
|
||||
Original image width.
|
||||
patch_size (`int`):
|
||||
Patch size for processing.
|
||||
max_num_patches (`int`):
|
||||
Maximum number of patches.
|
||||
eps (`float`):
|
||||
Small threshold for binary search.
|
||||
|
||||
Returns:
|
||||
Tuple: (target_height, target_width)
|
||||
"""
|
||||
|
||||
def get_scaled_image_size(scale: float, size: int, patch_size: int) -> int:
|
||||
scaled_size = size * scale
|
||||
scaled_size = math.ceil(scaled_size / patch_size) * patch_size # make divisible by patch_size
|
||||
scaled_size = max(patch_size, scaled_size) # ensure at least 1 patch
|
||||
return int(scaled_size)
|
||||
|
||||
# Binary search for optimal scale
|
||||
scale_min, scale_max = eps / 10, 100.0
|
||||
while (scale_max - scale_min) >= eps:
|
||||
scale = (scale_min + scale_max) / 2
|
||||
target_height = get_scaled_image_size(scale, image_height, patch_size)
|
||||
target_width = get_scaled_image_size(scale, image_width, patch_size)
|
||||
num_patches = (target_height / patch_size) * (target_width / patch_size)
|
||||
|
||||
if num_patches <= max_num_patches:
|
||||
scale_min = scale
|
||||
else:
|
||||
scale_max = scale
|
||||
|
||||
scale = scale_min
|
||||
target_height = get_scaled_image_size(scale, image_height, patch_size)
|
||||
target_width = get_scaled_image_size(scale, image_width, patch_size)
|
||||
return target_height, target_width
|
||||
|
||||
|
||||
def convert_image_to_patches(image: "torch.Tensor", patch_size: int) -> "torch.Tensor":
|
||||
"""
|
||||
Convert 3D tensor image of shape (num_channels, image_height, image_width) into 2D tensor of patches of shape
|
||||
(num_patches_height * num_patches_width, patch_size * patch_size * num_channels).
|
||||
"""
|
||||
num_channels, image_height, image_width = image.shape
|
||||
num_patches_height = image_height // patch_size
|
||||
num_patches_width = image_width // patch_size
|
||||
patched_image = image.reshape(num_channels, num_patches_height, patch_size, num_patches_width, patch_size)
|
||||
patched_image = patched_image.permute(1, 3, 2, 4, 0)
|
||||
patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1)
|
||||
return patched_image
|
||||
|
||||
|
||||
def pad_along_first_dim(
|
||||
tensor: "torch.Tensor", target_length: int, pad_value: int = 0
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor"]:
|
||||
"""
|
||||
Pad the tensor along the first dimension.
|
||||
"""
|
||||
current_length = tensor.shape[0]
|
||||
padding_length = target_length - current_length
|
||||
mask = torch.ones((target_length,), dtype=torch.int32)
|
||||
if padding_length > 0:
|
||||
padding = [0, 0] * (tensor.ndim - 1) + [0, padding_length]
|
||||
tensor = torch.nn.functional.pad(tensor, padding, mode="constant", value=pad_value)
|
||||
mask[-padding_length:] = 0
|
||||
return tensor, mask
|
||||
|
||||
|
||||
class Siglip2ImageProcessorFast(BaseImageProcessorFast):
|
||||
r"""
|
||||
Constructs a fast SigLIP2 image processor.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's dimensions to fit `max_num_patches` according to given `patch_size`.
|
||||
Can be overridden by `do_resize` in the `preprocess` method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
||||
the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
||||
method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
|
||||
`do_normalize` in the `preprocess` method.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
patch_size (`int`, *optional*, defaults to 16):
|
||||
The size (resolution) of each patch the image will be split to.
|
||||
max_num_patches (`int`, *optional*, defaults to 256):
|
||||
The image will be resized to contain at most this number of patches,
and then padded along the patch dimension to match this number exactly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: float = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: Optional[bool] = None,
|
||||
patch_size: int = 16,
|
||||
max_num_patches: int = 256,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
|
||||
image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self.patch_size = patch_size
|
||||
self.max_num_patches = max_num_patches
|
||||
|
||||
@filter_out_non_signature_kwargs()
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_resize: Optional[bool] = None,
|
||||
resample: Optional[PILImageResampling] = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
do_convert_rgb: Optional[bool] = None,
|
||||
patch_size: Optional[int] = None,
|
||||
max_num_patches: Optional[int] = None,
|
||||
device: Union["torch.device", str] = "cpu",
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Preprocess an image or batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
resample (`int`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
patch_size (`int`, *optional*, defaults to `self.patch_size`):
|
||||
Patch size for processing, same as the patch size used in the model.
|
||||
max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`):
|
||||
Maximum number of patches per image, the image will be resized to have at most this number of patches.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
patch_size = patch_size if patch_size is not None else self.patch_size
|
||||
max_num_patches = max_num_patches if max_num_patches is not None else self.max_num_patches
|
||||
|
||||
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
|
||||
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
|
||||
|
||||
image_mean, image_std, interpolation = self._prepare_process_arguments(
|
||||
do_normalize=do_normalize,
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
resample=resample,
|
||||
)
|
||||
|
||||
images = self._prepare_input_images(
|
||||
images=images,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
)
|
||||
|
||||
pixel_masks = []
|
||||
pixel_values = []
|
||||
spatial_shapes = []
|
||||
|
||||
for image in images:
|
||||
if do_resize:
|
||||
height, width = get_image_size_for_max_num_patches(
|
||||
image_height=image.shape[1],
|
||||
image_width=image.shape[2],
|
||||
patch_size=patch_size,
|
||||
max_num_patches=max_num_patches,
|
||||
)
|
||||
side_dict = SizeDict(height=height, width=width)
|
||||
image = self.resize(image=image, size=side_dict, interpolation=interpolation)
|
||||
|
||||
image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
|
||||
|
||||
# (num_channels, height, width) -> (num_patches, patch_size * patch_size * num_channels)
|
||||
patches = convert_image_to_patches(image, patch_size)
|
||||
patches, mask = pad_along_first_dim(patches, max_num_patches)
|
||||
|
||||
num_patches_height = image.shape[1] // patch_size
|
||||
num_patches_width = image.shape[2] // patch_size
|
||||
|
||||
spatial_shapes.append((num_patches_height, num_patches_width))
|
||||
pixel_values.append(patches)
|
||||
pixel_masks.append(mask)
|
||||
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
pixel_masks = torch.stack(pixel_masks)
|
||||
spatial_shapes = torch.tensor(spatial_shapes)
|
||||
|
||||
batch_feature = BatchFeature(
|
||||
data={
|
||||
"pixel_values": pixel_values,
|
||||
"pixel_attention_mask": pixel_masks,
|
||||
"spatial_shapes": spatial_shapes,
|
||||
},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
return batch_feature
|
||||
|
||||
|
||||
__all__ = ["Siglip2ImageProcessorFast"]
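
The fast processor exposes the same interface but works on `torch` tensors and additionally accepts a `device` argument; a hedged sketch:

```python
import numpy as np
from transformers import Siglip2ImageProcessorFast

processor = Siglip2ImageProcessorFast(patch_size=16, max_num_patches=256)
image = np.random.randint(0, 256, size=(224, 320, 3), dtype=np.uint8)

features = processor(images=image, return_tensors="pt", device="cpu")  # e.g. "cuda" to preprocess on GPU
print(features["pixel_values"].shape)  # torch.Size([1, 256, 768])
```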
|
1634
src/transformers/models/siglip2/modeling_siglip2.py
Normal file
File diff suppressed because it is too large
537
src/transformers/models/siglip2/modular_siglip2.py
Normal file
@ -0,0 +1,537 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers.models.siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
|
||||
from transformers.models.siglip.modeling_siglip import (
|
||||
BaseModelOutputWithPooling,
|
||||
ImageClassifierOutput,
|
||||
SiglipForImageClassification,
|
||||
SiglipModel,
|
||||
SiglipMultiheadAttentionPoolingHead,
|
||||
SiglipOutput,
|
||||
SiglipPreTrainedModel,
|
||||
SiglipTextModel,
|
||||
SiglipTextModelOutput,
|
||||
SiglipVisionModel,
|
||||
SiglipVisionModelOutput,
|
||||
SiglipVisionTransformer,
|
||||
)
|
||||
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
|
||||
|
||||
class Siglip2TextConfig(SiglipTextConfig):
|
||||
pass
|
||||
|
||||
|
||||
class Siglip2VisionConfig(SiglipVisionConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Siglip2VisionModel`]. It is used to instantiate a
|
||||
Siglip2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip2
|
||||
[google/siglip2-base-patch16-naflex](https://huggingface.co/google/siglip2-base-patch16-naflex) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
intermediate_size (`int`, *optional*, defaults to 3072):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
Number of channels in the input images.
|
||||
num_patches (`int`, *optional*, defaults to 256):
|
||||
The number of patches in the image with the size of (`patch_size`, `patch_size`).
|
||||
The image is resized to fill at most this number of patches while preserving
the aspect ratio. In case the resulting number of patches is lower, the image is
padded along the "patch" dimension.
|
||||
patch_size (`int`, *optional*, defaults to 16):
|
||||
The size (resolution) of each patch.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Siglip2VisionConfig, Siglip2VisionModel
|
||||
|
||||
>>> # Initializing a Siglip2VisionConfig with google/siglip2-base-patch16-naflex style configuration
|
||||
>>> configuration = Siglip2VisionConfig()
|
||||
|
||||
>>> # Initializing a Siglip2VisionModel (with random weights) from the google/siglip2-base-patch16-naflex style configuration
|
||||
>>> model = Siglip2VisionModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=768,
|
||||
intermediate_size=3072,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
num_channels=3,
|
||||
num_patches=256,
|
||||
patch_size=16,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
layer_norm_eps=1e-6,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.num_patches = num_patches
|
||||
del self.image_size
|
||||
|
||||
|
||||
class Siglip2Config(SiglipConfig):
|
||||
pass
|
||||
|
||||
|
||||
class Siglip2VisionOutput(SiglipVisionModelOutput):
|
||||
pass
|
||||
|
||||
|
||||
class Siglip2TextOutput(SiglipTextModelOutput):
|
||||
pass
|
||||
|
||||
|
||||
class Siglip2Output(SiglipOutput):
|
||||
pass
|
||||
|
||||
|
||||
class Siglip2VisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: Siglip2VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.patch_embedding = nn.Linear(
|
||||
in_features=config.num_channels * self.patch_size * self.patch_size,
|
||||
out_features=self.embed_dim,
|
||||
)
|
||||
|
||||
self.num_patches = config.num_patches
|
||||
self.position_embedding_size = int(self.num_patches**0.5)
|
||||
self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
|
||||
|
||||
@staticmethod
|
||||
def resize_positional_embeddings(
|
||||
positional_embeddings: torch.Tensor,
|
||||
spatial_shapes: torch.LongTensor,
|
||||
max_length: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Resize positional embeddings to image-specific size and pad to a fixed size.
|
||||
|
||||
Args:
|
||||
positional_embeddings (`torch.Tensor`):
|
||||
Position embeddings of shape (height, width, embed_dim)
|
||||
spatial_shapes (`torch.LongTensor`):
|
||||
Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
|
||||
max_length (`int`):
|
||||
Maximum length of the positional embeddings to pad resized positional embeddings to
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
|
||||
"""
|
||||
batch_size = spatial_shapes.shape[0]
|
||||
embed_dim = positional_embeddings.shape[-1]
|
||||
source_dtype = positional_embeddings.dtype
|
||||
|
||||
resulted_positional_embeddings = torch.empty(
|
||||
(batch_size, max_length, embed_dim),
|
||||
device=positional_embeddings.device,
|
||||
dtype=source_dtype,
|
||||
)
|
||||
|
||||
# (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
|
||||
positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
|
||||
|
||||
# Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
|
||||
if positional_embeddings.device.type == "cpu":
|
||||
positional_embeddings = positional_embeddings.to(torch.float32)
|
||||
|
||||
for i in range(batch_size):
|
||||
# (1, dim, height, width) -> (1, dim, target_height, target_width)
|
||||
height, width = spatial_shapes[i]
|
||||
resized_embeddings = F.interpolate(
|
||||
positional_embeddings,
|
||||
size=(height, width),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
antialias=True,
|
||||
)
|
||||
|
||||
# (1, dim, target_height, target_width) -> (target_height * target_width, dim)
|
||||
resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)
|
||||
|
||||
# Cast to original dtype
|
||||
resized_embeddings = resized_embeddings.to(source_dtype)
|
||||
|
||||
resulted_positional_embeddings[i, : height * width] = resized_embeddings
|
||||
resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
|
||||
|
||||
return resulted_positional_embeddings
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor`):
|
||||
Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
|
||||
spatial_shapes (`List[Tuple[int, int]]`):
|
||||
Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
|
||||
"""
|
||||
|
||||
# Apply patch embeddings to already patchified pixel values
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
|
||||
|
||||
# Get resized and padded positional embeddings
|
||||
positional_embeddings = self.position_embedding.weight.reshape(
|
||||
self.position_embedding_size, self.position_embedding_size, -1
|
||||
)
|
||||
resized_positional_embeddings = self.resize_positional_embeddings(
|
||||
positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
|
||||
)
|
||||
|
||||
# Add positional embeddings to patch embeddings
|
||||
embeddings = patch_embeds + resized_positional_embeddings
|
||||
return embeddings
|
||||
|
||||
|
||||
class Siglip2VisionTransformer(SiglipVisionTransformer):
|
||||
def __init__(self, config: Siglip2VisionConfig):
|
||||
super().__init__()
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
|
||||
# Update: add `spatial_shapes` and `attention_mask`
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
attention_mask: torch.Tensor,
|
||||
spatial_shapes: torch.LongTensor,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
hidden_states = self.embeddings(pixel_values, spatial_shapes)
|
||||
|
||||
if attention_mask is not None and not self._use_flash_attention_2:
|
||||
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
|
||||
else:
|
||||
encoder_attention_mask = attention_mask
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.post_layernorm(last_hidden_state)
|
||||
|
||||
pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooler_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooler_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
class Siglip2PreTrainedModel(SiglipPreTrainedModel):
|
||||
pass
|
||||
|
||||
|
||||
class Siglip2TextModel(SiglipTextModel):
|
||||
pass
|
||||
|
||||
|
||||
class Siglip2MultiheadAttentionPoolingHead(SiglipMultiheadAttentionPoolingHead):
|
||||
def __init__(self, config: Siglip2VisionConfig):
|
||||
super().__init__(config)
|
||||
self.num_heads = config.num_attention_heads
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
batch_size = hidden_state.shape[0]
|
||||
probe = self.probe.repeat(batch_size, 1, 1)
|
||||
|
||||
if attention_mask is not None:
|
||||
target_len, source_len = probe.shape[1], hidden_state.shape[1]
|
||||
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len)
|
||||
attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1)
|
||||
attention_mask = attention_mask.reshape(-1, target_len, source_len)
|
||||
|
||||
hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0]
|
||||
|
||||
residual = hidden_state
|
||||
hidden_state = self.layernorm(hidden_state)
|
||||
hidden_state = residual + self.mlp(hidden_state)
|
||||
|
||||
return hidden_state[:, 0]
|
||||
|
||||
|
||||
class Siglip2VisionModel(SiglipVisionModel):
|
||||
# Update: add `spatial_shapes` and `pixel_attention_mask`
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
pixel_attention_mask: torch.Tensor,
|
||||
spatial_shapes: torch.LongTensor,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
return self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=pixel_attention_mask,
|
||||
spatial_shapes=spatial_shapes,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
||||
class Siglip2Model(SiglipModel):
|
||||
# Update: add `spatial_shapes` and `pixel_attention_mask`
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
pixel_attention_mask: Optional[torch.Tensor] = None,
|
||||
spatial_shapes: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> torch.FloatTensor:
|
||||
# Use Siglip2Model's config for some fields (if specified) instead of those of vision & text components.
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=pixel_attention_mask,
|
||||
spatial_shapes=spatial_shapes,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
pooled_output = vision_outputs[1]
|
||||
|
||||
return pooled_output
|
||||
|
||||
# Update: add `spatial_shapes` and `pixel_attention_mask`
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
pixel_attention_mask: Optional[torch.Tensor] = None,
|
||||
spatial_shapes: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
return_loss: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, Siglip2Output]:
|
||||
# Use Siglip2 model's config for some fields (if specified) instead of those of vision & text components.
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=pixel_attention_mask,
|
||||
spatial_shapes=spatial_shapes,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
text_outputs = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
image_embeds = vision_outputs[1]
|
||||
text_embeds = text_outputs[1]
|
||||
|
||||
# normalized features
|
||||
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
||||
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
|
||||
|
||||
# cosine similarity as logits
|
||||
logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
|
||||
|
||||
logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
|
||||
logits_per_text = logits_per_text * logit_scale.exp() + logit_bias
|
||||
|
||||
logits_per_image = logits_per_text.t()
|
||||
|
||||
loss = None
|
||||
if return_loss:
|
||||
# Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip2.py#L287
|
||||
eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
|
||||
m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
|
||||
loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
|
||||
nll = -torch.sum(loglik, dim=-1)
|
||||
loss = nll.mean()
|
||||
|
||||
if not return_dict:
|
||||
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return Siglip2Output(
|
||||
loss=loss,
|
||||
logits_per_image=logits_per_image,
|
||||
logits_per_text=logits_per_text,
|
||||
text_embeds=text_embeds,
|
||||
image_embeds=image_embeds,
|
||||
text_model_output=text_outputs,
|
||||
vision_model_output=vision_outputs,
|
||||
)
|
||||
|
||||
|
||||
class Siglip2ForImageClassification(SiglipForImageClassification):
|
||||
# Update: add `spatial_shapes` and `pixel_attention_mask`
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
pixel_attention_mask: Optional[torch.Tensor] = None,
|
||||
spatial_shapes: Optional[torch.LongTensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, ImageClassifierOutput]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.vision_model(
|
||||
pixel_values,
|
||||
attention_mask=pixel_attention_mask,
|
||||
spatial_shapes=spatial_shapes,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
# average pool the patch tokens
|
||||
if pixel_attention_mask is not None:
|
||||
pool_mask = pixel_attention_mask[..., None].to(sequence_output.device)
|
||||
sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / torch.sum(pool_mask, dim=1)
|
||||
else:
|
||||
sequence_output = torch.mean(sequence_output, dim=1)
|
||||
|
||||
# apply classifier
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return ImageClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Siglip2Config",
|
||||
"Siglip2TextConfig",
|
||||
"Siglip2VisionConfig",
|
||||
"Siglip2Model",
|
||||
"Siglip2PreTrainedModel",
|
||||
"Siglip2TextModel",
|
||||
"Siglip2VisionModel",
|
||||
"Siglip2ForImageClassification",
|
||||
]
|
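
As the `Siglip2Model.forward` above shows, image-text similarity is scored with pairwise sigmoid logits (`logits_per_text * logit_scale.exp() + logit_bias`) rather than a softmax over candidates. A hedged end-to-end sketch, using the checkpoint name referenced in the config docstring and the standard Auto classes:

```python
import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

ckpt = "google/siglip2-base-patch16-naflex"
model = AutoModel.from_pretrained(ckpt)
processor = AutoProcessor.from_pretrained(ckpt)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
texts = ["a photo of 2 cats", "a photo of 2 dogs"]

# The processor returns pixel_values, pixel_attention_mask and spatial_shapes for the vision tower
inputs = processor(images=image, text=texts, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Independent probabilities per (image, text) pair, not a distribution over texts
probs = torch.sigmoid(outputs.logits_per_image)
```
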
171	src/transformers/models/siglip2/processing_siglip2.py (new file)
@ -0,0 +1,171 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Image/Text processor class for SigLIP2.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
|
||||
|
||||
class Siglip2ImagesKwargs(ImagesKwargs, total=False):
|
||||
max_num_patches: Optional[int]
|
||||
patch_size: Optional[int]
|
||||
|
||||
|
||||
class Siglip2ProcessorKwargs(ProcessingKwargs, total=False):
|
||||
images_kwargs: Siglip2ImagesKwargs
|
||||
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": "max_length",
|
||||
"truncation": True,
|
||||
"max_length": 64,
|
||||
},
|
||||
"images_kwargs": {
|
||||
"max_num_patches": 256,
|
||||
"patch_size": 16,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class Siglip2Processor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Siglip2 processor which wraps a Siglip2 image processor and a Gemma tokenizer into a single processor.
|
||||
|
||||
[`Siglip2Processor`] offers all the functionalities of [`Siglip2ImageProcessor`] and [`GemmaTokenizerFast`]. See the
|
||||
[`~Siglip2Processor.__call__`] and [`~Siglip2Processor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`Siglip2ImageProcessor`]):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`GemmaTokenizerFast`]):
|
||||
The tokenizer is a required input.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(self, image_processor, tokenizer):
|
||||
super().__init__(image_processor, tokenizer)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: Optional[Union[ImageInput, List[ImageInput], List[List[ImageInput]]]] = None,
|
||||
text: Optional[Union[TextInput, "PreTokenizedInput", List[TextInput], List["PreTokenizedInput"]]] = None,
|
||||
audio=None,
|
||||
videos=None,
|
||||
**kwargs: Unpack[Siglip2ProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
|
||||
and `kwargs` arguments to GemmaTokenizerFast's [`~GemmaTokenizerFast.__call__`] if `text` is not `None` to encode
|
||||
the text. To prepare the image(s), this method forwards the `images` argument to
|
||||
Siglip2ImageProcessor's [`~Siglip2ImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
|
||||
of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `max_length`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
||||
index) among:
|
||||
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
||||
acceptable input length for the model if that argument is not provided.
|
||||
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
||||
sequence is provided).
|
||||
- `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
|
||||
lengths).
|
||||
max_length (`int`, *optional*, defaults to 64):
|
||||
Maximum length of the returned list and optionally padding length (see above).
|
||||
truncation (`bool`, *optional*, defaults to `True`):
|
||||
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'pt'`):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
- **pixel_attention_mask** -- Attention mask for the pixel values. Returned when `images` is not `None`.
|
||||
- **spatial_shapes** -- The number of horizontal and vertical patches per image.
|
||||
Returned when `images` is not `None`.
|
||||
"""
|
||||
output_kwargs = self._merge_kwargs(
|
||||
Siglip2ProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if text is None and images is None:
|
||||
raise ValueError("You have to specify either text or images. Both cannot be none.")
|
||||
|
||||
if text is not None:
|
||||
encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
|
||||
if images is not None:
|
||||
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||
|
||||
if text is not None and images is not None:
|
||||
encoding.update(image_features)
|
||||
return encoding
|
||||
elif text is not None:
|
||||
return encoding
|
||||
else:
|
||||
return_tensors = output_kwargs["common_kwargs"]["return_tensors"]
|
||||
return BatchFeature(data=dict(**image_features), tensor_type=return_tensors)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Siglip2Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Siglip2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
|
||||
|
||||
__all__ = ["Siglip2Processor"]
|
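
A small sketch of the processor defaults above (`max_length=64` text padding, `max_num_patches=256`, `patch_size=16`) and of overriding the patch budget per call; `image` is assumed to be a PIL image as in the previous sketch, and the shapes in the comments follow from the defaults.

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-naflex")

batch = processor(images=[image], text=["two cats"], return_tensors="pt")
print(batch["pixel_values"].shape)   # (1, 256, 768) -> 768 = 3 * 16 * 16
print(batch["input_ids"].shape)      # (1, 64), padded to max_length=64 by default

# Trade compute for resolution by allowing a larger patch grid
hires = processor(images=[image], text=["two cats"], max_num_patches=1024, return_tensors="pt")
print(hires["pixel_values"].shape)   # (1, 1024, 768)
print(hires["spatial_shapes"])       # patches actually used along height and width per image
```
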
@ -145,8 +145,11 @@ class ZeroShotImageClassificationPipeline(Pipeline):
|
||||
inputs = inputs.to(self.torch_dtype)
|
||||
inputs["candidate_labels"] = candidate_labels
|
||||
sequences = [hypothesis_template.format(x) for x in candidate_labels]
|
||||
padding = "max_length" if self.model.config.model_type == "siglip" else True
|
||||
text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=padding, **tokenizer_kwargs)
|
||||
tokenizer_default_kwargs = {"padding": True}
|
||||
if "siglip" in self.model.config.model_type:
|
||||
tokenizer_default_kwargs.update(padding="max_length", max_length=64, truncation=True)
|
||||
tokenizer_default_kwargs.update(tokenizer_kwargs)
|
||||
text_inputs = self.tokenizer(sequences, return_tensors=self.framework, **tokenizer_default_kwargs)
|
||||
inputs["text_inputs"] = [text_inputs]
|
||||
return inputs
|
||||
|
||||
@ -170,7 +173,7 @@ class ZeroShotImageClassificationPipeline(Pipeline):
|
||||
def postprocess(self, model_outputs):
|
||||
candidate_labels = model_outputs.pop("candidate_labels")
|
||||
logits = model_outputs["logits"][0]
|
||||
if self.framework == "pt" and self.model.config.model_type == "siglip":
|
||||
if self.framework == "pt" and "siglip" in self.model.config.model_type:
|
||||
probs = torch.sigmoid(logits).squeeze(-1)
|
||||
scores = probs.tolist()
|
||||
if not isinstance(scores, list):
|
||||
|
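
The pipeline change above routes any `siglip`-family model through `max_length` padding and sigmoid scores. A usage sketch with the zero-shot image classification pipeline (checkpoint name as above):

```python
from transformers import pipeline

classifier = pipeline(
    task="zero-shot-image-classification",
    model="google/siglip2-base-patch16-naflex",
)
outputs = classifier(
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    candidate_labels=["2 cats", "a plane", "a remote"],
)
# Each candidate gets an independent sigmoid score, so the scores do not sum to 1
print(outputs)
```
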
@ -8849,6 +8849,41 @@ class SiglipVisionModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Siglip2ForImageClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Siglip2Model(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Siglip2PreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Siglip2TextModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Siglip2VisionModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class SmolVLMForConditionalGeneration(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -107,6 +107,13 @@ class SiglipImageProcessorFast(metaclass=DummyObject):
|
||||
requires_backends(self, ["torchvision"])
|
||||
|
||||
|
||||
class Siglip2ImageProcessorFast(metaclass=DummyObject):
|
||||
_backends = ["torchvision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torchvision"])
|
||||
|
||||
|
||||
class ViTImageProcessorFast(metaclass=DummyObject):
|
||||
_backends = ["torchvision"]
|
||||
|
||||
|
@ -639,6 +639,13 @@ class SiglipImageProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class Siglip2ImageProcessor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class SmolVLMImageProcessor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
|
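
These dummy objects keep `from transformers import Siglip2ImageProcessorFast` importable even when the optional backend is missing; instantiation then fails with an actionable error. Roughly (a sketch of the expected behaviour, not a test from this diff):

```python
from transformers.utils import is_torchvision_available

if not is_torchvision_available():
    # The import still succeeds and resolves to the DummyObject-based placeholder ...
    from transformers import Siglip2ImageProcessorFast

    # ... but instantiating it raises an ImportError asking to install torchvision
    Siglip2ImageProcessorFast()
```
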
0	tests/models/siglip2/__init__.py (new file)
200	tests/models/siglip2/test_image_processing_siglip2.py (new file)
@ -0,0 +1,200 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import Siglip2ImageProcessor
|
||||
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import Siglip2ImageProcessorFast
|
||||
|
||||
|
||||
class Siglip2ImageProcessingTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
num_channels=3,
|
||||
image_size=18,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
do_rescale=True,
|
||||
rescale_factor=1 / 255,
|
||||
do_normalize=True,
|
||||
image_mean=[0.5, 0.5, 0.5],
|
||||
image_std=[0.5, 0.5, 0.5],
|
||||
resample=None,
|
||||
patch_size=16,
|
||||
max_num_patches=256,
|
||||
):
|
||||
size = size if size is not None else {"height": 18, "width": 18}
|
||||
resample = resample if resample is not None else Image.Resampling.BILINEAR
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.resample = resample
|
||||
self.patch_size = patch_size
|
||||
self.max_num_patches = max_num_patches
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"do_rescale": self.do_rescale,
|
||||
"rescale_factor": self.rescale_factor,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"resample": self.resample,
|
||||
"patch_size": self.patch_size,
|
||||
"max_num_patches": self.max_num_patches,
|
||||
}
|
||||
|
||||
def expected_output_image_shape(self, images):
|
||||
return self.max_num_patches, self.patch_size * self.patch_size * self.num_channels
|
||||
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
return prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest with CLIP->Siglip2
|
||||
class Siglip2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = Siglip2ImageProcessor if is_vision_available() else None
|
||||
fast_image_processing_class = Siglip2ImageProcessorFast if is_torchvision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.image_processor_tester = Siglip2ImageProcessingTester(self)
|
||||
|
||||
@property
|
||||
def image_processor_dict(self):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
# Ignore copy
|
||||
def test_image_processor_properties(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "resample"))
|
||||
self.assertTrue(hasattr(image_processing, "do_rescale"))
|
||||
self.assertTrue(hasattr(image_processing, "rescale_factor"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "patch_size"))
|
||||
self.assertTrue(hasattr(image_processing, "max_num_patches"))
|
||||
|
||||
# Ignore copy
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.max_num_patches, 256)
|
||||
self.assertEqual(image_processor.patch_size, 16)
|
||||
|
||||
image_processor = self.image_processing_class.from_dict(
|
||||
self.image_processor_dict, patch_size=32, max_num_patches=512
|
||||
)
|
||||
self.assertEqual(image_processor.patch_size, 32)
|
||||
self.assertEqual(image_processor.max_num_patches, 512)
|
||||
|
||||
@unittest.skip(reason="not supported")
|
||||
# Ignore copy
|
||||
def test_call_numpy_4_channels(self):
|
||||
pass
|
||||
|
||||
# increase mean tolerance to 1e-3 -> 2e-3
|
||||
# Ignore copy
|
||||
def test_slow_fast_equivalence(self):
|
||||
if not self.test_slow_image_processor or not self.test_fast_image_processor:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test")
|
||||
|
||||
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
|
||||
|
||||
dummy_image = Image.open(
|
||||
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
|
||||
)
|
||||
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
||||
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
||||
|
||||
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
|
||||
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
|
||||
torch.testing.assert_close(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1, rtol=1e-1)
|
||||
self.assertLessEqual(
|
||||
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 2e-3
|
||||
)
|
||||
|
||||
# increase mean tolerance to 1e-3 -> 2e-3
|
||||
# Ignore copy
|
||||
def test_slow_fast_equivalence_batched(self):
|
||||
if not self.test_slow_image_processor or not self.test_fast_image_processor:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test")
|
||||
|
||||
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
|
||||
|
||||
if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
|
||||
self.skipTest(
|
||||
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
|
||||
)
|
||||
|
||||
dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
|
||||
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
||||
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
||||
|
||||
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
|
||||
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
|
||||
|
||||
torch.testing.assert_close(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1, rtol=1e-1)
|
||||
self.assertLessEqual(
|
||||
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 2e-3
|
||||
)
|
989	tests/models/siglip2/test_modeling_siglip2.py (new file)
@ -0,0 +1,989 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Siglip2 model."""
|
||||
|
||||
import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
from parameterized import parameterized
|
||||
from pytest import mark
|
||||
|
||||
from transformers import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig
|
||||
from transformers.testing_utils import (
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_sdpa,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import (
|
||||
is_torch_available,
|
||||
is_torch_bf16_available_on_device,
|
||||
is_torch_fp16_available_on_device,
|
||||
is_torch_sdpa_available,
|
||||
is_vision_available,
|
||||
)
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
ModelTesterMixin,
|
||||
floats_tensor,
|
||||
ids_tensor,
|
||||
is_flaky,
|
||||
random_attention_mask,
|
||||
)
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import Siglip2ForImageClassification, Siglip2Model, Siglip2TextModel, Siglip2VisionModel
|
||||
|
||||
if is_torch_sdpa_available():
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from transformers import Siglip2Processor
|
||||
|
||||
|
||||
class Siglip2ModelTesterMixin(ModelTesterMixin):
|
||||
def test_sdpa_can_dispatch_composite_models(self):
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
# Load the model with SDPA
|
||||
model_sdpa = model_class.from_pretrained(tmpdirname)
|
||||
model_sdpa = model_sdpa.eval().to(torch_device)
|
||||
|
||||
# Load model with eager attention
|
||||
model_eager = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
model_eager = model_eager.eval().to(torch_device)
|
||||
|
||||
# SigLIP has one shared cls attr for all models, so we assign both submodels here
|
||||
vision_attn = text_attn = "sdpa" if model._supports_sdpa else "eager"
|
||||
|
||||
if hasattr(model_sdpa, "vision_model") and hasattr(model_sdpa, "text_model"):
|
||||
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn)
|
||||
self.assertTrue(model_sdpa.text_model.config._attn_implementation == text_attn)
|
||||
self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
|
||||
self.assertTrue(model_eager.text_model.config._attn_implementation == "eager")
|
||||
|
||||
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
|
||||
self.assertTrue(model_eager.config._attn_implementation == "eager")
|
||||
|
||||
for name, submodule in model_eager.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
has_sdpa = False
|
||||
for name, submodule in model_sdpa.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
has_sdpa = True
|
||||
break
|
||||
if not has_sdpa and model_sdpa.config.model_type != "falcon":
|
||||
raise ValueError("The SDPA model should have SDPA attention layers")
|
||||
|
||||
def test_eager_matches_sdpa_inference(
|
||||
self,
|
||||
torch_dtype: str,
|
||||
use_attention_mask_options: Tuple[bool, ...] = (True, False),
|
||||
logit_keys: Tuple[str, ...] = ("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"),
|
||||
):
|
||||
if not self.all_model_classes[0]._supports_sdpa:
|
||||
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
|
||||
|
||||
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
|
||||
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
|
||||
|
||||
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
|
||||
self.skipTest(
|
||||
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
|
||||
)
|
||||
|
||||
# Convert to torch dtype
|
||||
dtypes = {
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float32": torch.float32,
|
||||
}
|
||||
torch_dtype = dtypes[torch_dtype]
|
||||
|
||||
atols = {
|
||||
torch.float32: 1e-5,
|
||||
torch.bfloat16: 3e-2,
|
||||
torch.float16: 5e-3,
|
||||
}
|
||||
rtols = {
|
||||
torch.float32: 1e-4,
|
||||
torch.bfloat16: 3e-2,
|
||||
torch.float16: 5e-3,
|
||||
}
|
||||
|
||||
atol = atols[torch_dtype]
|
||||
rtol = rtols[torch_dtype]
|
||||
|
||||
def get_mean_reldiff(msg, current_case, x, ref, atol, rtol):
|
||||
return f"{msg} {current_case}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
# Load the model with SDPA
|
||||
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
|
||||
model_sdpa = model_sdpa.eval().to(torch_device)
|
||||
|
||||
# Load model with eager attention
|
||||
model_eager = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch_dtype,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
model_eager = model_eager.eval().to(torch_device)
|
||||
|
||||
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving the model each time,
|
||||
# but it would be nicer to have an efficient way to use parameterized.expand
|
||||
cases = [
|
||||
(use_mask, output_attentions, sdpa_backend, batch_size)
|
||||
for use_mask in use_attention_mask_options
|
||||
for output_attentions in [True, False]
|
||||
for sdpa_backend in [
|
||||
SDPBackend.MATH,
|
||||
[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH],
|
||||
[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH],
|
||||
[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH],
|
||||
]
|
||||
for batch_size in [1, 5]
|
||||
]
|
||||
fail_cases = []
|
||||
|
||||
for use_mask, output_attentions, sdpa_backend, batch_size in cases:
|
||||
processed_inputs = inputs_dict.copy()
|
||||
|
||||
# convert to torch_dtype
|
||||
if "pixel_values" in processed_inputs:
|
||||
processed_inputs["pixel_values"] = processed_inputs["pixel_values"].to(torch_dtype)
|
||||
|
||||
# slice for different batch sizes
|
||||
for key in processed_inputs.keys():
|
||||
if isinstance(processed_inputs[key], (torch.Tensor, list, tuple)):
|
||||
processed_inputs[key] = processed_inputs[key][:batch_size]
|
||||
|
||||
# set attention mask with left padding
|
||||
if not use_mask:
|
||||
processed_inputs.pop("attention_mask", None)
|
||||
else:
|
||||
dummy_attention_mask = processed_inputs["attention_mask"]
|
||||
dummy_attention_mask[:] = 1
|
||||
dummy_attention_mask[:, :1] = 0
|
||||
processed_inputs["attention_mask"] = dummy_attention_mask
|
||||
|
||||
processed_inputs["output_attentions"] = output_attentions
|
||||
processed_inputs["output_hidden_states"] = True
|
||||
|
||||
current_case = (
|
||||
f"padding_side=left, use_mask={use_mask}, batch_size={batch_size}, sdpa_backend={sdpa_backend}"
|
||||
)
|
||||
|
||||
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
|
||||
|
||||
with torch.no_grad():
|
||||
try:
|
||||
with sdpa_kernel(sdpa_backend):
|
||||
outputs_eager = model_eager(**prepared_inputs)
|
||||
outputs_sdpa = model_sdpa(**prepared_inputs)
|
||||
except Exception as e:
|
||||
fail_cases.append(f"{current_case}: {e}")
|
||||
continue
|
||||
|
||||
for key in logit_keys:
|
||||
eager_logits = outputs_eager[key]
|
||||
sdpa_logits = outputs_sdpa[key]
|
||||
|
||||
if use_mask:
|
||||
eager_logits = eager_logits[:, 1:]
|
||||
sdpa_logits = sdpa_logits[:, 1:]
|
||||
|
||||
is_close = torch.allclose(eager_logits, sdpa_logits, atol=atol, rtol=rtol)
|
||||
if not is_close:
|
||||
fail_cases.append(get_mean_reldiff(key, current_case, sdpa_logits, eager_logits, atol, rtol))
|
||||
|
||||
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
dtype = torch.float16
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
# Prepare inputs
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if "pixel_values" in inputs_dict:
|
||||
inputs_dict["pixel_values"] = inputs_dict["pixel_values"].to(dtype)
|
||||
|
||||
# Separate masks
|
||||
attention_masks = {}
|
||||
if "attention_mask" in inputs_dict:
|
||||
# attention_masks["attention_mask"] = inputs_dict.pop("attention_mask")
|
||||
inputs_dict["attention_mask"] = None
|
||||
if "pixel_attention_mask" in inputs_dict:
|
||||
attention_masks["pixel_attention_mask"] = inputs_dict.pop("pixel_attention_mask")
|
||||
inputs_dict["pixel_attention_mask"] = None
|
||||
|
||||
# Save and load model with flash attention 2 and eager attentions
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model = model_class(config)
|
||||
model.save_pretrained(tmp_dir)
|
||||
|
||||
model = model_class.from_pretrained(tmp_dir, torch_dtype=dtype)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmp_dir, torch_dtype=dtype, attn_implementation="flash_attention_2"
|
||||
)
|
||||
|
||||
model_fa.to(torch_device)
|
||||
model.to(torch_device)
|
||||
|
||||
# Run forward pass without attention masks
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs_dict, output_hidden_states=True)
|
||||
outputs_fa = model_fa(**inputs_dict, output_hidden_states=True)
|
||||
|
||||
# Choose which key to compare
|
||||
key = [k for k in ["logits", "logits_per_image", "last_hidden_state"] if k in outputs][0]
|
||||
|
||||
torch.testing.assert_close(outputs[key], outputs_fa[key], atol=4e-2, rtol=4e-2)
|
||||
|
||||
# Run forward pass with attention masks
|
||||
inputs_dict.update(attention_masks)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs_dict, output_hidden_states=True)
|
||||
outputs_fa = model_fa(**inputs_dict, output_hidden_states=True)
|
||||
|
||||
output_tensor = outputs[key]
|
||||
output_tensor_fa = outputs_fa[key]
|
||||
|
||||
# Mask out padded tokens, they are different for SDPA and Flash Attention 2
|
||||
if key == "last_hidden_state" and "pixel_attention_mask" in inputs_dict:
|
||||
output_tensor = output_tensor * inputs_dict["pixel_attention_mask"][..., None]
|
||||
output_tensor_fa = output_tensor_fa * inputs_dict["pixel_attention_mask"][..., None]
|
||||
elif key == "last_hidden_state" and inputs_dict.get("attention_mask", None) is not None:
|
||||
output_tensor = output_tensor * inputs_dict["attention_mask"][..., None]
|
||||
output_tensor_fa = output_tensor_fa * inputs_dict["attention_mask"][..., None]
|
||||
|
||||
torch.testing.assert_close(output_tensor, output_tensor_fa, atol=4e-2, rtol=4e-2)
|
||||
|
||||
# Check with inference + dropout
|
||||
model.train()
|
||||
_ = model_fa(**inputs_dict, output_hidden_states=True)
|
||||
|
||||
@unittest.skip(reason="Siglip2 has default right padding (tested in test_flash_attn_2_inference_equivalence)")
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SDPA can't dispatch on flash with not None `attention_mask`")
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
|
||||
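
The mixin above checks that the eager, SDPA and FlashAttention-2 paths produce matching outputs; selecting an implementation at load time could look like the sketch below (FlashAttention-2 additionally requires a supported GPU and the `flash-attn` package).

```python
import torch
from transformers import Siglip2Model

model = Siglip2Model.from_pretrained(
    "google/siglip2-base-patch16-naflex",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # or "sdpa" (typically the default when available) / "eager"
).to("cuda")
```
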
class Siglip2VisionModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=12,
|
||||
num_patches=16,
|
||||
image_num_patches=24,
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
initializer_range=0.02,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_patches = num_patches
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
self.seq_length = image_num_patches
|
||||
|
||||
def prepare_config_and_inputs(self):
        pixel_values = floats_tensor(
            [self.batch_size, self.seq_length, self.num_channels * self.patch_size * self.patch_size]
        )
        pixel_attention_mask = torch.zeros(self.batch_size, self.seq_length, device=torch_device, dtype=torch.long)

        spatial_shapes = [
            (height, width)
            for height in range(1, self.seq_length)
            for width in range(1, self.seq_length)
            if height * width <= self.seq_length
        ] * self.batch_size
        spatial_shapes = spatial_shapes[: self.batch_size]
        spatial_shapes = torch.tensor(spatial_shapes, device=torch_device, dtype=torch.long)

        for i, (height, width) in enumerate(spatial_shapes):
            pixel_attention_mask[i, : height * width] = 1

        config = self.get_config()

        return config, pixel_values, pixel_attention_mask, spatial_shapes

    def get_config(self):
        return Siglip2VisionConfig(
            num_patches=self.num_patches,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, pixel_values, pixel_attention_mask, spatial_shapes):
        model = Siglip2VisionModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values, pixel_attention_mask, spatial_shapes)

        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config, pixel_values, pixel_attention_mask, spatial_shapes = self.prepare_config_and_inputs()
        inputs_dict = {
            "pixel_values": pixel_values,
            "pixel_attention_mask": pixel_attention_mask,
            "spatial_shapes": spatial_shapes,
        }
        return config, inputs_dict


@require_torch
class Siglip2VisionModelTest(Siglip2ModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as SIGLIP2 does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (Siglip2VisionModel,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    # MP works but offload doesn't work when the MultiheadAttention is offloaded
    # TODO: One potential solution would be to add to set preload_module_classes = ["Siglip2MultiheadAttentionPoolingHead"]
    # in the dispatch_model function
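    # A hypothetical sketch of that workaround (not exercised by these tests): accelerate's `dispatch_model`
    # accepts a `preload_module_classes` argument, e.g.
    #   dispatch_model(model, device_map=device_map, preload_module_classes=["Siglip2MultiheadAttentionPoolingHead"])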
    test_cpu_offload = False
    test_disk_offload_safetensors = False
    test_disk_offload_bin = False

    def setUp(self):
        self.model_tester = Siglip2VisionModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=Siglip2VisionConfig, has_text_modality=False, hidden_size=37
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="SIGLIP2 does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_get_set_embeddings(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="Siglip2VisionModel does not support standalone training")
    def test_training(self):
        pass

    @unittest.skip(reason="Siglip2VisionModel does not support standalone training")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="Siglip2VisionModel does not support standalone training")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(reason="Siglip2VisionModel does not support standalone training")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="Siglip2VisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="Siglip2VisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    @unittest.skip(reason="Siglip2 uses the same initialization scheme as the Flax original implementation")
    def test_initialization(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        model_name = "google/siglip2-base-patch16-naflex"
        model = Siglip2VisionModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
    @require_torch_sdpa
    @slow
    @is_flaky()
    def test_eager_matches_sdpa_inference(self, torch_dtype: str):
        super().test_eager_matches_sdpa_inference(
            torch_dtype=torch_dtype,
            logit_keys=("pooler_output", "last_hidden_state"),
            use_attention_mask_options=(False,),
        )

    @require_torch_sdpa
    def test_sdpa_can_dispatch_composite_models(self):
        super().test_sdpa_can_dispatch_composite_models()


class Siglip2TextModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        if input_mask is not None:
            batch_size, seq_length = input_mask.shape
            rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
            for batch_idx, start_index in enumerate(rnd_start_indices):
                input_mask[batch_idx, :start_index] = 1
                input_mask[batch_idx, start_index:] = 0

        config = self.get_config()

        return config, input_ids, input_mask

    def get_config(self):
        return Siglip2TextConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, input_ids, input_mask):
        model = Siglip2TextModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_ids, attention_mask=input_mask)
            result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, input_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class Siglip2TextModelTest(Siglip2ModelTesterMixin, unittest.TestCase):
    all_model_classes = (Siglip2TextModel,) if is_torch_available() else ()
    fx_compatible = False
    test_resize_embeddings = False
    test_pruning = False
    test_head_masking = False
    model_split_percents = [0.5, 0.8, 0.9]

    def setUp(self):
        self.model_tester = Siglip2TextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=Siglip2TextConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="Siglip2TextModel does not support standalone training")
    def test_training(self):
        pass

    @unittest.skip(reason="Siglip2TextModel does not support standalone training")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="Siglip2TextModel does not support standalone training")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(reason="Siglip2TextModel does not support standalone training")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="Siglip2 does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Siglip2TextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="Siglip2TextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    @unittest.skip(reason="Siglip2 uses the same initialization scheme as the Flax original implementation")
    def test_initialization(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        model_name = "google/siglip2-base-patch16-naflex"
        model = Siglip2TextModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
    @require_torch_sdpa
    @slow
    @is_flaky()
    def test_eager_matches_sdpa_inference(self, torch_dtype: str):
        super().test_eager_matches_sdpa_inference(
            torch_dtype=torch_dtype,
            logit_keys=("pooler_output", "last_hidden_state"),
            use_attention_mask_options=(False, True),
        )

    @require_torch_sdpa
    def test_sdpa_can_dispatch_composite_models(self):
        super().test_sdpa_can_dispatch_composite_models()


class Siglip2ModelTester:
    def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
        if text_kwargs is None:
            text_kwargs = {}
        if vision_kwargs is None:
            vision_kwargs = {}

        self.parent = parent
        self.text_model_tester = Siglip2TextModelTester(parent, **text_kwargs)
        self.vision_model_tester = Siglip2VisionModelTester(parent, **vision_kwargs)
        self.batch_size = self.text_model_tester.batch_size  # need bs for batching_equivalence test
        self.is_training = is_training

    def prepare_config_and_inputs(self):
        text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
        vision_config, pixel_values, pixel_attention_mask, spatial_shapes = (
            self.vision_model_tester.prepare_config_and_inputs()
        )

        config = self.get_config()

        return config, input_ids, attention_mask, pixel_values, pixel_attention_mask, spatial_shapes

    def get_config(self):
        return Siglip2Config.from_text_vision_configs(
            self.text_model_tester.get_config(),
            self.vision_model_tester.get_config(),
        )

    def create_and_check_model(
        self, config, input_ids, attention_mask, pixel_values, pixel_attention_mask, spatial_shapes
    ):
        model = Siglip2Model(config).to(torch_device).eval()
        with torch.no_grad():
            result = model(input_ids, pixel_values, pixel_attention_mask, spatial_shapes, attention_mask)
        self.parent.assertEqual(
            result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
        )
        self.parent.assertEqual(
            result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, attention_mask, pixel_values, pixel_attention_mask, spatial_shapes = config_and_inputs
        inputs_dict = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "pixel_attention_mask": pixel_attention_mask,
            "spatial_shapes": spatial_shapes,
            "attention_mask": attention_mask,
            "position_ids": None,
            "return_loss": False,
        }
        return config, inputs_dict


@require_torch
class Siglip2ModelTest(Siglip2ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (Siglip2Model,) if is_torch_available() else ()
    pipeline_model_mapping = {"feature-extraction": Siglip2Model} if is_torch_available() else {}
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False
    # MP works but offload doesn't work when the MultiheadAttention is offloaded
    # TODO: One potential solution would be to add to set preload_module_classes = ["Siglip2MultiheadAttentionPoolingHead"]
    # in the dispatch_model function
    test_cpu_offload = False
    test_disk_offload_safetensors = False
    test_disk_offload_bin = False
    _is_composite = True

    def setUp(self):
        self.model_tester = Siglip2ModelTester(self)
        self.config_tester = ConfigTester(self, config_class=Siglip2Config, has_text_modality=False)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Retain_grad is tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="Siglip2Model does not have input/output embeddings")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(reason="Siglip2 uses the same initialization scheme as the Flax original implementation")
    def test_initialization(self):
        pass

    def test_load_vision_text_config(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        # Save Siglip2Config and check if we can load Siglip2VisionConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            vision_config = Siglip2VisionConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())

        # Save Siglip2Config and check if we can load Siglip2TextConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            text_config = Siglip2TextConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())

    @slow
    def test_model_from_pretrained(self):
        model_name = "google/siglip2-base-patch16-naflex"
        model = Siglip2Model.from_pretrained(model_name)
        self.assertIsNotNone(model)

    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
    def test_flash_attn_2_inference_equivalence_right_padding(self):
        self.skipTest("Siglip2 does not support right padding")

    @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
    @require_torch_sdpa
    @slow
    @is_flaky()
    def test_eager_matches_sdpa_inference(self, torch_dtype: str):
        super().test_eager_matches_sdpa_inference(
            torch_dtype=torch_dtype,
            logit_keys=("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"),
            use_attention_mask_options=(False, True),
        )

    @require_torch_sdpa
    def test_sdpa_can_dispatch_composite_models(self):
        super().test_sdpa_can_dispatch_composite_models()


class Siglip2ForImageClassificationModelTester(Siglip2ModelTester):
    def __init__(self, parent):
        super().__init__(parent)
        self.batch_size = self.vision_model_tester.batch_size
        self.num_hidden_layers = self.vision_model_tester.num_hidden_layers
        self.hidden_size = self.vision_model_tester.hidden_size
        self.seq_length = self.vision_model_tester.seq_length

    def prepare_config_and_inputs(self):
        _, pixel_values, pixel_attention_mask, spatial_shapes = self.vision_model_tester.prepare_config_and_inputs()
        config = self.get_config()

        return config, pixel_values, pixel_attention_mask, spatial_shapes

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values, pixel_attention_mask, spatial_shapes = config_and_inputs
        inputs_dict = {
            "pixel_values": pixel_values,
            "pixel_attention_mask": pixel_attention_mask,
            "spatial_shapes": spatial_shapes,
        }
        return config, inputs_dict


@require_torch
class Siglip2ForImageClassificationModelTest(Siglip2ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (Siglip2ForImageClassification,) if is_torch_available() else ()
    pipeline_model_mapping = {"image-classification": Siglip2ForImageClassification} if is_torch_available() else {}
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False
    # MP works but offload doesn't work when the MultiheadAttention is offloaded
    # TODO: One potential solution would be to add to set preload_module_classes = ["Siglip2MultiheadAttentionPoolingHead"]
    # in the dispatch_model function
    test_cpu_offload = False
    test_disk_offload_safetensors = False
    test_disk_offload_bin = False
    _is_composite = True

    def setUp(self):
        self.model_tester = Siglip2ForImageClassificationModelTester(self)

    @unittest.skip(reason="Siglip2ForImageClassification does not support inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Siglip2ForImageClassification does not support inputs_embeds")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(reason="Siglip2ForImageClassification does not support gradient checkpointing yet")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="Siglip2ForImageClassification does not support gradient checkpointing yet")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(reason="Siglip2ForImageClassification does not support gradient checkpointing yet")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="Siglip2 uses the same initialization scheme as the Flax original implementation")
    def test_initialization(self):
        pass

    @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
    @require_torch_sdpa
    @slow
    @is_flaky()
    def test_eager_matches_sdpa_inference(self, torch_dtype: str):
        super().test_eager_matches_sdpa_inference(
            torch_dtype=torch_dtype, logit_keys=("logits",), use_attention_mask_options=(False,)
        )

    @require_torch_sdpa
    def test_sdpa_can_dispatch_composite_models(self):
        super().test_sdpa_can_dispatch_composite_models()


# Draw a circle on images with different aspect ratios
def prepare_images():
    shapes = [(224, 224), (1024, 1024), (224, 1024)]
    images = []
    for height, width in shapes:
        image = Image.new("RGB", (width, height), color="red")
        draw = ImageDraw.Draw(image)
        center_x = image.width // 2
        center_y = image.height // 2
        radius = min(center_x, center_y) // 8 * 7
        draw.ellipse(
            (center_x - radius, center_y - radius, center_x + radius, center_y + radius),
            fill="blue",
            outline="green",
            width=image.width // 20,
        )
        images.append(image)
    return images


@require_vision
@require_torch
class Siglip2ModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference(self):
        model_name = "google/siglip2-base-patch16-naflex"
        model = Siglip2Model.from_pretrained(model_name).to(torch_device)
        processor = Siglip2Processor.from_pretrained(model_name)

        images = prepare_images()
        text = [
            "circle",
            "ellipsoid",
            "blue circle on red background",
            "blue circle with green border on red background",
            "green circle on red background",
            "a dog",
            "a blue dog with a green border on a red background",
        ]

        inputs = processor(text=text, images=images, return_tensors="pt")
        inputs = inputs.to(torch_device)
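        # The processor output is expected to include "input_ids", "pixel_values", "pixel_attention_mask" and
        # "spatial_shapes" (an assumption consistent with the model call and the shape checks below).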

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        logits_per_image = outputs.logits_per_image
        logits_per_text = outputs.logits_per_text

        # verify the logits shape
        self.assertEqual(
            logits_per_image.shape,
            torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
        )
        self.assertEqual(
            logits_per_text.shape,
            torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
        )

        # verify the logits values
        # fmt: off
        expected_logits_per_text = torch.tensor(
            [
                [  1.0195,  -0.0280,  -1.4468],
                [ -4.5395,  -6.2269,  -1.5667],
                [  4.1757,   5.0358,   3.5159],
                [  9.4264,  10.1879,   6.3353],
                [  2.4409,   3.1058,   4.5491],
                [-12.3230, -13.7355, -13.4632],
                [  1.1520,   1.1687,  -1.9647],
            ]
        ).to(torch_device)
        # fmt: on

        torch.testing.assert_close(outputs.logits_per_text, expected_logits_per_text, rtol=1e-3, atol=1e-3)

@ -1175,6 +1175,10 @@ class ModelTesterMixin:
                traced_model = torch.jit.trace(
                    model, (pixel_values, prompt_pixel_values, prompt_masks), check_trace=False
                )  # when traced model is checked, an error is produced due to name mangling
            elif "Siglip2" in model_class.__name__:
                outputs = model(**inputs)
                example_inputs = [t for t in inputs.values() if isinstance(t, torch.Tensor)]
                traced_model = torch.jit.trace(model, example_inputs, check_trace=False)
            else:
                main_input = inputs[main_input_name]

@ -3035,6 +3039,7 @@ class ModelTesterMixin:
                "wav2vec2.masked_spec_embed",
                "Wav2Vec2ForSequenceClassification",
                "CLIPForImageClassification",
                "Siglip2ForImageClassification",
                "RegNetForImageClassification",
                "ResNetForImageClassification",
                "UniSpeechSatForSequenceClassification",

@ -334,6 +334,8 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
    "SegGptForImageSegmentation",
    "SiglipVisionModel",
    "SiglipTextModel",
    "Siglip2VisionModel",
    "Siglip2TextModel",
    "ChameleonVQVAE",  # no autoclass for VQ-VAE models
    "VitPoseForPoseEstimation",
    "CLIPTextModel",