mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-25 23:38:59 +06:00
Removes duplicate computations in DETR post processing (#21592)
* Remove redundant computations, comb variable names
* Fix scores to cur_scores
This commit is contained in:
parent
d4ba6e1a0e
commit
d3b1adf59f
@ -1311,6 +1311,7 @@ class DetrImageProcessor(BaseImageProcessor):
|
|||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
out_logits, raw_masks = outputs.logits, outputs.pred_masks
|
out_logits, raw_masks = outputs.logits, outputs.pred_masks
|
||||||
|
empty_label = out_logits.shape[-1] - 1
|
||||||
preds = []
|
preds = []
|
||||||
|
|
||||||
def to_tuple(tup):
|
def to_tuple(tup):
|
||||||
@ -1320,16 +1321,15 @@ class DetrImageProcessor(BaseImageProcessor):
|
|||||||
|
|
||||||
for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
|
for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
|
||||||
# we filter out empty queries and detections below the threshold
|
# we filter out empty queries and detections below the threshold
|
||||||
scores, labels = cur_logits.softmax(-1).max(-1)
|
cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
|
||||||
keep = labels.ne(outputs.logits.shape[-1] - 1) & (scores > threshold)
|
keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
|
||||||
cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
|
|
||||||
cur_scores = cur_scores[keep]
|
cur_scores = cur_scores[keep]
|
||||||
cur_classes = cur_classes[keep]
|
cur_labels = cur_labels[keep]
|
||||||
cur_masks = cur_masks[keep]
|
cur_masks = cur_masks[keep]
|
||||||
cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
|
cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
|
||||||
cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1
|
cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1
|
||||||
|
|
||||||
predictions = {"scores": cur_scores, "labels": cur_classes, "masks": cur_masks}
|
predictions = {"scores": cur_scores, "labels": cur_labels, "masks": cur_masks}
|
||||||
preds.append(predictions)
|
preds.append(predictions)
|
||||||
return preds
|
return preds
|
||||||
|
|
||||||
@ -1423,6 +1423,7 @@ class DetrImageProcessor(BaseImageProcessor):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Make sure that you pass in as many target sizes as the batch dimension of the logits and masks"
|
"Make sure that you pass in as many target sizes as the batch dimension of the logits and masks"
|
||||||
)
|
)
|
||||||
|
empty_label = out_logits.shape[-1] - 1
|
||||||
preds = []
|
preds = []
|
||||||
|
|
||||||
def to_tuple(tup):
|
def to_tuple(tup):
|
||||||
@ -1434,24 +1435,23 @@ class DetrImageProcessor(BaseImageProcessor):
|
|||||||
out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
|
out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
|
||||||
):
|
):
|
||||||
# we filter out empty queries and detections below the threshold
|
# we filter out empty queries and detections below the threshold
|
||||||
scores, labels = cur_logits.softmax(-1).max(-1)
|
cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
|
||||||
keep = labels.ne(outputs.logits.shape[-1] - 1) & (scores > threshold)
|
keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
|
||||||
cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
|
|
||||||
cur_scores = cur_scores[keep]
|
cur_scores = cur_scores[keep]
|
||||||
cur_classes = cur_classes[keep]
|
cur_labels = cur_labels[keep]
|
||||||
cur_masks = cur_masks[keep]
|
cur_masks = cur_masks[keep]
|
||||||
cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
|
cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
|
||||||
cur_boxes = center_to_corners_format(cur_boxes[keep])
|
cur_boxes = center_to_corners_format(cur_boxes[keep])
|
||||||
|
|
||||||
h, w = cur_masks.shape[-2:]
|
h, w = cur_masks.shape[-2:]
|
||||||
if len(cur_boxes) != len(cur_classes):
|
if len(cur_boxes) != len(cur_labels):
|
||||||
raise ValueError("Not as many boxes as there are classes")
|
raise ValueError("Not as many boxes as there are classes")
|
||||||
|
|
||||||
# It may be that we have several predicted masks for the same stuff class.
|
# It may be that we have several predicted masks for the same stuff class.
|
||||||
# In the following, we track the list of masks ids for each stuff class (they are merged later on)
|
# In the following, we track the list of masks ids for each stuff class (they are merged later on)
|
||||||
cur_masks = cur_masks.flatten(1)
|
cur_masks = cur_masks.flatten(1)
|
||||||
stuff_equiv_classes = defaultdict(lambda: [])
|
stuff_equiv_classes = defaultdict(lambda: [])
|
||||||
for k, label in enumerate(cur_classes):
|
for k, label in enumerate(cur_labels):
|
||||||
if not is_thing_map[label.item()]:
|
if not is_thing_map[label.item()]:
|
||||||
stuff_equiv_classes[label.item()].append(k)
|
stuff_equiv_classes[label.item()].append(k)
|
||||||
|
|
||||||
@ -1491,28 +1491,28 @@ class DetrImageProcessor(BaseImageProcessor):
|
|||||||
return area, seg_img
|
return area, seg_img
|
||||||
|
|
||||||
area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
|
area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
|
||||||
if cur_classes.numel() > 0:
|
if cur_labels.numel() > 0:
|
||||||
# We now filter empty masks as long as we find some
|
# We now filter empty masks as long as we find some
|
||||||
while True:
|
while True:
|
||||||
filtered_small = torch.as_tensor(
|
filtered_small = torch.as_tensor(
|
||||||
[area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device
|
[area[i] <= 4 for i, c in enumerate(cur_labels)], dtype=torch.bool, device=keep.device
|
||||||
)
|
)
|
||||||
if filtered_small.any().item():
|
if filtered_small.any().item():
|
||||||
cur_scores = cur_scores[~filtered_small]
|
cur_scores = cur_scores[~filtered_small]
|
||||||
cur_classes = cur_classes[~filtered_small]
|
cur_labels = cur_labels[~filtered_small]
|
||||||
cur_masks = cur_masks[~filtered_small]
|
cur_masks = cur_masks[~filtered_small]
|
||||||
area, seg_img = get_ids_area(cur_masks, cur_scores)
|
area, seg_img = get_ids_area(cur_masks, cur_scores)
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
else:
|
else:
|
||||||
cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)
|
cur_labels = torch.ones(1, dtype=torch.long, device=cur_labels.device)
|
||||||
|
|
||||||
segments_info = []
|
segments_info = []
|
||||||
for i, a in enumerate(area):
|
for i, a in enumerate(area):
|
||||||
cat = cur_classes[i].item()
|
cat = cur_labels[i].item()
|
||||||
segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a})
|
segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a})
|
||||||
del cur_classes
|
del cur_labels
|
||||||
|
|
||||||
with io.BytesIO() as out:
|
with io.BytesIO() as out:
|
||||||
seg_img.save(out, format="PNG")
|
seg_img.save(out, format="PNG")
|
||||||
|
Loading…
Reference in New Issue
Block a user