mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Add Fast owlvit Processor (#37164)
* Add Fast Owlvit Processor * Update image_processing_owlvit_fast.py * Update image_processing_owlvit_fast.py --------- Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
This commit is contained in:
parent
cb39f7dd5b
commit
20ceaca228
@ -94,6 +94,11 @@ A demo notebook on using OWL-ViT for zero- and one-shot (image-guided) object de
|
||||
|
||||
[[autodoc]] OwlViTImageProcessor
|
||||
- preprocess
|
||||
|
||||
## OwlViTImageProcessorFast
|
||||
|
||||
[[autodoc]] OwlViTImageProcessorFast
|
||||
- preprocess
|
||||
- post_process_object_detection
|
||||
- post_process_image_guided_detection
|
||||
|
||||
|
@ -123,7 +123,7 @@ else:
|
||||
("nougat", ("NougatImageProcessor",)),
|
||||
("oneformer", ("OneFormerImageProcessor",)),
|
||||
("owlv2", ("Owlv2ImageProcessor",)),
|
||||
("owlvit", ("OwlViTImageProcessor",)),
|
||||
("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")),
|
||||
("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
|
||||
("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")),
|
||||
("phi4_multimodal", "Phi4MultimodalImageProcessorFast"),
|
||||
|
@ -21,6 +21,7 @@ if TYPE_CHECKING:
|
||||
from .configuration_owlvit import *
|
||||
from .feature_extraction_owlvit import *
|
||||
from .image_processing_owlvit import *
|
||||
from .image_processing_owlvit_fast import *
|
||||
from .modeling_owlvit import *
|
||||
from .processing_owlvit import *
|
||||
else:
|
||||
|
240
src/transformers/models/owlvit/image_processing_owlvit_fast.py
Normal file
240
src/transformers/models/owlvit/image_processing_owlvit_fast.py
Normal file
@ -0,0 +1,240 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Fast Image processor class for OwlViT"""
|
||||
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
from ...image_processing_utils_fast import (
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
||||
BaseImageProcessorFast,
|
||||
)
|
||||
from ...image_transforms import center_to_corners_format
|
||||
from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
|
||||
from ...utils import TensorType, add_start_docstrings, is_torch_available, logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .modeling_owlvit import OwlViTObjectDetectionOutput
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from .image_processing_owlvit import _scale_boxes, box_iou
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Constructs a fast OwlViT image processor.",
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
||||
)
|
||||
class OwlViTImageProcessorFast(BaseImageProcessorFast):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = OPENAI_CLIP_MEAN
|
||||
image_std = OPENAI_CLIP_STD
|
||||
size = {"height": 768, "width": 768}
|
||||
default_to_square = True
|
||||
crop_size = {"height": 768, "width": 768}
|
||||
do_resize = True
|
||||
do_center_crop = False
|
||||
do_rescale = True
|
||||
do_normalize = None
|
||||
do_convert_rgb = None
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
# Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process
|
||||
def post_process(self, outputs, target_sizes):
|
||||
"""
|
||||
Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
|
||||
bottom_right_x, bottom_right_y) format.
|
||||
|
||||
Args:
|
||||
outputs ([`OwlViTObjectDetectionOutput`]):
|
||||
Raw outputs of the model.
|
||||
target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
|
||||
Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
|
||||
image size (before any data augmentation). For visualization, this should be the image size after data
|
||||
augment, but before padding.
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
|
||||
in the batch as predicted by the model.
|
||||
"""
|
||||
# TODO: (amy) add support for other frameworks
|
||||
warnings.warn(
|
||||
"`post_process` is deprecated and will be removed in v5 of Transformers, please use"
|
||||
" `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
logits, boxes = outputs.logits, outputs.pred_boxes
|
||||
|
||||
if len(logits) != len(target_sizes):
|
||||
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
|
||||
if target_sizes.shape[1] != 2:
|
||||
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
|
||||
|
||||
probs = torch.max(logits, dim=-1)
|
||||
scores = torch.sigmoid(probs.values)
|
||||
labels = probs.indices
|
||||
|
||||
# Convert to [x0, y0, x1, y1] format
|
||||
boxes = center_to_corners_format(boxes)
|
||||
|
||||
# Convert from relative [0, 1] to absolute [0, height] coordinates
|
||||
img_h, img_w = target_sizes.unbind(1)
|
||||
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
|
||||
boxes = boxes * scale_fct[:, None, :]
|
||||
|
||||
results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
|
||||
|
||||
return results
|
||||
|
||||
# Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_object_detection
|
||||
def post_process_object_detection(
|
||||
self,
|
||||
outputs: "OwlViTObjectDetectionOutput",
|
||||
threshold: float = 0.1,
|
||||
target_sizes: Optional[Union[TensorType, List[Tuple]]] = None,
|
||||
):
|
||||
"""
|
||||
Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
|
||||
bottom_right_x, bottom_right_y) format.
|
||||
|
||||
Args:
|
||||
outputs ([`OwlViTObjectDetectionOutput`]):
|
||||
Raw outputs of the model.
|
||||
threshold (`float`, *optional*, defaults to 0.1):
|
||||
Score threshold to keep object detection predictions.
|
||||
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
|
||||
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
|
||||
`(height, width)` of each image in the batch. If unset, predictions will not be resized.
|
||||
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, each dictionary containing the following keys:
|
||||
- "scores": The confidence scores for each predicted box on the image.
|
||||
- "labels": Indexes of the classes predicted by the model on the image.
|
||||
- "boxes": Image bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format.
|
||||
"""
|
||||
batch_logits, batch_boxes = outputs.logits, outputs.pred_boxes
|
||||
batch_size = len(batch_logits)
|
||||
|
||||
if target_sizes is not None and len(target_sizes) != batch_size:
|
||||
raise ValueError("Make sure that you pass in as many target sizes as images")
|
||||
|
||||
# batch_logits of shape (batch_size, num_queries, num_classes)
|
||||
batch_class_logits = torch.max(batch_logits, dim=-1)
|
||||
batch_scores = torch.sigmoid(batch_class_logits.values)
|
||||
batch_labels = batch_class_logits.indices
|
||||
|
||||
# Convert to [x0, y0, x1, y1] format
|
||||
batch_boxes = center_to_corners_format(batch_boxes)
|
||||
|
||||
# Convert from relative [0, 1] to absolute [0, height] coordinates
|
||||
if target_sizes is not None:
|
||||
batch_boxes = _scale_boxes(batch_boxes, target_sizes)
|
||||
|
||||
results = []
|
||||
for scores, labels, boxes in zip(batch_scores, batch_labels, batch_boxes):
|
||||
keep = scores > threshold
|
||||
scores = scores[keep]
|
||||
labels = labels[keep]
|
||||
boxes = boxes[keep]
|
||||
results.append({"scores": scores, "labels": labels, "boxes": boxes})
|
||||
|
||||
return results
|
||||
|
||||
# Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_image_guided_detection
|
||||
def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_threshold=0.3, target_sizes=None):
|
||||
"""
|
||||
Converts the output of [`OwlViTForObjectDetection.image_guided_detection`] into the format expected by the COCO
|
||||
api.
|
||||
|
||||
Args:
|
||||
outputs ([`OwlViTImageGuidedObjectDetectionOutput`]):
|
||||
Raw outputs of the model.
|
||||
threshold (`float`, *optional*, defaults to 0.0):
|
||||
Minimum confidence threshold to use to filter out predicted boxes.
|
||||
nms_threshold (`float`, *optional*, defaults to 0.3):
|
||||
IoU threshold for non-maximum suppression of overlapping boxes.
|
||||
target_sizes (`torch.Tensor`, *optional*):
|
||||
Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in
|
||||
the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to
|
||||
None, predictions will not be unnormalized.
|
||||
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
|
||||
in the batch as predicted by the model. All labels are set to None as
|
||||
`OwlViTForObjectDetection.image_guided_detection` perform one-shot object detection.
|
||||
"""
|
||||
logits, target_boxes = outputs.logits, outputs.target_pred_boxes
|
||||
|
||||
if target_sizes is not None and len(logits) != len(target_sizes):
|
||||
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
|
||||
if target_sizes is not None and target_sizes.shape[1] != 2:
|
||||
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
|
||||
|
||||
probs = torch.max(logits, dim=-1)
|
||||
scores = torch.sigmoid(probs.values)
|
||||
|
||||
# Convert to [x0, y0, x1, y1] format
|
||||
target_boxes = center_to_corners_format(target_boxes)
|
||||
|
||||
# Apply non-maximum suppression (NMS)
|
||||
if nms_threshold < 1.0:
|
||||
for idx in range(target_boxes.shape[0]):
|
||||
for i in torch.argsort(-scores[idx]):
|
||||
if not scores[idx][i]:
|
||||
continue
|
||||
|
||||
ious = box_iou(target_boxes[idx][i, :].unsqueeze(0), target_boxes[idx])[0][0]
|
||||
ious[i] = -1.0 # Mask self-IoU.
|
||||
scores[idx][ious > nms_threshold] = 0.0
|
||||
|
||||
# Convert from relative [0, 1] to absolute [0, height] coordinates
|
||||
if target_sizes is not None:
|
||||
target_boxes = _scale_boxes(target_boxes, target_sizes)
|
||||
|
||||
# Compute box display alphas based on prediction scores
|
||||
results = []
|
||||
alphas = torch.zeros_like(scores)
|
||||
|
||||
for idx in range(target_boxes.shape[0]):
|
||||
# Select scores for boxes matching the current query:
|
||||
query_scores = scores[idx]
|
||||
if not query_scores.nonzero().numel():
|
||||
continue
|
||||
|
||||
# Apply threshold on scores before scaling
|
||||
query_scores[query_scores < threshold] = 0.0
|
||||
|
||||
# Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1.
|
||||
# All other boxes will either belong to a different query, or will not be shown.
|
||||
max_score = torch.max(query_scores) + 1e-6
|
||||
query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9)
|
||||
query_alphas = torch.clip(query_alphas, 0.0, 1.0)
|
||||
alphas[idx] = query_alphas
|
||||
|
||||
mask = alphas[idx] > 0
|
||||
box_scores = alphas[idx][mask]
|
||||
boxes = target_boxes[idx][mask]
|
||||
results.append({"scores": box_scores, "labels": None, "boxes": boxes})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
__all__ = ["OwlViTImageProcessorFast"]
|
@ -16,7 +16,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
from transformers.utils import is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
@ -24,6 +24,9 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im
|
||||
if is_vision_available():
|
||||
from transformers import OwlViTImageProcessor
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import OwlViTImageProcessorFast
|
||||
|
||||
|
||||
class OwlViTImageProcessingTester:
|
||||
def __init__(
|
||||
@ -89,6 +92,7 @@ class OwlViTImageProcessingTester:
|
||||
@require_vision
|
||||
class OwlViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = OwlViTImageProcessor if is_vision_available() else None
|
||||
fast_image_processing_class = OwlViTImageProcessorFast if is_torchvision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@ -99,21 +103,23 @@ class OwlViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_center_crop"))
|
||||
self.assertTrue(hasattr(image_processing, "center_crop"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_center_crop"))
|
||||
self.assertTrue(hasattr(image_processing, "center_crop"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
||||
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"height": 18, "width": 18})
|
||||
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"height": 18, "width": 18})
|
||||
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
|
||||
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
|
||||
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
||||
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
|
||||
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
||||
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
|
||||
|
Loading…
Reference in New Issue
Block a user