Grounding DINO Processor standardization (#34853)

* Add input ids to model output

* Add text preprocessing for processor

* Fix snippet

* Add test for equivalence

* Add type checking guard

* Fixing typehint

* Fix test for added `input_ids` in output

* Add deprecations and "text_labels" to output

* Adjust tests

* Fix test

* Update code examples

* Minor docs and code improvement

* Remove one-liner functions and rename class to CamelCase

* Update docstring

* Fixup
Pavel Iakubovskii 2025-01-17 14:18:16 +00:00 committed by GitHub
parent 42b2857b01
commit 099d93d2e9
5 changed files with 217 additions and 80 deletions
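In user-facing terms, the change replaces the merged-string prompt plus explicit `input_ids` with nested candidate-label lists and an `input_ids` field carried on the model output. A before/after sketch (illustrative only; it assumes the `IDEA-Research/grounding-dino-tiny` checkpoint and COCO image used throughout this diff):

```python
import requests
import torch
from PIL import Image
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor

model_id = "IDEA-Research/grounding-dino-tiny"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)

image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(image_url, stream=True).raw)

# Before: a single dot-separated label string, explicit `input_ids`, and the
# old `box_threshold` argument (now deprecated in favor of `threshold`).
inputs = processor(images=image, text="a cat. a remote control.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
results = processor.post_process_grounded_object_detection(
    outputs, inputs.input_ids, box_threshold=0.4, text_threshold=0.3,
    target_sizes=[(image.height, image.width)],
)

# After: a nested list of candidate labels per image; `input_ids` is read from
# the model output, so it no longer has to be passed to post-processing.
inputs = processor(images=image, text=[["a cat", "a remote control"]], return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
results = processor.post_process_grounded_object_detection(
    outputs, threshold=0.4, text_threshold=0.3,
    target_sizes=[(image.height, image.width)],
)
```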


@@ -56,25 +56,26 @@ Here's how to use the model for zero-shot object detection:
>>> image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(image_url, stream=True).raw)
>>> # Check for cats and remote controls
>>> text = "a cat. a remote control."
>>> text_labels = [["a cat", "a remote control"]]
>>> inputs = processor(images=image, text=text, return_tensors="pt").to(device)
>>> inputs = processor(images=image, text=text_labels, return_tensors="pt").to(device)
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> results = processor.post_process_grounded_object_detection(
... outputs,
... inputs.input_ids,
... box_threshold=0.4,
... threshold=0.4,
... text_threshold=0.3,
... target_sizes=[image.size[::-1]]
... target_sizes=[(image.height, image.width)]
... )
>>> print(results)
[{'boxes': tensor([[344.6959, 23.1090, 637.1833, 374.2751],
[ 12.2666, 51.9145, 316.8582, 472.4392],
[ 38.5742, 70.0015, 176.7838, 118.1806]], device='cuda:0'),
'labels': ['a cat', 'a cat', 'a remote control'],
'scores': tensor([0.4785, 0.4381, 0.4776], device='cuda:0')}]
>>> # Retrieve the first image result
>>> result = results[0]
>>> for box, score, text_label in zip(result["boxes"], result["scores"], result["text_labels"]):
... box = [round(x, 2) for x in box.tolist()]
... print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}")
Detected a cat with confidence 0.479 at location [344.7, 23.11, 637.18, 374.28]
Detected a cat with confidence 0.438 at location [12.27, 51.91, 316.86, 472.44]
Detected a remote control with confidence 0.478 at location [38.57, 70.0, 176.78, 118.18]
```
## Grounded SAM


@@ -286,7 +286,7 @@ class GroundingDinoObjectDetectionOutput(ModelOutput):
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
possible padding). You can use [`~GroundingDinoProcessor.post_process_object_detection`] to retrieve the
possible padding). You can use [`~GroundingDinoProcessor.post_process_grounded_object_detection`] to retrieve the
unnormalized bounding boxes.
auxiliary_outputs (`List[Dict]`, *optional*):
Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
@@ -331,6 +331,8 @@ class GroundingDinoObjectDetectionOutput(ModelOutput):
background).
enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`):
Logits of predicted bounding boxes coordinates in the first stage.
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Encoded candidate labels sequence. Used by the processor to post-process the object detection results.
"""
loss: Optional[torch.FloatTensor] = None
@@ -351,6 +353,7 @@ class GroundingDinoObjectDetectionOutput(ModelOutput):
encoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
enc_outputs_class: Optional[torch.FloatTensor] = None
enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
input_ids: Optional[torch.LongTensor] = None
# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->GroundingDino
@@ -2546,30 +2549,41 @@ class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel):
Examples:
```python
>>> from transformers import AutoProcessor, GroundingDinoForObjectDetection
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = "a cat."
>>> import torch
>>> from PIL import Image
>>> from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
>>> processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
>>> model = GroundingDinoForObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny")
>>> model_id = "IDEA-Research/grounding-dino-tiny"
>>> device = "cuda"
>>> inputs = processor(images=image, text=text, return_tensors="pt")
>>> outputs = model(**inputs)
>>> processor = AutoProcessor.from_pretrained(model_id)
>>> model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
>>> # convert outputs (bounding boxes and class logits) to COCO API
>>> target_sizes = torch.tensor([image.size[::-1]])
>>> results = processor.image_processor.post_process_object_detection(
... outputs, threshold=0.35, target_sizes=target_sizes
... )[0]
>>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
... box = [round(i, 1) for i in box.tolist()]
... print(f"Detected {label.item()} with confidence " f"{round(score.item(), 2)} at location {box}")
Detected 1 with confidence 0.45 at location [344.8, 23.2, 637.4, 373.8]
Detected 1 with confidence 0.41 at location [11.9, 51.6, 316.6, 472.9]
>>> image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(image_url, stream=True).raw)
>>> # Check for cats and remote controls
>>> text_labels = [["a cat", "a remote control"]]
>>> inputs = processor(images=image, text=text_labels, return_tensors="pt").to(device)
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> results = processor.post_process_grounded_object_detection(
... outputs,
... threshold=0.4,
... text_threshold=0.3,
... target_sizes=[(image.height, image.width)]
... )
>>> # Retrieve the first image result
>>> result = results[0]
>>> for box, score, text_label in zip(result["boxes"], result["scores"], result["text_labels"]):
... box = [round(x, 2) for x in box.tolist()]
... print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}")
Detected a cat with confidence 0.479 at location [344.7, 23.11, 637.18, 374.28]
Detected a cat with confidence 0.438 at location [12.27, 51.91, 316.86, 472.44]
Detected a remote control with confidence 0.478 at location [38.57, 70.0, 176.78, 118.18]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -2639,13 +2653,10 @@ class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel):
)
if not return_dict:
if auxiliary_outputs is not None:
output = (logits, pred_boxes) + auxiliary_outputs + outputs
else:
output = (logits, pred_boxes) + outputs
tuple_outputs = ((loss, loss_dict) + output) if loss is not None else output
return tuple_outputs
auxiliary_outputs = auxiliary_outputs if auxiliary_outputs is not None else []
output = [loss, loss_dict, logits, pred_boxes, *auxiliary_outputs, *outputs, input_ids]
output = tuple(out for out in output if out is not None)
return output
dict_outputs = GroundingDinoObjectDetectionOutput(
loss=loss,
@@ -2666,6 +2677,7 @@ class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel):
init_reference_points=outputs.init_reference_points,
enc_outputs_class=outputs.enc_outputs_class,
enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
input_ids=input_ids,
)
return dict_outputs
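Since `input_ids` is now stored on `GroundingDinoObjectDetectionOutput`, the tokenized prompt travels with the model output and the processor can decode phrase labels without the caller re-passing it. A minimal sketch of that contract, assuming the same checkpoint and image as in the docstring example above:

```python
import requests
import torch
from PIL import Image
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor

processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny")

image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(image_url, stream=True).raw)
inputs = processor(images=image, text=[["a cat", "a remote control"]], return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# The encoded candidate-label prompt is carried on the output object ...
assert torch.equal(outputs.input_ids, inputs.input_ids)

# ... so `post_process_grounded_object_detection` works without `input_ids`;
# an explicitly passed tensor still takes precedence over `outputs.input_ids`.
results = processor.post_process_grounded_object_detection(
    outputs, threshold=0.4, text_threshold=0.3, target_sizes=[(image.height, image.width)]
)
```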


@@ -17,7 +17,8 @@ Processor class for Grounding DINO.
"""
import pathlib
from typing import Dict, List, Optional, Tuple, Union
import warnings
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from ...image_processing_utils import BatchFeature
from ...image_transforms import center_to_corners_format
@@ -25,11 +26,15 @@ from ...image_utils import AnnotationFormat, ImageInput
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
from ...utils import TensorType, is_torch_available
from ...utils.deprecation import deprecate_kwarg
if is_torch_available():
import torch
if TYPE_CHECKING:
from .modeling_grounding_dino import GroundingDinoObjectDetectionOutput
AnnotationType = Dict[str, Union[int, str, List[Dict]]]
@@ -60,6 +65,42 @@ def get_phrases_from_posmap(posmaps, input_ids):
return token_ids
def _is_list_of_candidate_labels(text) -> bool:
"""Check that text is list/tuple of strings and each string is a candidate label and not merged candidate labels text.
Merged candidate labels text is a string with candidate labels separated by a dot.
"""
if isinstance(text, (list, tuple)):
return all(isinstance(t, str) and "." not in t for t in text)
return False
def _merge_candidate_labels_text(text: List[str]) -> str:
"""
Merge candidate labels text into a single string. Ensure all labels are lowercase.
For example, ["A cat", "a dog"] -> "a cat. a dog."
"""
labels = [t.strip().lower() for t in text] # ensure lowercase
merged_labels_str = ". ".join(labels) + "." # join with dot and add a dot at the end
return merged_labels_str
class DictWithDeprecationWarning(dict):
message = (
"The key `labels` is will return integer ids in `GroundingDinoProcessor.post_process_grounded_object_detection` "
"output since v4.51.0. Use `text_labels` instead to retrieve string object names."
)
def __getitem__(self, key):
if key == "labels":
warnings.warn(self.message, FutureWarning)
return super().__getitem__(key)
def get(self, key, *args, **kwargs):
if key == "labels":
warnings.warn(self.message, FutureWarning)
return super().get(key, *args, **kwargs)
class GroundingDinoImagesKwargs(ImagesKwargs, total=False):
annotations: Optional[Union[AnnotationType, List[AnnotationType]]]
return_segmentation_masks: Optional[bool]
@@ -120,7 +161,15 @@ class GroundingDinoProcessor(ProcessorMixin):
This method uses [`GroundingDinoImageProcessor.__call__`] to prepare image(s) for the model, and
[`BertTokenizerFast.__call__`] to prepare text for the model.
Please refer to the docstring of the above two methods for more information.
Args:
images (`ImageInput`, `List[ImageInput]`, *optional*):
The image or batch of images to be processed. The image may be a PIL image, a numpy array or a torch tensor.
text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*):
Candidate labels to be detected on the image. The text may be one of the following:
- A list of candidate labels (strings) to be detected on the image (e.g. ["a cat", "a dog"]).
- A batch of candidate labels to be detected on the batch of images (e.g. [["a cat", "a dog"], ["a car", "a person"]]).
- A merged candidate labels string to be detected on the image, separated by "." (e.g. "a cat. a dog.").
- A batch of merged candidate labels text to be detected on the batch of images (e.g. ["a cat. a dog.", "a car. a person."]).
"""
if images is None and text is None:
raise ValueError("You must specify either text or images.")
@@ -138,6 +187,7 @@ class GroundingDinoProcessor(ProcessorMixin):
encoding_image_processor = BatchFeature()
if text is not None:
text = self._preprocess_input_text(text)
text_encoding = self.tokenizer(
text=text,
**output_kwargs["text_kwargs"],
@@ -149,6 +199,23 @@ class GroundingDinoProcessor(ProcessorMixin):
return text_encoding
def _preprocess_input_text(self, text):
"""
Preprocess input text to ensure that labels are in the correct format for the model.
If the text is a list of candidate labels, merge the candidate labels into a single string,
for example, ["a cat", "a dog"] -> "a cat. a dog.". In case candidate labels are already in a form of
"a cat. a dog.", the text is returned as is.
"""
if _is_list_of_candidate_labels(text):
text = _merge_candidate_labels_text(text)
# for batched input
elif isinstance(text, (list, tuple)) and all(_is_list_of_candidate_labels(t) for t in text):
text = [_merge_candidate_labels_text(sample) for sample in text]
return text
# Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
def batch_decode(self, *args, **kwargs):
"""
@@ -172,13 +239,15 @@ class GroundingDinoProcessor(ProcessorMixin):
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
@deprecate_kwarg("box_threshold", new_name="threshold", version="4.51.0")
def post_process_grounded_object_detection(
self,
outputs,
input_ids,
box_threshold: float = 0.25,
outputs: "GroundingDinoObjectDetectionOutput",
input_ids: Optional[TensorType] = None,
threshold: float = 0.25,
text_threshold: float = 0.25,
target_sizes: Union[TensorType, List[Tuple]] = None,
target_sizes: Optional[Union[TensorType, List[Tuple]]] = None,
text_labels: Optional[List[List[str]]] = None,
):
"""
Converts the raw output of [`GroundingDinoForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
@@ -187,32 +256,38 @@ class GroundingDinoProcessor(ProcessorMixin):
Args:
outputs ([`GroundingDinoObjectDetectionOutput`]):
Raw outputs of the model.
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The token ids of the input text.
box_threshold (`float`, *optional*, defaults to 0.25):
Score threshold to keep object detection predictions.
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
The token ids of the input text. If not provided, they will be taken from the model output.
threshold (`float`, *optional*, defaults to 0.25):
Threshold to keep object detection predictions based on confidence score.
text_threshold (`float`, *optional*, defaults to 0.25):
Score threshold to keep text detection predictions.
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
`(height, width)` of each image in the batch. If unset, predictions will not be resized.
text_labels (`List[List[str]]`, *optional*):
List of candidate labels to be detected on each image. At the moment it's *NOT used*, but required
to be in the signature for the zero-shot object detection pipeline. Text labels are instead extracted
from the `input_ids` tensor provided in `outputs`.
Returns:
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
in the batch as predicted by the model.
`List[Dict]`: A list of dictionaries, one per image in the batch, each containing the following keys:
- **scores**: tensor of confidence scores for detected objects
- **boxes**: tensor of bounding boxes in [x0, y0, x1, y1] format
- **labels**: list of text labels for each detected object (will be replaced with integer ids in v4.51.0)
- **text_labels**: list of text labels for detected objects
"""
logits, boxes = outputs.logits, outputs.pred_boxes
batch_logits, batch_boxes = outputs.logits, outputs.pred_boxes
input_ids = input_ids if input_ids is not None else outputs.input_ids
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
if target_sizes is not None and len(target_sizes) != len(batch_logits):
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
probs = torch.sigmoid(logits) # (batch_size, num_queries, 256)
scores = torch.max(probs, dim=-1)[0] # (batch_size, num_queries)
batch_probs = torch.sigmoid(batch_logits) # (batch_size, num_queries, 256)
batch_scores = torch.max(batch_probs, dim=-1)[0] # (batch_size, num_queries)
# Convert to [x0, y0, x1, y1] format
boxes = center_to_corners_format(boxes)
batch_boxes = center_to_corners_format(batch_boxes)
# Convert from relative [0, 1] to absolute [0, height] coordinates
if target_sizes is not None:
@@ -222,17 +297,30 @@ class GroundingDinoProcessor(ProcessorMixin):
else:
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
boxes = boxes * scale_fct[:, None, :]
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(batch_boxes.device)
batch_boxes = batch_boxes * scale_fct[:, None, :]
results = []
for idx, (s, b, p) in enumerate(zip(scores, boxes, probs)):
score = s[s > box_threshold]
box = b[s > box_threshold]
prob = p[s > box_threshold]
for idx, (scores, boxes, probs) in enumerate(zip(batch_scores, batch_boxes, batch_probs)):
keep = scores > threshold
scores = scores[keep]
boxes = boxes[keep]
# extract text labels
prob = probs[keep]
label_ids = get_phrases_from_posmap(prob > text_threshold, input_ids[idx])
label = self.batch_decode(label_ids)
results.append({"scores": score, "labels": label, "boxes": box})
objects_text_labels = self.batch_decode(label_ids)
result = DictWithDeprecationWarning(
{
"scores": scores,
"boxes": boxes,
"text_labels": objects_text_labels,
# TODO: @pavel, set labels to None since v4.51.0 or find a way to extract ids
"labels": objects_text_labels,
}
)
results.append(result)
return results
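Each per-image result returned above is a `DictWithDeprecationWarning`, so reading the legacy `labels` key still works but warns. A short sketch of that behaviour, continuing from a post-processing call like the ones above:

```python
import warnings

result = results[0]

# Preferred going forward: string phrases under the new "text_labels" key.
print(result["text_labels"])

# Deprecated: "labels" currently holds the same strings, but is scheduled to
# hold integer ids from v4.51.0, so reading it emits a FutureWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy_labels = result["labels"]

assert legacy_labels == result["text_labels"]
assert any(issubclass(w.category, FutureWarning) for w in caught)
```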


@@ -322,9 +322,9 @@ class GroundingDinoModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
# loss is at first position
if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning
# Object Detection model returns pred_logits and pred_boxes
# Object Detection model returns pred_logits and pred_boxes and input_ids
if model_class.__name__ == "GroundingDinoForObjectDetection":
correct_outlen += 2
correct_outlen += 3
self.assertEqual(out_len, correct_outlen)
@@ -653,7 +653,7 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
# verify postprocessing
results = processor.image_processor.post_process_object_detection(
outputs, threshold=0.35, target_sizes=[image.size[::-1]]
outputs, threshold=0.35, target_sizes=[(image.height, image.width)]
)[0]
expected_scores = torch.tensor([0.4526, 0.4082]).to(torch_device)
expected_slice_boxes = torch.tensor([344.8143, 23.1796, 637.4004, 373.8295]).to(torch_device)
@@ -667,14 +667,14 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
results = processor.post_process_grounded_object_detection(
outputs=outputs,
input_ids=encoding.input_ids,
box_threshold=0.35,
threshold=0.35,
text_threshold=0.3,
target_sizes=[image.size[::-1]],
target_sizes=[(image.height, image.width)],
)[0]
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-3))
self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-2))
self.assertListEqual(results["labels"], expected_labels)
self.assertListEqual(results["text_labels"], expected_labels)
@require_torch_accelerator
def test_inference_object_detection_head_equivalence_cpu_gpu(self):
@@ -706,11 +706,11 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
# assert postprocessing
results_cpu = processor.image_processor.post_process_object_detection(
cpu_outputs, threshold=0.35, target_sizes=[image.size[::-1]]
cpu_outputs, threshold=0.35, target_sizes=[(image.height, image.width)]
)[0]
result_gpu = processor.image_processor.post_process_object_detection(
gpu_outputs, threshold=0.35, target_sizes=[image.size[::-1]]
gpu_outputs, threshold=0.35, target_sizes=[(image.height, image.width)]
)[0]
self.assertTrue(torch.allclose(results_cpu["scores"], result_gpu["scores"].cpu(), atol=1e-3))


@@ -17,6 +17,7 @@ import os
import shutil
import tempfile
import unittest
from typing import Optional
import pytest
@@ -77,6 +78,20 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.embed_dim = 5
self.seq_length = 5
def prepare_text_inputs(self, batch_size: Optional[int] = None):
labels = ["a cat", "remote control"]
labels_longer = ["a person", "a car", "a dog", "a cat"]
if batch_size is None:
return labels
if batch_size < 1:
raise ValueError("batch_size must be greater than 0")
if batch_size == 1:
return [labels]
return [labels, labels_longer] + [labels] * (batch_size - 2)
# Copied from tests.models.clip.test_processor_clip.CLIPProcessorTest.get_tokenizer with CLIP->Bert
def get_tokenizer(self, **kwargs):
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
@@ -98,6 +113,7 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
return GroundingDinoObjectDetectionOutput(
pred_boxes=torch.rand(self.batch_size, self.num_queries, 4),
logits=torch.rand(self.batch_size, self.num_queries, self.embed_dim),
input_ids=self.get_fake_grounding_dino_input_ids(),
)
def get_fake_grounding_dino_input_ids(self):
@@ -111,14 +127,11 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor = GroundingDinoProcessor(tokenizer=tokenizer, image_processor=image_processor)
grounding_dino_output = self.get_fake_grounding_dino_output()
grounding_dino_input_ids = self.get_fake_grounding_dino_input_ids()
post_processed = processor.post_process_grounded_object_detection(
grounding_dino_output, grounding_dino_input_ids
)
post_processed = processor.post_process_grounded_object_detection(grounding_dino_output)
self.assertEqual(len(post_processed), self.batch_size)
self.assertEqual(list(post_processed[0].keys()), ["scores", "labels", "boxes"])
self.assertEqual(list(post_processed[0].keys()), ["scores", "boxes", "text_labels", "labels"])
self.assertEqual(post_processed[0]["boxes"].shape, (self.num_queries, 4))
self.assertEqual(post_processed[0]["scores"].shape, (self.num_queries,))
@@ -248,3 +261,26 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
inputs = processor(text=input_str, images=image_input)
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
def test_text_preprocessing_equivalence(self):
processor = GroundingDinoProcessor.from_pretrained(self.tmpdirname)
# check for single input
formatted_labels = "a cat. a remote control."
labels = ["a cat", "a remote control"]
inputs1 = processor(text=formatted_labels, return_tensors="pt")
inputs2 = processor(text=labels, return_tensors="pt")
self.assertTrue(
torch.allclose(inputs1["input_ids"], inputs2["input_ids"]),
f"Input ids are not equal for single input: {inputs1['input_ids']} != {inputs2['input_ids']}",
)
# check for batched input
formatted_labels = ["a cat. a remote control.", "a car. a person."]
labels = [["a cat", "a remote control"], ["a car", "a person"]]
inputs1 = processor(text=formatted_labels, return_tensors="pt", padding=True)
inputs2 = processor(text=labels, return_tensors="pt", padding=True)
self.assertTrue(
torch.allclose(inputs1["input_ids"], inputs2["input_ids"]),
f"Input ids are not equal for batched input: {inputs1['input_ids']} != {inputs2['input_ids']}",
)
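As a usage-level counterpart to the batched equivalence test above, a sketch of the two interchangeable text input forms (using the hub checkpoint from the documentation rather than the test fixture):

```python
import torch
from transformers import GroundingDinoProcessor

processor = GroundingDinoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")

# Nested-list form: one list of candidate labels per image in the batch ...
inputs_a = processor(
    text=[["a cat", "a remote control"], ["a car", "a person"]],
    return_tensors="pt",
    padding=True,
)

# ... tokenizes identically to the pre-merged, dot-separated strings.
inputs_b = processor(
    text=["a cat. a remote control.", "a car. a person."],
    return_tensors="pt",
    padding=True,
)

assert torch.equal(inputs_a.input_ids, inputs_b.input_ids)
```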