Add Fast Conditional-DETR Processor (#37071)

* Add Fast Conditional-DETR Processor

* Update image_processing_conditional_detr_fast.py

* Add modular_conditional_detr.py

* Update image_processing_conditional_detr_fast.py

* Update tests

* make fix

---------

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
Authored by Parteek on 2025-04-15 22:03:34 +05:30; committed by GitHub
Parent: 4f1dbe8152
Commit: 51f544a4d4
7 changed files with 1310 additions and 81 deletions


@@ -48,6 +48,11 @@ This model was contributed by [DepuMeng](https://huggingface.co/DepuMeng). The o
[[autodoc]] ConditionalDetrImageProcessor
- preprocess
## ConditionalDetrImageProcessorFast
[[autodoc]] ConditionalDetrImageProcessorFast
- preprocess
- post_process_object_detection
- post_process_instance_segmentation
- post_process_semantic_segmentation
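
For context, a minimal usage sketch of the fast processor documented above (the checkpoint name comes from the tests in this commit; the sample image URL and threshold are illustrative assumptions, the rest is standard Transformers API):

import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, ConditionalDetrForObjectDetection

# A standard COCO sample image, used here only as an illustrative input
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
processor = AutoImageProcessor.from_pretrained("microsoft/conditional-detr-resnet-50", use_fast=True)
model = ConditionalDetrForObjectDetection.from_pretrained("microsoft/conditional-detr-resnet-50")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# target_sizes expects (height, width); PIL's image.size is (width, height)
results = processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=[image.size[::-1]])[0]
print(results["boxes"].shape)  # (num_detections, 4) in absolute pixel coordinates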


@@ -43,6 +43,11 @@ alt="描画" width="600"/>
[[autodoc]] ConditionalDetrImageProcessor
- preprocess
## ConditionalDetrImageProcessorFast
[[autodoc]] ConditionalDetrImageProcessorFast
- preprocess
- post_process_object_detection
- post_process_instance_segmentation
- post_process_semantic_segmentation


@@ -67,7 +67,7 @@ else:
("chinese_clip", ("ChineseCLIPImageProcessor", "ChineseCLIPImageProcessorFast")),
("clip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")),
("conditional_detr", ("ConditionalDetrImageProcessor",)),
("conditional_detr", ("ConditionalDetrImageProcessor", "ConditionalDetrImageProcessorFast")),
("convnext", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("cvt", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),


@@ -21,6 +21,7 @@ if TYPE_CHECKING:
from .configuration_conditional_detr import *
from .feature_extraction_conditional_detr import *
from .image_processing_conditional_detr import *
from .image_processing_conditional_detr_fast import *
from .modeling_conditional_detr import *
else:
import sys
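
The star-import above, together with the `__all__` defined in the new module, exposes the class on the top-level package; a one-line check (assuming torchvision is installed, which the fast image processors require):

from transformers import ConditionalDetrImageProcessorFast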


@@ -0,0 +1,137 @@
from typing import List, Tuple, Union

from transformers.models.detr.image_processing_detr_fast import DetrImageProcessorFast

from ...image_transforms import (
    center_to_corners_format,
)
from ...utils import (
    TensorType,
    is_torch_available,
    logging,
)


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)


class ConditionalDetrImageProcessorFast(DetrImageProcessorFast):
    def post_process(self, outputs, target_sizes):
        """
        Converts the raw output of [`ConditionalDetrForObjectDetection`] into bounding boxes in the Pascal VOC
        format (xmin, ymin, xmax, ymax). Only supports PyTorch.

        Args:
            outputs ([`ConditionalDetrObjectDetectionOutput`]):
                Raw outputs of the model.
            target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
                Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the
                original image size (before any data augmentation). For visualization, this should be the image size
                after data augmentation, but before padding.

        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an
            image in the batch as predicted by the model.
        """
        logger.warning_once(
            "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
            " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
        )

        out_logits, out_bbox = outputs.logits, outputs.pred_boxes

        if len(out_logits) != len(target_sizes):
            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
        if target_sizes.shape[1] != 2:
            raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

        prob = out_logits.sigmoid()
        # Flatten the (num_queries, num_classes) scores and keep the best 300 (query, class) pairs per image
        topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 300, dim=1)
        scores = topk_values
        # Recover the query index (floor division) and the class label (modulo) from each flattened index
        topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
        labels = topk_indexes % out_logits.shape[2]
        boxes = center_to_corners_format(out_bbox)
        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]

        return results

    def post_process_object_detection(
        self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None, top_k: int = 100
    ):
        """
        Converts the raw output of [`ConditionalDetrForObjectDetection`] into final bounding boxes in (top_left_x,
        top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch.

        Args:
            outputs ([`ConditionalDetrObjectDetectionOutput`]):
                Raw outputs of the model.
            threshold (`float`, *optional*, defaults to 0.5):
                Score threshold to keep object detection predictions.
            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                (height, width) of each image in the batch. If left to None, predictions will not be resized.
            top_k (`int`, *optional*, defaults to 100):
                Keep only the top k bounding boxes before filtering by thresholding.

        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an
            image in the batch as predicted by the model.
        """
        out_logits, out_bbox = outputs.logits, outputs.pred_boxes

        if target_sizes is not None:
            if len(out_logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )

        prob = out_logits.sigmoid()
        prob = prob.view(out_logits.shape[0], -1)
        k_value = min(top_k, prob.size(1))
        topk_values, topk_indexes = torch.topk(prob, k_value, dim=1)
        scores = topk_values
        topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
        labels = topk_indexes % out_logits.shape[2]
        boxes = center_to_corners_format(out_bbox)
        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

        # and from relative [0, 1] to absolute [0, height] coordinates
        if target_sizes is not None:
            if isinstance(target_sizes, List):
                img_h = torch.Tensor([i[0] for i in target_sizes])
                img_w = torch.Tensor([i[1] for i in target_sizes])
            else:
                img_h, img_w = target_sizes.unbind(1)
            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
            boxes = boxes * scale_fct[:, None, :]

        results = []
        for s, l, b in zip(scores, labels, boxes):
            score = s[s > threshold]
            label = l[s > threshold]
            box = b[s > threshold]
            results.append({"scores": score, "labels": label, "boxes": box})

        return results

    def post_process_segmentation(self, *args, **kwargs):
        raise NotImplementedError("Segmentation post-processing is not implemented for Conditional DETR yet.")

    def post_process_instance(self, *args, **kwargs):
        raise NotImplementedError("Instance post-processing is not implemented for Conditional DETR yet.")

    def post_process_panoptic(self, *args, **kwargs):
        raise NotImplementedError("Panoptic post-processing is not implemented for Conditional DETR yet.")


__all__ = ["ConditionalDetrImageProcessorFast"]
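
A quick sanity-check sketch of the decoding above, with random tensors standing in for real model outputs (`SimpleNamespace` is a stand-in for `ConditionalDetrObjectDetectionOutput`, and the shapes are illustrative assumptions):

from types import SimpleNamespace

import torch

from transformers import ConditionalDetrImageProcessorFast  # requires torchvision

processor = ConditionalDetrImageProcessorFast()
# 2 images, 300 queries, 91 classes; pred_boxes are normalized (cx, cy, w, h)
outputs = SimpleNamespace(logits=torch.randn(2, 300, 91), pred_boxes=torch.rand(2, 300, 4))
results = processor.post_process_object_detection(
    outputs, threshold=0.9, target_sizes=[(480, 640), (600, 800)], top_k=100
)
for result in results:
    # at most top_k boxes per image, rescaled to absolute (xmin, ymin, xmax, ymax)
    print(result["scores"].shape, result["boxes"].shape)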


@@ -20,7 +20,7 @@ import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_vision, slow
from transformers.utils import is_torch_available, is_vision_available
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import AnnotationFormatTestMixin, ImageProcessingTestMixin, prepare_image_inputs
@@ -33,6 +33,9 @@ if is_vision_available():
from transformers import ConditionalDetrImageProcessor
if is_torchvision_available():
from transformers import ConditionalDetrImageProcessorFast
class ConditionalDetrImageProcessingTester:
def __init__(
@@ -132,6 +135,7 @@ class ConditionalDetrImageProcessingTester:
@require_vision
class ConditionalDetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = ConditionalDetrImageProcessor if is_vision_available() else None
fast_image_processing_class = ConditionalDetrImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
@@ -142,23 +146,25 @@ class ConditionalDetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcess
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 18, "longest_edge": 1333})
self.assertEqual(image_processor.do_pad, True)
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 18, "longest_edge": 1333})
self.assertEqual(image_processor.do_pad, True)
image_processor = self.image_processing_class.from_dict(
self.image_processor_dict, size=42, max_size=84, pad_and_return_pixel_mask=False
)
self.assertEqual(image_processor.size, {"shortest_edge": 42, "longest_edge": 84})
self.assertEqual(image_processor.do_pad, False)
image_processor = image_processing_class.from_dict(
self.image_processor_dict, size=42, max_size=84, pad_and_return_pixel_mask=False
)
self.assertEqual(image_processor.size, {"shortest_edge": 42, "longest_edge": 84})
self.assertEqual(image_processor.do_pad, False)
@slow
def test_call_pytorch_with_coco_detection_annotations(self):
@@ -169,40 +175,41 @@ class ConditionalDetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcess
target = {"image_id": 39769, "annotations": target}
# encode them
image_processing = ConditionalDetrImageProcessor.from_pretrained("microsoft/conditional-detr-resnet-50")
encoding = image_processing(images=image, annotations=target, return_tensors="pt")
for image_processing_class in self.image_processor_list:
# encode them
image_processing = image_processing_class.from_pretrained("microsoft/conditional-detr-resnet-50")
encoding = image_processing(images=image, annotations=target, return_tensors="pt")
# verify pixel values
expected_shape = torch.Size([1, 3, 800, 1066])
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
# verify pixel values
expected_shape = torch.Size([1, 3, 800, 1066])
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
expected_slice = torch.tensor([0.2796, 0.3138, 0.3481])
torch.testing.assert_close(encoding["pixel_values"][0, 0, 0, :3], expected_slice, rtol=1e-4, atol=1e-4)
expected_slice = torch.tensor([0.2796, 0.3138, 0.3481])
torch.testing.assert_close(encoding["pixel_values"][0, 0, 0, :3], expected_slice, rtol=1e-4, atol=1e-4)
# verify area
expected_area = torch.tensor([5887.9600, 11250.2061, 489353.8438, 837122.7500, 147967.5156, 165732.3438])
torch.testing.assert_close(encoding["labels"][0]["area"], expected_area)
# verify boxes
expected_boxes_shape = torch.Size([6, 4])
self.assertEqual(encoding["labels"][0]["boxes"].shape, expected_boxes_shape)
expected_boxes_slice = torch.tensor([0.5503, 0.2765, 0.0604, 0.2215])
torch.testing.assert_close(encoding["labels"][0]["boxes"][0], expected_boxes_slice, rtol=1e-3, atol=1e-3)
# verify image_id
expected_image_id = torch.tensor([39769])
torch.testing.assert_close(encoding["labels"][0]["image_id"], expected_image_id)
# verify is_crowd
expected_is_crowd = torch.tensor([0, 0, 0, 0, 0, 0])
torch.testing.assert_close(encoding["labels"][0]["iscrowd"], expected_is_crowd)
# verify class_labels
expected_class_labels = torch.tensor([75, 75, 63, 65, 17, 17])
torch.testing.assert_close(encoding["labels"][0]["class_labels"], expected_class_labels)
# verify orig_size
expected_orig_size = torch.tensor([480, 640])
torch.testing.assert_close(encoding["labels"][0]["orig_size"], expected_orig_size)
# verify size
expected_size = torch.tensor([800, 1066])
torch.testing.assert_close(encoding["labels"][0]["size"], expected_size)
# verify area
expected_area = torch.tensor([5887.9600, 11250.2061, 489353.8438, 837122.7500, 147967.5156, 165732.3438])
torch.testing.assert_close(encoding["labels"][0]["area"], expected_area)
# verify boxes
expected_boxes_shape = torch.Size([6, 4])
self.assertEqual(encoding["labels"][0]["boxes"].shape, expected_boxes_shape)
expected_boxes_slice = torch.tensor([0.5503, 0.2765, 0.0604, 0.2215])
torch.testing.assert_close(encoding["labels"][0]["boxes"][0], expected_boxes_slice, rtol=1e-3, atol=1e-3)
# verify image_id
expected_image_id = torch.tensor([39769])
torch.testing.assert_close(encoding["labels"][0]["image_id"], expected_image_id)
# verify is_crowd
expected_is_crowd = torch.tensor([0, 0, 0, 0, 0, 0])
torch.testing.assert_close(encoding["labels"][0]["iscrowd"], expected_is_crowd)
# verify class_labels
expected_class_labels = torch.tensor([75, 75, 63, 65, 17, 17])
torch.testing.assert_close(encoding["labels"][0]["class_labels"], expected_class_labels)
# verify orig_size
expected_orig_size = torch.tensor([480, 640])
torch.testing.assert_close(encoding["labels"][0]["orig_size"], expected_orig_size)
# verify size
expected_size = torch.tensor([800, 1066])
torch.testing.assert_close(encoding["labels"][0]["size"], expected_size)
@slow
def test_call_pytorch_with_coco_panoptic_annotations(self):
@@ -215,43 +222,45 @@ class ConditionalDetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcess
masks_path = pathlib.Path("./tests/fixtures/tests_samples/COCO/coco_panoptic")
# encode them
image_processing = ConditionalDetrImageProcessor(format="coco_panoptic")
encoding = image_processing(images=image, annotations=target, masks_path=masks_path, return_tensors="pt")
for image_processing_class in self.image_processor_list:
# encode them
image_processing = image_processing_class(format="coco_panoptic")
encoding = image_processing(images=image, annotations=target, masks_path=masks_path, return_tensors="pt")
# verify pixel values
expected_shape = torch.Size([1, 3, 800, 1066])
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
# verify pixel values
expected_shape = torch.Size([1, 3, 800, 1066])
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
expected_slice = torch.tensor([0.2796, 0.3138, 0.3481])
torch.testing.assert_close(encoding["pixel_values"][0, 0, 0, :3], expected_slice, rtol=1e-4, atol=1e-4)
expected_slice = torch.tensor([0.2796, 0.3138, 0.3481])
torch.testing.assert_close(encoding["pixel_values"][0, 0, 0, :3], expected_slice, rtol=1e-4, atol=1e-4)
# verify area
expected_area = torch.tensor([147979.6875, 165527.0469, 484638.5938, 11292.9375, 5879.6562, 7634.1147])
torch.testing.assert_close(encoding["labels"][0]["area"], expected_area)
# verify boxes
expected_boxes_shape = torch.Size([6, 4])
self.assertEqual(encoding["labels"][0]["boxes"].shape, expected_boxes_shape)
expected_boxes_slice = torch.tensor([0.2625, 0.5437, 0.4688, 0.8625])
torch.testing.assert_close(encoding["labels"][0]["boxes"][0], expected_boxes_slice, rtol=1e-3, atol=1e-3)
# verify image_id
expected_image_id = torch.tensor([39769])
torch.testing.assert_close(encoding["labels"][0]["image_id"], expected_image_id)
# verify is_crowd
expected_is_crowd = torch.tensor([0, 0, 0, 0, 0, 0])
torch.testing.assert_close(encoding["labels"][0]["iscrowd"], expected_is_crowd)
# verify class_labels
expected_class_labels = torch.tensor([17, 17, 63, 75, 75, 93])
torch.testing.assert_close(encoding["labels"][0]["class_labels"], expected_class_labels)
# verify masks
expected_masks_sum = 822873
self.assertEqual(encoding["labels"][0]["masks"].sum().item(), expected_masks_sum)
# verify orig_size
expected_orig_size = torch.tensor([480, 640])
torch.testing.assert_close(encoding["labels"][0]["orig_size"], expected_orig_size)
# verify size
expected_size = torch.tensor([800, 1066])
torch.testing.assert_close(encoding["labels"][0]["size"], expected_size)
# verify area
expected_area = torch.tensor([147979.6875, 165527.0469, 484638.5938, 11292.9375, 5879.6562, 7634.1147])
torch.testing.assert_close(encoding["labels"][0]["area"], expected_area)
# verify boxes
expected_boxes_shape = torch.Size([6, 4])
self.assertEqual(encoding["labels"][0]["boxes"].shape, expected_boxes_shape)
expected_boxes_slice = torch.tensor([0.2625, 0.5437, 0.4688, 0.8625])
torch.testing.assert_close(encoding["labels"][0]["boxes"][0], expected_boxes_slice, rtol=1e-3, atol=1e-3)
# verify image_id
expected_image_id = torch.tensor([39769])
torch.testing.assert_close(encoding["labels"][0]["image_id"], expected_image_id)
# verify is_crowd
expected_is_crowd = torch.tensor([0, 0, 0, 0, 0, 0])
torch.testing.assert_close(encoding["labels"][0]["iscrowd"], expected_is_crowd)
# verify class_labels
expected_class_labels = torch.tensor([17, 17, 63, 75, 75, 93])
torch.testing.assert_close(encoding["labels"][0]["class_labels"], expected_class_labels)
# verify masks
expected_masks_sum = 822873
relative_error = torch.abs(encoding["labels"][0]["masks"].sum() - expected_masks_sum) / expected_masks_sum
self.assertTrue(relative_error < 1e-3)
# verify orig_size
expected_orig_size = torch.tensor([480, 640])
torch.testing.assert_close(encoding["labels"][0]["orig_size"], expected_orig_size)
# verify size
expected_size = torch.tensor([800, 1066])
torch.testing.assert_close(encoding["labels"][0]["size"], expected_size)
@slow
# Copied from tests.models.detr.test_image_processing_detr.DetrImageProcessingTest.test_batched_coco_detection_annotations with Detr->ConditionalDetr, facebook/detr-resnet-50 ->microsoft/conditional-detr-resnet-50