# ColPali
[ColPali](https://huggingface.co/papers/2407.01449) is a model designed to retrieve documents by analyzing their visual features. Unlike traditional systems that rely heavily on text extraction and OCR, ColPali treats each page as an image. It uses [Paligemma-3B](./paligemma) to capture not only text, but also the layout, tables, charts, and other visual elements to create detailed multi-vector embeddings that can be used for retrieval by computing pairwise late interaction similarity scores. This offers a more comprehensive understanding of documents and enables more efficient and accurate retrieval.
This model was contributed by [@tonywu71](https://huggingface.co/tonywu71) (ILLUIN Technology) and [@yonigozlan](https://huggingface.co/yonigozlan) (HuggingFace).
You can find all the original ColPali checkpoints under Vidore's [Hf-native ColVision Models](https://huggingface.co/collections/vidore/hf-native-colvision-models-6755d68fc60a8553acaa96f7) collection.
> [!TIP]
> Click on the ColPali models in the right sidebar for more examples of how to use ColPali for image retrieval.
```python
import requests
import torch
from PIL import Image
from transformers import ColPaliForRetrieval, ColPaliProcessor
# Load the model and the processor
model_name = "vidore/colpali-v1.3-hf"
model = ColPaliForRetrieval.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto", # "cpu", "cuda", or "mps" for Apple Silicon
)
processor = ColPaliProcessor.from_pretrained(model_name)
# The document page screenshots from your corpus
url1 = "https://upload.wikimedia.org/wikipedia/commons/8/89/US-original-Declaration-1776.jpg"
url2 = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Romeoandjuliet1597.jpg/500px-Romeoandjuliet1597.jpg"
images = [
Image.open(requests.get(url1, stream=True).raw),
Image.open(requests.get(url2, stream=True).raw),
]
# The queries you want to retrieve documents for
queries = [
"When was the United States Declaration of Independence proclaimed?",
"Who printed the edition of Romeo and Juliet?",
]
# Process the inputs
inputs_images = processor(images=images).to(model.device)
inputs_text = processor(text=queries).to(model.device)
# Forward pass
with torch.no_grad():
image_embeddings = model(**inputs_images).embeddings
query_embeddings = model(**inputs_text).embeddings
# Score the queries against the images
scores = processor.score_retrieval(query_embeddings, image_embeddings)
print("Retrieval scores (query x image):")
print(scores)
```
If you have issue with loading the images with PIL, you can use the following code to create dummy images:
```python
images = [
Image.new("RGB", (128, 128), color="white"),
Image.new("RGB", (64, 32), color="black"),
]
```
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
The example below uses [bitsandbytes](../quantization/bitsandbytes.md) to quantize the weights to int4.
```python
import requests
import torch
from PIL import Image
from transformers import BitsAndBytesConfig, ColPaliForRetrieval, ColPaliProcessor
model_name = "vidore/colpali-v1.3-hf"
# 4-bit quantization configuration
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
model = ColPaliForRetrieval.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="cuda",
)
processor = ColPaliProcessor.from_pretrained(model_name)
url1 = "https://upload.wikimedia.org/wikipedia/commons/8/89/US-original-Declaration-1776.jpg"
url2 = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Romeoandjuliet1597.jpg/500px-Romeoandjuliet1597.jpg"
images = [
Image.open(requests.get(url1, stream=True).raw),
Image.open(requests.get(url2, stream=True).raw),
]
queries = [
"When was the United States Declaration of Independence proclaimed?",
"Who printed the edition of Romeo and Juliet?",
]
# Process the inputs
inputs_images = processor(images=images, return_tensors="pt").to(model.device)
inputs_text = processor(text=queries, return_tensors="pt").to(model.device)
# Forward pass
with torch.no_grad():
image_embeddings = model(**inputs_images).embeddings
query_embeddings = model(**inputs_text).embeddings
# Score the queries against the images
scores = processor.score_retrieval(query_embeddings, image_embeddings)
print("Retrieval scores (query x image):")
print(scores)
```
## Notes
- [`~ColPaliProcessor.score_retrieval`] returns a 2D tensor where the first dimension is the number of queries and the second dimension is the number of images. A higher score indicates more similarity between the query and image.
## ColPaliConfig
[[autodoc]] ColPaliConfig
## ColPaliProcessor
[[autodoc]] ColPaliProcessor
## ColPaliForRetrieval
[[autodoc]] ColPaliForRetrieval
- forward