diff --git a/docs/source/en/model_doc/yolos.mdx b/docs/source/en/model_doc/yolos.mdx
index bda65bec913..51c75252ac9 100644
--- a/docs/source/en/model_doc/yolos.mdx
+++ b/docs/source/en/model_doc/yolos.mdx
@@ -43,9 +43,7 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi
 [[autodoc]] YolosFeatureExtractor
     - __call__
     - pad
-    - post_process
-    - post_process_segmentation
-    - post_process_panoptic
+    - post_process_object_detection
 
 ## YolosModel
 
diff --git a/src/transformers/models/yolos/feature_extraction_yolos.py b/src/transformers/models/yolos/feature_extraction_yolos.py
index 81235ec6281..1358666d7e9 100644
--- a/src/transformers/models/yolos/feature_extraction_yolos.py
+++ b/src/transformers/models/yolos/feature_extraction_yolos.py
@@ -14,10 +14,8 @@
 # limitations under the License.
 """Feature extractor class for YOLOS."""
 
-import io
 import pathlib
 import warnings
-from collections import defaultdict
 from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -748,16 +746,16 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin)
             preds.append(predictions)
         return preds
 
-    # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_object_detection
+    # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_object_detection with Detr->Yolos
     def post_process_object_detection(
         self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None
     ):
         """
-        Converts the output of [`DetrForObjectDetection`] into the format expected by the COCO api. Only supports
+        Converts the output of [`YolosForObjectDetection`] into the format expected by the COCO api. Only supports
         PyTorch.
 
         Args:
-            outputs ([`DetrObjectDetectionOutput`]):
+            outputs ([`YolosObjectDetectionOutput`]):
                 Raw outputs of the model.
             threshold (`float`, *optional*):
                 Score threshold to keep object detection predictions.
@@ -802,192 +800,3 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin)
             results.append({"scores": score, "labels": label, "boxes": box})
 
         return results
-
-    # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_instance
-    def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5):
-        """
-        Converts the output of [`DetrForSegmentation`] into actual instance segmentation predictions. Only supports
-        PyTorch.
-
-        Args:
-            results (`List[Dict]`):
-                Results list obtained by [`~DetrFeatureExtractor.post_process`], to which "masks" results will be
-                added.
-            outputs ([`DetrSegmentationOutput`]):
-                Raw outputs of the model.
-            orig_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
-                Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
-                image size (before any data augmentation).
-            max_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
-                Tensor containing the maximum size (h, w) of each image of the batch. For evaluation, this must be the
-                original image size (before any data augmentation).
-            threshold (`float`, *optional*, defaults to 0.5):
-                Threshold to use when turning the predicted masks into binary values.
-
-        Returns:
-            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an
-            image in the batch as predicted by the model.
-        """
-        warnings.warn(
-            "`post_process_instance` is deprecated and will be removed in v5 of Transformers, please use"
-            " `post_process_instance_segmentation`.",
-            FutureWarning,
-        )
-
-        if len(orig_target_sizes) != len(max_target_sizes):
-            raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes")
-        max_h, max_w = max_target_sizes.max(0)[0].tolist()
-        outputs_masks = outputs.pred_masks.squeeze(2)
-        outputs_masks = nn.functional.interpolate(
-            outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False
-        )
-        outputs_masks = (outputs_masks.sigmoid() > threshold).cpu()
-
-        for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
-            img_h, img_w = t[0], t[1]
-            results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
-            results[i]["masks"] = nn.functional.interpolate(
-                results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
-            ).byte()
-
-        return results
-
-    # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_panoptic
-    def post_process_panoptic(self, outputs, processed_sizes, target_sizes=None, is_thing_map=None, threshold=0.85):
-        """
-        Converts the output of [`DetrForSegmentation`] into actual panoptic predictions. Only supports PyTorch.
-
-        Parameters:
-            outputs ([`DetrSegmentationOutput`]):
-                Raw outputs of the model.
-            processed_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`):
-                Torch Tensor (or list) containing the size (h, w) of each image of the batch, i.e. the size after data
-                augmentation but before batching.
-            target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*):
-                Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to
-                None, it will default to the `processed_sizes`.
-            is_thing_map (`torch.Tensor` of shape `(batch_size, 2)`, *optional*):
-                Dictionary mapping class indices to either True or False, depending on whether or not they are a thing.
-                If not set, defaults to the `is_thing_map` of COCO panoptic.
-            threshold (`float`, *optional*, defaults to 0.85):
-                Threshold to use to filter out queries.
-
-        Returns:
-            `List[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for
-            an image in the batch as predicted by the model.
-        """
-        warnings.warn(
-            "`post_process_panoptic is deprecated and will be removed in v5 of Transformers, please use"
-            " `post_process_panoptic_segmentation`.",
-            FutureWarning,
-        )
-        if target_sizes is None:
-            target_sizes = processed_sizes
-        if len(processed_sizes) != len(target_sizes):
-            raise ValueError("Make sure to pass in as many processed_sizes as target_sizes")
-
-        if is_thing_map is None:
-            # default to is_thing_map of COCO panoptic
-            is_thing_map = {i: i <= 90 for i in range(201)}
-
-        out_logits, raw_masks, raw_boxes = outputs.logits, outputs.pred_masks, outputs.pred_boxes
-        if not len(out_logits) == len(raw_masks) == len(target_sizes):
-            raise ValueError(
-                "Make sure that you pass in as many target sizes as the batch dimension of the logits and masks"
-            )
-        preds = []
-
-        def to_tuple(tup):
-            if isinstance(tup, tuple):
-                return tup
-            return tuple(tup.cpu().tolist())
-
-        for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
-            out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
-        ):
-            # we filter empty queries and detection below threshold
-            scores, labels = cur_logits.softmax(-1).max(-1)
-            keep = labels.ne(outputs.logits.shape[-1] - 1) & (scores > threshold)
-            cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
-            cur_scores = cur_scores[keep]
-            cur_classes = cur_classes[keep]
-            cur_masks = cur_masks[keep]
-            cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
-            cur_boxes = center_to_corners_format(cur_boxes[keep])
-
-            h, w = cur_masks.shape[-2:]
-            if len(cur_boxes) != len(cur_classes):
-                raise ValueError("Not as many boxes as there are classes")
-
-            # It may be that we have several predicted masks for the same stuff class.
-            # In the following, we track the list of masks ids for each stuff class (they are merged later on)
-            cur_masks = cur_masks.flatten(1)
-            stuff_equiv_classes = defaultdict(lambda: [])
-            for k, label in enumerate(cur_classes):
-                if not is_thing_map[label.item()]:
-                    stuff_equiv_classes[label.item()].append(k)
-
-            def get_ids_area(masks, scores, dedup=False):
-                # This helper function creates the final panoptic segmentation image
-                # It also returns the area of the masks that appears on the image
-
-                m_id = masks.transpose(0, 1).softmax(-1)
-
-                if m_id.shape[-1] == 0:
-                    # We didn't detect any mask :(
-                    m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
-                else:
-                    m_id = m_id.argmax(-1).view(h, w)
-
-                if dedup:
-                    # Merge the masks corresponding to the same stuff class
-                    for equiv in stuff_equiv_classes.values():
-                        if len(equiv) > 1:
-                            for eq_id in equiv:
-                                m_id.masked_fill_(m_id.eq(eq_id), equiv[0])
-
-                final_h, final_w = to_tuple(target_size)
-
-                seg_img = Image.fromarray(id_to_rgb(m_id.view(h, w).cpu().numpy()))
-                seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST)
-
-                np_seg_img = torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes()))
-                np_seg_img = np_seg_img.view(final_h, final_w, 3)
-                np_seg_img = np_seg_img.numpy()
-
-                m_id = torch.from_numpy(rgb_to_id(np_seg_img))
-
-                area = []
-                for i in range(len(scores)):
-                    area.append(m_id.eq(i).sum().item())
-                return area, seg_img
-
-            area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
-            if cur_classes.numel() > 0:
-                # We know filter empty masks as long as we find some
-                while True:
-                    filtered_small = torch.as_tensor(
-                        [area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device
-                    )
-                    if filtered_small.any().item():
-                        cur_scores = cur_scores[~filtered_small]
-                        cur_classes = cur_classes[~filtered_small]
-                        cur_masks = cur_masks[~filtered_small]
-                        area, seg_img = get_ids_area(cur_masks, cur_scores)
-                    else:
-                        break
-
-            else:
-                cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)
-
-            segments_info = []
-            for i, a in enumerate(area):
-                cat = cur_classes[i].item()
-                segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a})
-            del cur_classes
-
-            with io.BytesIO() as out:
-                seg_img.save(out, format="PNG")
-                predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
-            preds.append(predictions)
-        return preds
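For reference, a minimal usage sketch of the `post_process_object_detection` method the docs now point to. The checkpoint name, image URL, and score threshold below are illustrative assumptions, not part of this patch; the return format (dicts with "scores", "labels", "boxes") matches the docstring shown above.

```python
import requests
import torch
from PIL import Image

from transformers import YolosFeatureExtractor, YolosForObjectDetection

# Assumed checkpoint and sample image; any YOLOS detection checkpoint should behave the same way.
feature_extractor = YolosFeatureExtractor.from_pretrained("hustvl/yolos-small")
model = YolosForObjectDetection.from_pretrained("hustvl/yolos-small")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# target_sizes rescales the normalized boxes back to the original (height, width) of the image.
target_sizes = torch.tensor([image.size[::-1]])
results = feature_extractor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[0]

for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    print(model.config.id2label[label.item()], round(score.item(), 3), box.tolist())
```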