Add ColQwen2 to 🤗 transformers (#35778)

* feat: add colqwen2 (wip)

* tests: fix test_attention_outputs

* tests: reduce hidden size to accelerate tests

* tests: fix `test_attention_outputs` 🥳

* fix: fix wrong parent class for `ColQwen2ForRetrievalOutput`

* fix: minor typing and style changes

* chore: run `make style`

* feat: remove redundant `max_num_visual_tokens` attribute in `ColQwen2Processor`

* tests: tweak comments

* style: apply ruff formatter

* feat: move default values for `visual_prompt_prefix` and `query_prefix`

* docs: update ColQwen2 model card

* docs: tweak model cards

* docs: add required example config checkpoint

* tests: update expected scores in integration test

* docs: tweak quickstart snippets

* fix: address PR comments

* tests: fix colqwen2 tests + tweak comment in colpali test

* tests: unskip useful tests

* fix: fix bug when `visual_prompt_prefix` or `query_prefix` is an empty string

* fix: fix ColPali outputs when `return_dict == False`

* fix: fix issue with PaliGemma output not being a dict

* docs: set default dtype to bfloat16 in quickstart snippets

* fix: fix error when `return_dict=False` in ColPali and ColQwen2

* tests: fix special tokens not being replaced in input_ids

* style: fix lint

* fix: `ColQwen2Processor`'s `padding_side` is now set from `processor_config.json`

* fix: remove unused `padding_side` in ColQwen2 model

* docs: update ColQwen2's model doc

* fix: fix hardcoded vlm backbone class in ColQwen2Config

* fix: remove `padding_side` from ColQwen2Processor as it should be fed from kwargs

* docs: fix typo in model docstring

* docs: add illuin mention in model docs

* fix: let `padding_side` be handled by `tokenizer_config.json`

* docs: add colpali reference url in colqwen2's model doc

* docs: add Hf mention in model docs

* docs: add late interaction mention in model docs

* docs: tweak colqwen2 model doc

* docs: update reference checkpoint for ColPali to v1.3

* docs: simplify quickstart snippets

* docs: remove redundant `.eval()`

* refactor: use `can_return_tuple` decorator for ColPali and ColQwen2

* docs: fix copyright date

* docs: add missing copyright in tests

* fix: raise error when `initializer_range` is not in config

* docs: remove redundant `.eval()` in colpali doc

* fix: fix `get_text_config` now that Qwen2VL has a proper `text_config` attribute

See https://github.com/huggingface/transformers/pull/37268 for details about changes in Qwen2VL's config.

* fix: add missing `initializer_range` attribute in `ColQwen2Config`

* fix: use `get_text_config` in `resize_token_embeddings`

* update colqwen2 with auto_docstring

* docs: fix wrong copyright year

* chore: remove `raise` as `initializer_range` has a default value in `ColQwen2Config`

* refactor: merge `inner_forward` into `forward`

* Refactor colqwen2 after refactoring of qwen2VL, use modular for modeling code

* protect torch import in modular to protect in processing

* protect torch import in modular to protect in processing

* tests: fix hf model path in ColQwen2 integration test

* docs: clarify `attn_implementation` and add comments

* docs: add fallback snippet for using offline PIL dummy images

* docs: temporarily revert attn_implementation to `None` while sdpa is not fixed

* docs: tweaks in colpali/colqwen2 quick start snippets

* fix: add missing flags to enable SDPA/Flex Attention in ColQwen2 model

* fix: add missing changes in modular file

* fix modeling tests

---------

Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
Tony Wu 2025-06-02 14:58:01 +02:00 committed by GitHub
parent beaed8ce01
commit c72ba69441
23 changed files with 2289 additions and 95 deletions


@ -937,6 +937,8 @@
title: CLVP
- local: model_doc/colpali
title: ColPali
- local: model_doc/colqwen2
title: ColQwen2
- local: model_doc/data2vec
title: Data2Vec
- local: model_doc/deplot


@ -20,9 +20,11 @@ rendered properly in your Markdown viewer.
# ColPali
[ColPali](https://huggingface.co/papers/2407.01449) is a model designed to retrieve documents by analyzing their visual features. Unlike traditional systems that rely heavily on text extraction and OCR, ColPali treats each page as an image. It uses [Paligemma-3B](./paligemma) to capture not only text, but also the layout, tables, charts, and other visual elements to create detailed embeddings. This offers a more comprehensive understanding of documents and enables more efficient and accurate retrieval.
[ColPali](https://huggingface.co/papers/2407.01449) is a model designed to retrieve documents by analyzing their visual features. Unlike traditional systems that rely heavily on text extraction and OCR, ColPali treats each page as an image. It uses [Paligemma-3B](./paligemma) to capture not only text, but also the layout, tables, charts, and other visual elements to create detailed multi-vector embeddings that can be used for retrieval by computing pairwise late interaction similarity scores. This offers a more comprehensive understanding of documents and enables more efficient and accurate retrieval.
You can find all the original ColPali checkpoints under the [ColPali](https://huggingface.co/collections/vidore/hf-native-colvision-models-6755d68fc60a8553acaa96f7) collection.
This model was contributed by [@tonywu71](https://huggingface.co/tonywu71) (ILLUIN Technology) and [@yonigozlan](https://huggingface.co/yonigozlan) (HuggingFace).
You can find all the original ColPali checkpoints under Vidore's [Hf-native ColVision Models](https://huggingface.co/collections/vidore/hf-native-colvision-models-6755d68fc60a8553acaa96f7) collection.
> [!TIP]
> Click on the ColPali models in the right sidebar for more examples of how to use ColPali for image retrieval.
@ -30,21 +32,25 @@ You can find all the original ColPali checkpoints under the [ColPali](https://hu
<hfoptions id="usage">
<hfoption id="image retrieval">
```py
```python
import requests
import torch
from PIL import Image
from transformers import ColPaliForRetrieval, ColPaliProcessor
# Load model (bfloat16 support is limited; fallback to float32 if needed)
model = ColPaliForRetrieval.from_pretrained(
"vidore/colpali-v1.2-hf",
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
device_map="auto", # "cpu", "cuda", or "mps" for Apple Silicon
).eval()
# Load the model and the processor
model_name = "vidore/colpali-v1.3-hf"
model = ColPaliForRetrieval.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto", # "cpu", "cuda", or "mps" for Apple Silicon
)
processor = ColPaliProcessor.from_pretrained(model_name)
# The document page screenshots from your corpus
url1 = "https://upload.wikimedia.org/wikipedia/commons/8/89/US-original-Declaration-1776.jpg"
url2 = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Romeoandjuliet1597.jpg/500px-Romeoandjuliet1597.jpg"
@ -53,25 +59,37 @@ images = [
Image.open(requests.get(url2, stream=True).raw),
]
# The queries you want to retrieve documents for
queries = [
"Who printed the edition of Romeo and Juliet?",
"When was the United States Declaration of Independence proclaimed?",
"Who printed the edition of Romeo and Juliet?",
]
# Process the inputs
inputs_images = processor(images=images, return_tensors="pt").to(model.device)
inputs_text = processor(text=queries, return_tensors="pt").to(model.device)
inputs_images = processor(images=images).to(model.device)
inputs_text = processor(text=queries).to(model.device)
# Forward pass
with torch.no_grad():
image_embeddings = model(**inputs_images).embeddings
query_embeddings = model(**inputs_text).embeddings
# Score the queries against the images
scores = processor.score_retrieval(query_embeddings, image_embeddings)
print("Retrieval scores (query x image):")
print(scores)
```
If you have issues loading the images with PIL, you can use the following code to create dummy images:
```python
images = [
Image.new("RGB", (128, 128), color="white"),
Image.new("RGB", (64, 32), color="black"),
]
```
</hfoption>
</hfoptions>
@ -79,12 +97,15 @@ Quantization reduces the memory burden of large models by representing the weigh
The example below uses [bitsandbytes](../quantization/bitsandbytes.md) to quantize the weights to int4.
```py
```python
import requests
import torch
from PIL import Image
from transformers import ColPaliForRetrieval, ColPaliProcessor
from transformers import BitsAndBytesConfig
from transformers import BitsAndBytesConfig, ColPaliForRetrieval, ColPaliProcessor
model_name = "vidore/colpali-v1.3-hf"
# 4-bit quantization configuration
bnb_config = BitsAndBytesConfig(
@ -94,14 +115,11 @@ bnb_config = BitsAndBytesConfig(
bnb_4bit_compute_dtype=torch.float16,
)
model_name = "vidore/colpali-v1.2-hf"
# Load model
model = ColPaliForRetrieval.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="cuda"
).eval()
device_map="cuda",
)
processor = ColPaliProcessor.from_pretrained(model_name)
@ -114,8 +132,8 @@ images = [
]
queries = [
"Who printed the edition of Romeo and Juliet?",
"When was the United States Declaration of Independence proclaimed?",
"Who printed the edition of Romeo and Juliet?",
]
# Process the inputs
@ -127,6 +145,7 @@ with torch.no_grad():
image_embeddings = model(**inputs_images).embeddings
query_embeddings = model(**inputs_text).embeddings
# Score the queries against the images
scores = processor.score_retrieval(query_embeddings, image_embeddings)
print("Retrieval scores (query x image):")


@ -0,0 +1,176 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
</div>
# ColQwen2
[ColQwen2](https://doi.org/10.48550/arXiv.2407.01449) is a variant of the [ColPali](./colpali) model designed to retrieve documents by analyzing their visual features. Unlike traditional systems that rely heavily on text extraction and OCR, ColQwen2 treats each page as an image. It uses the [Qwen2-VL](./qwen2_vl) backbone to capture not only text, but also the layout, tables, charts, and other visual elements to create detailed multi-vector embeddings that can be used for retrieval by computing pairwise late interaction similarity scores. This offers a more comprehensive understanding of documents and enables more efficient and accurate retrieval.
This model was contributed by [@tonywu71](https://huggingface.co/tonywu71) (ILLUIN Technology) and [@yonigozlan](https://huggingface.co/yonigozlan) (HuggingFace).
You can find all the original ColQwen2 checkpoints under Vidore's [Hf-native ColVision Models](https://huggingface.co/collections/vidore/hf-native-colvision-models-6755d68fc60a8553acaa96f7) collection.
> [!TIP]
> Click on the ColQwen2 models in the right sidebar for more examples of how to use ColQwen2 for image retrieval.
<hfoptions id="usage">
<hfoption id="image retrieval">
```python
import requests
import torch
from PIL import Image
from transformers import ColQwen2ForRetrieval, ColQwen2Processor
from transformers.utils.import_utils import is_flash_attn_2_available
# Load the model and the processor
model_name = "vidore/colqwen2-v1.0-hf"
model = ColQwen2ForRetrieval.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto", # "cpu", "cuda", or "mps" for Apple Silicon
attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
)
processor = ColQwen2Processor.from_pretrained(model_name)
# The document page screenshots from your corpus
url1 = "https://upload.wikimedia.org/wikipedia/commons/8/89/US-original-Declaration-1776.jpg"
url2 = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Romeoandjuliet1597.jpg/500px-Romeoandjuliet1597.jpg"
images = [
Image.open(requests.get(url1, stream=True).raw),
Image.open(requests.get(url2, stream=True).raw),
]
# The queries you want to retrieve documents for
queries = [
"When was the United States Declaration of Independence proclaimed?",
"Who printed the edition of Romeo and Juliet?",
]
# Process the inputs
inputs_images = processor(images=images).to(model.device)
inputs_text = processor(text=queries).to(model.device)
# Forward pass
with torch.no_grad():
image_embeddings = model(**inputs_images).embeddings
query_embeddings = model(**inputs_text).embeddings
# Score the queries against the images
scores = processor.score_retrieval(query_embeddings, image_embeddings)
print("Retrieval scores (query x image):")
print(scores)
```
If you have issues loading the images with PIL, you can use the following code to create dummy images:
```python
images = [
Image.new("RGB", (128, 128), color="white"),
Image.new("RGB", (64, 32), color="black"),
]
```
</hfoption>
</hfoptions>
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
The example below uses [bitsandbytes](../quantization/bitsandbytes.md) to quantize the weights to int4.
```python
import requests
import torch
from PIL import Image
from transformers import BitsAndBytesConfig, ColQwen2ForRetrieval, ColQwen2Processor
model_name = "vidore/colqwen2-v1.0-hf"
# 4-bit quantization configuration
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
model = ColQwen2ForRetrieval.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="cuda",
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)
url1 = "https://upload.wikimedia.org/wikipedia/commons/8/89/US-original-Declaration-1776.jpg"
url2 = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Romeoandjuliet1597.jpg/500px-Romeoandjuliet1597.jpg"
images = [
Image.open(requests.get(url1, stream=True).raw),
Image.open(requests.get(url2, stream=True).raw),
]
queries = [
"When was the United States Declaration of Independence proclaimed?",
"Who printed the edition of Romeo and Juliet?",
]
# Process the inputs
inputs_images = processor(images=images, return_tensors="pt").to(model.device)
inputs_text = processor(text=queries, return_tensors="pt").to(model.device)
# Forward pass
with torch.no_grad():
image_embeddings = model(**inputs_images).embeddings
query_embeddings = model(**inputs_text).embeddings
# Score the queries against the images
scores = processor.score_retrieval(query_embeddings, image_embeddings)
print("Retrieval scores (query x image):")
print(scores)
```
## Notes
- [`~ColQwen2Processor.score_retrieval`] returns a 2D tensor where the first dimension is the number of queries and the second dimension is the number of images. A higher score indicates more similarity between the query and image; a conceptual sketch of this scoring is shown after these notes.
- Unlike ColPali, ColQwen2 supports arbitrary image resolutions and aspect ratios, which means images are not resized into fixed-size squares. This preserves more of the original input signal.
- Larger input images generate longer multi-vector embeddings, allowing users to adjust image resolution to balance performance and memory usage.
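The snippet below is a minimal sketch of the late-interaction (MaxSim) scoring that [`~ColQwen2Processor.score_retrieval`] conceptually performs. It is illustrative only: `maxsim_scores` is a hypothetical helper, it reuses the `query_embeddings` and `image_embeddings` tensors from the quickstart snippet above, and it ignores padding for brevity.
```python
import torch

def maxsim_scores(query_embeddings: torch.Tensor, image_embeddings: torch.Tensor) -> torch.Tensor:
    # query_embeddings: (num_queries, query_length, dim), image_embeddings: (num_images, image_length, dim)
    # Token-level similarities for every (query, image) pair: (num_queries, num_images, query_length, image_length)
    token_sims = torch.einsum("qnd,imd->qinm", query_embeddings.float(), image_embeddings.float())
    # Late interaction (MaxSim): take the best-matching image token for each query token, then sum over query tokens
    return token_sims.amax(dim=3).sum(dim=2)  # (num_queries, num_images)

scores = maxsim_scores(query_embeddings, image_embeddings)
best_page_per_query = scores.argmax(dim=1)  # index of the highest-scoring page for each query
```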
## ColQwen2Config
[[autodoc]] ColQwen2Config
## ColQwen2Processor
[[autodoc]] ColQwen2Processor
## ColQwen2ForRetrieval
[[autodoc]] ColQwen2ForRetrieval
- forward


@ -62,6 +62,7 @@ if TYPE_CHECKING:
from .cohere import *
from .cohere2 import *
from .colpali import *
from .colqwen2 import *
from .conditional_detr import *
from .convbert import *
from .convnext import *


@ -79,6 +79,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("cohere", "CohereConfig"),
("cohere2", "Cohere2Config"),
("colpali", "ColPaliConfig"),
("colqwen2", "ColQwen2Config"),
("conditional_detr", "ConditionalDetrConfig"),
("convbert", "ConvBertConfig"),
("convnext", "ConvNextConfig"),
@ -437,6 +438,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("cohere", "Cohere"),
("cohere2", "Cohere2"),
("colpali", "ColPali"),
("colqwen2", "ColQwen2"),
("conditional_detr", "Conditional DETR"),
("convbert", "ConvBERT"),
("convnext", "ConvNeXT"),


@ -365,6 +365,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("bloom", "BloomForCausalLM"),
("camembert", "CamembertForMaskedLM"),
("colpali", "ColPaliForRetrieval"),
("colqwen2", "ColQwen2ForRetrieval"),
("ctrl", "CTRLLMHeadModel"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("deberta", "DebertaForMaskedLM"),


@ -66,6 +66,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("clipseg", "CLIPSegProcessor"),
("clvp", "ClvpProcessor"),
("colpali", "ColPaliProcessor"),
("colqwen2", "ColQwen2Processor"),
("emu3", "Emu3Processor"),
("flava", "FlavaProcessor"),
("fuyu", "FuyuProcessor"),


@ -147,6 +147,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("colqwen2", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
(
"cpm",


@ -33,8 +33,6 @@ class ColPaliConfig(PretrainedConfig):
Creating a configuration with the default settings will result in a configuration where the VLM backbone is set to the
default PaliGemma configuration, i.e the one from [vidore/colpali-v1.2](https://huggingface.co/vidore/colpali-v1.2).
The ColPali config is very similar to [`PaligemmaConfig`], but with an extra attribute defining the embedding dimension.
Note that, contrary to what the class name suggests (the name actually refers to the ColPali **methodology**), you can
use a different VLM backbone model than PaliGemma by passing the corresponding VLM configuration to the class constructor.
@ -93,7 +91,7 @@ class ColPaliConfig(PretrainedConfig):
)
self.vlm_config = vlm_config
self.text_config = text_config = text_config if text_config is not None else vlm_config.text_config
self.text_config = text_config if text_config is not None else vlm_config.text_config
if isinstance(self.text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma"
self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)


@ -24,16 +24,10 @@ from transformers import AutoModelForImageTextToText
from ...cache_utils import Cache
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring
from ...utils import ModelOutput, auto_docstring, can_return_tuple
from .configuration_colpali import ColPaliConfig
@auto_docstring(
custom_intro="""
The bare ColPali model outputting raw hidden-states without any specific head on top.
"""
)
@auto_docstring
class ColPaliPreTrainedModel(PreTrainedModel):
config_class = ColPaliConfig
base_model_prefix = "model"
@ -98,13 +92,16 @@ class ColPaliForRetrievalOutput(ModelOutput):
@auto_docstring(
custom_intro="""
In our proposed ColPali approach, we leverage VLMs to construct efficient multi-vector embeddings directly
from document images (screenshots) for document retrieval. We train the model to maximize the similarity
The ColPali architecture leverages VLMs to construct efficient multi-vector embeddings directly
from document images (screenshots) for document retrieval. The model is trained to maximize the similarity
between these document embeddings and the corresponding query embeddings, using the late interaction method
introduced in ColBERT.
Using ColPali removes the need for potentially complex and brittle layout recognition and OCR pipelines with a
single model that can take into account both the textual and visual content (layout, charts, etc.) of a document.
ColPali is part of the ColVision model family, which was first introduced in the following paper:
[*ColPali: Efficient Document Retrieval with Vision Language Models*](https://arxiv.org/abs/2407.01449).
"""
)
class ColPaliForRetrieval(ColPaliPreTrainedModel):
@ -126,6 +123,7 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
self.post_init()
@can_return_tuple
@auto_docstring
def forward(
self,
@ -136,9 +134,9 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, ColPaliForRetrievalOutput]:
if "pixel_values" in kwargs:
kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype=self.dtype)
) -> ColPaliForRetrievalOutput:
if pixel_values is not None:
pixel_values = pixel_values.to(dtype=self.dtype)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -146,17 +144,19 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.vlm(
vlm_output = self.vlm(
input_ids=input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values,
output_hidden_states=True,
return_dict=return_dict,
return_dict=True,
output_attentions=output_attentions,
**kwargs,
)
vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None
vlm_image_hidden_states = vlm_output.image_hidden_states if pixel_values is not None else None
last_hidden_states = outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size)
last_hidden_states = vlm_output.hidden_states[-1] # (batch_size, sequence_length, hidden_size)
embeddings = self.embedding_proj_layer(last_hidden_states) # (batch_size, sequence_length, dim)
# L2 normalization
@ -164,20 +164,12 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
loss = None
if not return_dict:
output = (embeddings,) + outputs[2:]
output[2] = output[2] if output_hidden_states is not None else None
output[-1] = (outputs.image_hidden_states if pixel_values is not None else None,)
return (loss,) + output if loss is not None else output
return ColPaliForRetrievalOutput(
loss=loss,
embeddings=embeddings,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states if pixel_values is not None else None,
past_key_values=vlm_output.past_key_values,
hidden_states=vlm_hidden_states,
attentions=vlm_output.attentions,
image_hidden_states=vlm_image_hidden_states,
)
def get_input_embeddings(self):


@ -14,28 +14,15 @@
# limitations under the License.
from typing import ClassVar, List, Optional, Union
from typing import List, Optional, Union
from transformers.models.paligemma.processing_paligemma import (
IMAGE_TOKEN,
PaliGemmaProcessor,
build_string_from_input,
)
from transformers.models.paligemma.processing_paligemma import IMAGE_TOKEN, PaliGemmaProcessor, build_string_from_input
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images
from ...processing_utils import (
ProcessingKwargs,
Unpack,
)
from ...tokenization_utils_base import (
PreTokenizedInput,
TextInput,
)
from ...utils import (
is_torch_available,
logging,
)
from ...processing_utils import ProcessingKwargs, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import is_torch_available, logging
if is_torch_available():
@ -73,10 +60,23 @@ class ColPaliProcessor(PaliGemmaProcessor):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
visual_prompt_prefix (`str`, *optional*, defaults to `"Describe the image."`):
A string that gets tokenized and prepended to the image tokens.
query_prefix (`str`, *optional*, defaults to `"Question: "`):
A prefix to be used for the query.
"""
visual_prompt_prefix: ClassVar[str] = "Describe the image."
query_prefix: ClassVar[str] = "Question: "
def __init__(
self,
image_processor=None,
tokenizer=None,
chat_template=None,
visual_prompt_prefix: str = "Describe the image.",
query_prefix: str = "Question: ",
):
super().__init__(image_processor=image_processor, tokenizer=tokenizer, chat_template=chat_template)
self.visual_prompt_prefix = visual_prompt_prefix
self.query_prefix = query_prefix
@property
def query_augmentation_token(self) -> str:
@ -96,7 +96,7 @@ class ColPaliProcessor(PaliGemmaProcessor):
**kwargs: Unpack[ColPaliProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model either (1) one or several texts, either (2) one or several image(s). This method is custom
Main method to prepare for the model either (1) one or several texts or (2) one or several image(s). This method is a custom
wrapper around the PaliGemmaProcessor's [`~PaliGemmaProcessor.__call__`] method adapted for the ColPali model. It cannot process
both text and images at the same time.
@ -196,12 +196,10 @@ class ColPaliProcessor(PaliGemmaProcessor):
if suffix is None:
suffix = self.query_augmentation_token * 10
texts_query: List[str] = []
texts_query: List[str] = []
for query in text:
query = self.tokenizer.bos_token + self.query_prefix + query
query += suffix # add suffix (pad tokens)
query += "\n" # make input ISO to PaliGemma's processor
query = self.tokenizer.bos_token + self.query_prefix + query + suffix + "\n"
texts_query.append(query)
output_kwargs["text_kwargs"]["max_length"] = output_kwargs["text_kwargs"].get("max_length", 50)
@ -223,7 +221,7 @@ class ColPaliProcessor(PaliGemmaProcessor):
Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColPaliProcessor's
[`ColPaliProcessor.__call__`].
This method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's [`~SiglipImageProcessor.__call__`].
This method forwards the `images` and `kwargs` arguments to the image processor.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
@ -258,7 +256,7 @@ class ColPaliProcessor(PaliGemmaProcessor):
Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColPaliProcessor's
[`ColPaliProcessor.__call__`].
This method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`].
This method forwards the `text` and `kwargs` arguments to the tokenizer.
Args:
text (`str`, `List[str]`, `List[List[str]]`):


@ -20,7 +20,7 @@
# limitations under the License.
from typing import ClassVar, List, Optional, Union
from typing import List, Optional, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images
@ -87,22 +87,25 @@ class ColPaliProcessor(ProcessorMixin):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
visual_prompt_prefix (`str`, *optional*, defaults to `"Describe the image."`):
A string that gets tokenized and prepended to the image tokens.
query_prefix (`str`, *optional*, defaults to `"Question: "`):
A prefix to be used for the query.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = ("SiglipImageProcessor", "SiglipImageProcessorFast")
tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast")
visual_prompt_prefix: ClassVar[str] = "Describe the image."
query_prefix: ClassVar[str] = "Question: "
def __init__(
self,
image_processor=None,
tokenizer=None,
chat_template=None,
**kwargs,
visual_prompt_prefix: str = "Describe the image.",
query_prefix: str = "Question: ",
):
super().__init__(image_processor, tokenizer, chat_template=chat_template)
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
@ -125,8 +128,8 @@ class ColPaliProcessor(ProcessorMixin):
tokenizer.add_tokens(EXTRA_TOKENS)
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
super().__init__(image_processor, tokenizer, chat_template=chat_template)
self.visual_prompt_prefix = visual_prompt_prefix
self.query_prefix = query_prefix
def __call__(
self,
@ -137,7 +140,7 @@ class ColPaliProcessor(ProcessorMixin):
**kwargs: Unpack[ColPaliProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model either (1) one or several texts, either (2) one or several image(s). This method is custom
Main method to prepare for the model either (1) one or several texts or (2) one or several image(s). This method is a custom
wrapper around the PaliGemmaProcessor's [`~PaliGemmaProcessor.__call__`] method adapted for the ColPali model. It cannot process
both text and images at the same time.
@ -237,12 +240,10 @@ class ColPaliProcessor(ProcessorMixin):
if suffix is None:
suffix = self.query_augmentation_token * 10
texts_query: List[str] = []
texts_query: List[str] = []
for query in text:
query = self.tokenizer.bos_token + self.query_prefix + query
query += suffix # add suffix (pad tokens)
query += "\n" # make input ISO to PaliGemma's processor
query = self.tokenizer.bos_token + self.query_prefix + query + suffix + "\n"
texts_query.append(query)
output_kwargs["text_kwargs"]["max_length"] = output_kwargs["text_kwargs"].get("max_length", 50)
@ -312,7 +313,7 @@ class ColPaliProcessor(ProcessorMixin):
Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColPaliProcessor's
[`ColPaliProcessor.__call__`].
This method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's [`~SiglipImageProcessor.__call__`].
This method forwards the `images` and `kwargs` arguments to the image processor.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
@ -347,7 +348,7 @@ class ColPaliProcessor(ProcessorMixin):
Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColPaliProcessor's
[`ColPaliProcessor.__call__`].
This method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`].
This method forwards the `text` and `kwargs` arguments to the tokenizer.
Args:
text (`str`, `List[str]`, `List[List[str]]`):


@ -0,0 +1,28 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_colqwen2 import *
from .modeling_colqwen2 import *
from .processing_colqwen2 import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)


@ -0,0 +1,94 @@
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from typing import Any, Dict
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
class ColQwen2Config(PretrainedConfig):
r"""
Configuration class to store the configuration of a [`ColQwen2ForRetrieval`]. It is used to instantiate a
`ColQwen2ForRetrieval` model according to the specified arguments, defining the model architecture following the methodology
from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
Instantiating a configuration with the defaults will yield a similar configuration to the vision encoder used by the pre-trained
ColQwen2-v1.0 model, e.g. [vidore/colqwen2-v1.0-hf](https://huggingface.co/vidore/colqwen2-v1.0-hf).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vlm_config (`PretrainedConfig`, *optional*):
Configuration of the VLM backbone model.
embedding_dim (`int`, *optional*, defaults to 128):
Dimension of the multi-vector embeddings produced by the model.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
Example:
```python
from transformers.models.colqwen2 import ColQwen2Config, ColQwen2ForRetrieval
config = ColQwen2Config()
model = ColQwen2ForRetrieval(config)
```
"""
model_type = "colqwen2"
sub_configs: Dict[str, Any] = {"vlm_config": PretrainedConfig}
def __init__(
self,
vlm_config=None,
embedding_dim: int = 128,
initializer_range: float = 0.02,
**kwargs,
):
if vlm_config is None:
vlm_config = CONFIG_MAPPING["qwen2_vl"]()
logger.info(
"`vlm_config` is `None`. Initializing `vlm_config` with the `Qwen2VLConfig` with default values."
)
elif isinstance(vlm_config, dict):
vlm_config = deepcopy(vlm_config)
if "model_type" not in vlm_config:
raise KeyError(
"The `model_type` key is missing in the `vlm_config` dictionary. Please provide the model type."
)
vlm_config = CONFIG_MAPPING[vlm_config["model_type"]](**vlm_config)
elif isinstance(vlm_config, PretrainedConfig):
vlm_config = vlm_config
else:
raise TypeError(
f"Invalid type for `vlm_config`. Expected `PretrainedConfig`, `dict`, or `None`, but got {type(vlm_config)}."
)
self.vlm_config = vlm_config
self.embedding_dim = embedding_dim
self.initializer_range = initializer_range
super().__init__(**kwargs)
def get_text_config(self, decoder=False) -> PretrainedConfig:
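# Delegate to the VLM backbone's text config so generic utilities such as `resize_token_embeddings` can locate it.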
return self.vlm_config.get_text_config(decoder=decoder)
__all__ = ["ColQwen2Config"]


@ -0,0 +1,212 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert ColQwen2 weights from the original repository to the HF model format.
Don't forget to manually upload the processor-related files to the HF model repository
after running this script.
Original repository: https://github.com/illuin-tech/colqwen2.
NOTE: This script was originally run using `torch==2.5.1` and with:
```bash
python src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py \
--model_id vidore/colqwen2-v1.0-merged \
--revision eeccbae1d44bdcb0c83b1788127a2b2cad7d718e \
--original_vlm_name_or_path Qwen/Qwen2-VL-2B-Instruct \
--output_dir vidore/colqwen2-v1.0-hf-internal \
--push_to_hub
```
"""
import argparse
import glob
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import AutoConfig
from transformers.models.colqwen2 import ColQwen2ForRetrieval
from transformers.models.colqwen2.configuration_colqwen2 import ColQwen2Config
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
ORIGINAL_DTYPE = torch.bfloat16
def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> Dict[str, torch.Tensor]:
directory_path = snapshot_download(
repo_id=model_id,
revision=revision,
allow_patterns=["*.safetensors"],
)
original_state_dict = {}
for path in glob.glob(f"{directory_path}/*"):
if path.endswith(".safetensors"):
with safe_open(path, framework="pt", device="cpu") as f:
for key in f.keys():
original_state_dict[key] = f.get_tensor(key)
# Some weights are tied, so `lm_head` is not saved. Let's clone it to load the state dict.
if "lm_head.weight" not in original_state_dict:
original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone()
return original_state_dict
def rename_state_dict_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
new_state_dict: Dict[str, Any] = {}
for key, value in state_dict.items():
if key.startswith("custom_text_proj"):
new_key = key.replace("custom_text_proj", "embedding_proj_layer")
else:
# The original ColQwen2 inherits from Qwen2VL, so we simply need to add the `vlm.` prefix
# to all remaining keys.
if key.startswith("model."):
key = key.replace("model.", "model.language_model.")
if key.startswith("visual."):
key = key.replace("visual.", "model.visual.")
new_key = "vlm." + key
new_state_dict[new_key] = value
return new_state_dict
@torch.no_grad()
def convert_colqwen2_weights_to_hf(
model_id: str,
output_dir: str,
push_to_hub: bool,
revision: Optional[str] = None,
original_vlm_name_or_path: Optional[str] = None,
):
# Load the original model data
original_config = AutoConfig.from_pretrained(
model_id,
revision=revision,
)
if original_vlm_name_or_path is not None:
original_config._name_or_path = original_vlm_name_or_path
if hasattr(original_config, "architectures"):
delattr(original_config, "architectures")
original_state_dict = load_original_state_dict(model_id, revision=revision)
# Format the state_dict keys
original_state_dict = rename_state_dict_keys(original_state_dict)
# Create the new config
config = ColQwen2Config(
vlm_config=original_config,
embedding_dim=128, # hardcoded in the original model
)
config.model_type = "colqwen2"
config.is_composition = False
# Load the untrained model
model = ColQwen2ForRetrieval(config=config).to("cpu").eval()
print("Created model with new config and randomly initialized weights")
# NOTE: The new model was initialized with float32 weights. We need to convert it to the desired precision.
# There are two ways to set the model's dtype:
# - Using `model.from_pretrained(..., torch_dtype=dtype_precision)` doesn't convert the hyperparameters to the desired precision.
# - Using `model.to(dtype_precision)` converts all values - including the hyperparameters - to the desired precision.
# The following snippet allows a fine-grained control over the model's dtype, making sure that all
# the new weights' dtypes match the original model.
for param in model.parameters():
param.data = param.data.to(ORIGINAL_DTYPE)
print(f"Converted the new model weights to `{ORIGINAL_DTYPE}`")
# Load the original weights
model.load_state_dict(original_state_dict)
print("Loaded original model weights")
# Sanity check: ensure all keys are the same
state_dict_keys_old = set(original_state_dict.keys())
state_dict_keys_new = set(model.state_dict().keys())
disjoint_keys = state_dict_keys_old.symmetric_difference(state_dict_keys_new)
if disjoint_keys:
raise ValueError(f"Incompatible keys: {disjoint_keys}")
# Save the model
if push_to_hub:
model.push_to_hub(output_dir, private=True)
print(f"Model pushed to the hub at `{output_dir}`")
else:
Path(output_dir).mkdir(exist_ok=True, parents=True)
model.save_pretrained(output_dir)
print(f"Model saved to `{output_dir}`")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""
This script converts the original ColQwen2 model to the HF model format.
Don't forget to manually upload the processor-related files to the HF model repository
after running this script.
Example usage:
```bash
python src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py \
--model_id vidore/colqwen2-v1.0-merged \
--revision eeccbae1d44bdcb0c83b1788127a2b2cad7d718e \
--original_vlm_name_or_path Qwen/Qwen2-VL-2B-Instruct \
--output_dir vidore/colqwen2-v1.0-hf-internal \
--push_to_hub
```
"""
)
parser.add_argument(
"--model_id",
help="Model ID of the original model to convert",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--push_to_hub",
help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally",
action="store_true",
default=False,
)
parser.add_argument(
"--revision",
help="Revision of the model to download",
default=None,
)
parser.add_argument(
"--original_vlm_name_or_path",
help="Name or path of the original VLM backbone model",
default=None,
)
args = parser.parse_args()
convert_colqwen2_weights_to_hf(
model_id=args.model_id,
output_dir=args.output_dir,
push_to_hub=args.push_to_hub,
revision=args.revision,
original_vlm_name_or_path=args.original_vlm_name_or_path,
)


@ -0,0 +1,268 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/colqwen2/modular_colqwen2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_colqwen2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from torch import nn
from transformers import AutoModelForImageTextToText
from ...cache_utils import Cache
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torch_available
from .configuration_colqwen2 import ColQwen2Config
if is_torch_available():
import torch
class ColQwen2PreTrainedModel(PreTrainedModel):
config_class = ColQwen2Config
base_model_prefix = "model"
_no_split_modules = []
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.vlm_config.text_config.initializer_range
)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@dataclass
class ColQwen2ForRetrievalOutput(ModelOutput):
"""
Base class for ColQwen2 embeddings output.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
The embeddings of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
embeddings: Optional[torch.Tensor] = None
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@auto_docstring(
custom_intro="""
Following the ColPali approach, ColQwen2 leverages VLMs to construct efficient multi-vector embeddings directly
from document images (screenshots) for document retrieval. The model is trained to maximize the similarity
between these document embeddings and the corresponding query embeddings, using the late interaction method
introduced in ColBERT.
Using ColQwen2 removes the need for potentially complex and brittle layout recognition and OCR pipelines with
a single model that can take into account both the textual and visual content (layout, charts, ...) of a document.
ColQwen2 is part of the ColVision model family, which was introduced with ColPali in the following paper:
[*ColPali: Efficient Document Retrieval with Vision Language Models*](https://arxiv.org/abs/2407.01449).
"""
)
class ColQwen2ForRetrieval(ColQwen2PreTrainedModel):
def __init__(self, config: ColQwen2Config):
super().__init__(config)
self.config = config
self.vocab_size = config.vlm_config.text_config.vocab_size
vlm = AutoModelForImageTextToText.from_config(config.vlm_config)
if vlm._tied_weights_keys is not None:
self._tied_weights_keys = [f"vlm.{k}" for k in vlm._tied_weights_keys]
self.vlm = vlm
self.embedding_dim = self.config.embedding_dim
self.embedding_proj_layer = nn.Linear(
self.config.vlm_config.text_config.hidden_size,
self.embedding_dim,
)
self.post_init()
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
labels: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> ColQwen2ForRetrievalOutput:
r"""
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
"""
if pixel_values is not None:
pixel_values = pixel_values.to(dtype=self.dtype) # (batch_size, max_num_patches, pixel_values)
# Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
if pixel_values is not None and image_grid_thw is not None:
# NOTE: image_grid_thw has shape (batch_size, 3), where image_grid_thw[i] = (num_patches_t, num_patches_h, num_patches_w)
offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2] # (batch_size,) number of patches (num_patches_h * num_patches_w) per image
pixel_values = torch.cat(
[pixel_sequence[:offset] for pixel_sequence, offset in zip(pixel_values, offsets)],
dim=0,
) # (num_patches_h * num_patches_w, pixel_values)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
position_ids, rope_deltas = self.vlm.model.get_rope_index(
input_ids=input_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=None,
attention_mask=attention_mask,
)
# Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs.
if inputs_embeds is None:
inputs_embeds = self.vlm.model.language_model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
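# Merge the visual features into the text embeddings at the positions of the image placeholder tokens.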
image_mask = (
(input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
vlm_output = self.vlm.model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None
last_hidden_states = vlm_output[0] # (batch_size, sequence_length, hidden_size)
embeddings = self.embedding_proj_layer(last_hidden_states) # (batch_size, sequence_length, dim)
# L2 normalization
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
if attention_mask is not None:
embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
return ColQwen2ForRetrievalOutput(
embeddings=embeddings,
past_key_values=vlm_output.past_key_values,
hidden_states=vlm_hidden_states,
attentions=vlm_output.attentions,
)
def get_input_embeddings(self):
return self.vlm.get_input_embeddings()
def set_input_embeddings(self, value):
self.vlm.set_input_embeddings(value)
def get_output_embeddings(self):
return self.vlm.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.vlm.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.vlm.set_decoder(decoder)
def get_decoder(self):
return self.vlm.get_decoder()
def tie_weights(self):
return self.vlm.tie_weights()
def resize_token_embeddings(
self,
new_num_tokens: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
mean_resizing: bool = True,
) -> nn.Embedding:
model_embeds = self.vlm.resize_token_embeddings(
new_num_tokens=new_num_tokens,
pad_to_multiple_of=pad_to_multiple_of,
mean_resizing=mean_resizing,
)
self.config.vlm_config.text_config.vocab_size = model_embeds.num_embeddings
self.config.vlm_config.vocab_size = model_embeds.num_embeddings
self.vlm.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
__all__ = ["ColQwen2ForRetrieval", "ColQwen2PreTrainedModel"]


@ -0,0 +1,383 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from transformers.models.colpali.modeling_colpali import ColPaliForRetrieval, ColPaliPreTrainedModel
from transformers.models.colpali.processing_colpali import ColPaliProcessor
from ...cache_utils import Cache
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image
from ...processing_utils import ProcessingKwargs, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torch_available, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class ColQwen2ProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": "longest",
},
"images_kwargs": {
"data_format": "channels_first",
"do_convert_rgb": True,
},
"common_kwargs": {"return_tensors": "pt"},
}
class ColQwen2Processor(ColPaliProcessor):
r"""
Constructs a ColQwen2 processor which wraps a Qwen2VLProcessor and adds special methods to process images and queries, as
well as to compute the late-interaction retrieval score.
[`ColQwen2Processor`] offers all the functionalities of [`Qwen2VLProcessor`]. See the [`~Qwen2VLProcessor.__call__`]
for more information.
Args:
image_processor ([`Qwen2VLImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`Qwen2TokenizerFast`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
visual_prompt_prefix (`str`, *optional*): A string that gets tokenized and prepended to the image tokens.
query_prefix (`str`, *optional*): A prefix to be used for the query.
"""
image_processor_class = "Qwen2VLImageProcessor"
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
def __init__(
self,
image_processor=None,
tokenizer=None,
chat_template=None,
visual_prompt_prefix: Optional[str] = None,
query_prefix: Optional[str] = None,
**kwargs,
):
ColPaliProcessor().__init__(image_processor, tokenizer, chat_template=chat_template)
self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
if visual_prompt_prefix is None:
visual_prompt_prefix = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
self.visual_prompt_prefix = visual_prompt_prefix
if query_prefix is None:
query_prefix = "Query: "
self.query_prefix = query_prefix
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
audio=None,
videos=None,
**kwargs: Unpack[ColQwen2ProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model either (1) one or several texts or (2) one or several image(s). This method is a custom
wrapper around the Qwen2VLProcessor's [`~Qwen2VLProcessor.__call__`] method adapted for the ColQwen2 model. It cannot process
both text and images at the same time.
When preparing the text(s), this method forwards the `text` and `kwargs` arguments to Qwen2TokenizerFast's
[`~Qwen2TokenizerFast.__call__`].
When preparing the image(s), this method forwards the `images` and `kwargs` arguments to Qwen2VLImageProcessor's
[`~Qwen2VLImageProcessor.__call__`].
Please refer to the docstrings of the above two methods for more information.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
output_kwargs = self._merge_kwargs(
ColQwen2ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
suffix = output_kwargs["text_kwargs"].pop("suffix", None)
return_token_type_ids = suffix is not None
if text is None and images is None:
raise ValueError("Either text or images must be provided")
if text is not None and images is not None:
raise ValueError("Only one of text or images can be processed at a time")
if images is not None:
if is_valid_image(images):
images = [images]
elif isinstance(images, list) and is_valid_image(images[0]):
pass
elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
raise ValueError("images must be an image, list of images or list of list of images")
texts_doc = [self.visual_prompt_prefix] * len(images)
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
image_grid_thw = image_inputs["image_grid_thw"]
if image_grid_thw is not None:
merge_length = self.image_processor.merge_size**2
index = 0
for i in range(len(texts_doc)):
while self.image_token in texts_doc[i]:
texts_doc[i] = texts_doc[i].replace(
self.image_token, "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1
)
index += 1
texts_doc[i] = texts_doc[i].replace("<|placeholder|>", self.image_token)
text_inputs = self.tokenizer(
texts_doc,
return_token_type_ids=False,
**output_kwargs["text_kwargs"],
)
return_data = BatchFeature(data={**text_inputs, **image_inputs})
# NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs.
offsets = return_data["image_grid_thw"][:, 1] * return_data["image_grid_thw"][:, 2] # (batch_size,)
# Split the pixel_values tensor into a list of tensors, one per image
pixel_values = list(
torch.split(return_data["pixel_values"], offsets.tolist())
) # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)]
# Pad the list of pixel_value tensors to the same length along the sequence dimension
return_data["pixel_values"] = torch.nn.utils.rnn.pad_sequence(
pixel_values, batch_first=True
) # (batch_size, max_num_patches, pixel_values)
if return_token_type_ids:
labels = return_data["input_ids"].masked_fill(return_data["token_type_ids"] == 0, -100)
return_data.update({"labels": labels})
return return_data
elif text is not None:
if isinstance(text, str):
text = [text]
elif not (isinstance(text, list) and isinstance(text[0], str)):
raise ValueError("Text must be a string or a list of strings")
if suffix is None:
suffix = self.query_augmentation_token * 10
texts_query: List[str] = []
for query in text:
augmented_query = self.query_prefix + query + suffix
texts_query.append(augmented_query)
batch_query = self.tokenizer(
texts_query,
return_token_type_ids=False,
**output_kwargs["text_kwargs"],
)
return batch_query
class ColQwen2PreTrainedModel(ColPaliPreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
@dataclass
class ColQwen2ForRetrievalOutput(ModelOutput):
"""
Base class for ColQwen2 embeddings output.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
The embeddings of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
embeddings: Optional[torch.Tensor] = None
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@auto_docstring(
custom_intro="""
Following the ColPali approach, ColQwen2 leverages VLMs to construct efficient multi-vector embeddings directly
from document images (screenshots) for document retrieval. The model is trained to maximize the similarity
between these document embeddings and the corresponding query embeddings, using the late interaction method
introduced in ColBERT.
Using ColQwen2 removes the need for potentially complex and brittle layout-recognition and OCR pipelines: a single
model can take into account both the textual and visual content (layout, charts, ...) of a document.
ColQwen2 is part of the ColVision model family, which was introduced with ColPali in the following paper:
[*ColPali: Efficient Document Retrieval with Vision Language Models*](https://arxiv.org/abs/2407.01449).
"""
)
class ColQwen2ForRetrieval(ColPaliForRetrieval):
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
labels: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> ColQwen2ForRetrievalOutput:
r"""
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width dimensions of the feature grid for each image fed to the LLM.
"""
if pixel_values is not None:
pixel_values = pixel_values.to(dtype=self.dtype) # (batch_size, max_num_patches, pixel_values)
# Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
if pixel_values is not None and image_grid_thw is not None:
# NOTE: image_grid_thw: (batch_size, 3) where image_grid_thw[i] = (num_patches_t, num_patches_h, num_patches_w)
offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2] # (batch_size,), i.e. num_patches_h * num_patches_w per image
pixel_values = torch.cat(
[pixel_sequence[:offset] for pixel_sequence, offset in zip(pixel_values, offsets)],
dim=0,
) # (total_num_patches, pixel_feature_dim)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
position_ids, rope_deltas = self.vlm.model.get_rope_index(
input_ids=input_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=None,
attention_mask=attention_mask,
)
# Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs.
if inputs_embeds is None:
inputs_embeds = self.vlm.model.language_model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
image_mask = (
(input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
vlm_output = self.vlm.model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None
last_hidden_states = vlm_output[0] # (batch_size, sequence_length, hidden_size)
embeddings = self.embedding_proj_layer(last_hidden_states) # (batch_size, sequence_length, dim)
# L2 normalization
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
if attention_mask is not None:
embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
return ColQwen2ForRetrievalOutput(
embeddings=embeddings,
past_key_values=vlm_output.past_key_values,
hidden_states=vlm_hidden_states,
attentions=vlm_output.attentions,
)
__all__ = [
"ColQwen2ForRetrieval",
"ColQwen2PreTrainedModel",
"ColQwen2Processor",
]
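For context, here is a minimal end-to-end retrieval sketch built on the classes above. It assumes the `vidore/colqwen2-v1.0-hf` checkpoint referenced in the integration test further down; the blank PIL image and the query string are placeholders, and device handling is kept deliberately simple, so treat this as an illustration rather than the canonical usage snippet.

import torch
from PIL import Image
from transformers.models.colqwen2.modeling_colqwen2 import ColQwen2ForRetrieval
from transformers.models.colqwen2.processing_colqwen2 import ColQwen2Processor

model_name = "vidore/colqwen2-v1.0-hf"  # checkpoint used in the integration test below
processor = ColQwen2Processor.from_pretrained(model_name)
model = ColQwen2ForRetrieval.from_pretrained(model_name, torch_dtype=torch.bfloat16).eval()

# Placeholder inputs: in practice, `images` would be screenshots of document pages.
images = [Image.new("RGB", (448, 448), color="white")]
queries = ["What is the total revenue reported in 2023?"]

batch_images = processor(images=images)   # pixel_values, image_grid_thw, input_ids, ...
batch_queries = processor(text=queries)   # input_ids, attention_mask

with torch.inference_mode():
    image_embeddings = model(**batch_images).embeddings   # (n_pages, seq_len, dim)
    query_embeddings = model(**batch_queries).embeddings  # (n_queries, seq_len, dim)

# Late-interaction (MaxSim) scores, shape (n_queries, n_pages).
scores = processor.score_retrieval(query_embeddings, image_embeddings)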

View File

@ -0,0 +1,408 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/colqwen2/modular_colqwen2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_colqwen2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import is_torch_available
if is_torch_available():
import torch
class ColQwen2ProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": "longest",
},
"images_kwargs": {
"data_format": "channels_first",
"do_convert_rgb": True,
},
"common_kwargs": {"return_tensors": "pt"},
}
class ColQwen2Processor(ProcessorMixin):
r"""
Constructs a ColQwen2 processor which wraps a Qwen2VLProcessor and adds special methods to process images and queries,
as well as to compute the late-interaction retrieval score.
[`ColQwen2Processor`] offers all the functionalities of [`Qwen2VLProcessor`]. See the [`~Qwen2VLProcessor.__call__`]
for more information.
Args:
image_processor ([`Qwen2VLImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`Qwen2TokenizerFast`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
visual_prompt_prefix (`str`, *optional*): A string that gets tokenized and prepended to the image tokens.
query_prefix (`str`, *optional*): A prefix to be used for the query.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "Qwen2VLImageProcessor"
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
def __init__(
self,
image_processor=None,
tokenizer=None,
chat_template=None,
visual_prompt_prefix: Optional[str] = None,
query_prefix: Optional[str] = None,
**kwargs,
):
super().__init__(image_processor, tokenizer, chat_template=chat_template)
self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
if visual_prompt_prefix is None:
visual_prompt_prefix = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
self.visual_prompt_prefix = visual_prompt_prefix
if query_prefix is None:
query_prefix = "Query: "
self.query_prefix = query_prefix
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
audio=None,
videos=None,
**kwargs: Unpack[ColQwen2ProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model either (1) one or several texts or (2) one or several image(s). This method is a custom
wrapper around the Qwen2VLProcessor's [`~Qwen2VLProcessor.__call__`] method adapted for the ColQwen2 model. It cannot process
both text and images at the same time.
When preparing the text(s), this method forwards the `text` and `kwargs` arguments to Qwen2TokenizerFast's
[`~Qwen2TokenizerFast.__call__`].
When preparing the image(s), this method forwards the `images` and `kwargs` arguments to Qwen2VLImageProcessor's
[`~Qwen2VLImageProcessor.__call__`].
Please refer to the docstrings of the above two methods for more information.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
output_kwargs = self._merge_kwargs(
ColQwen2ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
suffix = output_kwargs["text_kwargs"].pop("suffix", None)
return_token_type_ids = suffix is not None
if text is None and images is None:
raise ValueError("Either text or images must be provided")
if text is not None and images is not None:
raise ValueError("Only one of text or images can be processed at a time")
if images is not None:
if is_valid_image(images):
images = [images]
elif isinstance(images, list) and is_valid_image(images[0]):
pass
elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
raise ValueError("images must be an image, list of images or list of list of images")
texts_doc = [self.visual_prompt_prefix] * len(images)
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
image_grid_thw = image_inputs["image_grid_thw"]
if image_grid_thw is not None:
merge_length = self.image_processor.merge_size**2
index = 0
for i in range(len(texts_doc)):
while self.image_token in texts_doc[i]:
texts_doc[i] = texts_doc[i].replace(
self.image_token, "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1
)
index += 1
texts_doc[i] = texts_doc[i].replace("<|placeholder|>", self.image_token)
text_inputs = self.tokenizer(
texts_doc,
return_token_type_ids=False,
**output_kwargs["text_kwargs"],
)
return_data = BatchFeature(data={**text_inputs, **image_inputs})
# NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs.
offsets = return_data["image_grid_thw"][:, 1] * return_data["image_grid_thw"][:, 2] # (batch_size,)
# Split the pixel_values tensor into a list of tensors, one per image
pixel_values = list(
torch.split(return_data["pixel_values"], offsets.tolist())
) # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)]
# Pad the list of pixel_value tensors to the same length along the sequence dimension
return_data["pixel_values"] = torch.nn.utils.rnn.pad_sequence(
pixel_values, batch_first=True
) # (batch_size, max_num_patches, pixel_values)
if return_token_type_ids:
labels = return_data["input_ids"].masked_fill(return_data["token_type_ids"] == 0, -100)
return_data.update({"labels": labels})
return return_data
elif text is not None:
if isinstance(text, str):
text = [text]
elif not (isinstance(text, list) and isinstance(text[0], str)):
raise ValueError("Text must be a string or a list of strings")
if suffix is None:
suffix = self.query_augmentation_token * 10
texts_query: List[str] = []
for query in text:
augmented_query = self.query_prefix + query + suffix
texts_query.append(augmented_query)
batch_query = self.tokenizer(
texts_query,
return_token_type_ids=False,
**output_kwargs["text_kwargs"],
)
return batch_query
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) for each image.
Returns:
Dict[str, List[int]]: A dictionary mapping each modality ("image", "video", "audio")
to a list containing the number of placeholder tokens required. If the model doesn't accept
a certain modality or no input sizes are provided, the dict value is set to an empty list.
"""
vision_data = {}
if image_sizes is not None:
num_image_tokens = [self.image_seq_length] * len(image_sizes)
num_image_patches = [1] * len(image_sizes)
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
@property
def query_augmentation_token(self) -> str:
"""
Return the query augmentation token.
Query augmentation buffers are used as reasoning buffers during inference.
"""
return self.tokenizer.pad_token
def process_images(
self,
images: ImageInput = None,
**kwargs: Unpack[ColQwen2ProcessorKwargs],
) -> BatchFeature:
"""
Prepare one or several image(s) for the model. This method is a wrapper around [`ColQwen2Processor.__call__`].
This method forwards the `images` and `kwargs` arguments to the image processor.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
return self.__call__(images=images, **kwargs)
def process_queries(
self,
text: Union[TextInput, List[TextInput]],
**kwargs: Unpack[ColQwen2ProcessorKwargs],
) -> BatchFeature:
"""
Prepare one or several texts for the model. This method is a wrapper around [`ColQwen2Processor.__call__`].
This method forwards the `text` and `kwargs` arguments to the tokenizer.
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
"""
return self.__call__(text=text, **kwargs)
def score_retrieval(
self,
query_embeddings: Union["torch.Tensor", List["torch.Tensor"]],
passage_embeddings: Union["torch.Tensor", List["torch.Tensor"]],
batch_size: int = 128,
output_dtype: Optional["torch.dtype"] = None,
output_device: Union["torch.device", str] = "cpu",
) -> "torch.Tensor":
"""
Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
query embeddings (`query_embeddings`) and passage embeddings (`passage_embeddings`). For ColQwen2, a passage is
the image of a document page.
Because the embedding tensors are multi-vector and can thus have different shapes, they
should be fed as either:
(1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim), or
(2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim), usually
obtained by padding the list of tensors.
Args:
query_embeddings (`Union[torch.Tensor, List[torch.Tensor]]`): Query embeddings.
passage_embeddings (`Union[torch.Tensor, List[torch.Tensor]]`): Passage embeddings.
batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
output_dtype (`torch.dtype`, *optional*): The dtype of the output tensor.
If `None` (the default), the dtype of the input embeddings is used.
output_device (`torch.device` or `str`, *optional*, defaults to `"cpu"`): The device of the output tensor.
Returns:
`torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
tensor is cast to `output_dtype` and moved to `output_device` before being returned.
"""
if len(query_embeddings) == 0:
raise ValueError("No queries provided")
if len(passage_embeddings) == 0:
raise ValueError("No passages provided")
if query_embeddings[0].device != passage_embeddings[0].device:
raise ValueError("Queries and passages must be on the same device")
if query_embeddings[0].dtype != passage_embeddings[0].dtype:
raise ValueError("Queries and passages must have the same dtype")
if output_dtype is None:
output_dtype = query_embeddings[0].dtype
scores: List[torch.Tensor] = []
for i in range(0, len(query_embeddings), batch_size):
batch_scores: List[torch.Tensor] = []
batch_queries = torch.nn.utils.rnn.pad_sequence(
query_embeddings[i : i + batch_size], batch_first=True, padding_value=0
)
for j in range(0, len(passage_embeddings), batch_size):
batch_passages = torch.nn.utils.rnn.pad_sequence(
passage_embeddings[j : j + batch_size], batch_first=True, padding_value=0
)
batch_scores.append(
torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2)
)
scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device))
return torch.cat(scores, dim=0)
__all__ = ["ColQwen2Processor"]

View File

@ -168,7 +168,6 @@ class ColPaliForRetrievalModelTester:
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": input_ids,
"token_type_ids": torch.zeros_like(input_ids),
}
return config, inputs_dict
@ -333,7 +332,7 @@ class ColPaliModelIntegrationTest(unittest.TestCase):
scores = self.processor.score_retrieval(
query_embeddings=query_embeddings,
passage_embeddings=image_embeddings,
) # (len(qs), len(ps))
) # (num_queries, num_passages)
assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}"
assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}"

View File

@ -1,3 +1,18 @@
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the ColPali processor."""
import shutil
import tempfile
import unittest
@ -89,7 +104,7 @@ class ColPaliProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertIsInstance(batch_feature["input_ids"], torch.Tensor)
self.assertEqual(batch_feature["input_ids"].shape[0], len(queries))
# The following tests are overwritten as ColPaliProcessor can only take one of images or text as input at a time
# The following tests override the parent tests because ColPaliProcessor can only take one of images or text as input at a time.
def test_tokenizer_defaults_preserved_by_kwargs(self):
if "image_processor" not in self.processor_class.attributes:

View File

@ -0,0 +1,333 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch ColQwen2 model."""
import gc
import unittest
from typing import ClassVar
import torch
from datasets import load_dataset
from tests.test_configuration_common import ConfigTester
from tests.test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from transformers import is_torch_available
from transformers.models.colqwen2.configuration_colqwen2 import ColQwen2Config
from transformers.models.colqwen2.modeling_colqwen2 import ColQwen2ForRetrieval, ColQwen2ForRetrievalOutput
from transformers.models.colqwen2.processing_colqwen2 import ColQwen2Processor
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
if is_torch_available():
import torch
class ColQwen2ForRetrievalModelTester:
def __init__(
self,
parent,
ignore_index=-100,
pad_token_id=2,
projector_hidden_act="gelu",
seq_length=11,
vision_feature_select_strategy="default",
vision_feature_layer=-1,
projection_dim=32,
is_training=False,
use_cache=False,
vlm_config={
"_name_or_path": "Qwen/Qwen2-VL-2B-Instruct",
"bos_token_id": 0,
"eos_token_id": 1,
"vision_start_token_id": 3,
"image_token_id": 4,
"video_token_id": 5,
"hidden_size": 64,
"intermediate_size": 2,
"max_window_layers": 2,
"model_type": "qwen2_vl",
"num_attention_heads": 2,
"num_hidden_layers": 2,
"num_key_value_heads": 2,
"rms_norm_eps": 1e-06,
"rope_scaling": {"mrope_section": [4, 6, 6], "rope_type": "default", "type": "default"},
"sliding_window": 32768,
"tie_word_embeddings": True,
"vision_config": {
"depth": 2,
"embed_dim": 32,
"hidden_act": "quick_gelu",
"hidden_size": 64,
"mlp_ratio": 4,
"num_heads": 4,
"patch_size": 14,
"in_chans": 3,
"spatial_merge_size": 1,
"temporal_patch_size": 2,
},
"vision_end_token_id": 151653,
"vision_token_id": 151654,
"vocab_size": 99,
},
embedding_dim=32,
initializer_range=0.02,
):
self.parent = parent
self.ignore_index = ignore_index
self.pad_token_id = pad_token_id
# `image_token_index` is set to 0 to pass "resize_embeddings" test, do not modify
self.image_token_index = 0
self.image_token_id = vlm_config["image_token_id"]
self.video_token_id = vlm_config["video_token_id"]
self.pad_token_id = vlm_config["eos_token_id"]
self.vision_start_token_id = vlm_config["vision_start_token_id"]
self.projector_hidden_act = projector_hidden_act
self.vision_feature_select_strategy = vision_feature_select_strategy
self.vision_feature_layer = vision_feature_layer
self.image_size = 56
self.num_image_tokens = 4
self.seq_length = seq_length + self.num_image_tokens
self.projection_dim = projection_dim
self.num_hidden_layers = vlm_config["num_hidden_layers"]
self.vocab_size = vlm_config["vocab_size"]
self.hidden_size = vlm_config["hidden_size"]
self.num_attention_heads = vlm_config["num_attention_heads"]
self.is_training = is_training
self.batch_size = 3
self.num_channels = vlm_config["vision_config"]["in_chans"]
self.encoder_seq_length = self.seq_length
self.use_cache = use_cache
self.vlm_config = vlm_config
self.embedding_dim = embedding_dim
self.initializer_range = initializer_range
def get_config(self):
return ColQwen2Config(
vlm_config=self.vlm_config,
embedding_dim=self.embedding_dim,
initializer_range=self.initializer_range,
)
def prepare_config_and_inputs(self):
config = self.get_config()
patch_size = config.vlm_config.vision_config.patch_size
temporal_patch_size = config.vlm_config.vision_config.temporal_patch_size
# NOTE: Assume all inputs are square images of the same size.
num_patches = (self.image_size // patch_size) ** 2
pixel_values = floats_tensor(
[
self.batch_size * num_patches,
self.num_channels * (patch_size**2) * temporal_patch_size,
]
)
# Hardcoded image grid size: do not change unless you modified image size or patch size!
image_grid_thw = torch.tensor([1, 4, 4]).repeat(self.batch_size, 1)
# NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs.
# Line is copied from `src/transformers/models/colqwen2/processing_colqwen2.py`
offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2] # (batch_size,)
pixel_values = list(
torch.split(pixel_values, offsets.tolist())
) # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)]
pixel_values = torch.nn.utils.rnn.pad_sequence(
pixel_values, batch_first=True
) # (batch_size, max_num_patches, pixel_values)
return config, pixel_values, image_grid_thw
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, image_grid_thw = config_and_inputs
input_ids = (
ids_tensor(
shape=[self.batch_size, self.seq_length],
vocab_size=config.vlm_config.vocab_size - 1,
)
+ 1
)
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
input_ids[:, -1] = self.pad_token_id
input_ids[:, : self.num_image_tokens] = self.image_token_id
input_ids[input_ids == self.video_token_id] = self.pad_token_id
input_ids[input_ids == self.image_token_id] = self.pad_token_id
input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
inputs_dict = {
"input_ids": input_ids,
"pixel_values": pixel_values,
"image_grid_thw": image_grid_thw,
"attention_mask": attention_mask,
"labels": input_ids,
}
return config, inputs_dict
@require_torch
class ColQwen2ForRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
"""
Model tester for `ColQwen2ForRetrieval`.
"""
all_model_classes = (ColQwen2ForRetrieval,) if is_torch_available() else ()
fx_compatible = False
test_torchscript = False
test_pruning = False
test_resize_embeddings = True
test_head_masking = False
def setUp(self):
self.model_tester = ColQwen2ForRetrievalModelTester(self)
self.config_tester = ConfigTester(self, config_class=ColQwen2Config, has_text_modality=False)
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# Overwrite inputs_embeds tests because we need to delete "pixel_values" for vision-language models,
# while some other models require pixel_values to be present.
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))
@slow
@require_vision
def test_colqwen2_forward_inputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
with torch.no_grad():
outputs = model(**inputs, return_dict=True)
self.assertIsInstance(outputs, ColQwen2ForRetrievalOutput)
@unittest.skip(reason="Some undefined behavior encountered with test versions of Qwen2-VL. Skip for now.")
def test_model_parallelism(self):
pass
@unittest.skip(reason="Pass because ColQwen2 requires `attention_mask is not None`")
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip(reason="Pass because ColQwen2 requires `attention_mask is not None`")
def test_sdpa_can_compile_dynamic(self):
pass
@require_torch
class ColQwen2ModelIntegrationTest(unittest.TestCase):
model_name: ClassVar[str] = "vidore/colqwen2-v1.0-hf"
def setUp(self):
self.processor = ColQwen2Processor.from_pretrained(self.model_name)
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
@slow
def test_model_integration_test(self):
"""
Test if the model is able to retrieve the correct pages for a small and easy dataset.
"""
model = ColQwen2ForRetrieval.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16,
device_map=torch_device,
).eval()
# Load the test dataset
ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test")
# Preprocess the examples
batch_images = self.processor(images=ds["image"]).to(torch_device)
batch_queries = self.processor(text=ds["query"]).to(torch_device)
# Run inference
with torch.inference_mode():
image_embeddings = model(**batch_images).embeddings
query_embeddings = model(**batch_queries).embeddings
# Compute retrieval scores
scores = self.processor.score_retrieval(
query_embeddings=query_embeddings,
passage_embeddings=image_embeddings,
) # (num_queries, num_passages)
assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}"
assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}"
# Check if the maximum scores per row are in the diagonal of the matrix score
self.assertTrue((scores.argmax(axis=1) == torch.arange(len(ds), device=scores.device)).all())
# Further validation: fine-grained check against hardcoded scores from the original HF implementation.
expected_scores = torch.tensor(
[
[16.2500, 7.8750, 14.6875],
[9.5000, 17.1250, 10.5000],
[14.9375, 10.9375, 20.0000],
],
dtype=scores.dtype,
)
assert torch.allclose(scores, expected_scores, atol=1e-3), f"Expected scores {expected_scores}, got {scores}"
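For reference, the hardcoded `image_grid_thw` and padding logic in `prepare_config_and_inputs` above mirror what `ColQwen2Processor` produces, and `ColQwen2ForRetrieval.forward` undoes that padding before the vision tower. A toy round-trip of this pad/unpad scheme (illustrative only, with made-up patch counts and feature size):

import torch

# Three images with 4, 6 and 2 patches respectively, 8 features per patch.
pixel_values = torch.randn(4 + 6 + 2, 8)
image_grid_thw = torch.tensor([[1, 2, 2], [1, 2, 3], [1, 1, 2]])
offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]  # tensor([4, 6, 2])

# Processor side: split per image and pad to (batch_size, max_num_patches, features).
padded = torch.nn.utils.rnn.pad_sequence(
    list(torch.split(pixel_values, offsets.tolist())), batch_first=True
)
assert padded.shape == (3, 6, 8)

# Model side: drop the padding again and concatenate, recovering the original flat tensor.
unpadded = torch.cat([seq[:off] for seq, off in zip(padded, offsets)], dim=0)
assert torch.equal(unpadded, pixel_values)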

View File

@ -0,0 +1,262 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the ColQwen2 processor."""
import shutil
import tempfile
import unittest
import torch
from transformers import AutoProcessor, Qwen2VLProcessor
from transformers.models.colqwen2.processing_colqwen2 import ColQwen2Processor
from transformers.testing_utils import get_tests_dir, require_torch, require_vision
from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from transformers import (
ColQwen2Processor,
)
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
@require_torch
@require_vision
class ColQwen2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = ColQwen2Processor
@classmethod
def setUpClass(cls):
cls.tmpdirname = tempfile.mkdtemp()
processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
processor.save_pretrained(cls.tmpdirname)
def get_tokenizer(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tmpdirname)
def test_process_images(self):
# Processor configuration
image_input = self.prepare_image_inputs()
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=112, padding="max_length")
image_processor.image_seq_length = 14
# Get the processor
processor = self.processor_class(
tokenizer=tokenizer,
image_processor=image_processor,
)
# Process the image
batch_feature = processor.process_images(images=image_input, return_tensors="pt")
# Assertions
self.assertIn("pixel_values", batch_feature)
self.assertEqual(batch_feature["pixel_values"].shape, torch.Size([1, 56, 1176]))
def test_process_queries(self):
# Inputs
queries = [
"Is attention really all you need?",
"Are Benjamin, Antoine, Merve, and Jo best friends?",
]
# Processor configuration
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=112, padding="max_length")
image_processor.image_seq_length = 14
# Get the processor
processor = self.processor_class(
tokenizer=tokenizer,
image_processor=image_processor,
)
# Process the image
batch_feature = processor.process_queries(text=queries, return_tensors="pt")
# Assertions
self.assertIn("input_ids", batch_feature)
self.assertIsInstance(batch_feature["input_ids"], torch.Tensor)
self.assertEqual(batch_feature["input_ids"].shape[0], len(queries))
# The following tests override the parent tests because ColQwen2Processor can only take one of images or text as input at a time.
def test_tokenizer_defaults_preserved_by_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
inputs = processor(text=input_str, return_tensors="pt")
self.assertEqual(inputs[self.text_input_name].shape[-1], 117)
def test_image_processor_defaults_preserved_by_image_kwargs(self):
"""
We use do_rescale=True, rescale_factor=-1 to ensure that image_processor kwargs are preserved in the processor.
We then check that the mean of the pixel_values is less than or equal to 0 after processing.
Since the original pixel_values are in [0, 255], this is a good indicator that the rescale_factor is indeed applied.
"""
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor_components["image_processor"] = self.get_component(
"image_processor", do_rescale=True, rescale_factor=-1
)
processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
image_input = self.prepare_image_inputs()
inputs = processor(images=image_input, return_tensors="pt")
self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)
def test_kwargs_overrides_default_tokenizer_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor_components["tokenizer"] = self.get_component("tokenizer", padding="longest")
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
inputs = processor(text=input_str, return_tensors="pt", max_length=112, padding="max_length")
self.assertEqual(inputs[self.text_input_name].shape[-1], 112)
def test_kwargs_overrides_default_image_processor_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor_components["image_processor"] = self.get_component(
"image_processor", do_rescale=True, rescale_factor=1
)
processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
image_input = self.prepare_image_inputs()
inputs = processor(images=image_input, do_rescale=True, rescale_factor=-1, return_tensors="pt")
self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)
def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
inputs = processor(
text=input_str,
return_tensors="pt",
do_rescale=True,
rescale_factor=-1,
padding="max_length",
max_length=76,
)
self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
image_input = self.prepare_image_inputs(batch_size=2)
inputs = processor(
images=image_input,
return_tensors="pt",
do_rescale=True,
rescale_factor=-1,
padding="longest",
max_length=76,
)
self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)
def test_doubly_passed_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
image_input = self.prepare_image_inputs()
with self.assertRaises(ValueError):
_ = processor(
images=image_input,
images_kwargs={"do_rescale": True, "rescale_factor": -1},
do_rescale=True,
return_tensors="pt",
)
def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"do_rescale": True, "rescale_factor": -1},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)
self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"do_rescale": True, "rescale_factor": -1},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(images=image_input, **all_kwargs)
self.assertEqual(inputs[self.text_input_name].shape[-1], 76)