mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Grounding DINO Processor standardization (#34853)
* Add input ids to model output * Add text preprocessing for processor * Fix snippet * Add test for equivalence * Add type checking guard * Fixing typehint * Fix test for added `input_ids` in output * Add deprecations and "text_labels" to output * Adjust tests * Fix test * Update code examples * Minor docs and code improvement * Remove one-liner functions and rename class to CamelCase * Update docstring * Fixup
This commit is contained in:
parent
42b2857b01
commit
099d93d2e9
@ -56,25 +56,26 @@ Here's how to use the model for zero-shot object detection:
|
||||
>>> image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(image_url, stream=True).raw)
|
||||
>>> # Check for cats and remote controls
|
||||
>>> text = "a cat. a remote control."
|
||||
>>> text_labels = [["a cat", "a remote control"]]
|
||||
|
||||
>>> inputs = processor(images=image, text=text, return_tensors="pt").to(device)
|
||||
>>> inputs = processor(images=image, text=text_labels, return_tensors="pt").to(device)
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
|
||||
>>> results = processor.post_process_grounded_object_detection(
|
||||
... outputs,
|
||||
... inputs.input_ids,
|
||||
... box_threshold=0.4,
|
||||
... threshold=0.4,
|
||||
... text_threshold=0.3,
|
||||
... target_sizes=[image.size[::-1]]
|
||||
... target_sizes=[(image.height, image.width)]
|
||||
... )
|
||||
>>> print(results)
|
||||
[{'boxes': tensor([[344.6959, 23.1090, 637.1833, 374.2751],
|
||||
[ 12.2666, 51.9145, 316.8582, 472.4392],
|
||||
[ 38.5742, 70.0015, 176.7838, 118.1806]], device='cuda:0'),
|
||||
'labels': ['a cat', 'a cat', 'a remote control'],
|
||||
'scores': tensor([0.4785, 0.4381, 0.4776], device='cuda:0')}]
|
||||
>>> # Retrieve the first image result
|
||||
>>> result = results[0]
|
||||
>>> for box, score, text_label in zip(result["boxes"], result["scores"], result["text_labels"]):
|
||||
... box = [round(x, 2) for x in box.tolist()]
|
||||
... print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}")
|
||||
Detected a cat with confidence 0.479 at location [344.7, 23.11, 637.18, 374.28]
|
||||
Detected a cat with confidence 0.438 at location [12.27, 51.91, 316.86, 472.44]
|
||||
Detected a remote control with confidence 0.478 at location [38.57, 70.0, 176.78, 118.18]
|
||||
```
|
||||
|
||||
## Grounded SAM
|
||||
|
@ -286,7 +286,7 @@ class GroundingDinoObjectDetectionOutput(ModelOutput):
|
||||
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
|
||||
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
|
||||
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
|
||||
possible padding). You can use [`~GroundingDinoProcessor.post_process_object_detection`] to retrieve the
|
||||
possible padding). You can use [`~GroundingDinoProcessor.post_process_grounded_object_detection`] to retrieve the
|
||||
unnormalized bounding boxes.
|
||||
auxiliary_outputs (`List[Dict]`, *optional*):
|
||||
Optional, only returned when auxilary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
|
||||
@ -331,6 +331,8 @@ class GroundingDinoObjectDetectionOutput(ModelOutput):
|
||||
background).
|
||||
enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`):
|
||||
Logits of predicted bounding boxes coordinates in the first stage.
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Encoded candidate labels sequence. Used in processor to post process object detection result.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
@ -351,6 +353,7 @@ class GroundingDinoObjectDetectionOutput(ModelOutput):
|
||||
encoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
enc_outputs_class: Optional[torch.FloatTensor] = None
|
||||
enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
|
||||
input_ids: Optional[torch.LongTensor] = None
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->GroundingDino
|
||||
@ -2546,30 +2549,41 @@ class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel):
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoProcessor, GroundingDinoForObjectDetection
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
>>> text = "a cat."
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
>>> from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
||||
|
||||
>>> processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
|
||||
>>> model = GroundingDinoForObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny")
|
||||
>>> model_id = "IDEA-Research/grounding-dino-tiny"
|
||||
>>> device = "cuda"
|
||||
|
||||
>>> inputs = processor(images=image, text=text, return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
>>> processor = AutoProcessor.from_pretrained(model_id)
|
||||
>>> model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
|
||||
|
||||
>>> # convert outputs (bounding boxes and class logits) to COCO API
|
||||
>>> target_sizes = torch.tensor([image.size[::-1]])
|
||||
>>> results = processor.image_processor.post_process_object_detection(
|
||||
... outputs, threshold=0.35, target_sizes=target_sizes
|
||||
... )[0]
|
||||
>>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
|
||||
... box = [round(i, 1) for i in box.tolist()]
|
||||
... print(f"Detected {label.item()} with confidence " f"{round(score.item(), 2)} at location {box}")
|
||||
Detected 1 with confidence 0.45 at location [344.8, 23.2, 637.4, 373.8]
|
||||
Detected 1 with confidence 0.41 at location [11.9, 51.6, 316.6, 472.9]
|
||||
>>> image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(image_url, stream=True).raw)
|
||||
>>> # Check for cats and remote controls
|
||||
>>> text_labels = [["a cat", "a remote control"]]
|
||||
|
||||
>>> inputs = processor(images=image, text=text_labels, return_tensors="pt").to(device)
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
|
||||
>>> results = processor.post_process_grounded_object_detection(
|
||||
... outputs,
|
||||
... threshold=0.4,
|
||||
... text_threshold=0.3,
|
||||
... target_sizes=[(image.height, image.width)]
|
||||
... )
|
||||
>>> # Retrieve the first image result
|
||||
>>> result = results[0]
|
||||
>>> for box, score, text_label in zip(result["boxes"], result["scores"], result["text_labels"]):
|
||||
... box = [round(x, 2) for x in box.tolist()]
|
||||
... print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}")
|
||||
Detected a cat with confidence 0.479 at location [344.7, 23.11, 637.18, 374.28]
|
||||
Detected a cat with confidence 0.438 at location [12.27, 51.91, 316.86, 472.44]
|
||||
Detected a remote control with confidence 0.478 at location [38.57, 70.0, 176.78, 118.18]
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
@ -2639,13 +2653,10 @@ class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel):
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
if auxiliary_outputs is not None:
|
||||
output = (logits, pred_boxes) + auxiliary_outputs + outputs
|
||||
else:
|
||||
output = (logits, pred_boxes) + outputs
|
||||
tuple_outputs = ((loss, loss_dict) + output) if loss is not None else output
|
||||
|
||||
return tuple_outputs
|
||||
auxiliary_outputs = auxiliary_outputs if auxiliary_outputs is not None else []
|
||||
output = [loss, loss_dict, logits, pred_boxes, *auxiliary_outputs, *outputs, input_ids]
|
||||
output = tuple(out for out in output if out is not None)
|
||||
return output
|
||||
|
||||
dict_outputs = GroundingDinoObjectDetectionOutput(
|
||||
loss=loss,
|
||||
@ -2666,6 +2677,7 @@ class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel):
|
||||
init_reference_points=outputs.init_reference_points,
|
||||
enc_outputs_class=outputs.enc_outputs_class,
|
||||
enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
|
||||
input_ids=input_ids,
|
||||
)
|
||||
|
||||
return dict_outputs
|
||||
|
@ -17,7 +17,8 @@ Processor class for Grounding DINO.
|
||||
"""
|
||||
|
||||
import pathlib
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_transforms import center_to_corners_format
|
||||
@ -25,11 +26,15 @@ from ...image_utils import AnnotationFormat, ImageInput
|
||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
|
||||
from ...utils import TensorType, is_torch_available
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .modeling_grounding_dino import GroundingDinoObjectDetectionOutput
|
||||
|
||||
|
||||
AnnotationType = Dict[str, Union[int, str, List[Dict]]]
|
||||
|
||||
@ -60,6 +65,42 @@ def get_phrases_from_posmap(posmaps, input_ids):
|
||||
return token_ids
|
||||
|
||||
|
||||
def _is_list_of_candidate_labels(text) -> bool:
|
||||
"""Check that text is list/tuple of strings and each string is a candidate label and not merged candidate labels text.
|
||||
Merged candidate labels text is a string with candidate labels separated by a dot.
|
||||
"""
|
||||
if isinstance(text, (list, tuple)):
|
||||
return all(isinstance(t, str) and "." not in t for t in text)
|
||||
return False
|
||||
|
||||
|
||||
def _merge_candidate_labels_text(text: List[str]) -> str:
|
||||
"""
|
||||
Merge candidate labels text into a single string. Ensure all labels are lowercase.
|
||||
For example, ["A cat", "a dog"] -> "a cat. a dog."
|
||||
"""
|
||||
labels = [t.strip().lower() for t in text] # ensure lowercase
|
||||
merged_labels_str = ". ".join(labels) + "." # join with dot and add a dot at the end
|
||||
return merged_labels_str
|
||||
|
||||
|
||||
class DictWithDeprecationWarning(dict):
|
||||
message = (
|
||||
"The key `labels` is will return integer ids in `GroundingDinoProcessor.post_process_grounded_object_detection` "
|
||||
"output since v4.51.0. Use `text_labels` instead to retrieve string object names."
|
||||
)
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key == "labels":
|
||||
warnings.warn(self.message, FutureWarning)
|
||||
return super().__getitem__(key)
|
||||
|
||||
def get(self, key, *args, **kwargs):
|
||||
if key == "labels":
|
||||
warnings.warn(self.message, FutureWarning)
|
||||
return super().get(key, *args, **kwargs)
|
||||
|
||||
|
||||
class GroundingDinoImagesKwargs(ImagesKwargs, total=False):
|
||||
annotations: Optional[Union[AnnotationType, List[AnnotationType]]]
|
||||
return_segmentation_masks: Optional[bool]
|
||||
@ -120,7 +161,15 @@ class GroundingDinoProcessor(ProcessorMixin):
|
||||
This method uses [`GroundingDinoImageProcessor.__call__`] method to prepare image(s) for the model, and
|
||||
[`BertTokenizerFast.__call__`] to prepare text for the model.
|
||||
|
||||
Please refer to the docstring of the above two methods for more information.
|
||||
Args:
|
||||
images (`ImageInput`, `List[ImageInput]`, *optional*):
|
||||
The image or batch of images to be processed. The image might be either PIL image, numpy array or a torch tensor.
|
||||
text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*):
|
||||
Candidate labels to be detected on the image. The text might be one of the following:
|
||||
- A list of candidate labels (strings) to be detected on the image (e.g. ["a cat", "a dog"]).
|
||||
- A batch of candidate labels to be detected on the batch of images (e.g. [["a cat", "a dog"], ["a car", "a person"]]).
|
||||
- A merged candidate labels string to be detected on the image, separated by "." (e.g. "a cat. a dog.").
|
||||
- A batch of merged candidate labels text to be detected on the batch of images (e.g. ["a cat. a dog.", "a car. a person."]).
|
||||
"""
|
||||
if images is None and text is None:
|
||||
raise ValueError("You must specify either text or images.")
|
||||
@ -138,6 +187,7 @@ class GroundingDinoProcessor(ProcessorMixin):
|
||||
encoding_image_processor = BatchFeature()
|
||||
|
||||
if text is not None:
|
||||
text = self._preprocess_input_text(text)
|
||||
text_encoding = self.tokenizer(
|
||||
text=text,
|
||||
**output_kwargs["text_kwargs"],
|
||||
@ -149,6 +199,23 @@ class GroundingDinoProcessor(ProcessorMixin):
|
||||
|
||||
return text_encoding
|
||||
|
||||
def _preprocess_input_text(self, text):
|
||||
"""
|
||||
Preprocess input text to ensure that labels are in the correct format for the model.
|
||||
If the text is a list of candidate labels, merge the candidate labels into a single string,
|
||||
for example, ["a cat", "a dog"] -> "a cat. a dog.". In case candidate labels are already in a form of
|
||||
"a cat. a dog.", the text is returned as is.
|
||||
"""
|
||||
|
||||
if _is_list_of_candidate_labels(text):
|
||||
text = _merge_candidate_labels_text(text)
|
||||
|
||||
# for batched input
|
||||
elif isinstance(text, (list, tuple)) and all(_is_list_of_candidate_labels(t) for t in text):
|
||||
text = [_merge_candidate_labels_text(sample) for sample in text]
|
||||
|
||||
return text
|
||||
|
||||
# Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
@ -172,13 +239,15 @@ class GroundingDinoProcessor(ProcessorMixin):
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
|
||||
@deprecate_kwarg("box_threshold", new_name="threshold", version="4.51.0")
|
||||
def post_process_grounded_object_detection(
|
||||
self,
|
||||
outputs,
|
||||
input_ids,
|
||||
box_threshold: float = 0.25,
|
||||
outputs: "GroundingDinoObjectDetectionOutput",
|
||||
input_ids: Optional[TensorType] = None,
|
||||
threshold: float = 0.25,
|
||||
text_threshold: float = 0.25,
|
||||
target_sizes: Union[TensorType, List[Tuple]] = None,
|
||||
target_sizes: Optional[Union[TensorType, List[Tuple]]] = None,
|
||||
text_labels: Optional[List[List[str]]] = None,
|
||||
):
|
||||
"""
|
||||
Converts the raw output of [`GroundingDinoForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
|
||||
@ -187,32 +256,38 @@ class GroundingDinoProcessor(ProcessorMixin):
|
||||
Args:
|
||||
outputs ([`GroundingDinoObjectDetectionOutput`]):
|
||||
Raw outputs of the model.
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
The token ids of the input text.
|
||||
box_threshold (`float`, *optional*, defaults to 0.25):
|
||||
Score threshold to keep object detection predictions.
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
The token ids of the input text. If not provided will be taken from the model output.
|
||||
threshold (`float`, *optional*, defaults to 0.25):
|
||||
Threshold to keep object detection predictions based on confidence score.
|
||||
text_threshold (`float`, *optional*, defaults to 0.25):
|
||||
Score threshold to keep text detection predictions.
|
||||
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
|
||||
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
|
||||
`(height, width)` of each image in the batch. If unset, predictions will not be resized.
|
||||
text_labels (`List[List[str]]`, *optional*):
|
||||
List of candidate labels to be detected on each image. At the moment it's *NOT used*, but required
|
||||
to be in signature for the zero-shot object detection pipeline. Text labels are instead extracted
|
||||
from the `input_ids` tensor provided in `outputs`.
|
||||
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
|
||||
in the batch as predicted by the model.
|
||||
`List[Dict]`: A list of dictionaries, each dictionary containing the
|
||||
- **scores**: tensor of confidence scores for detected objects
|
||||
- **boxes**: tensor of bounding boxes in [x0, y0, x1, y1] format
|
||||
- **labels**: list of text labels for each detected object (will be replaced with integer ids in v4.51.0)
|
||||
- **text_labels**: list of text labels for detected objects
|
||||
"""
|
||||
logits, boxes = outputs.logits, outputs.pred_boxes
|
||||
batch_logits, batch_boxes = outputs.logits, outputs.pred_boxes
|
||||
input_ids = input_ids if input_ids is not None else outputs.input_ids
|
||||
|
||||
if target_sizes is not None:
|
||||
if len(logits) != len(target_sizes):
|
||||
raise ValueError(
|
||||
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
|
||||
)
|
||||
if target_sizes is not None and len(target_sizes) != len(batch_logits):
|
||||
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
|
||||
|
||||
probs = torch.sigmoid(logits) # (batch_size, num_queries, 256)
|
||||
scores = torch.max(probs, dim=-1)[0] # (batch_size, num_queries)
|
||||
batch_probs = torch.sigmoid(batch_logits) # (batch_size, num_queries, 256)
|
||||
batch_scores = torch.max(batch_probs, dim=-1)[0] # (batch_size, num_queries)
|
||||
|
||||
# Convert to [x0, y0, x1, y1] format
|
||||
boxes = center_to_corners_format(boxes)
|
||||
batch_boxes = center_to_corners_format(batch_boxes)
|
||||
|
||||
# Convert from relative [0, 1] to absolute [0, height] coordinates
|
||||
if target_sizes is not None:
|
||||
@ -222,17 +297,30 @@ class GroundingDinoProcessor(ProcessorMixin):
|
||||
else:
|
||||
img_h, img_w = target_sizes.unbind(1)
|
||||
|
||||
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
|
||||
boxes = boxes * scale_fct[:, None, :]
|
||||
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(batch_boxes.device)
|
||||
batch_boxes = batch_boxes * scale_fct[:, None, :]
|
||||
|
||||
results = []
|
||||
for idx, (s, b, p) in enumerate(zip(scores, boxes, probs)):
|
||||
score = s[s > box_threshold]
|
||||
box = b[s > box_threshold]
|
||||
prob = p[s > box_threshold]
|
||||
for idx, (scores, boxes, probs) in enumerate(zip(batch_scores, batch_boxes, batch_probs)):
|
||||
keep = scores > threshold
|
||||
scores = scores[keep]
|
||||
boxes = boxes[keep]
|
||||
|
||||
# extract text labels
|
||||
prob = probs[keep]
|
||||
label_ids = get_phrases_from_posmap(prob > text_threshold, input_ids[idx])
|
||||
label = self.batch_decode(label_ids)
|
||||
results.append({"scores": score, "labels": label, "boxes": box})
|
||||
objects_text_labels = self.batch_decode(label_ids)
|
||||
|
||||
result = DictWithDeprecationWarning(
|
||||
{
|
||||
"scores": scores,
|
||||
"boxes": boxes,
|
||||
"text_labels": objects_text_labels,
|
||||
# TODO: @pavel, set labels to None since v4.51.0 or find a way to extract ids
|
||||
"labels": objects_text_labels,
|
||||
}
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
|
@ -322,9 +322,9 @@ class GroundingDinoModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes
|
||||
# loss is at first position
|
||||
if "labels" in inputs_dict:
|
||||
correct_outlen += 1 # loss is added to beginning
|
||||
# Object Detection model returns pred_logits and pred_boxes
|
||||
# Object Detection model returns pred_logits and pred_boxes and input_ids
|
||||
if model_class.__name__ == "GroundingDinoForObjectDetection":
|
||||
correct_outlen += 2
|
||||
correct_outlen += 3
|
||||
|
||||
self.assertEqual(out_len, correct_outlen)
|
||||
|
||||
@ -653,7 +653,7 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
# verify postprocessing
|
||||
results = processor.image_processor.post_process_object_detection(
|
||||
outputs, threshold=0.35, target_sizes=[image.size[::-1]]
|
||||
outputs, threshold=0.35, target_sizes=[(image.height, image.width)]
|
||||
)[0]
|
||||
expected_scores = torch.tensor([0.4526, 0.4082]).to(torch_device)
|
||||
expected_slice_boxes = torch.tensor([344.8143, 23.1796, 637.4004, 373.8295]).to(torch_device)
|
||||
@ -667,14 +667,14 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
|
||||
results = processor.post_process_grounded_object_detection(
|
||||
outputs=outputs,
|
||||
input_ids=encoding.input_ids,
|
||||
box_threshold=0.35,
|
||||
threshold=0.35,
|
||||
text_threshold=0.3,
|
||||
target_sizes=[image.size[::-1]],
|
||||
target_sizes=[(image.height, image.width)],
|
||||
)[0]
|
||||
|
||||
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-2))
|
||||
self.assertListEqual(results["labels"], expected_labels)
|
||||
self.assertListEqual(results["text_labels"], expected_labels)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_inference_object_detection_head_equivalence_cpu_gpu(self):
|
||||
@ -706,11 +706,11 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
# assert postprocessing
|
||||
results_cpu = processor.image_processor.post_process_object_detection(
|
||||
cpu_outputs, threshold=0.35, target_sizes=[image.size[::-1]]
|
||||
cpu_outputs, threshold=0.35, target_sizes=[(image.height, image.width)]
|
||||
)[0]
|
||||
|
||||
result_gpu = processor.image_processor.post_process_object_detection(
|
||||
gpu_outputs, threshold=0.35, target_sizes=[image.size[::-1]]
|
||||
gpu_outputs, threshold=0.35, target_sizes=[(image.height, image.width)]
|
||||
)[0]
|
||||
|
||||
self.assertTrue(torch.allclose(results_cpu["scores"], result_gpu["scores"].cpu(), atol=1e-3))
|
||||
|
@ -17,6 +17,7 @@ import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
@ -77,6 +78,20 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.embed_dim = 5
|
||||
self.seq_length = 5
|
||||
|
||||
def prepare_text_inputs(self, batch_size: Optional[int] = None):
|
||||
labels = ["a cat", "remote control"]
|
||||
labels_longer = ["a person", "a car", "a dog", "a cat"]
|
||||
|
||||
if batch_size is None:
|
||||
return labels
|
||||
|
||||
if batch_size < 1:
|
||||
raise ValueError("batch_size must be greater than 0")
|
||||
|
||||
if batch_size == 1:
|
||||
return [labels]
|
||||
return [labels, labels_longer] + [labels] * (batch_size - 2)
|
||||
|
||||
# Copied from tests.models.clip.test_processor_clip.CLIPProcessorTest.get_tokenizer with CLIP->Bert
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
@ -98,6 +113,7 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
return GroundingDinoObjectDetectionOutput(
|
||||
pred_boxes=torch.rand(self.batch_size, self.num_queries, 4),
|
||||
logits=torch.rand(self.batch_size, self.num_queries, self.embed_dim),
|
||||
input_ids=self.get_fake_grounding_dino_input_ids(),
|
||||
)
|
||||
|
||||
def get_fake_grounding_dino_input_ids(self):
|
||||
@ -111,14 +127,11 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor = GroundingDinoProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||
|
||||
grounding_dino_output = self.get_fake_grounding_dino_output()
|
||||
grounding_dino_input_ids = self.get_fake_grounding_dino_input_ids()
|
||||
|
||||
post_processed = processor.post_process_grounded_object_detection(
|
||||
grounding_dino_output, grounding_dino_input_ids
|
||||
)
|
||||
post_processed = processor.post_process_grounded_object_detection(grounding_dino_output)
|
||||
|
||||
self.assertEqual(len(post_processed), self.batch_size)
|
||||
self.assertEqual(list(post_processed[0].keys()), ["scores", "labels", "boxes"])
|
||||
self.assertEqual(list(post_processed[0].keys()), ["scores", "boxes", "text_labels", "labels"])
|
||||
self.assertEqual(post_processed[0]["boxes"].shape, (self.num_queries, 4))
|
||||
self.assertEqual(post_processed[0]["scores"].shape, (self.num_queries,))
|
||||
|
||||
@ -248,3 +261,26 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
inputs = processor(text=input_str, images=image_input)
|
||||
|
||||
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
|
||||
|
||||
def test_text_preprocessing_equivalence(self):
|
||||
processor = GroundingDinoProcessor.from_pretrained(self.tmpdirname)
|
||||
|
||||
# check for single input
|
||||
formatted_labels = "a cat. a remote control."
|
||||
labels = ["a cat", "a remote control"]
|
||||
inputs1 = processor(text=formatted_labels, return_tensors="pt")
|
||||
inputs2 = processor(text=labels, return_tensors="pt")
|
||||
self.assertTrue(
|
||||
torch.allclose(inputs1["input_ids"], inputs2["input_ids"]),
|
||||
f"Input ids are not equal for single input: {inputs1['input_ids']} != {inputs2['input_ids']}",
|
||||
)
|
||||
|
||||
# check for batched input
|
||||
formatted_labels = ["a cat. a remote control.", "a car. a person."]
|
||||
labels = [["a cat", "a remote control"], ["a car", "a person"]]
|
||||
inputs1 = processor(text=formatted_labels, return_tensors="pt", padding=True)
|
||||
inputs2 = processor(text=labels, return_tensors="pt", padding=True)
|
||||
self.assertTrue(
|
||||
torch.allclose(inputs1["input_ids"], inputs2["input_ids"]),
|
||||
f"Input ids are not equal for batched input: {inputs1['input_ids']} != {inputs2['input_ids']}",
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user