Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
fix owlvit tests, update docstring examples (#18586)
parent 05d3a43c59
commit f28f240828
@@ -57,8 +57,8 @@ OWL-ViT is a zero-shot text-conditioned object detection model. OWL-ViT uses [CL
 ...     box = [round(i, 2) for i in box.tolist()]
 ...     if score >= score_threshold:
 ...         print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
-Detected a photo of a cat with confidence 0.243 at location [1.42, 50.69, 308.58, 370.48]
-Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
+Detected a photo of a cat with confidence 0.298 at location [348.06, 20.56, 642.33, 372.61]
+Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
 ```
 
 This model was contributed by [adirik](https://huggingface.co/adirik). The original code can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit).
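The two `Detected …` lines updated here are the expected output of the zero-shot detection example in the OWL-ViT docs. A minimal sketch of the kind of example that produces them is shown below; the `google/owlvit-base-patch32` checkpoint, the COCO image URL, and the `processor.post_process` call are assumptions based on the visible fragment (newer transformers releases expose `post_process_object_detection` instead), and the variable names are illustrative.

```python
import requests
import torch
from PIL import Image

from transformers import OwlViTForObjectDetection, OwlViTProcessor

# Assumed checkpoint and example image (the COCO cats picture used throughout the docs).
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
texts = [["a photo of a cat", "a photo of a dog"]]

inputs = processor(text=texts, images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Rescale the normalized predicted boxes to the original image size (height, width).
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process(outputs=outputs, target_sizes=target_sizes)

text = texts[0]
boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

score_threshold = 0.1
for box, score, label in zip(boxes, scores, labels):
    box = [round(i, 2) for i in box.tolist()]
    if score >= score_threshold:
        print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
```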
@@ -1323,8 +1323,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         ...     box = [round(i, 2) for i in box.tolist()]
         ...     if score >= score_threshold:
         ...         print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
-        Detected a photo of a cat with confidence 0.243 at location [1.42, 50.69, 308.58, 370.48]
-        Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
+        Detected a photo of a cat with confidence 0.298 at location [348.06, 20.56, 642.33, 372.61]
+        Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
         ```"""
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
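The trailing context lines show the config-fallback idiom used throughout transformers `forward` methods: an explicitly passed flag wins, otherwise the value stored on the model config applies. A standalone sketch of the same pattern, with hypothetical names:

```python
# Minimal sketch of the config-fallback idiom from the context lines above.
# `resolve_flag` and its arguments are hypothetical names for illustration.
def resolve_flag(explicit_value, config_default):
    # An explicit keyword argument wins; otherwise fall back to the config default.
    return explicit_value if explicit_value is not None else config_default


assert resolve_flag(True, False) is True    # caller override wins
assert resolve_flag(None, False) is False   # falls back to the config default
```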
@@ -733,7 +733,6 @@ def prepare_img():
 
 @require_vision
 @require_torch
-@unittest.skip("These tests are broken, fix me Alara")
 class OwlViTModelIntegrationTest(unittest.TestCase):
     @slow
     def test_inference(self):
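The only change in this hunk is dropping the class-level `@unittest.skip`, which re-enables every test in `OwlViTModelIntegrationTest`. A toy illustration (not from the repository) of why one decorator disabled the whole class:

```python
import unittest


@unittest.skip("demonstration: skipping the whole class")
class ToyIntegrationTest(unittest.TestCase):
    def test_anything(self):
        # Never runs while the class-level skip decorator is present;
        # removing the decorator re-enables every test method in the class.
        self.assertTrue(True)
```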
@@ -763,8 +762,7 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
             outputs.logits_per_text.shape,
             torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
         )
-        expected_logits = torch.tensor([[4.4420, 0.6181]], device=torch_device)
-
+        expected_logits = torch.tensor([[3.4613, 0.9403]], device=torch_device)
         self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
 
     @slow
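The refreshed `expected_logits` are compared with `torch.allclose` at `atol=1e-3`, so the recorded reference only has to match the model output to roughly three decimal places. A small illustration; the `observed` values are hypothetical:

```python
import torch

expected = torch.tensor([[3.4613, 0.9403]])
observed = torch.tensor([[3.4618, 0.9398]])  # hypothetical output within tolerance

# allclose checks |observed - expected| <= atol + rtol * |expected| elementwise;
# with atol=1e-3 (and the default rtol=1e-5), a ~5e-4 drift still passes.
assert torch.allclose(observed, expected, atol=1e-3)
```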
@@ -788,7 +786,8 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
 
         num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
         self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
+
         expected_slice_boxes = torch.tensor(
-            [[0.0948, 0.0471, 0.1915], [0.3194, 0.0583, 0.6498], [0.1441, 0.0452, 0.2197]]
+            [[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
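For reference, `num_queries` here equals the number of image patches. Assuming the default `OwlViTVisionConfig` values for the base patch-32 checkpoint (`image_size=768`, `patch_size=32`), the test expects 576 predicted boxes:

```python
# Worked example of the num_queries expression above, assuming the default
# OwlViTVisionConfig values for the base patch-32 checkpoint.
image_size = 768
patch_size = 32

num_queries = int((image_size / patch_size) ** 2)
print(num_queries)  # 576 -> outputs.pred_boxes has shape (1, 576, 4)
```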