From d3b1adf59fff726f1c8f324728e562237f080ce6 Mon Sep 17 00:00:00 2001
From: Vitali Petsiuk
Date: Tue, 14 Feb 2023 13:00:02 -0500
Subject: [PATCH] Removes duplicate computations in DETR post processing
 (#21592)

* Remove redundant computations, comb variable names

* Fix scores to cur_scores
---
 .../models/detr/image_processing_detr.py | 34 +++++++++----------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/src/transformers/models/detr/image_processing_detr.py b/src/transformers/models/detr/image_processing_detr.py
index f03465954e8..433853efefa 100644
--- a/src/transformers/models/detr/image_processing_detr.py
+++ b/src/transformers/models/detr/image_processing_detr.py
@@ -1311,6 +1311,7 @@ class DetrImageProcessor(BaseImageProcessor):
                 FutureWarning,
             )
         out_logits, raw_masks = outputs.logits, outputs.pred_masks
+        empty_label = out_logits.shape[-1] - 1
         preds = []

         def to_tuple(tup):
@@ -1320,16 +1321,15 @@ class DetrImageProcessor(BaseImageProcessor):

         for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
             # we filter empty queries and detection below threshold
-            scores, labels = cur_logits.softmax(-1).max(-1)
-            keep = labels.ne(outputs.logits.shape[-1] - 1) & (scores > threshold)
-            cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
+            cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
+            keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
             cur_scores = cur_scores[keep]
-            cur_classes = cur_classes[keep]
+            cur_labels = cur_labels[keep]
             cur_masks = cur_masks[keep]
             cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
             cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1

-            predictions = {"scores": cur_scores, "labels": cur_classes, "masks": cur_masks}
+            predictions = {"scores": cur_scores, "labels": cur_labels, "masks": cur_masks}
             preds.append(predictions)
         return preds
@@ -1423,6 +1423,7 @@ class DetrImageProcessor(BaseImageProcessor):
             raise ValueError(
                 "Make sure that you pass in as many target sizes as the batch dimension of the logits and masks"
             )
+        empty_label = out_logits.shape[-1] - 1
         preds = []

         def to_tuple(tup):
@@ -1434,24 +1435,23 @@ class DetrImageProcessor(BaseImageProcessor):
             out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
         ):
             # we filter empty queries and detection below threshold
-            scores, labels = cur_logits.softmax(-1).max(-1)
-            keep = labels.ne(outputs.logits.shape[-1] - 1) & (scores > threshold)
-            cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
+            cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
+            keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
             cur_scores = cur_scores[keep]
-            cur_classes = cur_classes[keep]
+            cur_labels = cur_labels[keep]
             cur_masks = cur_masks[keep]
             cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
             cur_boxes = center_to_corners_format(cur_boxes[keep])

             h, w = cur_masks.shape[-2:]
-            if len(cur_boxes) != len(cur_classes):
+            if len(cur_boxes) != len(cur_labels):
                 raise ValueError("Not as many boxes as there are classes")

             # It may be that we have several predicted masks for the same stuff class.
             # In the following, we track the list of masks ids for each stuff class (they are merged later on)
             cur_masks = cur_masks.flatten(1)
             stuff_equiv_classes = defaultdict(lambda: [])
-            for k, label in enumerate(cur_classes):
+            for k, label in enumerate(cur_labels):
                 if not is_thing_map[label.item()]:
                     stuff_equiv_classes[label.item()].append(k)

@@ -1491,28 +1491,28 @@ class DetrImageProcessor(BaseImageProcessor):
                 return area, seg_img

             area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
-            if cur_classes.numel() > 0:
+            if cur_labels.numel() > 0:
                 # We know filter empty masks as long as we find some
                 while True:
                     filtered_small = torch.as_tensor(
-                        [area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device
+                        [area[i] <= 4 for i, c in enumerate(cur_labels)], dtype=torch.bool, device=keep.device
                     )
                     if filtered_small.any().item():
                         cur_scores = cur_scores[~filtered_small]
-                        cur_classes = cur_classes[~filtered_small]
+                        cur_labels = cur_labels[~filtered_small]
                         cur_masks = cur_masks[~filtered_small]
                         area, seg_img = get_ids_area(cur_masks, cur_scores)
                     else:
                         break
             else:
-                cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)
+                cur_labels = torch.ones(1, dtype=torch.long, device=cur_labels.device)

             segments_info = []
             for i, a in enumerate(area):
-                cat = cur_classes[i].item()
+                cat = cur_labels[i].item()
                 segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a})
-            del cur_classes
+            del cur_labels

             with io.BytesIO() as out:
                 seg_img.save(out, format="PNG")
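
Note on the change: the old code ran cur_logits.softmax(-1).max(-1) twice per
image (once for the keep mask, once for the returned scores and labels) and
recomputed the "no object" index inside the loop; the patch hoists that index
into empty_label and reuses a single softmax/max pass. Below is a minimal
standalone sketch of the resulting filtering pattern; the tensor shapes and
the 0.9 threshold are illustrative assumptions for the demo, not values taken
from the patch, and this is not the transformers implementation itself.

import torch

num_queries, num_classes = 100, 92  # e.g. 91 labels + 1 "no object" class
out_logits = torch.randn(num_queries, num_classes)
threshold = 0.9  # assumed score cutoff for this demo

# The last class index is the "no object" label; compute it once up front.
empty_label = out_logits.shape[-1] - 1

# A single softmax/max pass yields both the scores and the labels.
cur_scores, cur_labels = out_logits.softmax(-1).max(-1)
keep = cur_labels.ne(empty_label) & (cur_scores > threshold)

# Apply the same boolean mask to every per-query tensor.
cur_scores = cur_scores[keep]
cur_labels = cur_labels[keep]
print(f"kept {int(keep.sum())} of {num_queries} queries")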