transformers/tests/models/colqwen2/test_modeling_colqwen2.py
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch ColQwen2 model."""
import gc
import unittest
from typing import ClassVar
import torch
from datasets import load_dataset
from tests.test_configuration_common import ConfigTester
from tests.test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from transformers import is_torch_available
from transformers.models.colqwen2.configuration_colqwen2 import ColQwen2Config
from transformers.models.colqwen2.modeling_colqwen2 import ColQwen2ForRetrieval, ColQwen2ForRetrievalOutput
from transformers.models.colqwen2.processing_colqwen2 import ColQwen2Processor
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
if is_torch_available():
import torch


class ColQwen2ForRetrievalModelTester:
    def __init__(
        self,
        parent,
        ignore_index=-100,
        pad_token_id=2,
        projector_hidden_act="gelu",
        seq_length=11,
        vision_feature_select_strategy="default",
        vision_feature_layer=-1,
        projection_dim=32,
        is_training=False,
        use_cache=False,
        # Tiny Qwen2-VL backbone config (small hidden sizes, few layers) to keep unit tests fast.
        vlm_config={
            "_name_or_path": "Qwen/Qwen2-VL-2B-Instruct",
            "bos_token_id": 0,
            "eos_token_id": 1,
            "vision_start_token_id": 3,
            "image_token_id": 4,
            "video_token_id": 5,
            "hidden_size": 64,
            "intermediate_size": 2,
            "max_window_layers": 2,
            "model_type": "qwen2_vl",
            "num_attention_heads": 2,
            "num_hidden_layers": 2,
            "num_key_value_heads": 2,
            "rms_norm_eps": 1e-06,
            "rope_scaling": {"mrope_section": [4, 6, 6], "rope_type": "default", "type": "default"},
            "sliding_window": 32768,
            "tie_word_embeddings": True,
            "vision_config": {
                "depth": 2,
                "embed_dim": 32,
                "hidden_act": "quick_gelu",
                "hidden_size": 64,
                "mlp_ratio": 4,
                "num_heads": 4,
                "patch_size": 14,
                "in_chans": 3,
                "spatial_merge_size": 1,
                "temporal_patch_size": 2,
            },
            "vision_end_token_id": 151653,
            "vision_token_id": 151654,
            "vocab_size": 99,
        },
        embedding_dim=32,
        initializer_range=0.02,
    ):
        self.parent = parent
        self.ignore_index = ignore_index
        self.pad_token_id = pad_token_id
        # `image_token_index` is set to 0 to pass the "resize_embeddings" test; do not modify.
        self.image_token_index = 0
        self.image_token_id = vlm_config["image_token_id"]
        self.video_token_id = vlm_config["video_token_id"]
        # NOTE: this overrides the `pad_token_id` argument above with the backbone's EOS token id.
        self.pad_token_id = vlm_config["eos_token_id"]
        self.vision_start_token_id = vlm_config["vision_start_token_id"]
        self.projector_hidden_act = projector_hidden_act
        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer
        self.image_size = 56
        self.num_image_tokens = 4
        self.seq_length = seq_length + self.num_image_tokens
        self.projection_dim = projection_dim
        self.num_hidden_layers = vlm_config["num_hidden_layers"]
        self.vocab_size = vlm_config["vocab_size"]
        self.hidden_size = vlm_config["hidden_size"]
        self.num_attention_heads = vlm_config["num_attention_heads"]
        self.is_training = is_training
        self.batch_size = 3
        self.num_channels = vlm_config["vision_config"]["in_chans"]
        self.encoder_seq_length = self.seq_length
        self.use_cache = use_cache
        self.vlm_config = vlm_config
        self.embedding_dim = embedding_dim
        self.initializer_range = initializer_range

    def get_config(self):
        return ColQwen2Config(
            vlm_config=self.vlm_config,
            embedding_dim=self.embedding_dim,
            initializer_range=self.initializer_range,
        )
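
    # The backbone expects `pixel_values` as a flat sequence of image patches, one row per patch,
    # each row holding `num_channels * patch_size**2 * temporal_patch_size` values (the Qwen2-VL
    # flattened-patch convention). The helper below builds such a tensor for a batch of square
    # dummy images, then pads the per-image patch sequences to a common length.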
    def prepare_config_and_inputs(self):
        config = self.get_config()
        patch_size = config.vlm_config.vision_config.patch_size
        temporal_patch_size = config.vlm_config.vision_config.temporal_patch_size

        # NOTE: Assume all inputs are square images of the same size.
        num_patches = (self.image_size // patch_size) ** 2
        pixel_values = floats_tensor(
            [
                self.batch_size * num_patches,
                self.num_channels * (patch_size**2) * temporal_patch_size,
            ]
        )

        # Hardcoded image grid size: with `image_size=56` and `patch_size=14`, each image yields a
        # 4x4 patch grid. Do not change unless you modified the image size or patch size!
        image_grid_thw = torch.tensor([1, 4, 4]).repeat(self.batch_size, 1)

        # NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs.
        # Line is copied from `src/transformers/models/colqwen2/processing_colqwen2.py`.
        offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]  # (batch_size,)
        pixel_values = list(
            torch.split(pixel_values, offsets.tolist())
        )  # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)]
        pixel_values = torch.nn.utils.rnn.pad_sequence(
            pixel_values, batch_first=True
        )  # (batch_size, max_num_patches, pixel_values)

        return config, pixel_values, image_grid_thw

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values, image_grid_thw = config_and_inputs
        input_ids = (
            ids_tensor(
                shape=[self.batch_size, self.seq_length],
                vocab_size=config.vlm_config.vocab_size - 1,
            )
            + 1
        )
        attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)

        input_ids[:, -1] = self.pad_token_id
        input_ids[:, : self.num_image_tokens] = self.image_token_id
        # Make sure no special vision tokens remain in the randomly sampled ids.
        input_ids[input_ids == self.video_token_id] = self.pad_token_id
        input_ids[input_ids == self.image_token_id] = self.pad_token_id
        input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id

        inputs_dict = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "image_grid_thw": image_grid_thw,
            "attention_mask": attention_mask,
            "labels": input_ids,
        }
        return config, inputs_dict
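

# `ModelTesterMixin` drives the generic modeling checks (attention outputs, hidden states,
# embedding resizing, etc.) against the tiny model built by the tester above.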
@require_torch
class ColQwen2ForRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Model tester for `ColQwen2ForRetrieval`.
    """

    all_model_classes = (ColQwen2ForRetrieval,) if is_torch_available() else ()
    fx_compatible = False
    test_torchscript = False
    test_pruning = False
    test_resize_embeddings = True
    test_head_masking = False

    def setUp(self):
        self.model_tester = ColQwen2ForRetrievalModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ColQwen2Config, has_text_modality=False)

    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)

            input_ids = inputs["input_ids"]
            del inputs["input_ids"]
            del inputs["pixel_values"]

            wte = model.get_input_embeddings()
            inputs["inputs_embeds"] = wte(input_ids)

            with torch.no_grad():
                model(**inputs)

    # Overwrite `inputs_embeds` tests because we need to delete `pixel_values` for LVLMs,
    # while some other models require `pixel_values` to be present.
    def test_inputs_embeds_matches_input_ids(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)
            input_ids = inputs["input_ids"]
            del inputs["input_ids"]
            del inputs["pixel_values"]

            inputs_embeds = model.get_input_embeddings()(input_ids)

            with torch.no_grad():
                out_ids = model(input_ids=input_ids, **inputs)[0]
                out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))

    @slow
    @require_vision
    def test_colqwen2_forward_inputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)

            with torch.no_grad():
                outputs = model(**inputs, return_dict=True)

            self.assertIsInstance(outputs, ColQwen2ForRetrievalOutput)

    @unittest.skip(reason="Some undefined behavior encountered with test versions of Qwen2-VL. Skip for now.")
    def test_model_parallelism(self):
        pass

    @unittest.skip(reason="Pass because ColQwen2 requires `attention_mask is not None`")
    def test_sdpa_can_dispatch_on_flash(self):
        pass

    @unittest.skip(reason="Pass because ColQwen2 requires `attention_mask is not None`")
    def test_sdpa_can_compile_dynamic(self):
        pass


@require_torch
class ColQwen2ModelIntegrationTest(unittest.TestCase):
    model_name: ClassVar[str] = "vidore/colqwen2-v1.0-hf"

    def setUp(self):
        self.processor = ColQwen2Processor.from_pretrained(self.model_name)

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    @slow
    def test_model_integration_test(self):
        """
        Test if the model is able to retrieve the correct pages for a small and easy dataset.
        """
        model = ColQwen2ForRetrieval.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            device_map=torch_device,
        ).eval()

        # Load the test dataset
        ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test")

        # Preprocess the examples
        batch_images = self.processor(images=ds["image"]).to(torch_device)
        batch_queries = self.processor(text=ds["query"]).to(torch_device)

        # Run inference
        with torch.inference_mode():
            image_embeddings = model(**batch_images).embeddings
            query_embeddings = model(**batch_queries).embeddings

        # Compute retrieval scores
        scores = self.processor.score_retrieval(
            query_embeddings=query_embeddings,
            passage_embeddings=image_embeddings,
        )  # (num_queries, num_passages)
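
        # `score_retrieval` computes late-interaction (ColBERT-style MaxSim) scores: for each query
        # token embedding, take the maximum dot product over all passage token embeddings, then sum
        # these maxima over the query tokens.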

        assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}D"
        assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}"

        # Check that the maximum score per row lies on the diagonal of the score matrix.
        self.assertTrue((scores.argmax(axis=1) == torch.arange(len(ds), device=scores.device)).all())

        # Further validation: fine-grained check with hardcoded scores from the original Hf implementation.
        expected_scores = torch.tensor(
            [
                [16.2500, 7.8750, 14.6875],
                [9.5000, 17.1250, 10.5000],
                [14.9375, 10.9375, 20.0000],
            ],
            dtype=scores.dtype,
        )
        assert torch.allclose(scores, expected_scores, atol=1e-3), f"Expected scores {expected_scores}, got {scores}"
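
# NOTE: to run this suite locally from a source checkout of transformers (dev extras installed):
#   pytest tests/models/colqwen2/test_modeling_colqwen2.py
# The integration test is gated behind the usual slow-test flag:
#   RUN_SLOW=1 pytest tests/models/colqwen2/test_modeling_colqwen2.py -k integration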