diff --git a/docs/source/en/model_doc/owlvit.md b/docs/source/en/model_doc/owlvit.md index 5be8ffc8f58..bbcfe0de9f9 100644 --- a/docs/source/en/model_doc/owlvit.md +++ b/docs/source/en/model_doc/owlvit.md @@ -94,6 +94,11 @@ A demo notebook on using OWL-ViT for zero- and one-shot (image-guided) object de [[autodoc]] OwlViTImageProcessor - preprocess + +## OwlViTImageProcessorFast + +[[autodoc]] OwlViTImageProcessorFast + - preprocess - post_process_object_detection - post_process_image_guided_detection diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 3389261575c..caeed802811 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -123,7 +123,7 @@ else: ("nougat", ("NougatImageProcessor",)), ("oneformer", ("OneFormerImageProcessor",)), ("owlv2", ("Owlv2ImageProcessor",)), - ("owlvit", ("OwlViTImageProcessor",)), + ("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")), ("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")), ("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")), ("phi4_multimodal", "Phi4MultimodalImageProcessorFast"), diff --git a/src/transformers/models/owlvit/__init__.py b/src/transformers/models/owlvit/__init__.py index 3f36ea5e4bd..b3fd80c2bcc 100644 --- a/src/transformers/models/owlvit/__init__.py +++ b/src/transformers/models/owlvit/__init__.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from .configuration_owlvit import * from .feature_extraction_owlvit import * from .image_processing_owlvit import * + from .image_processing_owlvit_fast import * from .modeling_owlvit import * from .processing_owlvit import * else: diff --git a/src/transformers/models/owlvit/image_processing_owlvit_fast.py b/src/transformers/models/owlvit/image_processing_owlvit_fast.py new file mode 100644 index 00000000000..f048f1a0da5 --- /dev/null +++ b/src/transformers/models/owlvit/image_processing_owlvit_fast.py @@ -0,0 +1,240 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for OwlViT""" + +import warnings +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +from ...image_processing_utils_fast import ( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + BaseImageProcessorFast, +) +from ...image_transforms import center_to_corners_format +from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling +from ...utils import TensorType, add_start_docstrings, is_torch_available, logging + + +if TYPE_CHECKING: + from .modeling_owlvit import OwlViTObjectDetectionOutput + + +if is_torch_available(): + import torch + + from .image_processing_owlvit import _scale_boxes, box_iou + + +logger = logging.get_logger(__name__) + + +@add_start_docstrings( + "Constructs a fast OwlViT image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, +) +class OwlViTImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + size = {"height": 768, "width": 768} + default_to_square = True + crop_size = {"height": 768, "width": 768} + do_resize = True + do_center_crop = False + do_rescale = True + do_normalize = None + do_convert_rgb = None + model_input_names = ["pixel_values"] + + # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process + def post_process(self, outputs, target_sizes): + """ + Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. + + Args: + outputs ([`OwlViTObjectDetectionOutput`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` of shape `(batch_size, 2)`): + Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original + image size (before any data augmentation). For visualization, this should be the image size after data + augment, but before padding. + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. + """ + # TODO: (amy) add support for other frameworks + warnings.warn( + "`post_process` is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.", + FutureWarning, + ) + + logits, boxes = outputs.logits, outputs.pred_boxes + + if len(logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + probs = torch.max(logits, dim=-1) + scores = torch.sigmoid(probs.values) + labels = probs.indices + + # Convert to [x0, y0, x1, y1] format + boxes = center_to_corners_format(boxes) + + # Convert from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) + boxes = boxes * scale_fct[:, None, :] + + results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)] + + return results + + # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_object_detection + def post_process_object_detection( + self, + outputs: "OwlViTObjectDetectionOutput", + threshold: float = 0.1, + target_sizes: Optional[Union[TensorType, List[Tuple]]] = None, + ): + """ + Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. + + Args: + outputs ([`OwlViTObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.1): + Score threshold to keep object detection predictions. + target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. If unset, predictions will not be resized. + + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the following keys: + - "scores": The confidence scores for each predicted box on the image. + - "labels": Indexes of the classes predicted by the model on the image. + - "boxes": Image bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format. + """ + batch_logits, batch_boxes = outputs.logits, outputs.pred_boxes + batch_size = len(batch_logits) + + if target_sizes is not None and len(target_sizes) != batch_size: + raise ValueError("Make sure that you pass in as many target sizes as images") + + # batch_logits of shape (batch_size, num_queries, num_classes) + batch_class_logits = torch.max(batch_logits, dim=-1) + batch_scores = torch.sigmoid(batch_class_logits.values) + batch_labels = batch_class_logits.indices + + # Convert to [x0, y0, x1, y1] format + batch_boxes = center_to_corners_format(batch_boxes) + + # Convert from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + batch_boxes = _scale_boxes(batch_boxes, target_sizes) + + results = [] + for scores, labels, boxes in zip(batch_scores, batch_labels, batch_boxes): + keep = scores > threshold + scores = scores[keep] + labels = labels[keep] + boxes = boxes[keep] + results.append({"scores": scores, "labels": labels, "boxes": boxes}) + + return results + + # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_image_guided_detection + def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_threshold=0.3, target_sizes=None): + """ + Converts the output of [`OwlViTForObjectDetection.image_guided_detection`] into the format expected by the COCO + api. + + Args: + outputs ([`OwlViTImageGuidedObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.0): + Minimum confidence threshold to use to filter out predicted boxes. + nms_threshold (`float`, *optional*, defaults to 0.3): + IoU threshold for non-maximum suppression of overlapping boxes. + target_sizes (`torch.Tensor`, *optional*): + Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in + the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to + None, predictions will not be unnormalized. + + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. All labels are set to None as + `OwlViTForObjectDetection.image_guided_detection` perform one-shot object detection. + """ + logits, target_boxes = outputs.logits, outputs.target_pred_boxes + + if target_sizes is not None and len(logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if target_sizes is not None and target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + probs = torch.max(logits, dim=-1) + scores = torch.sigmoid(probs.values) + + # Convert to [x0, y0, x1, y1] format + target_boxes = center_to_corners_format(target_boxes) + + # Apply non-maximum suppression (NMS) + if nms_threshold < 1.0: + for idx in range(target_boxes.shape[0]): + for i in torch.argsort(-scores[idx]): + if not scores[idx][i]: + continue + + ious = box_iou(target_boxes[idx][i, :].unsqueeze(0), target_boxes[idx])[0][0] + ious[i] = -1.0 # Mask self-IoU. + scores[idx][ious > nms_threshold] = 0.0 + + # Convert from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + target_boxes = _scale_boxes(target_boxes, target_sizes) + + # Compute box display alphas based on prediction scores + results = [] + alphas = torch.zeros_like(scores) + + for idx in range(target_boxes.shape[0]): + # Select scores for boxes matching the current query: + query_scores = scores[idx] + if not query_scores.nonzero().numel(): + continue + + # Apply threshold on scores before scaling + query_scores[query_scores < threshold] = 0.0 + + # Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1. + # All other boxes will either belong to a different query, or will not be shown. + max_score = torch.max(query_scores) + 1e-6 + query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9) + query_alphas = torch.clip(query_alphas, 0.0, 1.0) + alphas[idx] = query_alphas + + mask = alphas[idx] > 0 + box_scores = alphas[idx][mask] + boxes = target_boxes[idx][mask] + results.append({"scores": box_scores, "labels": None, "boxes": boxes}) + + return results + + +__all__ = ["OwlViTImageProcessorFast"] diff --git a/tests/models/owlvit/test_image_processing_owlvit.py b/tests/models/owlvit/test_image_processing_owlvit.py index 7d79fb2fb39..00b65123267 100644 --- a/tests/models/owlvit/test_image_processing_owlvit.py +++ b/tests/models/owlvit/test_image_processing_owlvit.py @@ -16,7 +16,7 @@ import unittest from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_vision_available +from transformers.utils import is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -24,6 +24,9 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im if is_vision_available(): from transformers import OwlViTImageProcessor + if is_torchvision_available(): + from transformers import OwlViTImageProcessorFast + class OwlViTImageProcessingTester: def __init__( @@ -89,6 +92,7 @@ class OwlViTImageProcessingTester: @require_vision class OwlViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = OwlViTImageProcessor if is_vision_available() else None + fast_image_processing_class = OwlViTImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() @@ -99,21 +103,23 @@ class OwlViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "size")) - self.assertTrue(hasattr(image_processing, "do_center_crop")) - self.assertTrue(hasattr(image_processing, "center_crop")) - self.assertTrue(hasattr(image_processing, "do_normalize")) - self.assertTrue(hasattr(image_processing, "image_mean")) - self.assertTrue(hasattr(image_processing, "image_std")) - self.assertTrue(hasattr(image_processing, "do_convert_rgb")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_center_crop")) + self.assertTrue(hasattr(image_processing, "center_crop")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_convert_rgb")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"height": 18, "width": 18}) - self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 18, "width": 18}) + self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) - image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) - self.assertEqual(image_processor.size, {"height": 42, "width": 42}) - self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) + image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) + self.assertEqual(image_processor.size, {"height": 42, "width": 42}) + self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})