mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix MaskFormer failing postprocess tests (#19354)
Ensures post_process_instance_segmentation and post_process_panoptic_segmentation methods return a tensor of shape (target_height, target_width) filled with -1 values if no segment with score > threshold is found.
This commit is contained in:
parent
ad98642a82
commit
7598791c09
@ -772,8 +772,9 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
|
||||
|
||||
# No mask found
|
||||
if mask_probs_item.shape[0] <= 0:
|
||||
segmentation = None
|
||||
segments: List[Dict] = []
|
||||
height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
|
||||
segmentation = torch.zeros((height, width)) - 1
|
||||
results.append({"segmentation": segmentation, "segments_info": []})
|
||||
continue
|
||||
|
||||
# Get segmentation map and segment information of batch item
|
||||
@ -860,8 +861,9 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
|
||||
|
||||
# No mask found
|
||||
if mask_probs_item.shape[0] <= 0:
|
||||
segmentation = None
|
||||
segments: List[Dict] = []
|
||||
height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
|
||||
segmentation = torch.zeros((height, width)) - 1
|
||||
results.append({"segmentation": segmentation, "segments_info": []})
|
||||
continue
|
||||
|
||||
# Get segmentation map and segment information of batch item
|
||||
|
@ -401,10 +401,11 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
|
||||
|
||||
@unittest.skip("Fix me Alara!")
|
||||
def test_post_process_panoptic_segmentation(self):
|
||||
fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
|
||||
feature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
|
||||
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
|
||||
segmentation = fature_extractor.post_process_panoptic_segmentation(outputs, threshold=0)
|
||||
|
||||
segmentation = feature_extractor.post_process_panoptic_segmentation(outputs, threshold=0)
|
||||
print(len(segmentation))
|
||||
print(self.feature_extract_tester.batch_size)
|
||||
self.assertTrue(len(segmentation) == self.feature_extract_tester.batch_size)
|
||||
for el in segmentation:
|
||||
self.assertTrue("segmentation" in el)
|
||||
|
Loading…
Reference in New Issue
Block a user