fix owlvit tests, update docstring examples (#18586)

This commit is contained in:
Alara Dirik 2022-08-11 19:10:25 +03:00 committed by GitHub
parent 05d3a43c59
commit f28f240828
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 8 deletions

View File

@ -57,8 +57,8 @@ OWL-ViT is a zero-shot text-conditioned object detection model. OWL-ViT uses [CL
... box = [round(i, 2) for i in box.tolist()] ... box = [round(i, 2) for i in box.tolist()]
... if score >= score_threshold: ... if score >= score_threshold:
... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}") ... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
Detected a photo of a cat with confidence 0.243 at location [1.42, 50.69, 308.58, 370.48] Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
Detected a photo of a cat with confidence 0.298 at location [348.06, 20.56, 642.33, 372.61] Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
``` ```
This model was contributed by [adirik](https://huggingface.co/adirik). The original code can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit). This model was contributed by [adirik](https://huggingface.co/adirik). The original code can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit).

View File

@ -1323,8 +1323,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
... box = [round(i, 2) for i in box.tolist()] ... box = [round(i, 2) for i in box.tolist()]
... if score >= score_threshold: ... if score >= score_threshold:
... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}") ... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
Detected a photo of a cat with confidence 0.243 at location [1.42, 50.69, 308.58, 370.48] Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
Detected a photo of a cat with confidence 0.298 at location [348.06, 20.56, 642.33, 372.61] Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
```""" ```"""
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

View File

@ -733,7 +733,6 @@ def prepare_img():
@require_vision @require_vision
@require_torch @require_torch
@unittest.skip("These tests are broken, fix me Alara")
class OwlViTModelIntegrationTest(unittest.TestCase): class OwlViTModelIntegrationTest(unittest.TestCase):
@slow @slow
def test_inference(self): def test_inference(self):
@ -763,8 +762,7 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
outputs.logits_per_text.shape, outputs.logits_per_text.shape,
torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])), torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
) )
expected_logits = torch.tensor([[4.4420, 0.6181]], device=torch_device) expected_logits = torch.tensor([[3.4613, 0.9403]], device=torch_device)
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3)) self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
@slow @slow
@ -788,7 +786,8 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2) num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4))) self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
expected_slice_boxes = torch.tensor( expected_slice_boxes = torch.tensor(
[[0.0948, 0.0471, 0.1915], [0.3194, 0.0583, 0.6498], [0.1441, 0.0452, 0.2197]] [[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
).to(torch_device) ).to(torch_device)
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4)) self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))