diff --git a/src/transformers/models/owlv2/image_processing_owlv2.py b/src/transformers/models/owlv2/image_processing_owlv2.py
index 2ba3772d000..1e9a5163a1a 100644
--- a/src/transformers/models/owlv2/image_processing_owlv2.py
+++ b/src/transformers/models/owlv2/image_processing_owlv2.py
@@ -524,19 +524,11 @@ class Owlv2ImageProcessor(BaseImageProcessor):
         else:
             img_h, img_w = target_sizes.unbind(1)
 
-        # rescale coordinates
-        width_ratio = 1
-        height_ratio = 1
+        # Rescale coordinates, image is padded to square for inference,
+        # that is why we need to scale boxes to the max size
+        size = torch.max(img_h, img_w)
+        scale_fct = torch.stack([size, size, size, size], dim=1).to(boxes.device)
 
-        if img_w < img_h:
-            width_ratio = img_w / img_h
-        elif img_h < img_w:
-            height_ratio = img_h / img_w
-
-        img_w = img_w / width_ratio
-        img_h = img_h / height_ratio
-
-        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
         boxes = boxes * scale_fct[:, None, :]
 
         results = []
diff --git a/tests/models/owlv2/test_image_processor_owlv2.py b/tests/models/owlv2/test_image_processor_owlv2.py
index 16b6b24df3b..87b96d06547 100644
--- a/tests/models/owlv2/test_image_processor_owlv2.py
+++ b/tests/models/owlv2/test_image_processor_owlv2.py
@@ -130,17 +130,42 @@ class Owlv2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
         model = Owlv2ForObjectDetection.from_pretrained(checkpoint)
 
         image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
-        inputs = processor(text=["cat"], images=image, return_tensors="pt")
+        text = ["cat"]
+        target_size = image.size[::-1]
+        expected_boxes = torch.tensor(
+            [
+                [341.66656494140625, 23.38756561279297, 642.321044921875, 371.3482971191406],
+                [6.753320693969727, 51.96149826049805, 326.61810302734375, 473.12982177734375],
+            ]
+        )
 
+        # single image
+        inputs = processor(text=[text], images=[image], return_tensors="pt")
         with torch.no_grad():
             outputs = model(**inputs)
 
-        target_sizes = torch.tensor([image.size[::-1]])
-        results = processor.post_process_object_detection(outputs, threshold=0.2, target_sizes=target_sizes)[0]
+        results = processor.post_process_object_detection(outputs, threshold=0.2, target_sizes=[target_size])[0]
 
-        boxes = results["boxes"].tolist()
-        self.assertEqual(boxes[0], [341.66656494140625, 23.38756561279297, 642.321044921875, 371.3482971191406])
-        self.assertEqual(boxes[1], [6.753320693969727, 51.96149826049805, 326.61810302734375, 473.12982177734375])
+        boxes = results["boxes"]
+        self.assertTrue(
+            torch.allclose(boxes, expected_boxes, atol=1e-2),
+            f"Single image bounding boxes fail. Expected {expected_boxes}, got {boxes}",
+        )
+
+        # batch of images
+        inputs = processor(text=[text, text], images=[image, image], return_tensors="pt")
+        with torch.no_grad():
+            outputs = model(**inputs)
+        results = processor.post_process_object_detection(
+            outputs, threshold=0.2, target_sizes=[target_size, target_size]
+        )
+
+        for result in results:
+            boxes = result["boxes"]
+            self.assertTrue(
+                torch.allclose(boxes, expected_boxes, atol=1e-2),
+                f"Batch image bounding boxes fail. Expected {expected_boxes}, got {boxes}",
+            )
 
     @unittest.skip("OWLv2 doesn't treat 4 channel PIL and numpy consistently yet")  # FIXME Amy
     def test_call_numpy_4_channels(self):
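
Below is a minimal standalone sketch (not part of the patch) of the rescaling behavior the first hunk introduces: OWLv2 pads inputs to a square before resizing, so the model's normalized box coordinates are relative to the padded square, and scaling every coordinate by `max(img_h, img_w)` maps them back to pixels of the original image. The `rescale_boxes` helper and the example tensors are hypothetical names used only for illustration.

```python
import torch


def rescale_boxes(boxes: torch.Tensor, img_h: torch.Tensor, img_w: torch.Tensor) -> torch.Tensor:
    # boxes: (batch, num_boxes, 4) in normalized (x0, y0, x1, y1) coordinates
    # relative to the square-padded input.
    size = torch.max(img_h, img_w)  # one padded-square side length per image
    scale_fct = torch.stack([size, size, size, size], dim=1).to(boxes.device)
    return boxes * scale_fct[:, None, :]


# Example: a 480x640 (h x w) image; a box covering the top-left quarter of the padded square.
boxes = torch.tensor([[[0.0, 0.0, 0.5, 0.5]]])
print(rescale_boxes(boxes, torch.tensor([480.0]), torch.tensor([640.0])))
# tensor([[[  0.,   0., 320., 320.]]])
```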