Fix MaskFormer failing postprocess tests (#19354)

Ensures post_process_instance_segmentation and post_process_panoptic_segmentation methods return a tensor of shape (target_height, target_width) filled with -1 values if no segment with score > threshold is found.
This commit is contained in:
Alara Dirik 2022-10-05 23:25:58 +03:00 committed by GitHub
parent ad98642a82
commit 7598791c09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 7 deletions

View File

@ -772,8 +772,9 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
# No mask found
if mask_probs_item.shape[0] <= 0:
segmentation = None
segments: List[Dict] = []
height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
segmentation = torch.zeros((height, width)) - 1
results.append({"segmentation": segmentation, "segments_info": []})
continue
# Get segmentation map and segment information of batch item
@ -860,8 +861,9 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
# No mask found
if mask_probs_item.shape[0] <= 0:
segmentation = None
segments: List[Dict] = []
height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
segmentation = torch.zeros((height, width)) - 1
results.append({"segmentation": segmentation, "segments_info": []})
continue
# Get segmentation map and segment information of batch item

View File

@ -401,10 +401,11 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
@unittest.skip("Fix me Alara!")
def test_post_process_panoptic_segmentation(self):
fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
feature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
segmentation = fature_extractor.post_process_panoptic_segmentation(outputs, threshold=0)
segmentation = feature_extractor.post_process_panoptic_segmentation(outputs, threshold=0)
print(len(segmentation))
print(self.feature_extract_tester.batch_size)
self.assertTrue(len(segmentation) == self.feature_extract_tester.batch_size)
for el in segmentation:
self.assertTrue("segmentation" in el)