Add check for target_sizes is None in post_process_image_guided_detection for owlv2 (#31934)

* Add check for target_sizes is None in post_process_image_guided_detection

* Make sure OwlViT and Owlv2 stay in sync

* Fix incorrect indentation; add check for correct size of target_sizes
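
For context, a minimal usage sketch of what the guarded `None` path enables. The checkpoint name and image files below are placeholders, not part of this change:

```python
import torch
from PIL import Image
from transformers import Owlv2Processor, Owlv2ForObjectDetection

# Placeholder checkpoint and images; substitute your own.
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

image = Image.open("scene.jpg")        # image to search in
query_image = Image.open("query.jpg")  # image of the object to find
inputs = processor(images=image, query_images=query_image, return_tensors="pt")

with torch.no_grad():
    outputs = model.image_guided_detection(**inputs)

# Before this fix, target_sizes=None crashed on len(None); now the rescaling
# step is skipped and boxes come back in relative [0, 1] corner coordinates.
results = processor.post_process_image_guided_detection(outputs, target_sizes=None)
print(results[0]["boxes"])
```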
Connor Anderson 2024-07-26 05:05:46 -04:00 committed by GitHub
parent f9756d9edb
commit 5f841c74b6
2 changed files with 20 additions and 10 deletions

src/transformers/models/owlv2/image_processing_owlv2.py

@@ -565,9 +565,9 @@ class Owlv2ImageProcessor(BaseImageProcessor):
         """
         logits, target_boxes = outputs.logits, outputs.target_pred_boxes
 
-        if len(logits) != len(target_sizes):
+        if target_sizes is not None and len(logits) != len(target_sizes):
             raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
-        if target_sizes.shape[1] != 2:
+        if target_sizes is not None and target_sizes.shape[1] != 2:
             raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
 
         probs = torch.max(logits, dim=-1)
@@ -588,9 +588,14 @@ class Owlv2ImageProcessor(BaseImageProcessor):
                     scores[idx][ious > nms_threshold] = 0.0
 
         # Convert from relative [0, 1] to absolute [0, height] coordinates
-        img_h, img_w = target_sizes.unbind(1)
-        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(target_boxes.device)
-        target_boxes = target_boxes * scale_fct[:, None, :]
+        if target_sizes is not None:
+            if isinstance(target_sizes, List):
+                img_h = torch.tensor([i[0] for i in target_sizes])
+                img_w = torch.tensor([i[1] for i in target_sizes])
+            else:
+                img_h, img_w = target_sizes.unbind(1)
+            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(target_boxes.device)
+            target_boxes = target_boxes * scale_fct[:, None, :]
 
         # Compute box display alphas based on prediction scores
         results = []
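
For reference, a standalone sketch of the two scaling paths the new branch accepts; the box and size values are made up for illustration:

```python
import torch

# Normalized (cx, cy, w, h) boxes for a batch of 1 image with 2 detections.
target_boxes = torch.tensor([[[0.5, 0.5, 0.2, 0.2], [0.1, 0.9, 0.3, 0.1]]])

sizes_as_tensor = torch.tensor([[480, 640]])  # shape (batch, 2), rows are (h, w)
sizes_as_list = [(480, 640)]                  # list of (h, w) pairs

# Tensor path: split heights and widths along dim 1.
img_h, img_w = sizes_as_tensor.unbind(1)

# List path: build the height/width tensors element by element.
img_h_list = torch.tensor([i[0] for i in sizes_as_list])
img_w_list = torch.tensor([i[1] for i in sizes_as_list])
assert torch.equal(img_h, img_h_list) and torch.equal(img_w, img_w_list)

# Either way, the same scale factor maps relative boxes to pixel coordinates.
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(target_boxes.device)
print(target_boxes * scale_fct[:, None, :])
```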

src/transformers/models/owlvit/image_processing_owlvit.py

@@ -556,9 +556,9 @@ class OwlViTImageProcessor(BaseImageProcessor):
         """
         logits, target_boxes = outputs.logits, outputs.target_pred_boxes
 
-        if len(logits) != len(target_sizes):
+        if target_sizes is not None and len(logits) != len(target_sizes):
             raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
-        if target_sizes.shape[1] != 2:
+        if target_sizes is not None and target_sizes.shape[1] != 2:
             raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
 
         probs = torch.max(logits, dim=-1)
@@ -579,9 +579,14 @@ class OwlViTImageProcessor(BaseImageProcessor):
                     scores[idx][ious > nms_threshold] = 0.0
 
         # Convert from relative [0, 1] to absolute [0, height] coordinates
-        img_h, img_w = target_sizes.unbind(1)
-        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(target_boxes.device)
-        target_boxes = target_boxes * scale_fct[:, None, :]
+        if target_sizes is not None:
+            if isinstance(target_sizes, List):
+                img_h = torch.tensor([i[0] for i in target_sizes])
+                img_w = torch.tensor([i[1] for i in target_sizes])
+            else:
+                img_h, img_w = target_sizes.unbind(1)
+            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(target_boxes.device)
+            target_boxes = target_boxes * scale_fct[:, None, :]
 
         # Compute box display alphas based on prediction scores
         results = []
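
And a quick sketch of how the guarded validation behaves, using dummy tensors in place of real model outputs:

```python
import torch

logits = torch.randn(2, 100, 1)         # dummy logits for a batch of 2 images
bad_sizes = torch.tensor([[480, 640]])  # only 1 size for a batch of 2

# Mirrors the checks added to both image processors.
def check(target_sizes):
    if target_sizes is not None and len(logits) != len(target_sizes):
        raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
    if target_sizes is not None and target_sizes.shape[1] != 2:
        raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

check(None)  # both checks are skipped when target_sizes is None
try:
    check(bad_sizes)
except ValueError as e:
    print(e)  # the batch-size mismatch is still caught when sizes are given
```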