diff --git a/docs/source/en/model_doc/grounding-dino.md b/docs/source/en/model_doc/grounding-dino.md index a6da554f8d5..d024ff6ba73 100644 --- a/docs/source/en/model_doc/grounding-dino.md +++ b/docs/source/en/model_doc/grounding-dino.md @@ -56,25 +56,26 @@ Here's how to use the model for zero-shot object detection: >>> image_url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(image_url, stream=True).raw) >>> # Check for cats and remote controls ->>> text = "a cat. a remote control." +>>> text_labels = [["a cat", "a remote control"]] ->>> inputs = processor(images=image, text=text, return_tensors="pt").to(device) +>>> inputs = processor(images=image, text=text_labels, return_tensors="pt").to(device) >>> with torch.no_grad(): ... outputs = model(**inputs) >>> results = processor.post_process_grounded_object_detection( ... outputs, -... inputs.input_ids, -... box_threshold=0.4, +... threshold=0.4, ... text_threshold=0.3, -... target_sizes=[image.size[::-1]] +... target_sizes=[(image.height, image.width)] ... ) ->>> print(results) -[{'boxes': tensor([[344.6959, 23.1090, 637.1833, 374.2751], - [ 12.2666, 51.9145, 316.8582, 472.4392], - [ 38.5742, 70.0015, 176.7838, 118.1806]], device='cuda:0'), - 'labels': ['a cat', 'a cat', 'a remote control'], - 'scores': tensor([0.4785, 0.4381, 0.4776], device='cuda:0')}] +>>> # Retrieve the first image result +>>> result = results[0] +>>> for box, score, text_label in zip(result["boxes"], result["scores"], result["text_labels"]): +... box = [round(x, 2) for x in box.tolist()] +... print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}") +Detected a cat with confidence 0.479 at location [344.7, 23.11, 637.18, 374.28] +Detected a cat with confidence 0.438 at location [12.27, 51.91, 316.86, 472.44] +Detected a remote control with confidence 0.478 at location [38.57, 70.0, 176.78, 118.18] ``` ## Grounded SAM diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index 283409327bf..695ef41e2e0 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -286,7 +286,7 @@ class GroundingDinoObjectDetectionOutput(ModelOutput): pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding - possible padding). You can use [`~GroundingDinoProcessor.post_process_object_detection`] to retrieve the + possible padding). You can use [`~GroundingDinoProcessor.post_process_grounded_object_detection`] to retrieve the unnormalized bounding boxes. auxiliary_outputs (`List[Dict]`, *optional*): Optional, only returned when auxilary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) @@ -331,6 +331,8 @@ class GroundingDinoObjectDetectionOutput(ModelOutput): background). enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`): Logits of predicted bounding boxes coordinates in the first stage. + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Encoded candidate labels sequence. Used by the processor to post-process the object detection results.
""" loss: Optional[torch.FloatTensor] = None @@ -351,6 +353,7 @@ class GroundingDinoObjectDetectionOutput(ModelOutput): encoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None enc_outputs_class: Optional[torch.FloatTensor] = None enc_outputs_coord_logits: Optional[torch.FloatTensor] = None + input_ids: Optional[torch.LongTensor] = None # Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->GroundingDino @@ -2546,30 +2549,41 @@ class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel): Examples: ```python - >>> from transformers import AutoProcessor, GroundingDinoForObjectDetection - >>> from PIL import Image >>> import requests - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - >>> text = "a cat." + >>> import torch + >>> from PIL import Image + >>> from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection - >>> processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny") - >>> model = GroundingDinoForObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny") + >>> model_id = "IDEA-Research/grounding-dino-tiny" + >>> device = "cuda" - >>> inputs = processor(images=image, text=text, return_tensors="pt") - >>> outputs = model(**inputs) + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device) - >>> # convert outputs (bounding boxes and class logits) to COCO API - >>> target_sizes = torch.tensor([image.size[::-1]]) - >>> results = processor.image_processor.post_process_object_detection( - ... outputs, threshold=0.35, target_sizes=target_sizes - ... )[0] - >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): - ... box = [round(i, 1) for i in box.tolist()] - ... print(f"Detected {label.item()} with confidence " f"{round(score.item(), 2)} at location {box}") - Detected 1 with confidence 0.45 at location [344.8, 23.2, 637.4, 373.8] - Detected 1 with confidence 0.41 at location [11.9, 51.6, 316.6, 472.9] + >>> image_url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(image_url, stream=True).raw) + >>> # Check for cats and remote controls + >>> text_labels = [["a cat", "a remote control"]] + + >>> inputs = processor(images=image, text=text_labels, return_tensors="pt").to(device) + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> results = processor.post_process_grounded_object_detection( + ... outputs, + ... threshold=0.4, + ... text_threshold=0.3, + ... target_sizes=[(image.height, image.width)] + ... ) + >>> # Retrieve the first image result + >>> result = results[0] + >>> for box, score, text_label in zip(result["boxes"], result["scores"], result["text_labels"]): + ... box = [round(x, 2) for x in box.tolist()] + ... 
print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}") + Detected a cat with confidence 0.479 at location [344.7, 23.11, 637.18, 374.28] + Detected a cat with confidence 0.438 at location [12.27, 51.91, 316.86, 472.44] + Detected a remote control with confidence 0.478 at location [38.57, 70.0, 176.78, 118.18] ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -2639,13 +2653,10 @@ class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel): ) if not return_dict: - if auxiliary_outputs is not None: - output = (logits, pred_boxes) + auxiliary_outputs + outputs - else: - output = (logits, pred_boxes) + outputs - tuple_outputs = ((loss, loss_dict) + output) if loss is not None else output - - return tuple_outputs + auxiliary_outputs = auxiliary_outputs if auxiliary_outputs is not None else [] + output = [loss, loss_dict, logits, pred_boxes, *auxiliary_outputs, *outputs, input_ids] + output = tuple(out for out in output if out is not None) + return output dict_outputs = GroundingDinoObjectDetectionOutput( loss=loss, @@ -2666,6 +2677,7 @@ class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel): init_reference_points=outputs.init_reference_points, enc_outputs_class=outputs.enc_outputs_class, enc_outputs_coord_logits=outputs.enc_outputs_coord_logits, + input_ids=input_ids, ) return dict_outputs diff --git a/src/transformers/models/grounding_dino/processing_grounding_dino.py b/src/transformers/models/grounding_dino/processing_grounding_dino.py index 9dbcea64328..f21846ab189 100644 --- a/src/transformers/models/grounding_dino/processing_grounding_dino.py +++ b/src/transformers/models/grounding_dino/processing_grounding_dino.py @@ -17,7 +17,8 @@ Processor class for Grounding DINO. """ import pathlib -from typing import Dict, List, Optional, Tuple, Union +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from ...image_processing_utils import BatchFeature from ...image_transforms import center_to_corners_format @@ -25,11 +26,15 @@ from ...image_utils import AnnotationFormat, ImageInput from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput from ...utils import TensorType, is_torch_available +from ...utils.deprecation import deprecate_kwarg if is_torch_available(): import torch +if TYPE_CHECKING: + from .modeling_grounding_dino import GroundingDinoObjectDetectionOutput + AnnotationType = Dict[str, Union[int, str, List[Dict]]] @@ -60,6 +65,42 @@ def get_phrases_from_posmap(posmaps, input_ids): return token_ids +def _is_list_of_candidate_labels(text) -> bool: + """Check that text is list/tuple of strings and each string is a candidate label and not merged candidate labels text. + Merged candidate labels text is a string with candidate labels separated by a dot. + """ + if isinstance(text, (list, tuple)): + return all(isinstance(t, str) and "." not in t for t in text) + return False + + +def _merge_candidate_labels_text(text: List[str]) -> str: + """ + Merge candidate labels text into a single string. Ensure all labels are lowercase. + For example, ["A cat", "a dog"] -> "a cat. a dog." + """ + labels = [t.strip().lower() for t in text] # ensure lowercase + merged_labels_str = ". ".join(labels) + "." 
# join with dot and add a dot at the end + return merged_labels_str + + + class DictWithDeprecationWarning(dict): + message = ( + "The key `labels` will return integer ids in `GroundingDinoProcessor.post_process_grounded_object_detection` " + "output starting from v4.51.0. Use `text_labels` instead to retrieve string object names." + ) + + def __getitem__(self, key): + if key == "labels": + warnings.warn(self.message, FutureWarning) + return super().__getitem__(key) + + def get(self, key, *args, **kwargs): + if key == "labels": + warnings.warn(self.message, FutureWarning) + return super().get(key, *args, **kwargs) + + class GroundingDinoImagesKwargs(ImagesKwargs, total=False): annotations: Optional[Union[AnnotationType, List[AnnotationType]]] return_segmentation_masks: Optional[bool] @@ -120,7 +161,15 @@ class GroundingDinoProcessor(ProcessorMixin): This method uses [`GroundingDinoImageProcessor.__call__`] method to prepare image(s) for the model, and [`BertTokenizerFast.__call__`] to prepare text for the model. - Please refer to the docstring of the above two methods for more information. + Args: + images (`ImageInput`, `List[ImageInput]`, *optional*): + The image or batch of images to be processed. The image can be a PIL image, a numpy array or a torch tensor. + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): + Candidate labels to be detected on the image. The text can be one of the following: + - A list of candidate labels (strings) to be detected on the image (e.g. ["a cat", "a dog"]). + - A batch of candidate labels to be detected on the batch of images (e.g. [["a cat", "a dog"], ["a car", "a person"]]). + - A merged candidate labels string to be detected on the image, separated by "." (e.g. "a cat. a dog."). + - A batch of merged candidate labels text to be detected on the batch of images (e.g. ["a cat. a dog.", "a car. a person."]). """ if images is None and text is None: raise ValueError("You must specify either text or images.") @@ -138,6 +187,7 @@ class GroundingDinoProcessor(ProcessorMixin): encoding_image_processor = BatchFeature() if text is not None: + text = self._preprocess_input_text(text) text_encoding = self.tokenizer( text=text, **output_kwargs["text_kwargs"], @@ -149,6 +199,23 @@ class GroundingDinoProcessor(ProcessorMixin): return text_encoding + def _preprocess_input_text(self, text): + """ + Preprocess input text to ensure that labels are in the correct format for the model. + If the text is a list of candidate labels, merge the candidate labels into a single string, + for example, ["a cat", "a dog"] -> "a cat. a dog.". If the candidate labels are already in the merged form + "a cat. a dog.", the text is returned as is.
+ """ + + if _is_list_of_candidate_labels(text): + text = _merge_candidate_labels_text(text) + + # for batched input + elif isinstance(text, (list, tuple)) and all(_is_list_of_candidate_labels(t) for t in text): + text = [_merge_candidate_labels_text(sample) for sample in text] + + return text + # Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer def batch_decode(self, *args, **kwargs): """ @@ -172,13 +239,15 @@ class GroundingDinoProcessor(ProcessorMixin): image_processor_input_names = self.image_processor.model_input_names return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + @deprecate_kwarg("box_threshold", new_name="threshold", version="4.51.0") def post_process_grounded_object_detection( self, - outputs, - input_ids, - box_threshold: float = 0.25, + outputs: "GroundingDinoObjectDetectionOutput", + input_ids: Optional[TensorType] = None, + threshold: float = 0.25, text_threshold: float = 0.25, - target_sizes: Union[TensorType, List[Tuple]] = None, + target_sizes: Optional[Union[TensorType, List[Tuple]]] = None, + text_labels: Optional[List[List[str]]] = None, ): """ Converts the raw output of [`GroundingDinoForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, @@ -187,32 +256,38 @@ class GroundingDinoProcessor(ProcessorMixin): Args: outputs ([`GroundingDinoObjectDetectionOutput`]): Raw outputs of the model. - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The token ids of the input text. - box_threshold (`float`, *optional*, defaults to 0.25): - Score threshold to keep object detection predictions. + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + The token ids of the input text. If not provided, it will be taken from the model output. + threshold (`float`, *optional*, defaults to 0.25): + Threshold to keep object detection predictions based on confidence score. text_threshold (`float`, *optional*, defaults to 0.25): Score threshold to keep text detection predictions. target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*): Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size `(height, width)` of each image in the batch. If unset, predictions will not be resized. + text_labels (`List[List[str]]`, *optional*): + List of candidate labels to be detected on each image. At the moment it is *NOT used*, but it is required + to be in the signature for the zero-shot object detection pipeline. Text labels are instead extracted + from the `input_ids` tensor provided in `outputs`. + Returns: - `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image in the batch as predicted by the model.
+ `List[Dict]`: A list of dictionaries, each dictionary containing the + - **scores**: tensor of confidence scores for detected objects + - **boxes**: tensor of bounding boxes in [x0, y0, x1, y1] format + - **labels**: list of text labels for each detected object (will be replaced with integer ids in v4.51.0) + - **text_labels**: list of text labels for detected objects """ - logits, boxes = outputs.logits, outputs.pred_boxes + batch_logits, batch_boxes = outputs.logits, outputs.pred_boxes + input_ids = input_ids if input_ids is not None else outputs.input_ids - if target_sizes is not None: - if len(logits) != len(target_sizes): - raise ValueError( - "Make sure that you pass in as many target sizes as the batch dimension of the logits" - ) + if target_sizes is not None and len(target_sizes) != len(batch_logits): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") - probs = torch.sigmoid(logits) # (batch_size, num_queries, 256) - scores = torch.max(probs, dim=-1)[0] # (batch_size, num_queries) + batch_probs = torch.sigmoid(batch_logits) # (batch_size, num_queries, 256) + batch_scores = torch.max(batch_probs, dim=-1)[0] # (batch_size, num_queries) # Convert to [x0, y0, x1, y1] format - boxes = center_to_corners_format(boxes) + batch_boxes = center_to_corners_format(batch_boxes) # Convert from relative [0, 1] to absolute [0, height] coordinates if target_sizes is not None: @@ -222,17 +297,30 @@ class GroundingDinoProcessor(ProcessorMixin): else: img_h, img_w = target_sizes.unbind(1) - scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) - boxes = boxes * scale_fct[:, None, :] + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(batch_boxes.device) + batch_boxes = batch_boxes * scale_fct[:, None, :] results = [] - for idx, (s, b, p) in enumerate(zip(scores, boxes, probs)): - score = s[s > box_threshold] - box = b[s > box_threshold] - prob = p[s > box_threshold] + for idx, (scores, boxes, probs) in enumerate(zip(batch_scores, batch_boxes, batch_probs)): + keep = scores > threshold + scores = scores[keep] + boxes = boxes[keep] + + # extract text labels + prob = probs[keep] label_ids = get_phrases_from_posmap(prob > text_threshold, input_ids[idx]) - label = self.batch_decode(label_ids) - results.append({"scores": score, "labels": label, "boxes": box}) + objects_text_labels = self.batch_decode(label_ids) + + result = DictWithDeprecationWarning( + { + "scores": scores, + "boxes": boxes, + "text_labels": objects_text_labels, + # TODO: @pavel, set labels to None since v4.51.0 or find a way to extract ids + "labels": objects_text_labels, + } + ) + results.append(result) return results diff --git a/tests/models/grounding_dino/test_modeling_grounding_dino.py b/tests/models/grounding_dino/test_modeling_grounding_dino.py index c6e9671dd59..30a8d44c8e9 100644 --- a/tests/models/grounding_dino/test_modeling_grounding_dino.py +++ b/tests/models/grounding_dino/test_modeling_grounding_dino.py @@ -322,9 +322,9 @@ class GroundingDinoModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes # loss is at first position if "labels" in inputs_dict: correct_outlen += 1 # loss is added to beginning - # Object Detection model returns pred_logits and pred_boxes + # Object Detection model returns pred_logits and pred_boxes and input_ids if model_class.__name__ == "GroundingDinoForObjectDetection": - correct_outlen += 2 + correct_outlen += 3 self.assertEqual(out_len, correct_outlen) @@ -653,7 +653,7 @@ class 
GroundingDinoModelIntegrationTests(unittest.TestCase): # verify postprocessing results = processor.image_processor.post_process_object_detection( - outputs, threshold=0.35, target_sizes=[image.size[::-1]] + outputs, threshold=0.35, target_sizes=[(image.height, image.width)] )[0] expected_scores = torch.tensor([0.4526, 0.4082]).to(torch_device) expected_slice_boxes = torch.tensor([344.8143, 23.1796, 637.4004, 373.8295]).to(torch_device) @@ -667,14 +667,14 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase): results = processor.post_process_grounded_object_detection( outputs=outputs, input_ids=encoding.input_ids, - box_threshold=0.35, + threshold=0.35, text_threshold=0.3, - target_sizes=[image.size[::-1]], + target_sizes=[(image.height, image.width)], )[0] self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-3)) self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-2)) - self.assertListEqual(results["labels"], expected_labels) + self.assertListEqual(results["text_labels"], expected_labels) @require_torch_accelerator def test_inference_object_detection_head_equivalence_cpu_gpu(self): @@ -706,11 +706,11 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase): # assert postprocessing results_cpu = processor.image_processor.post_process_object_detection( - cpu_outputs, threshold=0.35, target_sizes=[image.size[::-1]] + cpu_outputs, threshold=0.35, target_sizes=[(image.height, image.width)] )[0] result_gpu = processor.image_processor.post_process_object_detection( - gpu_outputs, threshold=0.35, target_sizes=[image.size[::-1]] + gpu_outputs, threshold=0.35, target_sizes=[(image.height, image.width)] )[0] self.assertTrue(torch.allclose(results_cpu["scores"], result_gpu["scores"].cpu(), atol=1e-3)) diff --git a/tests/models/grounding_dino/test_processor_grounding_dino.py b/tests/models/grounding_dino/test_processor_grounding_dino.py index c2d8aee828d..8f9ced4b0c4 100644 --- a/tests/models/grounding_dino/test_processor_grounding_dino.py +++ b/tests/models/grounding_dino/test_processor_grounding_dino.py @@ -17,6 +17,7 @@ import os import shutil import tempfile import unittest +from typing import Optional import pytest @@ -77,6 +78,20 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase): self.embed_dim = 5 self.seq_length = 5 + def prepare_text_inputs(self, batch_size: Optional[int] = None): + labels = ["a cat", "remote control"] + labels_longer = ["a person", "a car", "a dog", "a cat"] + + if batch_size is None: + return labels + + if batch_size < 1: + raise ValueError("batch_size must be greater than 0") + + if batch_size == 1: + return [labels] + return [labels, labels_longer] + [labels] * (batch_size - 2) + # Copied from tests.models.clip.test_processor_clip.CLIPProcessorTest.get_tokenizer with CLIP->Bert def get_tokenizer(self, **kwargs): return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs) @@ -98,6 +113,7 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase): return GroundingDinoObjectDetectionOutput( pred_boxes=torch.rand(self.batch_size, self.num_queries, 4), logits=torch.rand(self.batch_size, self.num_queries, self.embed_dim), + input_ids=self.get_fake_grounding_dino_input_ids(), ) def get_fake_grounding_dino_input_ids(self): @@ -111,14 +127,11 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase): processor = GroundingDinoProcessor(tokenizer=tokenizer, image_processor=image_processor) grounding_dino_output = 
self.get_fake_grounding_dino_output() - grounding_dino_input_ids = self.get_fake_grounding_dino_input_ids() - post_processed = processor.post_process_grounded_object_detection( - grounding_dino_output, grounding_dino_input_ids - ) + post_processed = processor.post_process_grounded_object_detection(grounding_dino_output) self.assertEqual(len(post_processed), self.batch_size) - self.assertEqual(list(post_processed[0].keys()), ["scores", "labels", "boxes"]) + self.assertEqual(list(post_processed[0].keys()), ["scores", "boxes", "text_labels", "labels"]) self.assertEqual(post_processed[0]["boxes"].shape, (self.num_queries, 4)) self.assertEqual(post_processed[0]["scores"].shape, (self.num_queries,)) @@ -248,3 +261,26 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase): inputs = processor(text=input_str, images=image_input) self.assertListEqual(list(inputs.keys()), processor.model_input_names) + + def test_text_preprocessing_equivalence(self): + processor = GroundingDinoProcessor.from_pretrained(self.tmpdirname) + + # check for single input + formatted_labels = "a cat. a remote control." + labels = ["a cat", "a remote control"] + inputs1 = processor(text=formatted_labels, return_tensors="pt") + inputs2 = processor(text=labels, return_tensors="pt") + self.assertTrue( + torch.allclose(inputs1["input_ids"], inputs2["input_ids"]), + f"Input ids are not equal for single input: {inputs1['input_ids']} != {inputs2['input_ids']}", + ) + + # check for batched input + formatted_labels = ["a cat. a remote control.", "a car. a person."] + labels = [["a cat", "a remote control"], ["a car", "a person"]] + inputs1 = processor(text=formatted_labels, return_tensors="pt", padding=True) + inputs2 = processor(text=labels, return_tensors="pt", padding=True) + self.assertTrue( + torch.allclose(inputs1["input_ids"], inputs2["input_ids"]), + f"Input ids are not equal for batched input: {inputs1['input_ids']} != {inputs2['input_ids']}", + )
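Taken together, these changes let callers pass plain lists of candidate labels to the processor and drop the explicit `input_ids` argument at post-processing time, since the token ids are now carried on the model output. The sketch below summarizes the intended call pattern for a batch of two images; it is illustrative only, reuses the `IDEA-Research/grounding-dino-tiny` checkpoint referenced elsewhere in this diff (which must be downloaded), and the 0.4/0.3 thresholds are arbitrary example values rather than recommended settings.

```python
import requests
import torch
from PIL import Image

from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor

model_id = "IDEA-Research/grounding-dino-tiny"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# One list of candidate labels per image; the processor merges each list
# into a dot-separated string ("a cat. a remote control.") internally.
text_labels = [["a cat", "a remote control"], ["a cat", "a remote control"]]
inputs = processor(images=[image, image], text=text_labels, return_tensors="pt", padding=True)

with torch.no_grad():
    outputs = model(**inputs)

# `input_ids` no longer has to be passed explicitly: it is read from `outputs.input_ids`.
results = processor.post_process_grounded_object_detection(
    outputs,
    threshold=0.4,
    text_threshold=0.3,
    target_sizes=[(image.height, image.width)] * 2,
)

for result in results:
    print(result["text_labels"], result["scores"])

# The legacy "labels" key still works, but accessing it emits a FutureWarning.
legacy_labels = results[0]["labels"]
```

Keeping `text_labels` in the post-processing signature, even though it is unused here, is what lets the zero-shot object detection pipeline call this method uniformly across models, as noted in the docstring above.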