Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
Add ColQwen2 to 🤗 transformers (#35778)
* feat: add colqwen2 (wip)
* tests: fix test_attention_outputs
* tests: reduce hidden size to accelerate tests
* tests: fix `test_attention_outputs` 🥳
* fix: fix wrong parent class for `ColQwen2ForRetrievalOutput`
* fix: minor typing and style changes
* chore: run `make style`
* feat: remove redundant `max_num_visual_tokens` attribute in `ColQwen2Processor`
* tests: tweak comments
* style: apply ruff formatter
* feat: move default values for `visual_prompt_prefix` and `query_prefix`
* docs: update ColQwen2 model card
* docs: tweak model cards
* docs: add required example config checkpoint
* tests: update expected scores in integration test
* docs: tweak quickstart snippets
* fix: address PR comments
* tests: fix colqwen2 tests + tweak comment in colpali test
* tests: unskip useful tests
* fix: fix bug when `visual_prompt_prefix` or `query_prefix` is an empty string
* fix: fix ColPali outputs when `return_dict == False`
* fix: fix issue with PaliGemma output not being a dict
* docs: set default dtype to bfloat16 in quickstart snippets
* fix: fix error when `return_dict=False` in ColPali and ColQwen2
* tests: fix special tokens not being replaced in input_ids
* style: fix lint
* fix: `ColQwen2Processor`'s `padding_side` is now set from `processor_config.json`
* fix: remove unused `padding_side` in ColQwen2 model
* docs: update ColQwen2's model doc
* fix: fix harcoded vlm backbone class in ColQwen2Config
* fix: remove `padding_side` from ColQwen2Processor as should fed from kwargs
* docs: fix typo in model docstring
* docs: add illuin mention in model docs
* fix: let `padding_size` be handled by `tokenizer_config.json`
* docs: add colpali reference url in colqwen2's model doc
* docs: add Hf mention in model docs
* docs: add late interaction mention in model docs
* docs: tweak colqwen2 model doc
* docs: update reference checkpoint for ColPali to v1.3
* docs: simplify quickstart snippets
* docs: remove redundant `.eval()`
* refactor: use `can_return_tuple` decorator for ColPali and ColQwen2
* docs: fix copyright date
* docs: add missing copyright in tests
* fix: raise error when `initializer_range` is not in config
* docs: remove redundant `.eval()` in colpali doc
* fix: fix `get_text_config` now that Qwen2VL has a proper `text_config` attribute

  See https://github.com/huggingface/transformers/pull/37268 for details about changes in Qwen2VL's config.

* fix: add missing `initializer_range` attribute in `ColQwen2Config`
* fix: use `get_text_config` in `resize_token_embeddings`
* update colwen2 with auto_docstring
* docs: fix wrong copyright year
* chore: remove `raise` as `initializer_range` has a default value in `ColQwen2Config`
* refactor: merge `inner_forward` into `forward`
* Refactor colqwen2 after refactoring of qwen2VL, use modular for modeling code
* protect torch import in modular to protect in processing
* protect torch import in modular to protect in processing
* tests: fix hf model path in ColQwen2 integration test
* docs: clarify `attn_implementation` and add comments
* docs: add fallback snippet for using offline PIL dummy images
* docs: temporarily revert attn_implementation to `None` while sdpa is not fixed
* docs: tweaks in colpali/colqwen2 quick start snippets
* fix: add missing flags to enable SDPA/Flex Attention in ColQwen2 model
* fix: add missing changes in modular file
* fix modeling tests

---------

Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
This commit is contained in:
parent beaed8ce01
commit c72ba69441
@@ -937,6 +937,8 @@
    title: CLVP
  - local: model_doc/colpali
    title: ColPali
  - local: model_doc/colqwen2
    title: ColQwen2
  - local: model_doc/data2vec
    title: Data2Vec
  - local: model_doc/deplot
@@ -20,9 +20,11 @@ rendered properly in your Markdown viewer.

# ColPali

[ColPali](https://huggingface.co/papers/2407.01449) is a model designed to retrieve documents by analyzing their visual features. Unlike traditional systems that rely heavily on text extraction and OCR, ColPali treats each page as an image. It uses [Paligemma-3B](./paligemma) to capture not only text, but also the layout, tables, charts, and other visual elements to create detailed embeddings. This offers a more comprehensive understanding of documents and enables more efficient and accurate retrieval.
[ColPali](https://huggingface.co/papers/2407.01449) is a model designed to retrieve documents by analyzing their visual features. Unlike traditional systems that rely heavily on text extraction and OCR, ColPali treats each page as an image. It uses [Paligemma-3B](./paligemma) to capture not only text, but also the layout, tables, charts, and other visual elements to create detailed multi-vector embeddings that can be used for retrieval by computing pairwise late interaction similarity scores. This offers a more comprehensive understanding of documents and enables more efficient and accurate retrieval.

You can find all the original ColPali checkpoints under the [ColPali](https://huggingface.co/collections/vidore/hf-native-colvision-models-6755d68fc60a8553acaa96f7) collection.
This model was contributed by [@tonywu71](https://huggingface.co/tonywu71) (ILLUIN Technology) and [@yonigozlan](https://huggingface.co/yonigozlan) (HuggingFace).

You can find all the original ColPali checkpoints under Vidore's [Hf-native ColVision Models](https://huggingface.co/collections/vidore/hf-native-colvision-models-6755d68fc60a8553acaa96f7) collection.

> [!TIP]
> Click on the ColPali models in the right sidebar for more examples of how to use ColPali for image retrieval.
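In short, the late interaction scoring computed by `processor.score_retrieval` is a MaxSim reduction over token-level similarities: for every query token, take the best-matching document token and sum those maxima. A minimal sketch, assuming L2-normalized, equal-length (padded) embedding tensors; the actual helper additionally handles batching and variable lengths:

```python
import torch

def late_interaction_scores(query_embeddings: torch.Tensor, image_embeddings: torch.Tensor) -> torch.Tensor:
    # query_embeddings: (num_queries, query_len, dim), L2-normalized
    # image_embeddings: (num_images, doc_len, dim), L2-normalized
    # Pairwise token similarities: (num_queries, num_images, query_len, doc_len)
    sim = torch.einsum("qnd,imd->qinm", query_embeddings, image_embeddings)
    # Max over document tokens, then sum over query tokens -> (num_queries, num_images)
    return sim.max(dim=-1).values.sum(dim=-1)
```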
@@ -30,21 +32,25 @@ You can find all the original ColPali checkpoints under the [ColPali](https://hu

<hfoptions id="usage">
<hfoption id="image retrieval">

```py
```python
import requests
import torch
from PIL import Image

from transformers import ColPaliForRetrieval, ColPaliProcessor

# Load model (bfloat16 support is limited; fallback to float32 if needed)
model = ColPaliForRetrieval.from_pretrained(
    "vidore/colpali-v1.2-hf",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",  # "cpu", "cuda", or "mps" for Apple Silicon
).eval()

# Load the model and the processor
model_name = "vidore/colpali-v1.3-hf"

model = ColPaliForRetrieval.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # "cpu", "cuda", or "mps" for Apple Silicon
)
processor = ColPaliProcessor.from_pretrained(model_name)

# The document page screenshots from your corpus
url1 = "https://upload.wikimedia.org/wikipedia/commons/8/89/US-original-Declaration-1776.jpg"
url2 = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Romeoandjuliet1597.jpg/500px-Romeoandjuliet1597.jpg"

@@ -53,25 +59,37 @@ images = [
    Image.open(requests.get(url2, stream=True).raw),
]

# The queries you want to retrieve documents for
queries = [
    "Who printed the edition of Romeo and Juliet?",
    "When was the United States Declaration of Independence proclaimed?",
    "Who printed the edition of Romeo and Juliet?",
]

# Process the inputs
inputs_images = processor(images=images, return_tensors="pt").to(model.device)
inputs_text = processor(text=queries, return_tensors="pt").to(model.device)
inputs_images = processor(images=images).to(model.device)
inputs_text = processor(text=queries).to(model.device)

# Forward pass
with torch.no_grad():
    image_embeddings = model(**inputs_images).embeddings
    query_embeddings = model(**inputs_text).embeddings

# Score the queries against the images
scores = processor.score_retrieval(query_embeddings, image_embeddings)

print("Retrieval scores (query x image):")
print(scores)
```
If you have issues loading the images with PIL, you can use the following code to create dummy images:

```python
images = [
    Image.new("RGB", (128, 128), color="white"),
    Image.new("RGB", (64, 32), color="black"),
]
```

</hfoption>
</hfoptions>
@@ -79,12 +97,15 @@ Quantization reduces the memory burden of large models by representing the weigh

The example below uses [bitsandbytes](../quantization/bitsandbytes.md) to quantize the weights to int4.

```py
```python
import requests
import torch
from PIL import Image
from transformers import ColPaliForRetrieval, ColPaliProcessor
from transformers import BitsAndBytesConfig

from transformers import BitsAndBytesConfig, ColPaliForRetrieval, ColPaliProcessor


model_name = "vidore/colpali-v1.3-hf"

# 4-bit quantization configuration
bnb_config = BitsAndBytesConfig(
@@ -94,14 +115,11 @@ bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
)

model_name = "vidore/colpali-v1.2-hf"

# Load model
model = ColPaliForRetrieval.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="cuda"
).eval()
    device_map="cuda",
)

processor = ColPaliProcessor.from_pretrained(model_name)

@@ -114,8 +132,8 @@ images = [
]

queries = [
    "Who printed the edition of Romeo and Juliet?",
    "When was the United States Declaration of Independence proclaimed?",
    "Who printed the edition of Romeo and Juliet?",
]

# Process the inputs
@@ -127,6 +145,7 @@ with torch.no_grad():
    image_embeddings = model(**inputs_images).embeddings
    query_embeddings = model(**inputs_text).embeddings

# Score the queries against the images
scores = processor.score_retrieval(query_embeddings, image_embeddings)

print("Retrieval scores (query x image):")
docs/source/en/model_doc/colqwen2.md (new file, 176 lines)

@@ -0,0 +1,176 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
    </div>
</div>

# ColQwen2

[ColQwen2](https://doi.org/10.48550/arXiv.2407.01449) is a variant of the [ColPali](./colpali) model designed to retrieve documents by analyzing their visual features. Unlike traditional systems that rely heavily on text extraction and OCR, ColQwen2 treats each page as an image. It uses the [Qwen2-VL](./qwen2_vl) backbone to capture not only text, but also the layout, tables, charts, and other visual elements to create detailed multi-vector embeddings that can be used for retrieval by computing pairwise late interaction similarity scores. This offers a more comprehensive understanding of documents and enables more efficient and accurate retrieval.

This model was contributed by [@tonywu71](https://huggingface.co/tonywu71) (ILLUIN Technology) and [@yonigozlan](https://huggingface.co/yonigozlan) (HuggingFace).

You can find all the original ColQwen2 checkpoints under Vidore's [Hf-native ColVision Models](https://huggingface.co/collections/vidore/hf-native-colvision-models-6755d68fc60a8553acaa96f7) collection.

> [!TIP]
> Click on the ColQwen2 models in the right sidebar for more examples of how to use ColQwen2 for image retrieval.

<hfoptions id="usage">
<hfoption id="image retrieval">

```python
import requests
import torch
from PIL import Image

from transformers import ColQwen2ForRetrieval, ColQwen2Processor
from transformers.utils.import_utils import is_flash_attn_2_available


# Load the model and the processor
model_name = "vidore/colqwen2-v1.0-hf"

model = ColQwen2ForRetrieval.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # "cpu", "cuda", or "mps" for Apple Silicon
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
)
processor = ColQwen2Processor.from_pretrained(model_name)

# The document page screenshots from your corpus
url1 = "https://upload.wikimedia.org/wikipedia/commons/8/89/US-original-Declaration-1776.jpg"
url2 = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Romeoandjuliet1597.jpg/500px-Romeoandjuliet1597.jpg"

images = [
    Image.open(requests.get(url1, stream=True).raw),
    Image.open(requests.get(url2, stream=True).raw),
]

# The queries you want to retrieve documents for
queries = [
    "When was the United States Declaration of Independence proclaimed?",
    "Who printed the edition of Romeo and Juliet?",
]

# Process the inputs
inputs_images = processor(images=images).to(model.device)
inputs_text = processor(text=queries).to(model.device)

# Forward pass
with torch.no_grad():
    image_embeddings = model(**inputs_images).embeddings
    query_embeddings = model(**inputs_text).embeddings

# Score the queries against the images
scores = processor.score_retrieval(query_embeddings, image_embeddings)

print("Retrieval scores (query x image):")
print(scores)
```

If you have issues loading the images with PIL, you can use the following code to create dummy images:

```python
images = [
    Image.new("RGB", (128, 128), color="white"),
    Image.new("RGB", (64, 32), color="black"),
]
```

</hfoption>
</hfoptions>

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.

The example below uses [bitsandbytes](../quantization/bitsandbytes.md) to quantize the weights to int4.

```python
import requests
import torch
from PIL import Image

from transformers import BitsAndBytesConfig, ColQwen2ForRetrieval, ColQwen2Processor


model_name = "vidore/colqwen2-v1.0-hf"

# 4-bit quantization configuration
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = ColQwen2ForRetrieval.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="cuda",
).eval()

processor = ColQwen2Processor.from_pretrained(model_name)

url1 = "https://upload.wikimedia.org/wikipedia/commons/8/89/US-original-Declaration-1776.jpg"
url2 = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Romeoandjuliet1597.jpg/500px-Romeoandjuliet1597.jpg"

images = [
    Image.open(requests.get(url1, stream=True).raw),
    Image.open(requests.get(url2, stream=True).raw),
]

queries = [
    "When was the United States Declaration of Independence proclaimed?",
    "Who printed the edition of Romeo and Juliet?",
]

# Process the inputs
inputs_images = processor(images=images, return_tensors="pt").to(model.device)
inputs_text = processor(text=queries, return_tensors="pt").to(model.device)

# Forward pass
with torch.no_grad():
    image_embeddings = model(**inputs_images).embeddings
    query_embeddings = model(**inputs_text).embeddings

# Score the queries against the images
scores = processor.score_retrieval(query_embeddings, image_embeddings)

print("Retrieval scores (query x image):")
print(scores)
```

## Notes

- [`~ColQwen2Processor.score_retrieval`] returns a 2D tensor where the first dimension is the number of queries and the second dimension is the number of images. A higher score indicates more similarity between the query and image (see the short sketch after this list).
- Unlike ColPali, ColQwen2 supports arbitrary image resolutions and aspect ratios, which means images are not resized into fixed-size squares. This preserves more of the original input signal.
- Larger input images generate longer multi-vector embeddings, allowing users to adjust image resolution to balance performance and memory usage.

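As a quick illustration of how the `score_retrieval` output is typically consumed, here is a minimal sketch that ranks the document pages for each query; `scores` and `queries` are assumed to be the variables from the quickstart snippet above:

```python
# scores: (num_queries, num_images); a higher value means a better match.
for i, query in enumerate(queries):
    ranking = scores[i].argsort(descending=True).tolist()  # page indices, best first
    best_idx = ranking[0]
    print(f"{query!r} -> image #{best_idx} (score={float(scores[i, best_idx]):.2f})")
```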
## ColQwen2Config

[[autodoc]] ColQwen2Config

## ColQwen2Processor

[[autodoc]] ColQwen2Processor

## ColQwen2ForRetrieval

[[autodoc]] ColQwen2ForRetrieval
    - forward
@@ -62,6 +62,7 @@ if TYPE_CHECKING:
    from .cohere import *
    from .cohere2 import *
    from .colpali import *
    from .colqwen2 import *
    from .conditional_detr import *
    from .convbert import *
    from .convnext import *
@@ -79,6 +79,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
        ("cohere", "CohereConfig"),
        ("cohere2", "Cohere2Config"),
        ("colpali", "ColPaliConfig"),
        ("colqwen2", "ColQwen2Config"),
        ("conditional_detr", "ConditionalDetrConfig"),
        ("convbert", "ConvBertConfig"),
        ("convnext", "ConvNextConfig"),
@@ -437,6 +438,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
        ("cohere", "Cohere"),
        ("cohere2", "Cohere2"),
        ("colpali", "ColPali"),
        ("colqwen2", "ColQwen2"),
        ("conditional_detr", "Conditional DETR"),
        ("convbert", "ConvBERT"),
        ("convnext", "ConvNeXT"),
@@ -365,6 +365,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("colpali", "ColPaliForRetrieval"),
        ("colqwen2", "ColQwen2ForRetrieval"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
@@ -66,6 +66,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
        ("clipseg", "CLIPSegProcessor"),
        ("clvp", "ClvpProcessor"),
        ("colpali", "ColPaliProcessor"),
        ("colqwen2", "ColQwen2Processor"),
        ("emu3", "Emu3Processor"),
        ("flava", "FlavaProcessor"),
        ("fuyu", "FuyuProcessor"),
@@ -147,6 +147,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
        ("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
        ("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
        ("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("colqwen2", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
        (
            "cpm",
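These auto-mapping entries are what make the new model reachable through the generic Auto classes. A minimal sketch of what they enable (the checkpoint name is taken from the model docs above):

```python
from transformers import AutoConfig, AutoProcessor

# Both calls resolve through the "colqwen2" entries registered above.
config = AutoConfig.from_pretrained("vidore/colqwen2-v1.0-hf")        # -> ColQwen2Config
processor = AutoProcessor.from_pretrained("vidore/colqwen2-v1.0-hf")  # -> ColQwen2Processor
print(type(config).__name__, type(processor).__name__)
```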
@ -33,8 +33,6 @@ class ColPaliConfig(PretrainedConfig):
|
||||
Creating a configuration with the default settings will result in a configuration where the VLM backbone is set to the
|
||||
default PaliGemma configuration, i.e the one from [vidore/colpali-v1.2](https://huggingface.co/vidore/colpali-v1.2).
|
||||
|
||||
The ColPali config is very similar to [`PaligemmaConfig`], but with an extra attribute defining the embedding dimension.
|
||||
|
||||
Note that contrarily to what the class name suggests (actually the name refers to the ColPali **methodology**), you can
|
||||
use a different VLM backbone model than PaliGemma by passing the corresponding VLM configuration to the class constructor.
|
||||
|
||||
@ -93,7 +91,7 @@ class ColPaliConfig(PretrainedConfig):
|
||||
)
|
||||
|
||||
self.vlm_config = vlm_config
|
||||
self.text_config = text_config = text_config if text_config is not None else vlm_config.text_config
|
||||
self.text_config = text_config if text_config is not None else vlm_config.text_config
|
||||
if isinstance(self.text_config, dict):
|
||||
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma"
|
||||
self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
||||
|
@ -24,16 +24,10 @@ from transformers import AutoModelForImageTextToText
|
||||
|
||||
from ...cache_utils import Cache
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import ModelOutput, auto_docstring
|
||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple
|
||||
from .configuration_colpali import ColPaliConfig
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The bare ColPali model outputting raw hidden-states without any specific head on top.
|
||||
"""
|
||||
)
|
||||
@auto_docstring
|
||||
class ColPaliPreTrainedModel(PreTrainedModel):
|
||||
config_class = ColPaliConfig
|
||||
base_model_prefix = "model"
|
||||
@ -98,13 +92,16 @@ class ColPaliForRetrievalOutput(ModelOutput):
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
In our proposed ColPali approach, we leverage VLMs to construct efficient multi-vector embeddings directly
|
||||
from document images (“screenshots”) for document retrieval. We train the model to maximize the similarity
|
||||
The ColPali architecture leverages VLMs to construct efficient multi-vector embeddings directly
|
||||
from document images (“screenshots”) for document retrieval. The model is trained to maximize the similarity
|
||||
between these document embeddings and the corresponding query embeddings, using the late interaction method
|
||||
introduced in ColBERT.
|
||||
|
||||
Using ColPali removes the need for potentially complex and brittle layout recognition and OCR pipelines with a
|
||||
single model that can take into account both the textual and visual content (layout, charts, etc.) of a document.
|
||||
|
||||
ColPali is part of the ColVision model family, which was first introduced in the following paper:
|
||||
[*ColPali: Efficient Document Retrieval with Vision Language Models*](https://arxiv.org/abs/2407.01449).
|
||||
"""
|
||||
)
|
||||
class ColPaliForRetrieval(ColPaliPreTrainedModel):
|
||||
@ -126,6 +123,7 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
|
||||
|
||||
self.post_init()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -136,9 +134,9 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, ColPaliForRetrievalOutput]:
|
||||
if "pixel_values" in kwargs:
|
||||
kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype=self.dtype)
|
||||
) -> ColPaliForRetrievalOutput:
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.to(dtype=self.dtype)
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
||||
output_hidden_states = (
|
||||
@ -146,17 +144,19 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.vlm(
|
||||
vlm_output = self.vlm(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
pixel_values=pixel_values,
|
||||
output_hidden_states=True,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None
|
||||
vlm_image_hidden_states = vlm_output.image_hidden_states if pixel_values is not None else None
|
||||
|
||||
last_hidden_states = outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size)
|
||||
last_hidden_states = vlm_output.hidden_states[-1] # (batch_size, sequence_length, hidden_size)
|
||||
embeddings = self.embedding_proj_layer(last_hidden_states) # (batch_size, sequence_length, dim)
|
||||
|
||||
# L2 normalization
|
||||
@ -164,20 +164,12 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
|
||||
|
||||
embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
|
||||
|
||||
loss = None
|
||||
if not return_dict:
|
||||
output = (embeddings,) + outputs[2:]
|
||||
output[2] = output[2] if output_hidden_states is not None else None
|
||||
output[-1] = (outputs.image_hidden_states if pixel_values is not None else None,)
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return ColPaliForRetrievalOutput(
|
||||
loss=loss,
|
||||
embeddings=embeddings,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=outputs.image_hidden_states if pixel_values is not None else None,
|
||||
past_key_values=vlm_output.past_key_values,
|
||||
hidden_states=vlm_hidden_states,
|
||||
attentions=vlm_output.attentions,
|
||||
image_hidden_states=vlm_image_hidden_states,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
|
@ -14,28 +14,15 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import ClassVar, List, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from transformers.models.paligemma.processing_paligemma import (
|
||||
IMAGE_TOKEN,
|
||||
PaliGemmaProcessor,
|
||||
build_string_from_input,
|
||||
)
|
||||
from transformers.models.paligemma.processing_paligemma import IMAGE_TOKEN, PaliGemmaProcessor, build_string_from_input
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images
|
||||
from ...processing_utils import (
|
||||
ProcessingKwargs,
|
||||
Unpack,
|
||||
)
|
||||
from ...tokenization_utils_base import (
|
||||
PreTokenizedInput,
|
||||
TextInput,
|
||||
)
|
||||
from ...utils import (
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
from ...processing_utils import ProcessingKwargs, Unpack
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...utils import is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@ -73,10 +60,23 @@ class ColPaliProcessor(PaliGemmaProcessor):
|
||||
The tokenizer is a required input.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
visual_prompt_prefix (`str`, *optional*, defaults to `"Describe the image."`):
|
||||
A string that gets tokenized and prepended to the image tokens.
|
||||
query_prefix (`str`, *optional*, defaults to `"Question: "`):
|
||||
A prefix to be used for the query.
|
||||
"""
|
||||
|
||||
visual_prompt_prefix: ClassVar[str] = "Describe the image."
|
||||
query_prefix: ClassVar[str] = "Question: "
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
chat_template=None,
|
||||
visual_prompt_prefix: str = "Describe the image.",
|
||||
query_prefix: str = "Question: ",
|
||||
):
|
||||
super().__init__(image_processor=image_processor, tokenizer=tokenizer, chat_template=chat_template)
|
||||
self.visual_prompt_prefix = visual_prompt_prefix
|
||||
self.query_prefix = query_prefix
|
||||
|
||||
@property
|
||||
def query_augmentation_token(self) -> str:
|
||||
@ -96,7 +96,7 @@ class ColPaliProcessor(PaliGemmaProcessor):
|
||||
**kwargs: Unpack[ColPaliProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model either (1) one or several texts, either (2) one or several image(s). This method is custom
|
||||
Main method to prepare for the model either (1) one or several texts, either (2) one or several image(s). This method is a custom
|
||||
wrapper around the PaliGemmaProcessor's [`~PaliGemmaProcessor.__call__`] method adapted for the ColPali model. It cannot process
|
||||
both text and images at the same time.
|
||||
|
||||
@ -196,12 +196,10 @@ class ColPaliProcessor(PaliGemmaProcessor):
|
||||
|
||||
if suffix is None:
|
||||
suffix = self.query_augmentation_token * 10
|
||||
texts_query: List[str] = []
|
||||
|
||||
texts_query: List[str] = []
|
||||
for query in text:
|
||||
query = self.tokenizer.bos_token + self.query_prefix + query
|
||||
query += suffix # add suffix (pad tokens)
|
||||
query += "\n" # make input ISO to PaliGemma's processor
|
||||
query = self.tokenizer.bos_token + self.query_prefix + query + suffix + "\n"
|
||||
texts_query.append(query)
|
||||
|
||||
output_kwargs["text_kwargs"]["max_length"] = output_kwargs["text_kwargs"].get("max_length", 50)
|
||||
@ -223,7 +221,7 @@ class ColPaliProcessor(PaliGemmaProcessor):
|
||||
Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColPaliProcessor's
|
||||
[`ColPaliProcessor.__call__`].
|
||||
|
||||
This method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's [`~SiglipImageProcessor.__call__`].
|
||||
This method forwards the `images` and `kwargs` arguments to the image processor.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
@ -258,7 +256,7 @@ class ColPaliProcessor(PaliGemmaProcessor):
|
||||
Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColPaliProcessor's
|
||||
[`ColPaliProcessor.__call__`].
|
||||
|
||||
This method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`].
|
||||
This method forwards the `text` and `kwargs` arguments to the tokenizer.
|
||||
|
||||
Args:
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
|
@ -20,7 +20,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import ClassVar, List, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images
|
||||
@ -87,22 +87,25 @@ class ColPaliProcessor(ProcessorMixin):
|
||||
The tokenizer is a required input.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
visual_prompt_prefix (`str`, *optional*, defaults to `"Describe the image."`):
|
||||
A string that gets tokenized and prepended to the image tokens.
|
||||
query_prefix (`str`, *optional*, defaults to `"Question: "`):
|
||||
A prefix to be used for the query.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
image_processor_class = ("SiglipImageProcessor", "SiglipImageProcessorFast")
|
||||
tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast")
|
||||
|
||||
visual_prompt_prefix: ClassVar[str] = "Describe the image."
|
||||
query_prefix: ClassVar[str] = "Question: "
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
chat_template=None,
|
||||
**kwargs,
|
||||
visual_prompt_prefix: str = "Describe the image.",
|
||||
query_prefix: str = "Question: ",
|
||||
):
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
if image_processor is None:
|
||||
raise ValueError("You need to specify an `image_processor`.")
|
||||
if tokenizer is None:
|
||||
@ -125,8 +128,8 @@ class ColPaliProcessor(ProcessorMixin):
|
||||
tokenizer.add_tokens(EXTRA_TOKENS)
|
||||
tokenizer.add_bos_token = False
|
||||
tokenizer.add_eos_token = False
|
||||
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
self.visual_prompt_prefix = visual_prompt_prefix
|
||||
self.query_prefix = query_prefix
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -137,7 +140,7 @@ class ColPaliProcessor(ProcessorMixin):
|
||||
**kwargs: Unpack[ColPaliProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model either (1) one or several texts, either (2) one or several image(s). This method is custom
|
||||
Main method to prepare for the model either (1) one or several texts, either (2) one or several image(s). This method is a custom
|
||||
wrapper around the PaliGemmaProcessor's [`~PaliGemmaProcessor.__call__`] method adapted for the ColPali model. It cannot process
|
||||
both text and images at the same time.
|
||||
|
||||
@ -237,12 +240,10 @@ class ColPaliProcessor(ProcessorMixin):
|
||||
|
||||
if suffix is None:
|
||||
suffix = self.query_augmentation_token * 10
|
||||
texts_query: List[str] = []
|
||||
|
||||
texts_query: List[str] = []
|
||||
for query in text:
|
||||
query = self.tokenizer.bos_token + self.query_prefix + query
|
||||
query += suffix # add suffix (pad tokens)
|
||||
query += "\n" # make input ISO to PaliGemma's processor
|
||||
query = self.tokenizer.bos_token + self.query_prefix + query + suffix + "\n"
|
||||
texts_query.append(query)
|
||||
|
||||
output_kwargs["text_kwargs"]["max_length"] = output_kwargs["text_kwargs"].get("max_length", 50)
|
||||
@ -312,7 +313,7 @@ class ColPaliProcessor(ProcessorMixin):
|
||||
Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColPaliProcessor's
|
||||
[`ColPaliProcessor.__call__`].
|
||||
|
||||
This method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's [`~SiglipImageProcessor.__call__`].
|
||||
This method forwards the `images` and `kwargs` arguments to the image processor.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
@ -347,7 +348,7 @@ class ColPaliProcessor(ProcessorMixin):
|
||||
Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColPaliProcessor's
|
||||
[`ColPaliProcessor.__call__`].
|
||||
|
||||
This method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`].
|
||||
This method forwards the `text` and `kwargs` arguments to the tokenizer.
|
||||
|
||||
Args:
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
|
src/transformers/models/colqwen2/__init__.py (new file, 28 lines)

@@ -0,0 +1,28 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_colqwen2 import *
|
||||
from .modeling_colqwen2 import *
|
||||
from .processing_colqwen2 import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
src/transformers/models/colqwen2/configuration_colqwen2.py (new file, 94 lines)

@@ -0,0 +1,94 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from ..auto import CONFIG_MAPPING
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class ColQwen2Config(PretrainedConfig):
|
||||
r"""
|
||||
Configuration class to store the configuration of a [`ColQwen2ForRetrieval`]. It is used to instantiate an instance
|
||||
of `ColQwen2ForRetrieval` according to the specified arguments, defining the model architecture following the methodology
|
||||
from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
|
||||
|
||||
Instantiating a configuration with the defaults will yield a similar configuration to the vision encoder used by the pre-trained
|
||||
ColQwen2-v1.0 model, e.g. [vidore/colqwen2-v1.0-hf](https://huggingface.co/vidore/colqwen2-v1.0-hf).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vlm_config (`PretrainedConfig`, *optional*):
|
||||
Configuration of the VLM backbone model.
|
||||
embedding_dim (`int`, *optional*, defaults to 128):
|
||||
Dimension of the multi-vector embeddings produced by the model.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
Example:
|
||||
|
||||
```python
|
||||
from transformers.models.colqwen2 import ColQwen2Config, ColQwen2ForRetrieval
|
||||
|
||||
config = ColQwen2Config()
|
||||
model = ColQwen2ForRetrieval(config)
|
||||
```
|
||||
"""
|
||||
|
||||
model_type = "colqwen2"
|
||||
sub_configs: Dict[str, Any] = {"vlm_config": PretrainedConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vlm_config=None,
|
||||
embedding_dim: int = 128,
|
||||
initializer_range: float = 0.02,
|
||||
**kwargs,
|
||||
):
|
||||
if vlm_config is None:
|
||||
vlm_config = CONFIG_MAPPING["qwen2_vl"]()
|
||||
logger.info(
|
||||
"`vlm_config` is `None`. Initializing `vlm_config` with the `Qwen2VLConfig` with default values."
|
||||
)
|
||||
elif isinstance(vlm_config, dict):
|
||||
vlm_config = deepcopy(vlm_config)
|
||||
if "model_type" not in vlm_config:
|
||||
raise KeyError(
|
||||
"The `model_type` key is missing in the `vlm_config` dictionary. Please provide the model type."
|
||||
)
|
||||
vlm_config = CONFIG_MAPPING[vlm_config["model_type"]](**vlm_config)
|
||||
elif isinstance(vlm_config, PretrainedConfig):
|
||||
vlm_config = vlm_config
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Invalid type for `vlm_config`. Expected `PretrainedConfig`, `dict`, or `None`, but got {type(vlm_config)}."
|
||||
)
|
||||
|
||||
self.vlm_config = vlm_config
|
||||
self.embedding_dim = embedding_dim
|
||||
self.initializer_range = initializer_range
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_text_config(self, decoder=False) -> PretrainedConfig:
|
||||
return self.vlm_config.get_text_config(decoder=decoder)
|
||||
|
||||
|
||||
__all__ = ["ColQwen2Config"]
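The `vlm_config` argument accepted above can be a ready-made `PretrainedConfig`, a plain dict carrying a `model_type` key, or `None` (which falls back to a default Qwen2-VL backbone). A minimal sketch of the two explicit paths; the parameter values are illustrative only:

```python
from transformers import Qwen2VLConfig
from transformers.models.colqwen2 import ColQwen2Config

# Pass an already-built backbone config object...
config_from_object = ColQwen2Config(vlm_config=Qwen2VLConfig(), embedding_dim=128)

# ...or a dict, which must include "model_type" so the right backbone config class is resolved.
config_from_dict = ColQwen2Config(vlm_config={"model_type": "qwen2_vl"}, embedding_dim=128)
```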
|
@ -0,0 +1,212 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Convert ColQwen2 weights from the original repository to the HF model format.
|
||||
|
||||
Don't forget to manually upload the processor-related files to the HF model repository
|
||||
after running this script.
|
||||
|
||||
Original repository: https://github.com/illuin-tech/colqwen2.
|
||||
|
||||
NOTE: This script was originally run using `torch==2.5.1` and with:
|
||||
|
||||
```bash
|
||||
python src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py \
|
||||
--model_id vidore/colqwen2-v1.0-merged \
|
||||
--revision eeccbae1d44bdcb0c83b1788127a2b2cad7d718e \
|
||||
--original_vlm_name_or_path Qwen/Qwen2-VL-2B-Instruct \
|
||||
--output_dir vidore/colqwen2-v1.0-hf-internal \
|
||||
--push_to_hub
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
|
||||
from transformers import AutoConfig
|
||||
from transformers.models.colqwen2 import ColQwen2ForRetrieval
|
||||
from transformers.models.colqwen2.configuration_colqwen2 import ColQwen2Config
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
ORIGINAL_DTYPE = torch.bfloat16
|
||||
|
||||
|
||||
def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> Dict[str, torch.Tensor]:
|
||||
directory_path = snapshot_download(
|
||||
repo_id=model_id,
|
||||
revision=revision,
|
||||
allow_patterns=["*.safetensors"],
|
||||
)
|
||||
|
||||
original_state_dict = {}
|
||||
for path in glob.glob(f"{directory_path}/*"):
|
||||
if path.endswith(".safetensors"):
|
||||
with safe_open(path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
original_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
# Some weights are tied, so `lm_head` is not saved. Let's clone it to load the state dict.
|
||||
if "lm_head.weight" not in original_state_dict:
|
||||
original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone()
|
||||
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def rename_state_dict_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
new_state_dict: Dict[str, Any] = {}
|
||||
for key, value in state_dict.items():
|
||||
if key.startswith("custom_text_proj"):
|
||||
new_key = key.replace("custom_text_proj", "embedding_proj_layer")
|
||||
else:
|
||||
# The original ColQwen2 inherits from Qwen2VL, so we simply need to add the `vlm.` prefix
|
||||
# to all remaining keys.
|
||||
if key.startswith("model."):
|
||||
key = key.replace("model.", "model.language_model.")
|
||||
if key.startswith("visual."):
|
||||
key = key.replace("visual.", "model.visual.")
|
||||
new_key = "vlm." + key
|
||||
new_state_dict[new_key] = value
|
||||
return new_state_dict
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_colqwen2_weights_to_hf(
|
||||
model_id: str,
|
||||
output_dir: str,
|
||||
push_to_hub: bool,
|
||||
revision: Optional[str] = None,
|
||||
original_vlm_name_or_path: Optional[str] = None,
|
||||
):
|
||||
# Load the original model data
|
||||
original_config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
)
|
||||
if original_vlm_name_or_path is not None:
|
||||
original_config._name_or_path = original_vlm_name_or_path
|
||||
if hasattr(original_config, "architectures"):
|
||||
delattr(original_config, "architectures")
|
||||
|
||||
original_state_dict = load_original_state_dict(model_id, revision=revision)
|
||||
|
||||
# Format the state_dict keys
|
||||
original_state_dict = rename_state_dict_keys(original_state_dict)
|
||||
|
||||
# Create the new config
|
||||
config = ColQwen2Config(
|
||||
vlm_config=original_config,
|
||||
embedding_dim=128, # hardcoded in the original model
|
||||
)
|
||||
config.model_type = "colqwen2"
|
||||
config.is_composition = False
|
||||
|
||||
# Load the untrained model
|
||||
model = ColQwen2ForRetrieval(config=config).to("cpu").eval()
|
||||
print("Created model with new config and randomly initialized weights")
|
||||
|
||||
# NOTE: The new model was initialized with float32 weights. We need to convert it to the desired precision.
|
||||
# There are two ways to set the model's dtype:
|
||||
# - Using `model.from_pretrained(..., torch_dtype=dtype_precision)` doesn't convert the hyperparameters to the desired precision.
|
||||
# - Using `model.to(dtype_precision)` converts all values - including the hyperparameters - to the desired precision.
|
||||
# The following snippet allows a fine-grained control over the model's dtype, making sure that all
|
||||
# the new weights' dtypes match the original model.
|
||||
for param in model.parameters():
|
||||
param.data = param.data.to(ORIGINAL_DTYPE)
|
||||
print(f"Converted the new model weights to `{ORIGINAL_DTYPE}`")
|
||||
|
||||
# Load the original weights
|
||||
model.load_state_dict(original_state_dict)
|
||||
print("Loaded original model weights")
|
||||
|
||||
# Sanity check: ensure all keys are the same
|
||||
state_dict_keys_old = set(original_state_dict.keys())
|
||||
state_dict_keys_new = set(model.state_dict().keys())
|
||||
disjoint_keys = state_dict_keys_old.symmetric_difference(state_dict_keys_new)
|
||||
if disjoint_keys:
|
||||
raise ValueError(f"Incompatible keys: {disjoint_keys}")
|
||||
|
||||
# Save the model
|
||||
if push_to_hub:
|
||||
model.push_to_hub(output_dir, private=True)
|
||||
print(f"Model pushed to the hub at `{output_dir}`")
|
||||
else:
|
||||
Path(output_dir).mkdir(exist_ok=True, parents=True)
|
||||
model.save_pretrained(output_dir)
|
||||
print(f"Model saved to `{output_dir}`")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""
|
||||
This script converts the original ColQwen2 model to the HF model format.
|
||||
|
||||
Don't forget to manually upload the processor-related files to the HF model repository
|
||||
after running this script.
|
||||
|
||||
Example usage:
|
||||
```bash
|
||||
python src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py \
|
||||
--model_id vidore/colqwen2-v1.0-merged \
|
||||
--revision eeccbae1d44bdcb0c83b1788127a2b2cad7d718e \
|
||||
--original_vlm_name_or_path Qwen/Qwen2-VL-2B-Instruct \
|
||||
--output_dir vidore/colqwen2-v1.0-hf-internal \
|
||||
--push_to_hub
|
||||
```
|
||||
"""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_id",
|
||||
help="Model ID of the original model to convert",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
help="Location to write HF model and tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
help="Revision of the model to download",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--original_vlm_name_or_path",
|
||||
help="Name or path of the original VLM backbone model",
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_colqwen2_weights_to_hf(
|
||||
model_id=args.model_id,
|
||||
output_dir=args.output_dir,
|
||||
push_to_hub=args.push_to_hub,
|
||||
revision=args.revision,
|
||||
original_vlm_name_or_path=args.original_vlm_name_or_path,
|
||||
)
|
src/transformers/models/colqwen2/modeling_colqwen2.py (new file, 268 lines)

@@ -0,0 +1,268 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/colqwen2/modular_colqwen2.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_colqwen2.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from torch import nn
|
||||
|
||||
from transformers import AutoModelForImageTextToText
|
||||
|
||||
from ...cache_utils import Cache
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torch_available
|
||||
from .configuration_colqwen2 import ColQwen2Config
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class ColQwen2PreTrainedModel(PreTrainedModel):
|
||||
config_class = ColQwen2Config
|
||||
base_model_prefix = "model"
|
||||
_no_split_modules = []
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.vlm_config.text_config.initializer_range
|
||||
)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ColQwen2ForRetrievalOutput(ModelOutput):
|
||||
"""
|
||||
Base class for ColQwen2 embeddings output.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
The embeddings of the model.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
embeddings: Optional[torch.Tensor] = None
|
||||
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Following the ColPali approach, ColQwen2 leverages VLMs to construct efficient multi-vector embeddings directly
|
||||
from document images (“screenshots”) for document retrieval. The model is trained to maximize the similarity
|
||||
between these document embeddings and the corresponding query embeddings, using the late interaction method
|
||||
introduced in ColBERT.
|
||||
|
||||
Using ColQwen2 removes the need for potentially complex and brittle layout recognition and OCR pipelines with
|
||||
a single model that can take into account both the textual and visual content (layout, charts, ...) of a document.
|
||||
|
||||
ColQwen2 is part of the ColVision model family, which was introduced with ColPali in the following paper:
|
||||
[*ColPali: Efficient Document Retrieval with Vision Language Models*](https://arxiv.org/abs/2407.01449).
|
||||
"""
|
||||
)
|
||||
class ColQwen2ForRetrieval(ColQwen2PreTrainedModel):
|
||||
def __init__(self, config: ColQwen2Config):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.vocab_size = config.vlm_config.text_config.vocab_size
|
||||
|
||||
vlm = AutoModelForImageTextToText.from_config(config.vlm_config)
|
||||
if vlm._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"vlm.{k}" for k in vlm._tied_weights_keys]
|
||||
self.vlm = vlm
|
||||
|
||||
self.embedding_dim = self.config.embedding_dim
|
||||
self.embedding_proj_layer = nn.Linear(
|
||||
self.config.vlm_config.text_config.hidden_size,
|
||||
self.embedding_dim,
|
||||
)
|
||||
|
||||
self.post_init()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> ColQwen2ForRetrievalOutput:
|
||||
r"""
|
||||
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
||||
The temporal, height and width dimensions of the feature grid for each image fed to the language model.
|
||||
"""
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.to(dtype=self.dtype) # (batch_size, max_num_patches, pixel_values)
|
||||
|
||||
# Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
|
||||
if pixel_values is not None and image_grid_thw is not None:
|
||||
# NOTE: image_grid_thw has shape (batch_size, 3), where image_grid_thw[i] = (temporal_size, num_patches_h, num_patches_w).
|
||||
offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]  # (batch_size,) number of patches per image
|
||||
pixel_values = torch.cat(
|
||||
[pixel_sequence[:offset] for pixel_sequence, offset in zip(pixel_values, offsets)],
|
||||
dim=0,
|
||||
) # (sum(offsets), pixel_values): unpadded patches from all images concatenated along dim 0
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
position_ids, rope_deltas = self.vlm.model.get_rope_index(
|
||||
input_ids=input_ids,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=None,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
# Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs.
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.vlm.model.language_model.embed_tokens(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
|
||||
image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
image_mask = (
|
||||
(input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
vlm_output = self.vlm.model(
|
||||
input_ids=None,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None
|
||||
|
||||
last_hidden_states = vlm_output[0] # (batch_size, sequence_length, hidden_size)
|
||||
embeddings = self.embedding_proj_layer(last_hidden_states) # (batch_size, sequence_length, dim)
|
||||
|
||||
# L2 normalization
|
||||
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
|
||||
if attention_mask is not None:
|
||||
embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
|
||||
|
||||
return ColQwen2ForRetrievalOutput(
|
||||
embeddings=embeddings,
|
||||
past_key_values=vlm_output.past_key_values,
|
||||
hidden_states=vlm_hidden_states,
|
||||
attentions=vlm_output.attentions,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.vlm.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.vlm.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.vlm.get_output_embeddings()
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.vlm.set_output_embeddings(new_embeddings)
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.vlm.set_decoder(decoder)
|
||||
|
||||
def get_decoder(self):
|
||||
return self.vlm.get_decoder()
|
||||
|
||||
def tie_weights(self):
|
||||
return self.vlm.tie_weights()
|
||||
|
||||
def resize_token_embeddings(
|
||||
self,
|
||||
new_num_tokens: Optional[int] = None,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
mean_resizing: bool = True,
|
||||
) -> nn.Embedding:
|
||||
model_embeds = self.vlm.resize_token_embeddings(
|
||||
new_num_tokens=new_num_tokens,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
mean_resizing=mean_resizing,
|
||||
)
|
||||
|
||||
self.config.vlm_config.text_config.vocab_size = model_embeds.num_embeddings
|
||||
self.config.vlm_config.vocab_size = model_embeds.num_embeddings
|
||||
self.vlm.vocab_size = model_embeds.num_embeddings
|
||||
self.vocab_size = model_embeds.num_embeddings
|
||||
|
||||
return model_embeds
|
||||
|
||||
|
||||
__all__ = ["ColQwen2ForRetrieval", "ColQwen2PreTrainedModel"]
|
src/transformers/models/colqwen2/modular_colqwen2.py (new file, 383 lines)
@ -0,0 +1,383 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from transformers.models.colpali.modeling_colpali import ColPaliForRetrieval, ColPaliPreTrainedModel
|
||||
from transformers.models.colpali.processing_colpali import ColPaliProcessor
|
||||
|
||||
from ...cache_utils import Cache
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, is_valid_image
|
||||
from ...processing_utils import ProcessingKwargs, Unpack
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class ColQwen2ProcessorKwargs(ProcessingKwargs, total=False):
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": "longest",
|
||||
},
|
||||
"images_kwargs": {
|
||||
"data_format": "channels_first",
|
||||
"do_convert_rgb": True,
|
||||
},
|
||||
"common_kwargs": {"return_tensors": "pt"},
|
||||
}
|
||||
|
||||
|
||||
class ColQwen2Processor(ColPaliProcessor):
|
||||
r"""
|
||||
Constructs a ColQwen2 processor which wraps a Qwen2VLProcessor and adds special methods to process images and queries, as
|
||||
well as to compute the late-interaction retrieval score.
|
||||
|
||||
[`ColQwen2Processor`] offers all the functionalities of [`Qwen2VLProcessor`]. See the [`~Qwen2VLProcessor.__call__`]
|
||||
for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`Qwen2VLImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
visual_prompt_prefix (`str`, *optional*): A string that gets tokenized and prepended to the image tokens.
|
||||
query_prefix (`str`, *optional*): A prefix to be used for the query.
|
||||
"""
|
||||
|
||||
image_processor_class = "Qwen2VLImageProcessor"
|
||||
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
chat_template=None,
|
||||
visual_prompt_prefix: Optional[str] = None,
|
||||
query_prefix: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
|
||||
self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
|
||||
|
||||
if visual_prompt_prefix is None:
|
||||
visual_prompt_prefix = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
|
||||
self.visual_prompt_prefix = visual_prompt_prefix
|
||||
|
||||
if query_prefix is None:
|
||||
query_prefix = "Query: "
|
||||
self.query_prefix = query_prefix
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
||||
audio=None,
|
||||
videos=None,
|
||||
**kwargs: Unpack[ColQwen2ProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model either (1) one or several texts, or (2) one or several image(s). This method is a custom
|
||||
wrapper around the Qwen2VLProcessor's [`~Qwen2VLProcessor.__call__`] method adapted for the ColQwen2 model. It cannot process
|
||||
both text and images at the same time.
|
||||
|
||||
When preparing the text(s), this method forwards the `text` and `kwargs` arguments to Qwen2TokenizerFast's
|
||||
[`~Qwen2TokenizerFast.__call__`].
|
||||
When preparing the image(s), this method forwards the `images` and `kwargs` arguments to Qwen2VLImageProcessor's
|
||||
[`~Qwen2VLImageProcessor.__call__`].
|
||||
Please refer to the docstrings of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
|
||||
number of channels, H and W are image height and width.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
"""
|
||||
output_kwargs = self._merge_kwargs(
|
||||
ColQwen2ProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
suffix = output_kwargs["text_kwargs"].pop("suffix", None)
|
||||
|
||||
return_token_type_ids = True if suffix is not None else False
|
||||
|
||||
if text is None and images is None:
|
||||
raise ValueError("Either text or images must be provided")
|
||||
if text is not None and images is not None:
|
||||
raise ValueError("Only one of text or images can be processed at a time")
|
||||
|
||||
if images is not None:
|
||||
if is_valid_image(images):
|
||||
images = [images]
|
||||
elif isinstance(images, list) and is_valid_image(images[0]):
|
||||
pass
|
||||
elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
|
||||
raise ValueError("images must be an image, list of images or list of list of images")
|
||||
|
||||
texts_doc = [self.visual_prompt_prefix] * len(images)
|
||||
|
||||
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
|
||||
image_grid_thw = image_inputs["image_grid_thw"]
|
||||
|
||||
if image_grid_thw is not None:
|
||||
merge_length = self.image_processor.merge_size**2
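# e.g. with image_grid_thw[index] = (1, 4, 4) and merge_size = 2, the single image token
# placeholder expands into 1 * 4 * 4 // 2**2 = 4 image tokens for this image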
|
||||
index = 0
|
||||
for i in range(len(texts_doc)):
|
||||
while self.image_token in texts_doc[i]:
|
||||
texts_doc[i] = texts_doc[i].replace(
|
||||
self.image_token, "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1
|
||||
)
|
||||
index += 1
|
||||
texts_doc[i] = texts_doc[i].replace("<|placeholder|>", self.image_token)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
texts_doc,
|
||||
return_token_type_ids=False,
|
||||
**output_kwargs["text_kwargs"],
|
||||
)
|
||||
|
||||
return_data = BatchFeature(data={**text_inputs, **image_inputs})
|
||||
|
||||
# NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs.
|
||||
offsets = return_data["image_grid_thw"][:, 1] * return_data["image_grid_thw"][:, 2] # (batch_size,)
|
||||
|
||||
# Split the pixel_values tensor into a list of tensors, one per image
|
||||
pixel_values = list(
|
||||
torch.split(return_data["pixel_values"], offsets.tolist())
|
||||
) # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)]
|
||||
|
||||
# Pad the list of pixel_value tensors to the same length along the sequence dimension
|
||||
return_data["pixel_values"] = torch.nn.utils.rnn.pad_sequence(
|
||||
pixel_values, batch_first=True
|
||||
) # (batch_size, max_num_patches, pixel_values)
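# e.g. two images with 4 and 6 patches respectively yield a (2, 6, pixel_dim) tensor, with zero
# rows appended to the first image; the model unpads them again using image_grid_thw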
|
||||
|
||||
if return_token_type_ids:
|
||||
labels = return_data["input_ids"].masked_fill(return_data["token_type_ids"] == 0, -100)
|
||||
return_data.update({"labels": labels})
|
||||
|
||||
return return_data
|
||||
|
||||
elif text is not None:
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not (isinstance(text, list) and isinstance(text[0], str)):
|
||||
raise ValueError("Text must be a string or a list of strings")
|
||||
|
||||
if suffix is None:
|
||||
suffix = self.query_augmentation_token * 10
|
||||
|
||||
texts_query: List[str] = []
|
||||
|
||||
for query in text:
|
||||
augmented_query = self.query_prefix + query + suffix
|
||||
texts_query.append(augmented_query)
|
||||
|
||||
batch_query = self.tokenizer(
|
||||
texts_query,
|
||||
return_token_type_ids=False,
|
||||
**output_kwargs["text_kwargs"],
|
||||
)
|
||||
|
||||
return batch_query
|
||||
|
||||
|
||||
class ColQwen2PreTrainedModel(ColPaliPreTrainedModel):
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class ColQwen2ForRetrievalOutput(ModelOutput):
|
||||
"""
|
||||
Base class for ColQwen2 embeddings output.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
The embeddings of the model.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
embeddings: Optional[torch.Tensor] = None
|
||||
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Following the ColPali approach, ColQwen2 leverages VLMs to construct efficient multi-vector embeddings directly
|
||||
from document images (“screenshots”) for document retrieval. The model is trained to maximize the similarity
|
||||
between these document embeddings and the corresponding query embeddings, using the late interaction method
|
||||
introduced in ColBERT.
|
||||
|
||||
Using ColQwen2 removes the need for potentially complex and brittle layout recognition and OCR pipelines with
|
||||
a single model that can take into account both the textual and visual content (layout, charts, ...) of a document.
|
||||
|
||||
ColQwen2 is part of the ColVision model family, which was introduced with ColPali in the following paper:
|
||||
[*ColPali: Efficient Document Retrieval with Vision Language Models*](https://arxiv.org/abs/2407.01449).
|
||||
"""
|
||||
)
|
||||
class ColQwen2ForRetrieval(ColPaliForRetrieval):
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> ColQwen2ForRetrievalOutput:
|
||||
r"""
|
||||
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
||||
The temporal, height and width dimensions of the feature grid for each image fed to the language model.
|
||||
"""
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.to(dtype=self.dtype) # (batch_size, max_num_patches, pixel_values)
|
||||
|
||||
# Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
|
||||
if pixel_values is not None and image_grid_thw is not None:
|
||||
# NOTE: image_grid_thw has shape (batch_size, 3), where image_grid_thw[i] = (temporal_size, num_patches_h, num_patches_w).
|
||||
offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]  # (batch_size,) number of patches per image
|
||||
pixel_values = torch.cat(
|
||||
[pixel_sequence[:offset] for pixel_sequence, offset in zip(pixel_values, offsets)],
|
||||
dim=0,
|
||||
) # (sum(offsets), pixel_values): unpadded patches from all images concatenated along dim 0
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
position_ids, rope_deltas = self.vlm.model.get_rope_index(
|
||||
input_ids=input_ids,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=None,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
# Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs.
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.vlm.model.language_model.embed_tokens(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
|
||||
image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
image_mask = (
|
||||
(input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
vlm_output = self.vlm.model(
|
||||
input_ids=None,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None
|
||||
|
||||
last_hidden_states = vlm_output[0] # (batch_size, sequence_length, hidden_size)
|
||||
embeddings = self.embedding_proj_layer(last_hidden_states) # (batch_size, sequence_length, dim)
|
||||
|
||||
# L2 normalization
|
||||
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
|
||||
if attention_mask is not None:
|
||||
embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
|
||||
|
||||
return ColQwen2ForRetrievalOutput(
|
||||
embeddings=embeddings,
|
||||
past_key_values=vlm_output.past_key_values,
|
||||
hidden_states=vlm_hidden_states,
|
||||
attentions=vlm_output.attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ColQwen2ForRetrieval",
|
||||
"ColQwen2PreTrainedModel",
|
||||
"ColQwen2Processor",
|
||||
]
|
src/transformers/models/colqwen2/processing_colqwen2.py (new file, 408 lines)
@ -0,0 +1,408 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/colqwen2/modular_colqwen2.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_colqwen2.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, is_valid_image
|
||||
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class ColQwen2ProcessorKwargs(ProcessingKwargs, total=False):
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": "longest",
|
||||
},
|
||||
"images_kwargs": {
|
||||
"data_format": "channels_first",
|
||||
"do_convert_rgb": True,
|
||||
},
|
||||
"common_kwargs": {"return_tensors": "pt"},
|
||||
}
|
||||
|
||||
|
||||
class ColQwen2Processor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a ColQwen2 processor which wraps a Qwen2VLProcessor and adds special methods to process images and queries, as
|
||||
well as to compute the late-interaction retrieval score.
|
||||
|
||||
[`ColQwen2Processor`] offers all the functionalities of [`Qwen2VLProcessor`]. See the [`~Qwen2VLProcessor.__call__`]
|
||||
for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`Qwen2VLImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
visual_prompt_prefix (`str`, *optional*): A string that gets tokenized and prepended to the image tokens.
|
||||
query_prefix (`str`, *optional*): A prefix to be used for the query.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
|
||||
image_processor_class = "Qwen2VLImageProcessor"
|
||||
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
chat_template=None,
|
||||
visual_prompt_prefix: Optional[str] = None,
|
||||
query_prefix: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
|
||||
self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
|
||||
|
||||
if visual_prompt_prefix is None:
|
||||
visual_prompt_prefix = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
|
||||
self.visual_prompt_prefix = visual_prompt_prefix
|
||||
|
||||
if query_prefix is None:
|
||||
query_prefix = "Query: "
|
||||
self.query_prefix = query_prefix
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
||||
audio=None,
|
||||
videos=None,
|
||||
**kwargs: Unpack[ColQwen2ProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model either (1) one or several texts, or (2) one or several image(s). This method is a custom
|
||||
wrapper around the Qwen2VLProcessor's [`~Qwen2VLProcessor.__call__`] method adapted for the ColQwen2 model. It cannot process
|
||||
both text and images at the same time.
|
||||
|
||||
When preparing the text(s), this method forwards the `text` and `kwargs` arguments to Qwen2TokenizerFast's
|
||||
[`~Qwen2TokenizerFast.__call__`].
|
||||
When preparing the image(s), this method forwards the `images` and `kwargs` arguments to Qwen2VLImageProcessor's
|
||||
[`~Qwen2VLImageProcessor.__call__`].
|
||||
Please refer to the docstrings of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
|
||||
number of channels, H and W are image height and width.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
"""
|
||||
output_kwargs = self._merge_kwargs(
|
||||
ColQwen2ProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
suffix = output_kwargs["text_kwargs"].pop("suffix", None)
|
||||
|
||||
return_token_type_ids = True if suffix is not None else False
|
||||
|
||||
if text is None and images is None:
|
||||
raise ValueError("Either text or images must be provided")
|
||||
if text is not None and images is not None:
|
||||
raise ValueError("Only one of text or images can be processed at a time")
|
||||
|
||||
if images is not None:
|
||||
if is_valid_image(images):
|
||||
images = [images]
|
||||
elif isinstance(images, list) and is_valid_image(images[0]):
|
||||
pass
|
||||
elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
|
||||
raise ValueError("images must be an image, list of images or list of list of images")
|
||||
|
||||
texts_doc = [self.visual_prompt_prefix] * len(images)
|
||||
|
||||
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
|
||||
image_grid_thw = image_inputs["image_grid_thw"]
|
||||
|
||||
if image_grid_thw is not None:
|
||||
merge_length = self.image_processor.merge_size**2
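# e.g. with image_grid_thw[index] = (1, 4, 4) and merge_size = 2, the single image token
# placeholder expands into 1 * 4 * 4 // 2**2 = 4 image tokens for this image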
|
||||
index = 0
|
||||
for i in range(len(texts_doc)):
|
||||
while self.image_token in texts_doc[i]:
|
||||
texts_doc[i] = texts_doc[i].replace(
|
||||
self.image_token, "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1
|
||||
)
|
||||
index += 1
|
||||
texts_doc[i] = texts_doc[i].replace("<|placeholder|>", self.image_token)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
texts_doc,
|
||||
return_token_type_ids=False,
|
||||
**output_kwargs["text_kwargs"],
|
||||
)
|
||||
|
||||
return_data = BatchFeature(data={**text_inputs, **image_inputs})
|
||||
|
||||
# NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs.
|
||||
offsets = return_data["image_grid_thw"][:, 1] * return_data["image_grid_thw"][:, 2] # (batch_size,)
|
||||
|
||||
# Split the pixel_values tensor into a list of tensors, one per image
|
||||
pixel_values = list(
|
||||
torch.split(return_data["pixel_values"], offsets.tolist())
|
||||
) # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)]
|
||||
|
||||
# Pad the list of pixel_value tensors to the same length along the sequence dimension
|
||||
return_data["pixel_values"] = torch.nn.utils.rnn.pad_sequence(
|
||||
pixel_values, batch_first=True
|
||||
) # (batch_size, max_num_patches, pixel_values)
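# e.g. two images with 4 and 6 patches respectively yield a (2, 6, pixel_dim) tensor, with zero
# rows appended to the first image; the model unpads them again using image_grid_thw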
|
||||
|
||||
if return_token_type_ids:
|
||||
labels = return_data["input_ids"].masked_fill(return_data["token_type_ids"] == 0, -100)
|
||||
return_data.update({"labels": labels})
|
||||
|
||||
return return_data
|
||||
|
||||
elif text is not None:
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not (isinstance(text, list) and isinstance(text[0], str)):
|
||||
raise ValueError("Text must be a string or a list of strings")
|
||||
|
||||
if suffix is None:
|
||||
suffix = self.query_augmentation_token * 10
|
||||
|
||||
texts_query: List[str] = []
|
||||
|
||||
for query in text:
|
||||
augmented_query = self.query_prefix + query + suffix
|
||||
texts_query.append(augmented_query)
|
||||
|
||||
batch_query = self.tokenizer(
|
||||
texts_query,
|
||||
return_token_type_ids=False,
|
||||
**output_kwargs["text_kwargs"],
|
||||
)
|
||||
|
||||
return batch_query
|
||||
|
||||
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
||||
"""
|
||||
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||
|
||||
Args:
|
||||
image_sizes (List[List[int]], *optional*):
|
||||
The input sizes formatted as (height, width) per each image.
|
||||
Returns:
|
||||
Dict[str, List[int]]: A dictionary mapping each modality ("image", "video", "audio")
|
||||
to a list containing the number of placeholder tokens required. If the model doesn't accept
|
||||
a certain modality or no input sizes are provided, the dict value is set to an empty list.
|
||||
"""
|
||||
vision_data = {}
|
||||
if image_sizes is not None:
|
||||
num_image_tokens = [self.image_seq_length] * len(image_sizes)
|
||||
num_image_patches = [1] * len(image_sizes)
|
||||
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
||||
return MultiModalData(**vision_data)
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
|
||||
@property
|
||||
def query_augmentation_token(self) -> str:
|
||||
"""
|
||||
Return the query augmentation token.
|
||||
|
||||
Query augmentation buffers are used as reasoning buffers during inference.
|
||||
"""
|
||||
return self.tokenizer.pad_token
|
||||
|
||||
def process_images(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
**kwargs: Unpack[ColQwen2ProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColQwen2Processor's
|
||||
[`ColQwen2Processor.__call__`].
|
||||
|
||||
This method forwards the `images` and `kwargs` arguments to the image processor.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
|
||||
number of channels, H and W are image height and width.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
"""
|
||||
return self.__call__(images=images, **kwargs)
|
||||
|
||||
def process_queries(
|
||||
self,
|
||||
text: Union[TextInput, List[TextInput]],
|
||||
**kwargs: Unpack[ColQwen2ProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColQwen2Processor's
|
||||
[`ColQwen2Processor.__call__`].
|
||||
|
||||
This method forwards the `text` and `kwargs` arguments to the tokenizer.
|
||||
|
||||
Args:
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
"""
|
||||
return self.__call__(text=text, **kwargs)
|
||||
|
||||
def score_retrieval(
|
||||
self,
|
||||
query_embeddings: Union["torch.Tensor", List["torch.Tensor"]],
|
||||
passage_embeddings: Union["torch.Tensor", List["torch.Tensor"]],
|
||||
batch_size: int = 128,
|
||||
output_dtype: Optional["torch.dtype"] = None,
|
||||
output_device: Union["torch.device", str] = "cpu",
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
|
||||
query embeddings and passage embeddings. For ColQwen2, a passage is the
|
||||
image of a document page.
|
||||
|
||||
Because the embedding tensors are multi-vector and can thus have different shapes, they
|
||||
should be fed as:
|
||||
(1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim), or
|
||||
(2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
|
||||
obtained by padding the list of tensors.
|
||||
|
||||
Args:
|
||||
query_embeddings (`Union[torch.Tensor, List[torch.Tensor]]`): Query embeddings.
|
||||
passage_embeddings (`Union[torch.Tensor, List[torch.Tensor]]`): Passage embeddings.
|
||||
batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
|
||||
output_dtype (`torch.dtype`, *optional*): The dtype of the output tensor.
|
||||
If `None`, the dtype of the input embeddings is used.
|
||||
output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
|
||||
tensor is saved on the "cpu" device.
|
||||
"""
|
||||
|
||||
if len(query_embeddings) == 0:
|
||||
raise ValueError("No queries provided")
|
||||
if len(passage_embeddings) == 0:
|
||||
raise ValueError("No passages provided")
|
||||
|
||||
if query_embeddings[0].device != passage_embeddings[0].device:
|
||||
raise ValueError("Queries and passages must be on the same device")
|
||||
|
||||
if query_embeddings[0].dtype != passage_embeddings[0].dtype:
|
||||
raise ValueError("Queries and passages must have the same dtype")
|
||||
|
||||
if output_dtype is None:
|
||||
output_dtype = query_embeddings[0].dtype
|
||||
|
||||
scores: List[torch.Tensor] = []
|
||||
|
||||
for i in range(0, len(query_embeddings), batch_size):
|
||||
batch_scores: List[torch.Tensor] = []
|
||||
batch_queries = torch.nn.utils.rnn.pad_sequence(
|
||||
query_embeddings[i : i + batch_size], batch_first=True, padding_value=0
|
||||
)
|
||||
for j in range(0, len(passage_embeddings), batch_size):
|
||||
batch_passages = torch.nn.utils.rnn.pad_sequence(
|
||||
passage_embeddings[j : j + batch_size], batch_first=True, padding_value=0
|
||||
)
|
||||
batch_scores.append(
|
||||
torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2)
|
||||
)
|
||||
scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device))
|
||||
|
||||
return torch.cat(scores, dim=0)
|
||||
|
||||
|
||||
__all__ = ["ColQwen2Processor"]
|
tests/models/colpali/test_modeling_colpali.py
@ -168,7 +168,6 @@ class ColPaliForRetrievalModelTester:
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": input_ids,
|
||||
"token_type_ids": torch.zeros_like(input_ids),
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
@ -333,7 +332,7 @@ class ColPaliModelIntegrationTest(unittest.TestCase):
|
||||
scores = self.processor.score_retrieval(
|
||||
query_embeddings=query_embeddings,
|
||||
passage_embeddings=image_embeddings,
|
||||
) # (len(qs), len(ps))
|
||||
) # (num_queries, num_passages)
|
||||
|
||||
assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}"
|
||||
assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}"
|
||||
|
tests/models/colpali/test_processing_colpali.py
@ -1,3 +1,18 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the ColPali processor."""
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
@ -89,7 +104,7 @@ class ColPaliProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertIsInstance(batch_feature["input_ids"], torch.Tensor)
|
||||
self.assertEqual(batch_feature["input_ids"].shape[0], len(queries))
|
||||
|
||||
# The following tests are overwritten as ColPaliProcessor can only take one of images or text as input at a time
|
||||
# The following tests override the parent tests because ColPaliProcessor can only take one of images or text as input at a time.
|
||||
|
||||
def test_tokenizer_defaults_preserved_by_kwargs(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
|
tests/models/colqwen2/__init__.py (new empty file)

tests/models/colqwen2/test_modeling_colqwen2.py (new file, 333 lines)
@ -0,0 +1,333 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch ColQwen2 model."""
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
from tests.test_configuration_common import ConfigTester
|
||||
from tests.test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
from transformers import is_torch_available
|
||||
from transformers.models.colqwen2.configuration_colqwen2 import ColQwen2Config
|
||||
from transformers.models.colqwen2.modeling_colqwen2 import ColQwen2ForRetrieval, ColQwen2ForRetrievalOutput
|
||||
from transformers.models.colqwen2.processing_colqwen2 import ColQwen2Processor
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class ColQwen2ForRetrievalModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
ignore_index=-100,
|
||||
pad_token_id=2,
|
||||
projector_hidden_act="gelu",
|
||||
seq_length=11,
|
||||
vision_feature_select_strategy="default",
|
||||
vision_feature_layer=-1,
|
||||
projection_dim=32,
|
||||
is_training=False,
|
||||
use_cache=False,
|
||||
vlm_config={
|
||||
"_name_or_path": "Qwen/Qwen2-VL-2B-Instruct",
|
||||
"bos_token_id": 0,
|
||||
"eos_token_id": 1,
|
||||
"vision_start_token_id": 3,
|
||||
"image_token_id": 4,
|
||||
"video_token_id": 5,
|
||||
"hidden_size": 64,
|
||||
"intermediate_size": 2,
|
||||
"max_window_layers": 2,
|
||||
"model_type": "qwen2_vl",
|
||||
"num_attention_heads": 2,
|
||||
"num_hidden_layers": 2,
|
||||
"num_key_value_heads": 2,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": {"mrope_section": [4, 6, 6], "rope_type": "default", "type": "default"},
|
||||
"sliding_window": 32768,
|
||||
"tie_word_embeddings": True,
|
||||
"vision_config": {
|
||||
"depth": 2,
|
||||
"embed_dim": 32,
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 64,
|
||||
"mlp_ratio": 4,
|
||||
"num_heads": 4,
|
||||
"patch_size": 14,
|
||||
"in_chans": 3,
|
||||
"spatial_merge_size": 1,
|
||||
"temporal_patch_size": 2,
|
||||
},
|
||||
"vision_end_token_id": 151653,
|
||||
"vision_token_id": 151654,
|
||||
"vocab_size": 99,
|
||||
},
|
||||
embedding_dim=32,
|
||||
initializer_range=0.02,
|
||||
):
|
||||
self.parent = parent
|
||||
self.ignore_index = ignore_index
|
||||
self.pad_token_id = pad_token_id
|
||||
|
||||
# `image_token_index` is set to 0 to pass "resize_embeddings" test, do not modify
|
||||
self.image_token_index = 0
|
||||
|
||||
self.image_token_id = vlm_config["image_token_id"]
|
||||
self.video_token_id = vlm_config["video_token_id"]
|
||||
self.pad_token_id = vlm_config["eos_token_id"]
|
||||
self.vision_start_token_id = vlm_config["vision_start_token_id"]
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||
self.vision_feature_layer = vision_feature_layer
|
||||
|
||||
self.image_size = 56
|
||||
self.num_image_tokens = 4
|
||||
|
||||
self.seq_length = seq_length + self.num_image_tokens
|
||||
self.projection_dim = projection_dim
|
||||
|
||||
self.num_hidden_layers = vlm_config["num_hidden_layers"]
|
||||
self.vocab_size = vlm_config["vocab_size"]
|
||||
self.hidden_size = vlm_config["hidden_size"]
|
||||
self.num_attention_heads = vlm_config["num_attention_heads"]
|
||||
self.is_training = is_training
|
||||
|
||||
self.batch_size = 3
|
||||
self.num_channels = vlm_config["vision_config"]["in_chans"]
|
||||
|
||||
self.encoder_seq_length = self.seq_length
|
||||
self.use_cache = use_cache
|
||||
|
||||
self.vlm_config = vlm_config
|
||||
self.embedding_dim = embedding_dim
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
def get_config(self):
|
||||
return ColQwen2Config(
|
||||
vlm_config=self.vlm_config,
|
||||
embedding_dim=self.embedding_dim,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
config = self.get_config()
|
||||
patch_size = config.vlm_config.vision_config.patch_size
|
||||
temporal_patch_size = config.vlm_config.vision_config.temporal_patch_size
|
||||
|
||||
# NOTE: Assume all inputs are square images of the same size.
|
||||
num_patches = (self.image_size // patch_size) ** 2
|
||||
pixel_values = floats_tensor(
|
||||
[
|
||||
self.batch_size * num_patches,
|
||||
self.num_channels * (patch_size**2) * temporal_patch_size,
|
||||
]
|
||||
)
|
||||
|
||||
# Hardcoded image grid size: do not change unless you modified image size or patch size!
|
||||
image_grid_thw = torch.tensor([1, 4, 4]).repeat(self.batch_size, 1)
|
||||
|
||||
# NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs.
|
||||
# Line is copied from `src/transformers/models/colqwen2/processing_colqwen2.py`
|
||||
offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2] # (batch_size,)
|
||||
pixel_values = list(
|
||||
torch.split(pixel_values, offsets.tolist())
|
||||
) # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)]
|
||||
pixel_values = torch.nn.utils.rnn.pad_sequence(
|
||||
pixel_values, batch_first=True
|
||||
) # (batch_size, max_num_patches, pixel_values)
|
||||
|
||||
return config, pixel_values, image_grid_thw
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values, image_grid_thw = config_and_inputs
|
||||
input_ids = (
|
||||
ids_tensor(
|
||||
shape=[self.batch_size, self.seq_length],
|
||||
vocab_size=config.vlm_config.vocab_size - 1,
|
||||
)
|
||||
+ 1
|
||||
)
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||
|
||||
input_ids[:, -1] = self.pad_token_id
|
||||
input_ids[:, : self.num_image_tokens] = self.image_token_id
|
||||
input_ids[input_ids == self.video_token_id] = self.pad_token_id
|
||||
input_ids[input_ids == self.image_token_id] = self.pad_token_id
|
||||
input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
|
||||
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"pixel_values": pixel_values,
|
||||
"image_grid_thw": image_grid_thw,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": input_ids,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class ColQwen2ForRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Model tester for `ColQwen2ForRetrieval`.
|
||||
"""
|
||||
|
||||
all_model_classes = (ColQwen2ForRetrieval,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_torchscript = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = True
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = ColQwen2ForRetrievalModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=ColQwen2Config, has_text_modality=False)
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
@slow
|
||||
@require_vision
|
||||
def test_colqwen2_forward_inputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, return_dict=True)
|
||||
|
||||
self.assertIsInstance(outputs, ColQwen2ForRetrievalOutput)
|
||||
|
||||
@unittest.skip(reason="Some undefined behavior encountered with test versions of Qwen2-VL. Skip for now.")
|
||||
def test_model_parallelism(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Pass because ColQwen2 requires `attention_mask is not None`")
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Pass because ColQwen2 requires `attention_mask is not None`")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class ColQwen2ModelIntegrationTest(unittest.TestCase):
|
||||
model_name: ClassVar[str] = "vidore/colqwen2-v1.0-hf"
|
||||
|
||||
def setUp(self):
|
||||
self.processor = ColQwen2Processor.from_pretrained(self.model_name)
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@slow
|
||||
def test_model_integration_test(self):
|
||||
"""
|
||||
Test if the model is able to retrieve the correct pages for a small and easy dataset.
|
||||
"""
|
||||
model = ColQwen2ForRetrieval.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=torch_device,
|
||||
).eval()
|
||||
|
||||
# Load the test dataset
|
||||
ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test")
|
||||
|
||||
# Preprocess the examples
|
||||
batch_images = self.processor(images=ds["image"]).to(torch_device)
|
||||
batch_queries = self.processor(text=ds["query"]).to(torch_device)
|
||||
|
||||
# Run inference
|
||||
with torch.inference_mode():
|
||||
image_embeddings = model(**batch_images).embeddings
|
||||
query_embeddings = model(**batch_queries).embeddings
|
||||
|
||||
# Compute retrieval scores
|
||||
scores = self.processor.score_retrieval(
|
||||
query_embeddings=query_embeddings,
|
||||
passage_embeddings=image_embeddings,
|
||||
) # (num_queries, num_passages)
|
||||
|
||||
assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}"
|
||||
assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}"
|
||||
|
||||
# Check that the maximum score in each row lies on the diagonal of the score matrix
|
||||
self.assertTrue((scores.argmax(axis=1) == torch.arange(len(ds), device=scores.device)).all())
|
||||
|
||||
# Further validation: fine-grained check against hardcoded scores from the original HF implementation.
|
||||
expected_scores = torch.tensor(
|
||||
[
|
||||
[16.2500, 7.8750, 14.6875],
|
||||
[9.5000, 17.1250, 10.5000],
|
||||
[14.9375, 10.9375, 20.0000],
|
||||
],
|
||||
dtype=scores.dtype,
|
||||
)
|
||||
|
||||
assert torch.allclose(scores, expected_scores, atol=1e-3), f"Expected scores {expected_scores}, got {scores}"
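The integration test above relies on `ColQwen2Processor.score_retrieval`, which applies the late-interaction (MaxSim) scoring used by ColQwen2: every query-token embedding is compared with every passage-token embedding, the best match per query token is kept, and the matches are summed. The snippet below is a minimal sketch of that computation, assuming zero-padded embedding tensors of shape (num_queries, query_len, dim) and (num_passages, passage_len, dim); the helper name `maxsim_scores` is illustrative and not part of the transformers API.

import torch


def maxsim_scores(query_embeddings: torch.Tensor, passage_embeddings: torch.Tensor) -> torch.Tensor:
    """Sketch of late-interaction (MaxSim) scoring.

    query_embeddings:   (num_queries, query_len, dim)
    passage_embeddings: (num_passages, passage_len, dim)
    Returns a (num_queries, num_passages) score matrix.
    """
    # Token-level similarities: (num_queries, num_passages, query_len, passage_len)
    sim = torch.einsum("qnd,pmd->qpnm", query_embeddings, passage_embeddings)
    # For each query token, keep its best-matching passage token, then sum over query tokens.
    return sim.amax(dim=-1).sum(dim=-1)

Because the test dataset pairs each query with exactly one page, the largest value in every row of this matrix is expected on the diagonal, which is what the assertions above verify.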
tests/models/colqwen2/test_processing_colqwen2.py (new file, 262 lines)
@@ -0,0 +1,262 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the ColQwen2 processor."""

import shutil
import tempfile
import unittest

import torch

from transformers import AutoProcessor, Qwen2VLProcessor
from transformers.models.colqwen2.processing_colqwen2 import ColQwen2Processor
from transformers.testing_utils import get_tests_dir, require_torch, require_vision
from transformers.utils import is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
    from transformers import (
        ColQwen2Processor,
    )

SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")


@require_torch
@require_vision
class ColQwen2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    processor_class = ColQwen2Processor

    @classmethod
    def setUpClass(cls):
        cls.tmpdirname = tempfile.mkdtemp()
        processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
        processor.save_pretrained(cls.tmpdirname)

    def get_tokenizer(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer

    def get_image_processor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.tmpdirname)

    def test_process_images(self):
        # Processor configuration
        image_input = self.prepare_image_inputs()
        image_processor = self.get_component("image_processor")
        tokenizer = self.get_component("tokenizer", max_length=112, padding="max_length")
        image_processor.image_seq_length = 14

        # Get the processor
        processor = self.processor_class(
            tokenizer=tokenizer,
            image_processor=image_processor,
        )

        # Process the image
        batch_feature = processor.process_images(images=image_input, return_tensors="pt")

        # Assertions
        self.assertIn("pixel_values", batch_feature)
        self.assertEqual(batch_feature["pixel_values"].shape, torch.Size([1, 56, 1176]))

    def test_process_queries(self):
        # Inputs
        queries = [
            "Is attention really all you need?",
            "Are Benjamin, Antoine, Merve, and Jo best friends?",
        ]

        # Processor configuration
        image_processor = self.get_component("image_processor")
        tokenizer = self.get_component("tokenizer", max_length=112, padding="max_length")
        image_processor.image_seq_length = 14

        # Get the processor
        processor = self.processor_class(
            tokenizer=tokenizer,
            image_processor=image_processor,
        )

        # Process the queries
        batch_feature = processor.process_queries(text=queries, return_tensors="pt")

        # Assertions
        self.assertIn("input_ids", batch_feature)
        self.assertIsInstance(batch_feature["input_ids"], torch.Tensor)
        self.assertEqual(batch_feature["input_ids"].shape[0], len(queries))

    # The following tests override the parent tests because ColQwen2Processor can only take one of images or text as input at a time.

    def test_tokenizer_defaults_preserved_by_kwargs(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        processor_components = self.prepare_components()
        processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")

        processor = self.processor_class(**processor_components)
        self.skip_processor_without_typed_kwargs(processor)
        input_str = self.prepare_text_inputs()
        inputs = processor(text=input_str, return_tensors="pt")
        self.assertEqual(inputs[self.text_input_name].shape[-1], 117)

    def test_image_processor_defaults_preserved_by_image_kwargs(self):
        """
        We use do_rescale=True, rescale_factor=-1 to ensure that image_processor kwargs are preserved in the processor.
        We then check that the mean of the pixel_values is less than or equal to 0 after processing.
        Since the original pixel_values are in [0, 255], this is a good indicator that the rescale_factor is indeed applied.
        """
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        processor_components = self.prepare_components()
        processor_components["image_processor"] = self.get_component(
            "image_processor", do_rescale=True, rescale_factor=-1
        )
        processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")

        processor = self.processor_class(**processor_components)
        self.skip_processor_without_typed_kwargs(processor)

        image_input = self.prepare_image_inputs()

        inputs = processor(images=image_input, return_tensors="pt")
        self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)

    def test_kwargs_overrides_default_tokenizer_kwargs(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        processor_components = self.prepare_components()
        processor_components["tokenizer"] = self.get_component("tokenizer", padding="longest")

        processor = self.processor_class(**processor_components)
        self.skip_processor_without_typed_kwargs(processor)
        input_str = self.prepare_text_inputs()
        inputs = processor(text=input_str, return_tensors="pt", max_length=112, padding="max_length")
        self.assertEqual(inputs[self.text_input_name].shape[-1], 112)

    def test_kwargs_overrides_default_image_processor_kwargs(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        processor_components = self.prepare_components()
        processor_components["image_processor"] = self.get_component(
            "image_processor", do_rescale=True, rescale_factor=1
        )
        processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")

        processor = self.processor_class(**processor_components)
        self.skip_processor_without_typed_kwargs(processor)

        image_input = self.prepare_image_inputs()

        inputs = processor(images=image_input, do_rescale=True, rescale_factor=-1, return_tensors="pt")
        self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)

    def test_unstructured_kwargs(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        processor_components = self.prepare_components()
        processor = self.processor_class(**processor_components)
        self.skip_processor_without_typed_kwargs(processor)

        input_str = self.prepare_text_inputs()
        inputs = processor(
            text=input_str,
            return_tensors="pt",
            do_rescale=True,
            rescale_factor=-1,
            padding="max_length",
            max_length=76,
        )

        self.assertEqual(inputs[self.text_input_name].shape[-1], 76)

    def test_unstructured_kwargs_batched(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        processor_components = self.prepare_components()
        processor = self.processor_class(**processor_components)
        self.skip_processor_without_typed_kwargs(processor)

        image_input = self.prepare_image_inputs(batch_size=2)
        inputs = processor(
            images=image_input,
            return_tensors="pt",
            do_rescale=True,
            rescale_factor=-1,
            padding="longest",
            max_length=76,
        )

        self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)

    def test_doubly_passed_kwargs(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        processor_components = self.prepare_components()
        processor = self.processor_class(**processor_components)
        self.skip_processor_without_typed_kwargs(processor)

        image_input = self.prepare_image_inputs()
        with self.assertRaises(ValueError):
            _ = processor(
                images=image_input,
                images_kwargs={"do_rescale": True, "rescale_factor": -1},
                do_rescale=True,
                return_tensors="pt",
            )

    def test_structured_kwargs_nested(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        processor_components = self.prepare_components()
        processor = self.processor_class(**processor_components)
        self.skip_processor_without_typed_kwargs(processor)

        input_str = self.prepare_text_inputs()

        # Define the kwargs for each modality
        all_kwargs = {
            "common_kwargs": {"return_tensors": "pt"},
            "images_kwargs": {"do_rescale": True, "rescale_factor": -1},
            "text_kwargs": {"padding": "max_length", "max_length": 76},
        }

        inputs = processor(text=input_str, **all_kwargs)
        self.skip_processor_without_typed_kwargs(processor)

        self.assertEqual(inputs[self.text_input_name].shape[-1], 76)

    def test_structured_kwargs_nested_from_dict(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        processor_components = self.prepare_components()
        processor = self.processor_class(**processor_components)
        self.skip_processor_without_typed_kwargs(processor)
        image_input = self.prepare_image_inputs()

        # Define the kwargs for each modality
        all_kwargs = {
            "common_kwargs": {"return_tensors": "pt"},
            "images_kwargs": {"do_rescale": True, "rescale_factor": -1},
            "text_kwargs": {"padding": "max_length", "max_length": 76},
        }

        inputs = processor(images=image_input, **all_kwargs)
        self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
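The overridden tests above all call the processor with a single modality, since ColQwen2Processor accepts either images or text in a call, not both. A minimal usage sketch under that constraint is shown below; the checkpoint name is reused from the integration test earlier in this PR, and the blank PIL image is only an illustrative stand-in for a real document page.

from PIL import Image

from transformers import ColQwen2Processor

processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0-hf")

# Images and queries go through separate calls; mixing them in one call is not supported.
dummy_page = Image.new("RGB", (448, 448), color="white")  # illustrative placeholder image
image_batch = processor(images=[dummy_page], return_tensors="pt")
query_batch = processor(text=["Is attention really all you need?"], return_tensors="pt")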