mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-24 23:08:57 +06:00
Removes duplicate computations in DETR post processing (#21592)
* Remove redundant computations, comb variable names * Fix scores to cur_scores
This commit is contained in:
parent
d4ba6e1a0e
commit
d3b1adf59f
@ -1311,6 +1311,7 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
FutureWarning,
|
||||
)
|
||||
out_logits, raw_masks = outputs.logits, outputs.pred_masks
|
||||
empty_label = out_logits.shape[-1] - 1
|
||||
preds = []
|
||||
|
||||
def to_tuple(tup):
|
||||
@ -1320,16 +1321,15 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
|
||||
for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
|
||||
# we filter empty queries and detection below threshold
|
||||
scores, labels = cur_logits.softmax(-1).max(-1)
|
||||
keep = labels.ne(outputs.logits.shape[-1] - 1) & (scores > threshold)
|
||||
cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
|
||||
cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
|
||||
keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
|
||||
cur_scores = cur_scores[keep]
|
||||
cur_classes = cur_classes[keep]
|
||||
cur_labels = cur_labels[keep]
|
||||
cur_masks = cur_masks[keep]
|
||||
cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
|
||||
cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1
|
||||
|
||||
predictions = {"scores": cur_scores, "labels": cur_classes, "masks": cur_masks}
|
||||
predictions = {"scores": cur_scores, "labels": cur_labels, "masks": cur_masks}
|
||||
preds.append(predictions)
|
||||
return preds
|
||||
|
||||
@ -1423,6 +1423,7 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
raise ValueError(
|
||||
"Make sure that you pass in as many target sizes as the batch dimension of the logits and masks"
|
||||
)
|
||||
empty_label = out_logits.shape[-1] - 1
|
||||
preds = []
|
||||
|
||||
def to_tuple(tup):
|
||||
@ -1434,24 +1435,23 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
|
||||
):
|
||||
# we filter empty queries and detection below threshold
|
||||
scores, labels = cur_logits.softmax(-1).max(-1)
|
||||
keep = labels.ne(outputs.logits.shape[-1] - 1) & (scores > threshold)
|
||||
cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
|
||||
cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
|
||||
keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
|
||||
cur_scores = cur_scores[keep]
|
||||
cur_classes = cur_classes[keep]
|
||||
cur_labels = cur_labels[keep]
|
||||
cur_masks = cur_masks[keep]
|
||||
cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
|
||||
cur_boxes = center_to_corners_format(cur_boxes[keep])
|
||||
|
||||
h, w = cur_masks.shape[-2:]
|
||||
if len(cur_boxes) != len(cur_classes):
|
||||
if len(cur_boxes) != len(cur_labels):
|
||||
raise ValueError("Not as many boxes as there are classes")
|
||||
|
||||
# It may be that we have several predicted masks for the same stuff class.
|
||||
# In the following, we track the list of masks ids for each stuff class (they are merged later on)
|
||||
cur_masks = cur_masks.flatten(1)
|
||||
stuff_equiv_classes = defaultdict(lambda: [])
|
||||
for k, label in enumerate(cur_classes):
|
||||
for k, label in enumerate(cur_labels):
|
||||
if not is_thing_map[label.item()]:
|
||||
stuff_equiv_classes[label.item()].append(k)
|
||||
|
||||
@ -1491,28 +1491,28 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
return area, seg_img
|
||||
|
||||
area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
|
||||
if cur_classes.numel() > 0:
|
||||
if cur_labels.numel() > 0:
|
||||
# We know filter empty masks as long as we find some
|
||||
while True:
|
||||
filtered_small = torch.as_tensor(
|
||||
[area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device
|
||||
[area[i] <= 4 for i, c in enumerate(cur_labels)], dtype=torch.bool, device=keep.device
|
||||
)
|
||||
if filtered_small.any().item():
|
||||
cur_scores = cur_scores[~filtered_small]
|
||||
cur_classes = cur_classes[~filtered_small]
|
||||
cur_labels = cur_labels[~filtered_small]
|
||||
cur_masks = cur_masks[~filtered_small]
|
||||
area, seg_img = get_ids_area(cur_masks, cur_scores)
|
||||
else:
|
||||
break
|
||||
|
||||
else:
|
||||
cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)
|
||||
cur_labels = torch.ones(1, dtype=torch.long, device=cur_labels.device)
|
||||
|
||||
segments_info = []
|
||||
for i, a in enumerate(area):
|
||||
cat = cur_classes[i].item()
|
||||
cat = cur_labels[i].item()
|
||||
segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a})
|
||||
del cur_classes
|
||||
del cur_labels
|
||||
|
||||
with io.BytesIO() as out:
|
||||
seg_img.save(out, format="PNG")
|
||||
|
Loading…
Reference in New Issue
Block a user