fixed Mask2Former image processor segmentation maps handling (#33364)

* fixed mask2former image processor segmentation maps handling

* introduced review suggestions

* introduced review suggestions
This commit is contained in:
Maciej Adamiak 2024-09-10 12:19:56 +02:00 committed by GitHub
parent 7d2d6ce9cb
commit 8e8e7d8558
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 54 additions and 8 deletions

View File

@ -935,7 +935,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
if segmentation_maps is not None:
mask_labels = []
class_labels = []
pad_size = get_max_height_width(pixel_values_list)
pad_size = get_max_height_width(pixel_values_list, input_data_format=input_data_format)
# Convert to list of binary masks and labels
for idx, segmentation_map in enumerate(segmentation_maps):
segmentation_map = to_numpy_array(segmentation_map)

View File

@ -20,6 +20,7 @@ import numpy as np
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from transformers.image_utils import ChannelDimension
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
@ -180,31 +181,44 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
self.assertEqual(image_processor.size_divisor, 8)
def comm_get_image_processing_inputs(
self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np"
self,
image_processor_tester,
with_segmentation_maps=False,
is_instance_map=False,
segmentation_type="np",
numpify=False,
input_data_format=None,
):
image_processing = self.image_processing_class(**self.image_processor_dict)
image_processing = self.image_processing_class(**image_processor_tester.prepare_image_processor_dict())
# prepare image and target
num_labels = self.image_processor_tester.num_labels
num_labels = image_processor_tester.num_labels
annotations = None
instance_id_to_semantic_id = None
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
image_inputs = image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=numpify)
if with_segmentation_maps:
high = num_labels
if is_instance_map:
labels_expanded = list(range(num_labels)) * 2
instance_id_to_semantic_id = dict(enumerate(labels_expanded))
annotations = [
np.random.randint(0, high * 2, (img.size[1], img.size[0])).astype(np.uint8) for img in image_inputs
np.random.randint(0, high * 2, img.shape[:2] if numpify else (img.size[1], img.size[0])).astype(
np.uint8
)
for img in image_inputs
]
if segmentation_type == "pil":
annotations = [Image.fromarray(annotation) for annotation in annotations]
if input_data_format is ChannelDimension.FIRST and numpify:
image_inputs = [np.moveaxis(img, -1, 0) for img in image_inputs]
inputs = image_processing(
image_inputs,
annotations,
return_tensors="pt",
instance_id_to_semantic_id=instance_id_to_semantic_id,
pad_and_return_pixel_mask=True,
input_data_format=input_data_format,
)
return inputs
@ -223,9 +237,29 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
self.assertTrue((pixel_values.shape[-2] % size_divisor) == 0)
def test_call_with_segmentation_maps(self):
def common(is_instance_map=False, segmentation_type=None):
def common(
is_instance_map=False,
segmentation_type=None,
numpify=False,
num_channels=3,
input_data_format=None,
do_resize=True,
):
image_processor_tester = Mask2FormerImageProcessingTester(
self,
num_channels=num_channels,
do_resize=do_resize,
image_mean=[0.5] * num_channels,
image_std=[0.5] * num_channels,
)
inputs = self.comm_get_image_processing_inputs(
with_segmentation_maps=True, is_instance_map=is_instance_map, segmentation_type=segmentation_type
image_processor_tester=image_processor_tester,
with_segmentation_maps=True,
is_instance_map=is_instance_map,
segmentation_type=segmentation_type,
numpify=numpify,
input_data_format=input_data_format,
)
mask_labels = inputs["mask_labels"]
@ -243,6 +277,18 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
common(is_instance_map=False, segmentation_type="pil")
common(is_instance_map=True, segmentation_type="pil")
common(num_channels=1, numpify=True)
common(num_channels=1, numpify=True, input_data_format=ChannelDimension.FIRST)
common(num_channels=2, numpify=True, input_data_format=ChannelDimension.LAST)
common(num_channels=5, numpify=True, input_data_format=ChannelDimension.LAST, do_resize=False)
common(num_channels=5, numpify=True, input_data_format=ChannelDimension.FIRST, do_resize=False)
with self.assertRaisesRegex(ValueError, expected_regex="Unable to infer channel dimension format"):
common(num_channels=5, numpify=True, do_resize=False)
with self.assertRaisesRegex(TypeError, expected_regex=r"Cannot handle this data type: .*"):
common(num_channels=5, numpify=True, input_data_format=ChannelDimension.LAST)
def test_integration_instance_segmentation(self):
# load 2 images and corresponding annotations from the hub
repo_id = "nielsr/image-segmentation-toy-data"