Removes duplicate computations in DETR post processing (#21592)

* Remove redundant computations, comb variable names

* Fix scores to cur_scores
This commit is contained in:
Vitali Petsiuk 2023-02-14 13:00:02 -05:00 committed by GitHub
parent d4ba6e1a0e
commit d3b1adf59f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1311,6 +1311,7 @@ class DetrImageProcessor(BaseImageProcessor):
FutureWarning, FutureWarning,
) )
out_logits, raw_masks = outputs.logits, outputs.pred_masks out_logits, raw_masks = outputs.logits, outputs.pred_masks
empty_label = out_logits.shape[-1] - 1
preds = [] preds = []
def to_tuple(tup): def to_tuple(tup):
@ -1320,16 +1321,15 @@ class DetrImageProcessor(BaseImageProcessor):
for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes): for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
# we filter empty queries and detections below threshold # we filter empty queries and detections below threshold
scores, labels = cur_logits.softmax(-1).max(-1) cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
keep = labels.ne(outputs.logits.shape[-1] - 1) & (scores > threshold) keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
cur_scores = cur_scores[keep] cur_scores = cur_scores[keep]
cur_classes = cur_classes[keep] cur_labels = cur_labels[keep]
cur_masks = cur_masks[keep] cur_masks = cur_masks[keep]
cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1) cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1 cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1
predictions = {"scores": cur_scores, "labels": cur_classes, "masks": cur_masks} predictions = {"scores": cur_scores, "labels": cur_labels, "masks": cur_masks}
preds.append(predictions) preds.append(predictions)
return preds return preds
@ -1423,6 +1423,7 @@ class DetrImageProcessor(BaseImageProcessor):
raise ValueError( raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits and masks" "Make sure that you pass in as many target sizes as the batch dimension of the logits and masks"
) )
empty_label = out_logits.shape[-1] - 1
preds = [] preds = []
def to_tuple(tup): def to_tuple(tup):
@ -1434,24 +1435,23 @@ class DetrImageProcessor(BaseImageProcessor):
out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
): ):
# we filter empty queries and detections below threshold # we filter empty queries and detections below threshold
scores, labels = cur_logits.softmax(-1).max(-1) cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
keep = labels.ne(outputs.logits.shape[-1] - 1) & (scores > threshold) keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
cur_scores = cur_scores[keep] cur_scores = cur_scores[keep]
cur_classes = cur_classes[keep] cur_labels = cur_labels[keep]
cur_masks = cur_masks[keep] cur_masks = cur_masks[keep]
cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1) cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
cur_boxes = center_to_corners_format(cur_boxes[keep]) cur_boxes = center_to_corners_format(cur_boxes[keep])
h, w = cur_masks.shape[-2:] h, w = cur_masks.shape[-2:]
if len(cur_boxes) != len(cur_classes): if len(cur_boxes) != len(cur_labels):
raise ValueError("Not as many boxes as there are classes") raise ValueError("Not as many boxes as there are classes")
# It may be that we have several predicted masks for the same stuff class. # It may be that we have several predicted masks for the same stuff class.
# In the following, we track the list of mask ids for each stuff class (they are merged later on) # In the following, we track the list of mask ids for each stuff class (they are merged later on)
cur_masks = cur_masks.flatten(1) cur_masks = cur_masks.flatten(1)
stuff_equiv_classes = defaultdict(lambda: []) stuff_equiv_classes = defaultdict(lambda: [])
for k, label in enumerate(cur_classes): for k, label in enumerate(cur_labels):
if not is_thing_map[label.item()]: if not is_thing_map[label.item()]:
stuff_equiv_classes[label.item()].append(k) stuff_equiv_classes[label.item()].append(k)
@ -1491,28 +1491,28 @@ class DetrImageProcessor(BaseImageProcessor):
return area, seg_img return area, seg_img
area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True) area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
if cur_classes.numel() > 0: if cur_labels.numel() > 0:
# We now filter empty masks as long as we find some # We now filter empty masks as long as we find some
while True: while True:
filtered_small = torch.as_tensor( filtered_small = torch.as_tensor(
[area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device [area[i] <= 4 for i, c in enumerate(cur_labels)], dtype=torch.bool, device=keep.device
) )
if filtered_small.any().item(): if filtered_small.any().item():
cur_scores = cur_scores[~filtered_small] cur_scores = cur_scores[~filtered_small]
cur_classes = cur_classes[~filtered_small] cur_labels = cur_labels[~filtered_small]
cur_masks = cur_masks[~filtered_small] cur_masks = cur_masks[~filtered_small]
area, seg_img = get_ids_area(cur_masks, cur_scores) area, seg_img = get_ids_area(cur_masks, cur_scores)
else: else:
break break
else: else:
cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device) cur_labels = torch.ones(1, dtype=torch.long, device=cur_labels.device)
segments_info = [] segments_info = []
for i, a in enumerate(area): for i, a in enumerate(area):
cat = cur_classes[i].item() cat = cur_labels[i].item()
segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a}) segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a})
del cur_classes del cur_labels
with io.BytesIO() as out: with io.BytesIO() as out:
seg_img.save(out, format="PNG") seg_img.save(out, format="PNG")