diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py
index 8fcbfe09c04..f2fb0539f92 100644
--- a/src/transformers/models/detr/modeling_detr.py
+++ b/src/transformers/models/detr/modeling_detr.py
@@ -1426,16 +1426,16 @@ class DetrForObjectDetection(DetrPreTrainedModel):
 
         >>> # convert outputs (bounding boxes and class logits) to COCO API
         >>> target_sizes = torch.tensor([image.size[::-1]])
-        >>> results = feature_extractor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
+        >>> results = feature_extractor.post_process_object_detection(
+        ...     outputs, threshold=0.9, target_sizes=target_sizes
+        ... )[0]
 
         >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
         ...     box = [round(i, 2) for i in box.tolist()]
-        ...     # let's only keep detections with score > 0.9
-        ...     if score > 0.9:
-        ...         print(
-        ...             f"Detected {model.config.id2label[label.item()]} with confidence "
-        ...             f"{round(score.item(), 3)} at location {box}"
-        ...         )
+        ...     print(
+        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
+        ...         f"{round(score.item(), 3)} at location {box}"
+        ...     )
         Detected remote with confidence 0.998 at location [40.16, 70.81, 175.55, 117.98]
         Detected remote with confidence 0.996 at location [333.24, 72.55, 368.33, 187.66]
         Detected couch with confidence 0.995 at location [-0.02, 1.15, 639.73, 473.76]
diff --git a/src/transformers/models/yolos/feature_extraction_yolos.py b/src/transformers/models/yolos/feature_extraction_yolos.py
index 7239986d587..81235ec6281 100644
--- a/src/transformers/models/yolos/feature_extraction_yolos.py
+++ b/src/transformers/models/yolos/feature_extraction_yolos.py
@@ -18,7 +18,7 @@ import io
 import pathlib
 import warnings
 from collections import defaultdict
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 from PIL import Image
@@ -748,6 +748,61 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
             preds.append(predictions)
         return preds
 
+    # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_object_detection
+    def post_process_object_detection(
+        self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None
+    ):
+        """
+        Converts the output of [`DetrForObjectDetection`] into the format expected by the COCO api. Only supports
+        PyTorch.
+
+        Args:
+            outputs ([`DetrObjectDetectionOutput`]):
+                Raw outputs of the model.
+            threshold (`float`, *optional*):
+                Score threshold to keep object detection predictions.
+            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*, defaults to `None`):
+                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
+                (height, width) of each image in the batch. If left to None, predictions will not be resized.
+
+        Returns:
+            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+            in the batch as predicted by the model.
+ """ + out_logits, out_bbox = outputs.logits, outputs.pred_boxes + + if target_sizes is not None: + if len(out_logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + prob = nn.functional.softmax(out_logits, -1) + scores, labels = prob[..., :-1].max(-1) + + # Convert to [x0, y0, x1, y1] format + boxes = center_to_corners_format(out_bbox) + + # Convert from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + if isinstance(target_sizes, List): + img_h = torch.Tensor([i[0] for i in target_sizes]) + img_w = torch.Tensor([i[1] for i in target_sizes]) + else: + img_h, img_w = target_sizes.unbind(1) + + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + results = [] + for s, l, b in zip(scores, labels, boxes): + score = s[s > threshold] + label = l[s > threshold] + box = b[s > threshold] + results.append({"scores": score, "labels": label, "boxes": box}) + + return results + # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_instance def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5): """ diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index d2b8132cc81..fe876132b1c 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -754,24 +754,39 @@ class YolosForObjectDetection(YolosPreTrainedModel): Returns: Examples: + ```python - >>> from transformers import YolosFeatureExtractor, YolosForObjectDetection + >>> from transformers import AutoFeatureExtractor, AutoModelForObjectDetection + >>> import torch >>> from PIL import Image >>> import requests >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) - >>> feature_extractor = YolosFeatureExtractor.from_pretrained("hustvl/yolos-small") - >>> model = YolosForObjectDetection.from_pretrained("hustvl/yolos-small") + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-tiny") + >>> model = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-tiny") >>> inputs = feature_extractor(images=image, return_tensors="pt") - >>> outputs = model(**inputs) - >>> # model predicts bounding boxes and corresponding COCO classes - >>> logits = outputs.logits - >>> bboxes = outputs.pred_boxes + >>> # convert outputs (bounding boxes and class logits) to COCO API + >>> target_sizes = torch.tensor([image.size[::-1]]) + >>> results = feature_extractor.post_process_object_detection( + ... outputs, threshold=0.9, target_sizes=target_sizes + ... )[0] + + >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + ... box = [round(i, 2) for i in box.tolist()] + ... print( + ... f"Detected {model.config.id2label[label.item()]} with confidence " + ... f"{round(score.item(), 3)} at location {box}" + ... 
+        ...     )
+        Detected remote with confidence 0.994 at location [46.96, 72.61, 181.02, 119.73]
+        Detected remote with confidence 0.975 at location [340.66, 79.19, 372.59, 192.65]
+        Detected cat with confidence 0.984 at location [12.27, 54.25, 319.42, 470.99]
+        Detected remote with confidence 0.922 at location [41.66, 71.96, 178.7, 120.33]
+        Detected cat with confidence 0.914 at location [342.34, 21.48, 638.64, 372.46]
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
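For reviewers who want to exercise the new method outside a docstring, here is a minimal sketch covering the code path the doc examples above don't: passing `target_sizes` as a plain list of `(height, width)` tuples (the `isinstance(target_sizes, List)` branch) and leaving `threshold` at its 0.5 default. It reuses the `hustvl/yolos-tiny` checkpoint and COCO test image from the example; nothing else here is part of the diff.

```python
import requests
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForObjectDetection

feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-tiny")
model = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-tiny")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():  # inference only, no gradients needed
    outputs = model(**inputs)

# target_sizes as List[Tuple[int, int]] rather than a tensor; PIL's
# image.size is (width, height), so reverse it to get (height, width)
target_sizes = [image.size[::-1]]
results = feature_extractor.post_process_object_detection(outputs, target_sizes=target_sizes)

# the default threshold=0.5 keeps more low-confidence boxes than the
# threshold=0.9 used in the doc example
for score, label, box in zip(results[0]["scores"], results[0]["labels"], results[0]["boxes"]):
    box = [round(c, 2) for c in box.tolist()]
    print(f"{model.config.id2label[label.item()]}: {round(score.item(), 3)} at {box}")
```

One design note: because scores, labels and boxes are filtered per image before being packed into dictionaries, each entry of `results` can hold a different number of detections, which is why the method returns a `List[Dict]` rather than a single batched tensor.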