Mirror of https://github.com/huggingface/transformers.git
Fix code examples of DETR and YOLOS (#19669)

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>

commit bf0addc56e (parent 35bd089a24)
src/transformers/models/detr/modeling_detr.py
@@ -1426,16 +1426,16 @@ class DetrForObjectDetection(DetrPreTrainedModel):
         >>> # convert outputs (bounding boxes and class logits) to COCO API
         >>> target_sizes = torch.tensor([image.size[::-1]])
-        >>> results = feature_extractor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
+        >>> results = feature_extractor.post_process_object_detection(
+        ...     outputs, threshold=0.9, target_sizes=target_sizes
+        ... )[0]

         >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
         ...     box = [round(i, 2) for i in box.tolist()]
-        ...     # let's only keep detections with score > 0.9
-        ...     if score > 0.9:
-        ...         print(
-        ...             f"Detected {model.config.id2label[label.item()]} with confidence "
-        ...             f"{round(score.item(), 3)} at location {box}"
-        ...         )
+        ...     print(
+        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
+        ...         f"{round(score.item(), 3)} at location {box}"
+        ...     )
         Detected remote with confidence 0.998 at location [40.16, 70.81, 175.55, 117.98]
         Detected remote with confidence 0.996 at location [333.24, 72.55, 368.33, 187.66]
         Detected couch with confidence 0.995 at location [-0.02, 1.15, 639.73, 473.76]
         ```"""
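The substance of this hunk: the old example fetched every prediction and filtered with `if score > 0.9` inside the print loop, while the new one passes `threshold=0.9` to post-processing, which applies the same mask up front. A minimal sketch of the equivalence, using made-up scores and boxes (the real method masks with `s[s > threshold]`, as shown in the YOLOS hunk below):

```python
import torch

# Stand-in predictions for one image (random boxes, illustrative scores)
scores = torch.tensor([0.998, 0.62, 0.995])
labels = torch.tensor([75, 75, 63])
boxes = torch.rand(3, 4)

# Old style: keep everything at post-processing time, filter while printing
old_kept = [(s, l, b) for s, l, b in zip(scores, labels, boxes) if s > 0.9]

# New style: threshold=0.9 applies the mask during post-processing itself
keep = scores > 0.9
new_kept = list(zip(scores[keep], labels[keep], boxes[keep]))

assert len(old_kept) == len(new_kept) == 2  # both drop the 0.62 detection
```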
src/transformers/models/yolos/feature_extraction_yolos.py
@@ -18,7 +18,7 @@ import io
 import pathlib
 import warnings
 from collections import defaultdict
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 from PIL import Image
@@ -748,6 +748,61 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
             preds.append(predictions)
         return preds

+    # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_object_detection
+    def post_process_object_detection(
+        self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None
+    ):
+        """
+        Converts the output of [`DetrForObjectDetection`] into the format expected by the COCO api. Only supports
+        PyTorch.
+
+        Args:
+            outputs ([`DetrObjectDetectionOutput`]):
+                Raw outputs of the model.
+            threshold (`float`, *optional*):
+                Score threshold to keep object detection predictions.
+            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*, defaults to `None`):
+                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
+                (height, width) of each image in the batch. If left to None, predictions will not be resized.
+
+        Returns:
+            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+            in the batch as predicted by the model.
+        """
+        out_logits, out_bbox = outputs.logits, outputs.pred_boxes
+
+        if target_sizes is not None:
+            if len(out_logits) != len(target_sizes):
+                raise ValueError(
+                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+                )
+
+        prob = nn.functional.softmax(out_logits, -1)
+        scores, labels = prob[..., :-1].max(-1)
+
+        # Convert to [x0, y0, x1, y1] format
+        boxes = center_to_corners_format(out_bbox)
+
+        # Convert from relative [0, 1] to absolute [0, height] coordinates
+        if target_sizes is not None:
+            if isinstance(target_sizes, List):
+                img_h = torch.Tensor([i[0] for i in target_sizes])
+                img_w = torch.Tensor([i[1] for i in target_sizes])
+            else:
+                img_h, img_w = target_sizes.unbind(1)
+
+            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+            boxes = boxes * scale_fct[:, None, :]
+
+        results = []
+        for s, l, b in zip(scores, labels, boxes):
+            score = s[s > threshold]
+            label = l[s > threshold]
+            box = b[s > threshold]
+            results.append({"scores": score, "labels": label, "boxes": box})
+
+        return results
+
     # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_instance
     def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5):
         """
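The method's numeric pipeline is: softmax over the logits (the last class is the "no object" slot, dropped by `prob[..., :-1]`), conversion from center format to corner format, then scaling from relative `[0, 1]` coordinates to pixel coordinates. A minimal sketch of the box math with a hand-made box; `center_to_corners_format` is reimplemented here for illustration (the real helper lives in the DETR feature extraction module):

```python
import torch

def center_to_corners_format(x):
    # (center_x, center_y, width, height) -> (x0, y0, x1, y1)
    cx, cy, w, h = x.unbind(-1)
    return torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)

# One relative box: centered, covering half the image in each dimension
rel_box = torch.tensor([[0.5, 0.5, 0.5, 0.5]])
corners = center_to_corners_format(rel_box)  # [[0.25, 0.25, 0.75, 0.75]]

# Scale to a 480x640 (height, width) image, mirroring the method's scale_fct
img_h, img_w = torch.tensor([480.0]), torch.tensor([640.0])
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)  # [[640., 480., 640., 480.]]
print(corners * scale_fct)  # tensor([[160., 120., 480., 360.]])
```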
src/transformers/models/yolos/modeling_yolos.py
@@ -754,24 +754,39 @@ class YolosForObjectDetection(YolosPreTrainedModel):
         Returns:

         Examples:

         ```python
-        >>> from transformers import YolosFeatureExtractor, YolosForObjectDetection
+        >>> from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
+        >>> import torch
         >>> from PIL import Image
         >>> import requests

         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         >>> image = Image.open(requests.get(url, stream=True).raw)

-        >>> feature_extractor = YolosFeatureExtractor.from_pretrained("hustvl/yolos-small")
-        >>> model = YolosForObjectDetection.from_pretrained("hustvl/yolos-small")
+        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-tiny")
+        >>> model = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-tiny")

         >>> inputs = feature_extractor(images=image, return_tensors="pt")

         >>> outputs = model(**inputs)

-        >>> # model predicts bounding boxes and corresponding COCO classes
-        >>> logits = outputs.logits
-        >>> bboxes = outputs.pred_boxes
+        >>> # convert outputs (bounding boxes and class logits) to COCO API
+        >>> target_sizes = torch.tensor([image.size[::-1]])
+        >>> results = feature_extractor.post_process_object_detection(
+        ...     outputs, threshold=0.9, target_sizes=target_sizes
+        ... )[0]
+
+        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+        ...     box = [round(i, 2) for i in box.tolist()]
+        ...     print(
+        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
+        ...         f"{round(score.item(), 3)} at location {box}"
+        ...     )
+        Detected remote with confidence 0.994 at location [46.96, 72.61, 181.02, 119.73]
+        Detected remote with confidence 0.975 at location [340.66, 79.19, 372.59, 192.65]
+        Detected cat with confidence 0.984 at location [12.27, 54.25, 319.42, 470.99]
+        Detected remote with confidence 0.922 at location [41.66, 71.96, 178.7, 120.33]
+        Detected cat with confidence 0.914 at location [342.34, 21.48, 638.64, 372.46]
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
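The new `post_process_object_detection` accepts `target_sizes` either as a tensor of shape `(batch_size, 2)` or as a list of `(height, width)` tuples (the `isinstance(target_sizes, List)` branch in the hunk above). A minimal sketch showing both call styles produce the same boxes, reusing the checkpoint from the updated example:

```python
import torch
import requests
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForObjectDetection

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-tiny")
model = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-tiny")
outputs = model(**feature_extractor(images=image, return_tensors="pt"))

# Tensor form: shape (batch_size, 2) holding (height, width)
as_tensor = torch.tensor([image.size[::-1]])
# List form: one (height, width) tuple per image in the batch
as_list = [image.size[::-1]]

res_a = feature_extractor.post_process_object_detection(outputs, threshold=0.9, target_sizes=as_tensor)[0]
res_b = feature_extractor.post_process_object_detection(outputs, threshold=0.9, target_sizes=as_list)[0]
assert torch.allclose(res_a["boxes"], res_b["boxes"])
```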