Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00.

Commit: Add `top_k` argument to post-process of conditional/deformable DETR (#22787)

Commit message:
* update min k_value of conditional detr post-processing
* feat: add top_k arg to post processing of deformable and conditional detr
* refactor: revert changes to deprecated methods
* refactor: move prob reshape to improve code clarity and reduce repetition
This commit (b92abfa6e0) has parent f82ee109e6.
@ -1328,7 +1328,7 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
|
||||
|
||||
# Copied from transformers.models.deformable_detr.image_processing_deformable_detr.DeformableDetrImageProcessor.post_process_object_detection with DeformableDetr->ConditionalDetr
|
||||
def post_process_object_detection(
|
||||
self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None
|
||||
self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None, top_k: int = 100
|
||||
):
|
||||
"""
|
||||
Converts the raw output of [`ConditionalDetrForObjectDetection`] into final bounding boxes in (top_left_x,
|
||||
@ -1342,6 +1342,8 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
|
||||
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
|
||||
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
|
||||
(height, width) of each image in the batch. If left to None, predictions will not be resized.
|
||||
top_k (`int`, *optional*, defaults to 100):
|
||||
Keep only top k bounding boxes before filtering by thresholding.
|
||||
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
|
||||
@ -1356,7 +1358,9 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
|
||||
)
|
||||
|
||||
prob = out_logits.sigmoid()
|
||||
topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
|
||||
prob = prob.view(out_logits.shape[0], -1)
|
||||
k_value = min(top_k, prob.size(1))
|
||||
topk_values, topk_indexes = torch.topk(prob, k_value, dim=1)
|
||||
scores = topk_values
|
||||
topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
|
||||
labels = topk_indexes % out_logits.shape[2]
|
||||
|
@ -1325,7 +1325,7 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
|
||||
return results
|
||||
|
||||
def post_process_object_detection(
|
||||
self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None
|
||||
self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None, top_k: int = 100
|
||||
):
|
||||
"""
|
||||
Converts the raw output of [`DeformableDetrForObjectDetection`] into final bounding boxes in (top_left_x,
|
||||
@ -1339,6 +1339,8 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
|
||||
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
|
||||
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
|
||||
(height, width) of each image in the batch. If left to None, predictions will not be resized.
|
||||
top_k (`int`, *optional*, defaults to 100):
|
||||
Keep only top k bounding boxes before filtering by thresholding.
|
||||
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
|
||||
@ -1353,7 +1355,9 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
|
||||
)
|
||||
|
||||
prob = out_logits.sigmoid()
|
||||
topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
|
||||
prob = prob.view(out_logits.shape[0], -1)
|
||||
k_value = min(top_k, prob.size(1))
|
||||
topk_values, topk_indexes = torch.topk(prob, k_value, dim=1)
|
||||
scores = topk_values
|
||||
topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
|
||||
labels = topk_indexes % out_logits.shape[2]
|
||||
|
[Page chrome: "Loading… / Reference in New Issue / Block a user" — UI residue from the scraped diff page, not part of the commit.]