mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
Fix OneFormer post_process_instance_segmentation
for panoptic tasks (#29304)
* 🐛 Fix oneformer instance post processing when using panoptic task type * ✅ Add unit test for oneformer instance post processing panoptic bug --------- Co-authored-by: Nick DeGroot <1966472+nickthegroot@users.noreply.github.com>
This commit is contained in:
parent
81220cba61
commit
8ef9862864
@ -1244,8 +1244,8 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
# if this is panoptic segmentation, we only keep the "thing" classes
|
||||
if task_type == "panoptic":
|
||||
keep = torch.zeros_like(scores_per_image).bool()
|
||||
for i, lab in enumerate(labels_per_image):
|
||||
keep[i] = lab in self.metadata["thing_ids"]
|
||||
for j, lab in enumerate(labels_per_image):
|
||||
keep[j] = lab in self.metadata["thing_ids"]
|
||||
|
||||
scores_per_image = scores_per_image[keep]
|
||||
labels_per_image = labels_per_image[keep]
|
||||
@ -1258,8 +1258,8 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
continue
|
||||
|
||||
if "ade20k" in self.class_info_file and not is_demo and "instance" in task_type:
|
||||
for i in range(labels_per_image.shape[0]):
|
||||
labels_per_image[i] = self.metadata["thing_ids"].index(labels_per_image[i].item())
|
||||
for j in range(labels_per_image.shape[0]):
|
||||
labels_per_image[j] = self.metadata["thing_ids"].index(labels_per_image[j].item())
|
||||
|
||||
# Get segmentation map and segment information of batch item
|
||||
target_size = target_sizes[i] if target_sizes is not None else None
|
||||
|
@ -295,6 +295,19 @@ class OneFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
el["segmentation"].shape, (self.image_processor_tester.height, self.image_processor_tester.width)
|
||||
)
|
||||
|
||||
segmentation_with_opts = image_processor.post_process_instance_segmentation(
|
||||
outputs,
|
||||
threshold=0,
|
||||
target_sizes=[(1, 4) for _ in range(self.image_processor_tester.batch_size)],
|
||||
task_type="panoptic",
|
||||
)
|
||||
self.assertTrue(len(segmentation_with_opts) == self.image_processor_tester.batch_size)
|
||||
for el in segmentation_with_opts:
|
||||
self.assertTrue("segmentation" in el)
|
||||
self.assertTrue("segments_info" in el)
|
||||
self.assertEqual(type(el["segments_info"]), list)
|
||||
self.assertEqual(el["segmentation"].shape, (1, 4))
|
||||
|
||||
def test_post_process_panoptic_segmentation(self):
|
||||
image_processor = self.image_processing_class(
|
||||
num_labels=self.image_processor_tester.num_classes,
|
||||
|
Loading…
Reference in New Issue
Block a user