Fix code examples of DETR and YOLOS (#19669)

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge 2022-10-17 17:48:22 +02:00 committed by GitHub
parent 35bd089a24
commit bf0addc56e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 85 additions and 15 deletions

View File

@ -1426,12 +1426,12 @@ class DetrForObjectDetection(DetrPreTrainedModel):
>>> # convert outputs (bounding boxes and class logits) to COCO API
>>> target_sizes = torch.tensor([image.size[::-1]])
>>> results = feature_extractor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
>>> results = feature_extractor.post_process_object_detection(
... outputs, threshold=0.9, target_sizes=target_sizes
... )[0]
>>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
... box = [round(i, 2) for i in box.tolist()]
... # let's only keep detections with score > 0.9
... if score > 0.9:
... print(
... f"Detected {model.config.id2label[label.item()]} with confidence "
... f"{round(score.item(), 3)} at location {box}"

View File

@ -18,7 +18,7 @@ import io
import pathlib
import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from PIL import Image
@ -748,6 +748,61 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin)
preds.append(predictions)
return preds
# Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_object_detection
def post_process_object_detection(
self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None
):
"""
Converts the output of [`DetrForObjectDetection`] into the format expected by the COCO api. Only supports
PyTorch.
Args:
outputs ([`DetrObjectDetectionOutput`]):
Raw outputs of the model.
threshold (`float`, *optional*):
Score threshold to keep object detection predictions.
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*, defaults to `None`):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
(height, width) of each image in the batch. If left to None, predictions will not be resized.
Returns:
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
in the batch as predicted by the model.
"""
out_logits, out_bbox = outputs.logits, outputs.pred_boxes
if target_sizes is not None:
if len(out_logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
prob = nn.functional.softmax(out_logits, -1)
scores, labels = prob[..., :-1].max(-1)
# Convert to [x0, y0, x1, y1] format
boxes = center_to_corners_format(out_bbox)
# Convert from relative [0, 1] to absolute [0, height] coordinates
if target_sizes is not None:
if isinstance(target_sizes, List):
img_h = torch.Tensor([i[0] for i in target_sizes])
img_w = torch.Tensor([i[1] for i in target_sizes])
else:
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes = boxes * scale_fct[:, None, :]
results = []
for s, l, b in zip(scores, labels, boxes):
score = s[s > threshold]
label = l[s > threshold]
box = b[s > threshold]
results.append({"scores": score, "labels": label, "boxes": box})
return results
# Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_instance
def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5):
"""

View File

@ -754,24 +754,39 @@ class YolosForObjectDetection(YolosPreTrainedModel):
Returns:
Examples:
```python
>>> from transformers import YolosFeatureExtractor, YolosForObjectDetection
>>> from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
>>> import torch
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = YolosFeatureExtractor.from_pretrained("hustvl/yolos-small")
>>> model = YolosForObjectDetection.from_pretrained("hustvl/yolos-small")
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-tiny")
>>> model = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-tiny")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> # model predicts bounding boxes and corresponding COCO classes
>>> logits = outputs.logits
>>> bboxes = outputs.pred_boxes
>>> # convert outputs (bounding boxes and class logits) to COCO API
>>> target_sizes = torch.tensor([image.size[::-1]])
>>> results = feature_extractor.post_process_object_detection(
... outputs, threshold=0.9, target_sizes=target_sizes
... )[0]
>>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
... box = [round(i, 2) for i in box.tolist()]
... print(
... f"Detected {model.config.id2label[label.item()]} with confidence "
... f"{round(score.item(), 3)} at location {box}"
... )
Detected remote with confidence 0.994 at location [46.96, 72.61, 181.02, 119.73]
Detected remote with confidence 0.975 at location [340.66, 79.19, 372.59, 192.65]
Detected cat with confidence 0.984 at location [12.27, 54.25, 319.42, 470.99]
Detected remote with confidence 0.922 at location [41.66, 71.96, 178.7, 120.33]
Detected cat with confidence 0.914 at location [342.34, 21.48, 638.64, 372.46]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict