mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00

* feat: run `add-new-model-like` * feat: add paligemma code with "copied from" * feat: add ColPaliProcessor * feat: add ColPaliModel * feat: add ColPaliConfig * feat: rename `ColPaliForConditionalGeneration` to `ColPaliModel` * fixup modeling colpali * fix: fix root import shortcuts * fix: fix `modeling_auto` dict * feat: comment out ColPali test file * fix: fix typos from `add-new-model-like` * feat: explicit the forward input args * feat: move everything to `modular_colpali.py` * fix: put back ColPaliProcesor * feat: add auto-generated files * fix: run `fix-copies` * fix: remove DOCStRING constants to make modular converter work * fix: fix typo + modular converter * fix: add missing imports * feat: no more errors when loading ColPaliModel * fix: remove unused args in forward + tweak doc * feat: rename `ColPaliModel` to `ColPaliForRetrieval` * fix: apply `fix-copies` * feat: add ColPaliProcessor to `modular_colpali` * fix: run make quality + make style * fix: remove duplicate line in configuration_auto * feat: make ColPaliModel inehrit from PaliGemmaForConditionalGeneration * fix: tweak and use ColPaliConfig * feat: rename `score` to `post_process_retrieval` * build: run modular formatter + make style * feat: convert colpali weights + fixes * feat: remove old weight converter file * feat: add and validate tests * feat: replace harcoded path to "vidore/colpali-v1.2-hf" in tests * fix: add bfloat16 conversion in weight converter * feat: replace pytest with unittest in modeling colpali test * feat: add sanity check for weight conversion (doesn't work yet) * feat: add shape sanity check in weigth converter * feat: make ColPaliProcessor args explicit * doc: add doc for ColPali * fix: trying to fix output mismatch * feat: tweaks * fix: ColPaliModelOutput inherits from ModelOutput instead of PaliGemmaCausalLMOutputWithPast * fix: address comments on PR * fix: adapt tests to the Hf norm * wip: try things * feat: add `__call__` method to `ColPaliProcessor` * feat: remove 
need for dummy image in `process_queries` * build: run new modular converter * fix: fix incorrect method override * Fix tests, processing, modular, convert * fix tokenization auto * hotfix: manually fix processor -> fixme once convert modular is fixed * fix: convert weights working * feat: rename and improve convert weight script * feat: tweaks * fest: remove `device` input for `post_process_retrieval` * refactor: remove unused `get_torch_device` * Fix all tests * docs: update ColPali model doc * wip: fix convert weights to hf * fix logging modular * docs: add acknowledgements in model doc * docs: add missing docstring to ColPaliProcessor * docs: tweak * docs: add doc for `ColPaliForRetrievalOutput.forward` * feat: add modifications from colpali-engine v0.3.2 in ColPaliProcessor * fix: fix and upload colapli hf weights * refactor: rename `post_process_retrieval` to `score_retrieval` * fix: fix wrong typing for `score_retrieval` * test: add integration test for ColPali * chore: rerun convert modular * build: fix root imports * Update docs/source/en/index.md Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> * fix: address PR comments * wip: reduce the prediction gap in weight conversion * docs: add comment in weight conversion script * docs: add example for `ColPaliForRetrieval.forward` * tests: change dataset path to the new one in hf-internal * fix: colpali weight conversion works * test: add fine-grained check for ColPali integration test * fix: fix typos in convert weight script * docs: move input docstring in a variable * fix: remove hardcoded torch device in test * fix: run the new modular refactor * docs: fix python example for ColPali * feat: add option to choose `score_retrieval`'s output dtype and device * docs: update doc for `score_retrieval` * feat: add `patch_size` property in ColPali model * chore: run `make fix-copies` * docs: update description for ColPali cookbooks * fix: remove `ignore_index` methods * feat: remove 
non-transformers specific methods * feat: update `__init__.py` to new hf format * fix: fix root imports in transformers * feat: remove ColPali's inheritance from PaliGemma * Fix CI issues * nit remove prints * feat: remove ColPali config and model from `modular_colpali.py` * feat: add `ColPaliPreTrainedModel` and update modeling and configuration code * fix: fix auto-removed imports in root `__init__.py` * fix: various fixes * fix: fix `_init_weight` * temp: comment `AutoModel.from_config` for experiments * fix: add missing `output_attentions` arg in ColPali's forward * fix: fix `resize_token_embeddings` * fix: make `input_ids` optional in forward * feat: rename `projection_layer` to `embedding_proj_layer` * wip: fix convert colpali weight script * fix tests and convert weights from original repo * fix unprotected import * fix unprotected torch import * fix style * change vlm_backbone_config to vlm_config * fix unprotected import in modular this time * fix: load config from Hub + tweaks in convert weight script * docs: move example usage from model docstring to model markdown * docs: fix input docstring for ColPali's forward method * fix: use `sub_configs` for ColPaliConfig * fix: remove non-needed sanity checks in weight conversion script + tweaks * fix: fix issue with `replace_return_docstrings` in ColPali's `forward` * docs: update docstring for `ColPaliConfig` * test: change model path in ColPali test * fix: fix ColPaliConfig * fix: fix weight conversion script * test: fix expected weights for ColPali model * docs: update ColPali markdown * docs: fix minor typo in ColPaliProcessor * Fix tests and add _no_split_modules * add text_config to colpali config * [run slow] colpali * move inputs to torch_device in integration test * skip test_model_parallelism * docs: clarify quickstart snippet in ColPali's model card * docs: update ColPali's model card --------- Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co> Co-authored-by: Yoni Gozlan 
<74535834+yonigozlan@users.noreply.github.com>
248 lines
10 KiB
Python
import shutil
import tempfile
import unittest

import torch

from transformers import GemmaTokenizer
from transformers.models.colpali.processing_colpali import ColPaliProcessor
from transformers.testing_utils import get_tests_dir, require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils.dummy_vision_objects import SiglipImageProcessor

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
    from transformers import (
        ColPaliProcessor,
        PaliGemmaProcessor,
        SiglipImageProcessor,
    )

SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")


@require_vision
class ColPaliProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    processor_class = ColPaliProcessor

    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()
        image_processor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
        image_processor.image_seq_length = 0
        tokenizer = GemmaTokenizer(SAMPLE_VOCAB, keep_accents=True)
        processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer)
        processor.save_pretrained(self.tmpdirname)

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    @require_torch
    @require_vision
    def test_process_images(self):
        # Processor configuration
        image_input = self.prepare_image_inputs()
        image_processor = self.get_component("image_processor")
        tokenizer = self.get_component("tokenizer", max_length=112, padding="max_length")
        image_processor.image_seq_length = 14

        # Get the processor
        processor = self.processor_class(
            tokenizer=tokenizer,
            image_processor=image_processor,
        )

        # Process the image
        batch_feature = processor.process_images(images=image_input, return_tensors="pt")

        # Assertions
        self.assertIn("pixel_values", batch_feature)
        self.assertEqual(batch_feature["pixel_values"].shape, torch.Size([1, 3, 384, 384]))

    @require_torch
    @require_vision
    def test_process_queries(self):
        # Inputs
        queries = [
            "Is attention really all you need?",
            "Are Benjamin, Antoine, Merve, and Jo best friends?",
        ]

        # Processor configuration
        image_processor = self.get_component("image_processor")
        tokenizer = self.get_component("tokenizer", max_length=112, padding="max_length")
        image_processor.image_seq_length = 14

        # Get the processor
        processor = self.processor_class(
            tokenizer=tokenizer,
            image_processor=image_processor,
        )

        # Process the queries
        batch_feature = processor.process_queries(text=queries, return_tensors="pt")

        # Assertions
        self.assertIn("input_ids", batch_feature)
        self.assertIsInstance(batch_feature["input_ids"], torch.Tensor)
        self.assertEqual(batch_feature["input_ids"].shape[0], len(queries))

    # The following tests are overridden because ColPaliProcessor accepts either images or text
    # in a single call, not both.

    def test_tokenizer_defaults_preserved_by_kwargs(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        processor_components = self.prepare_components()
        processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")

        processor = self.processor_class(**processor_components)
        self.skip_processor_without_typed_kwargs(processor)
        input_str = self.prepare_text_inputs()
        inputs = processor(text=input_str, return_tensors="pt")
        self.assertEqual(inputs[self.text_input_name].shape[-1], 117)

    def test_image_processor_defaults_preserved_by_image_kwargs(self):
        """
        We use do_rescale=True, rescale_factor=-1 to ensure that image_processor kwargs are preserved in the processor.
        We then check that the mean of the pixel_values is less than or equal to 0 after processing.
        Since the original pixel_values are in [0, 255], this is a good indicator that the rescale_factor is indeed applied.
        """
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        processor_components = self.prepare_components()
        processor_components["image_processor"] = self.get_component(
            "image_processor", do_rescale=True, rescale_factor=-1
        )
        processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")

        processor = self.processor_class(**processor_components)
        self.skip_processor_without_typed_kwargs(processor)

        image_input = self.prepare_image_inputs()

        inputs = processor(images=image_input, return_tensors="pt")
        self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)

    def test_kwargs_overrides_default_tokenizer_kwargs(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        processor_components = self.prepare_components()
        processor_components["tokenizer"] = self.get_component("tokenizer", padding="longest")

        processor = self.processor_class(**processor_components)
        self.skip_processor_without_typed_kwargs(processor)
        input_str = self.prepare_text_inputs()
        inputs = processor(text=input_str, return_tensors="pt", max_length=112, padding="max_length")
        self.assertEqual(inputs[self.text_input_name].shape[-1], 112)

    def test_kwargs_overrides_default_image_processor_kwargs(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        processor_components = self.prepare_components()
        processor_components["image_processor"] = self.get_component(
            "image_processor", do_rescale=True, rescale_factor=1
        )
        processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")

        processor = self.processor_class(**processor_components)
        self.skip_processor_without_typed_kwargs(processor)

        image_input = self.prepare_image_inputs()

        inputs = processor(images=image_input, do_rescale=True, rescale_factor=-1, return_tensors="pt")
        self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)

    def test_unstructured_kwargs(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        processor_components = self.prepare_components()
        processor = self.processor_class(**processor_components)
        self.skip_processor_without_typed_kwargs(processor)

        input_str = self.prepare_text_inputs()
        inputs = processor(
            text=input_str,
            return_tensors="pt",
            do_rescale=True,
            rescale_factor=-1,
            padding="max_length",
            max_length=76,
        )

        self.assertEqual(inputs[self.text_input_name].shape[-1], 76)

    def test_unstructured_kwargs_batched(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        processor_components = self.prepare_components()
        processor = self.processor_class(**processor_components)
        self.skip_processor_without_typed_kwargs(processor)

        image_input = self.prepare_image_inputs(batch_size=2)
        inputs = processor(
            images=image_input,
            return_tensors="pt",
            do_rescale=True,
            rescale_factor=-1,
            padding="longest",
            max_length=76,
        )

        self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)

    def test_doubly_passed_kwargs(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        processor_components = self.prepare_components()
        processor = self.processor_class(**processor_components)
        self.skip_processor_without_typed_kwargs(processor)

        image_input = self.prepare_image_inputs()
        with self.assertRaises(ValueError):
            _ = processor(
                images=image_input,
                images_kwargs={"do_rescale": True, "rescale_factor": -1},
                do_rescale=True,
                return_tensors="pt",
            )

    def test_structured_kwargs_nested(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        processor_components = self.prepare_components()
        processor = self.processor_class(**processor_components)
        self.skip_processor_without_typed_kwargs(processor)

        input_str = self.prepare_text_inputs()

        # Define the kwargs for each modality
        all_kwargs = {
            "common_kwargs": {"return_tensors": "pt"},
            "images_kwargs": {"do_rescale": True, "rescale_factor": -1},
            "text_kwargs": {"padding": "max_length", "max_length": 76},
        }

        inputs = processor(text=input_str, **all_kwargs)

        self.assertEqual(inputs[self.text_input_name].shape[-1], 76)

    def test_structured_kwargs_nested_from_dict(self):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        processor_components = self.prepare_components()
        processor = self.processor_class(**processor_components)
        self.skip_processor_without_typed_kwargs(processor)
        image_input = self.prepare_image_inputs()

        # Define the kwargs for each modality
        all_kwargs = {
            "common_kwargs": {"return_tensors": "pt"},
            "images_kwargs": {"do_rescale": True, "rescale_factor": -1},
            "text_kwargs": {"padding": "max_length", "max_length": 76},
        }

        inputs = processor(images=image_input, **all_kwargs)
        self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
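The overridden kwargs tests in this file (`test_kwargs_overrides_default_*`, `test_unstructured_kwargs*`, `test_doubly_passed_kwargs`, `test_structured_kwargs_nested*`) all exercise one precedence rule: call-time kwargs override the defaults stored on the tokenizer or image processor, they may be passed flat or nested per modality, and passing the same key both ways is an error. The following is a minimal sketch of that rule using plain dictionaries, under stated assumptions: `merge_modality_kwargs` and `MODALITY_KEYS` are hypothetical names (not part of `transformers`), the key registry is a tiny illustrative subset, and the library's own merging logic differs in its details.

```python
from typing import Any, Dict, Set

# Hypothetical registry of the keys each modality accepts (an illustrative subset only).
MODALITY_KEYS: Dict[str, Set[str]] = {
    "text_kwargs": {"padding", "max_length"},
    "images_kwargs": {"do_rescale", "rescale_factor"},
}


def merge_modality_kwargs(defaults: Dict[str, Dict[str, Any]], **kwargs: Any) -> Dict[str, Dict[str, Any]]:
    """Return one kwargs dict per modality (hypothetical helper, not a transformers API).

    Component defaults are overridden by call-time kwargs, which may be nested
    (e.g. ``images_kwargs={...}``) or flat (e.g. ``do_rescale=True``).
    Passing the same key both nested and flat for one modality raises ValueError.
    Flat keys that no modality claims (e.g. ``return_tensors``) are common kwargs
    and are copied into every modality, together with ``common_kwargs``.
    """
    # Pull out the nested dicts; whatever remains in kwargs is a flat kwarg.
    nested = {name: dict(kwargs.pop(name, {})) for name in (*MODALITY_KEYS, "common_kwargs")}
    common = nested.pop("common_kwargs")
    claimed = set().union(*MODALITY_KEYS.values())

    merged: Dict[str, Dict[str, Any]] = {}
    for modality, accepted in MODALITY_KEYS.items():
        # Start from a copy of the component defaults, then apply nested overrides.
        out = dict(defaults.get(modality, {}))
        out.update(nested[modality])
        # Route flat kwargs to this modality, refusing keys already given in its nested dict.
        for key, value in kwargs.items():
            if key in accepted:
                if key in nested[modality]:
                    raise ValueError(f"{key!r} was passed both flat and inside {modality}")
                out[key] = value
        merged[modality] = out

    # Unclaimed flat kwargs join the explicit common kwargs, which apply to every modality.
    common.update({key: value for key, value in kwargs.items() if key not in claimed})
    for out in merged.values():
        out.update(common)
    return merged
```

For example, `merge_modality_kwargs({"text_kwargs": {"padding": "max_length", "max_length": 117}}, max_length=112)` keeps the default `padding` but uses a text `max_length` of 112, mirroring what `test_kwargs_overrides_default_tokenizer_kwargs` asserts for the real processor. Raising on a doubly passed key, rather than silently picking one value, matches the `ValueError` that `test_doubly_passed_kwargs` expects.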