Add Fast OwlViT Processor (#37164)

* Add Fast Owlvit Processor

* Update image_processing_owlvit_fast.py

* Update image_processing_owlvit_fast.py

---------

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
Parteek 2025-04-14 21:28:09 +05:30 committed by GitHub
parent cb39f7dd5b
commit 20ceaca228
5 changed files with 269 additions and 17 deletions

@@ -94,6 +94,11 @@ A demo notebook on using OWL-ViT for zero- and one-shot (image-guided) object de
 [[autodoc]] OwlViTImageProcessor
     - preprocess
+
+## OwlViTImageProcessorFast
+
+[[autodoc]] OwlViTImageProcessorFast
+    - preprocess
     - post_process_object_detection
     - post_process_image_guided_detection

@@ -123,7 +123,7 @@ else:
             ("nougat", ("NougatImageProcessor",)),
             ("oneformer", ("OneFormerImageProcessor",)),
             ("owlv2", ("Owlv2ImageProcessor",)),
-            ("owlvit", ("OwlViTImageProcessor",)),
+            ("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")),
             ("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
             ("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")),
             ("phi4_multimodal", "Phi4MultimodalImageProcessorFast"),

@@ -21,6 +21,7 @@ if TYPE_CHECKING:
     from .configuration_owlvit import *
     from .feature_extraction_owlvit import *
     from .image_processing_owlvit import *
+    from .image_processing_owlvit_fast import *
     from .modeling_owlvit import *
     from .processing_owlvit import *
 else:
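
Once the owlvit submodule re-exports the class, it is also importable from the top-level package (the new tests rely on this). A quick illustrative check, not part of the diff:

```python
from transformers import OwlViTImageProcessorFast
from transformers.models.owlvit import OwlViTImageProcessorFast as FromSubmodule

# Both import paths should resolve to the same class object.
assert OwlViTImageProcessorFast is FromSubmodule
```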

@@ -0,0 +1,240 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for OwlViT"""

import warnings
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from ...image_processing_utils_fast import (
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    BaseImageProcessorFast,
)
from ...image_transforms import center_to_corners_format
from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
from ...utils import TensorType, add_start_docstrings, is_torch_available, logging


if TYPE_CHECKING:
    from .modeling_owlvit import OwlViTObjectDetectionOutput

if is_torch_available():
    import torch

    from .image_processing_owlvit import _scale_boxes, box_iou


logger = logging.get_logger(__name__)


@add_start_docstrings(
    "Constructs a fast OwlViT image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
)
class OwlViTImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.BICUBIC
    image_mean = OPENAI_CLIP_MEAN
    image_std = OPENAI_CLIP_STD
    size = {"height": 768, "width": 768}
    default_to_square = True
    crop_size = {"height": 768, "width": 768}
    do_resize = True
    do_center_crop = False
    do_rescale = True
    do_normalize = None
    do_convert_rgb = None
    model_input_names = ["pixel_values"]
    # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process
    def post_process(self, outputs, target_sizes):
        """
        Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
        bottom_right_x, bottom_right_y) format.

        Args:
            outputs ([`OwlViTObjectDetectionOutput`]):
                Raw outputs of the model.
            target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
                Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
                image size (before any data augmentation). For visualization, this should be the image size after data
                augment, but before padding.
        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model.
        """
        # TODO: (amy) add support for other frameworks
        warnings.warn(
            "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
            " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
            FutureWarning,
        )

        logits, boxes = outputs.logits, outputs.pred_boxes

        if len(logits) != len(target_sizes):
            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
        if target_sizes.shape[1] != 2:
            raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

        probs = torch.max(logits, dim=-1)
        scores = torch.sigmoid(probs.values)
        labels = probs.indices

        # Convert to [x0, y0, x1, y1] format
        boxes = center_to_corners_format(boxes)

        # Convert from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
        boxes = boxes * scale_fct[:, None, :]

        results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]

        return results
    # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_object_detection
    def post_process_object_detection(
        self,
        outputs: "OwlViTObjectDetectionOutput",
        threshold: float = 0.1,
        target_sizes: Optional[Union[TensorType, List[Tuple]]] = None,
    ):
        """
        Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
        bottom_right_x, bottom_right_y) format.

        Args:
            outputs ([`OwlViTObjectDetectionOutput`]):
                Raw outputs of the model.
            threshold (`float`, *optional*, defaults to 0.1):
                Score threshold to keep object detection predictions.
            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                `(height, width)` of each image in the batch. If unset, predictions will not be resized.
        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the following keys:
            - "scores": The confidence scores for each predicted box on the image.
            - "labels": Indexes of the classes predicted by the model on the image.
            - "boxes": Image bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format.
        """
        batch_logits, batch_boxes = outputs.logits, outputs.pred_boxes
        batch_size = len(batch_logits)

        if target_sizes is not None and len(target_sizes) != batch_size:
            raise ValueError("Make sure that you pass in as many target sizes as images")

        # batch_logits of shape (batch_size, num_queries, num_classes)
        batch_class_logits = torch.max(batch_logits, dim=-1)
        batch_scores = torch.sigmoid(batch_class_logits.values)
        batch_labels = batch_class_logits.indices

        # Convert to [x0, y0, x1, y1] format
        batch_boxes = center_to_corners_format(batch_boxes)

        # Convert from relative [0, 1] to absolute [0, height] coordinates
        if target_sizes is not None:
            batch_boxes = _scale_boxes(batch_boxes, target_sizes)

        results = []
        for scores, labels, boxes in zip(batch_scores, batch_labels, batch_boxes):
            keep = scores > threshold
            scores = scores[keep]
            labels = labels[keep]
            boxes = boxes[keep]
            results.append({"scores": scores, "labels": labels, "boxes": boxes})

        return results
    # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_image_guided_detection
    def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_threshold=0.3, target_sizes=None):
        """
        Converts the output of [`OwlViTForObjectDetection.image_guided_detection`] into the format expected by the COCO
        api.

        Args:
            outputs ([`OwlViTImageGuidedObjectDetectionOutput`]):
                Raw outputs of the model.
            threshold (`float`, *optional*, defaults to 0.0):
                Minimum confidence threshold to use to filter out predicted boxes.
            nms_threshold (`float`, *optional*, defaults to 0.3):
                IoU threshold for non-maximum suppression of overlapping boxes.
            target_sizes (`torch.Tensor`, *optional*):
                Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in
                the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to
                None, predictions will not be unnormalized.

        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model. All labels are set to None as
            `OwlViTForObjectDetection.image_guided_detection` perform one-shot object detection.
        """
        logits, target_boxes = outputs.logits, outputs.target_pred_boxes

        if target_sizes is not None and len(logits) != len(target_sizes):
            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
        if target_sizes is not None and target_sizes.shape[1] != 2:
            raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

        probs = torch.max(logits, dim=-1)
        scores = torch.sigmoid(probs.values)

        # Convert to [x0, y0, x1, y1] format
        target_boxes = center_to_corners_format(target_boxes)

        # Apply non-maximum suppression (NMS)
        if nms_threshold < 1.0:
            for idx in range(target_boxes.shape[0]):
                for i in torch.argsort(-scores[idx]):
                    if not scores[idx][i]:
                        continue

                    ious = box_iou(target_boxes[idx][i, :].unsqueeze(0), target_boxes[idx])[0][0]
                    ious[i] = -1.0  # Mask self-IoU.
                    scores[idx][ious > nms_threshold] = 0.0

        # Convert from relative [0, 1] to absolute [0, height] coordinates
        if target_sizes is not None:
            target_boxes = _scale_boxes(target_boxes, target_sizes)

        # Compute box display alphas based on prediction scores
        results = []
        alphas = torch.zeros_like(scores)

        for idx in range(target_boxes.shape[0]):
            # Select scores for boxes matching the current query:
            query_scores = scores[idx]
            if not query_scores.nonzero().numel():
                continue

            # Apply threshold on scores before scaling
            query_scores[query_scores < threshold] = 0.0

            # Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1.
            # All other boxes will either belong to a different query, or will not be shown.
            max_score = torch.max(query_scores) + 1e-6
            query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9)
            query_alphas = torch.clip(query_alphas, 0.0, 1.0)
            alphas[idx] = query_alphas

            mask = alphas[idx] > 0
            box_scores = alphas[idx][mask]
            boxes = target_boxes[idx][mask]
            results.append({"scores": box_scores, "labels": None, "boxes": boxes})

        return results


__all__ = ["OwlViTImageProcessorFast"]
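
For orientation, a short end-to-end sketch of how the fast processor's `post_process_object_detection` is meant to be used (not part of the committed file; checkpoint, image URL, and text queries are illustrative):

```python
import requests
import torch
from PIL import Image

from transformers import OwlViTForObjectDetection, OwlViTImageProcessorFast, OwlViTProcessor

checkpoint = "google/owlvit-base-patch32"  # illustrative checkpoint
processor = OwlViTProcessor.from_pretrained(checkpoint)
fast_image_processor = OwlViTImageProcessorFast.from_pretrained(checkpoint)
model = OwlViTForObjectDetection.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
queries = [["a photo of a cat", "a photo of a dog"]]

inputs = processor(text=queries, images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Convert normalized (cx, cy, w, h) predictions to absolute (x0, y0, x1, y1)
# boxes and keep detections above the score threshold.
target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
results = fast_image_processor.post_process_object_detection(
    outputs, threshold=0.1, target_sizes=target_sizes
)
for score, label, box in zip(results[0]["scores"], results[0]["labels"], results[0]["boxes"]):
    print(f"{queries[0][label]}: {score:.2f} at {[round(v, 1) for v in box.tolist()]}")
```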

@@ -16,7 +16,7 @@
 import unittest
 
 from transformers.testing_utils import require_torch, require_vision
-from transformers.utils import is_vision_available
+from transformers.utils import is_torchvision_available, is_vision_available
 
 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
 
@@ -24,6 +24,9 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im
 if is_vision_available():
     from transformers import OwlViTImageProcessor
 
+if is_torchvision_available():
+    from transformers import OwlViTImageProcessorFast
+
 
 class OwlViTImageProcessingTester:
     def __init__(
@@ -89,6 +92,7 @@ class OwlViTImageProcessingTester:
 @require_vision
 class OwlViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
     image_processing_class = OwlViTImageProcessor if is_vision_available() else None
+    fast_image_processing_class = OwlViTImageProcessorFast if is_torchvision_available() else None
 
     def setUp(self):
         super().setUp()
@@ -99,21 +103,23 @@ class OwlViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
         return self.image_processor_tester.prepare_image_processor_dict()
 
     def test_image_processor_properties(self):
-        image_processing = self.image_processing_class(**self.image_processor_dict)
-        self.assertTrue(hasattr(image_processing, "do_resize"))
-        self.assertTrue(hasattr(image_processing, "size"))
-        self.assertTrue(hasattr(image_processing, "do_center_crop"))
-        self.assertTrue(hasattr(image_processing, "center_crop"))
-        self.assertTrue(hasattr(image_processing, "do_normalize"))
-        self.assertTrue(hasattr(image_processing, "image_mean"))
-        self.assertTrue(hasattr(image_processing, "image_std"))
-        self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
+        for image_processing_class in self.image_processor_list:
+            image_processing = image_processing_class(**self.image_processor_dict)
+            self.assertTrue(hasattr(image_processing, "do_resize"))
+            self.assertTrue(hasattr(image_processing, "size"))
+            self.assertTrue(hasattr(image_processing, "do_center_crop"))
+            self.assertTrue(hasattr(image_processing, "center_crop"))
+            self.assertTrue(hasattr(image_processing, "do_normalize"))
+            self.assertTrue(hasattr(image_processing, "image_mean"))
+            self.assertTrue(hasattr(image_processing, "image_std"))
+            self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
 
     def test_image_processor_from_dict_with_kwargs(self):
-        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
-        self.assertEqual(image_processor.size, {"height": 18, "width": 18})
-        self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
+        for image_processing_class in self.image_processor_list:
+            image_processor = image_processing_class.from_dict(self.image_processor_dict)
+            self.assertEqual(image_processor.size, {"height": 18, "width": 18})
+            self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
 
-        image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
-        self.assertEqual(image_processor.size, {"height": 42, "width": 42})
-        self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
+            image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
+            self.assertEqual(image_processor.size, {"height": 42, "width": 42})
+            self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
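
The shared mixin exercises both classes through `self.image_processor_list`; a standalone sanity check of slow/fast equivalence could look like the sketch below (not part of the test diff; the tolerance check is printed rather than asserted because PIL and torchvision bicubic resampling differ slightly):

```python
import numpy as np
import torch
from PIL import Image

from transformers import OwlViTImageProcessor, OwlViTImageProcessorFast

# Random RGB image, processed by both the slow (PIL/NumPy) and fast (torchvision) variants.
image = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))

slow = OwlViTImageProcessor()
fast = OwlViTImageProcessorFast()

slow_pixels = slow(images=image, return_tensors="pt").pixel_values
fast_pixels = fast(images=image, return_tensors="pt").pixel_values

# Both should produce (1, 3, 768, 768) tensors with near-identical values.
assert slow_pixels.shape == fast_pixels.shape
print("max abs diff:", (slow_pixels - fast_pixels).abs().max().item())
```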