Removes duplicate computations in DETR post processing (#21592)

* Remove redundant computations, comb variable names

* Fix scores to cur_scores
This commit is contained in:
Vitali Petsiuk 2023-02-14 13:00:02 -05:00 committed by GitHub
parent d4ba6e1a0e
commit d3b1adf59f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -1311,6 +1311,7 @@ class DetrImageProcessor(BaseImageProcessor):
FutureWarning,
)
out_logits, raw_masks = outputs.logits, outputs.pred_masks
empty_label = out_logits.shape[-1] - 1
preds = []
def to_tuple(tup):
@@ -1320,16 +1321,15 @@ class DetrImageProcessor(BaseImageProcessor):
for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
# we filter empty queries and detections below threshold
scores, labels = cur_logits.softmax(-1).max(-1)
keep = labels.ne(outputs.logits.shape[-1] - 1) & (scores > threshold)
cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
cur_scores = cur_scores[keep]
cur_classes = cur_classes[keep]
cur_labels = cur_labels[keep]
cur_masks = cur_masks[keep]
cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1
predictions = {"scores": cur_scores, "labels": cur_classes, "masks": cur_masks}
predictions = {"scores": cur_scores, "labels": cur_labels, "masks": cur_masks}
preds.append(predictions)
return preds
@@ -1423,6 +1423,7 @@ class DetrImageProcessor(BaseImageProcessor):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits and masks"
)
empty_label = out_logits.shape[-1] - 1
preds = []
def to_tuple(tup):
@@ -1434,24 +1435,23 @@ class DetrImageProcessor(BaseImageProcessor):
out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
):
# we filter empty queries and detections below threshold
scores, labels = cur_logits.softmax(-1).max(-1)
keep = labels.ne(outputs.logits.shape[-1] - 1) & (scores > threshold)
cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
cur_scores = cur_scores[keep]
cur_classes = cur_classes[keep]
cur_labels = cur_labels[keep]
cur_masks = cur_masks[keep]
cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
cur_boxes = center_to_corners_format(cur_boxes[keep])
h, w = cur_masks.shape[-2:]
if len(cur_boxes) != len(cur_classes):
if len(cur_boxes) != len(cur_labels):
raise ValueError("Not as many boxes as there are classes")
# It may be that we have several predicted masks for the same stuff class.
# In the following, we track the list of masks ids for each stuff class (they are merged later on)
cur_masks = cur_masks.flatten(1)
stuff_equiv_classes = defaultdict(lambda: [])
for k, label in enumerate(cur_classes):
for k, label in enumerate(cur_labels):
if not is_thing_map[label.item()]:
stuff_equiv_classes[label.item()].append(k)
@@ -1491,28 +1491,28 @@ class DetrImageProcessor(BaseImageProcessor):
return area, seg_img
area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
if cur_classes.numel() > 0:
if cur_labels.numel() > 0:
# We now filter empty masks as long as we find some
while True:
filtered_small = torch.as_tensor(
[area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device
[area[i] <= 4 for i, c in enumerate(cur_labels)], dtype=torch.bool, device=keep.device
)
if filtered_small.any().item():
cur_scores = cur_scores[~filtered_small]
cur_classes = cur_classes[~filtered_small]
cur_labels = cur_labels[~filtered_small]
cur_masks = cur_masks[~filtered_small]
area, seg_img = get_ids_area(cur_masks, cur_scores)
else:
break
else:
cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)
cur_labels = torch.ones(1, dtype=torch.long, device=cur_labels.device)
segments_info = []
for i, a in enumerate(area):
cat = cur_classes[i].item()
cat = cur_labels[i].item()
segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a})
del cur_classes
del cur_labels
with io.BytesIO() as out:
seg_img.save(out, format="PNG")