Mirror of https://github.com/huggingface/transformers.git
fixed Mask2Former image processor segmentation maps handling (#33364)
* fixed mask2former image processor segmentation maps handling
* introduced review suggestions
* introduced review suggestions
This commit is contained in:
parent
7d2d6ce9cb
commit
8e8e7d8558
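In practice the fix means Mask2FormerImageProcessor can pair segmentation maps with numpy images whose channel layout is stated explicitly instead of guessed. A minimal usage sketch (the image sizes and label ids below are made up for illustration; the processor, its segmentation_maps argument, and input_data_format are the public transformers API exercised by this commit):

    import numpy as np
    from transformers import Mask2FormerImageProcessor
    from transformers.image_utils import ChannelDimension

    processor = Mask2FormerImageProcessor(ignore_index=255)

    # Two channels-first (C, H, W) numpy images and matching (H, W) semantic maps.
    images = [np.random.randint(0, 256, (3, 480, 640), dtype=np.uint8) for _ in range(2)]
    seg_maps = [np.random.randint(0, 10, (480, 640), dtype=np.uint8) for _ in range(2)]

    # input_data_format names the channel layout instead of letting the processor infer it.
    inputs = processor(
        images,
        segmentation_maps=seg_maps,
        return_tensors="pt",
        input_data_format=ChannelDimension.FIRST,
    )
    print(inputs["pixel_values"].shape)              # padded image batch
    print([m.shape for m in inputs["mask_labels"]])  # one (num_masks, H, W) tensor per image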
@@ -935,7 +935,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
         if segmentation_maps is not None:
             mask_labels = []
             class_labels = []
-            pad_size = get_max_height_width(pixel_values_list)
+            pad_size = get_max_height_width(pixel_values_list, input_data_format=input_data_format)
             # Convert to list of binary masks and labels
             for idx, segmentation_map in enumerate(segmentation_maps):
                 segmentation_map = to_numpy_array(segmentation_map)
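get_max_height_width derives the common padding size from the tallest and widest image in the batch, so it has to know which axes are height and width; forwarding input_data_format avoids re-inferring the layout for every image. A rough sketch of the idea, not the library's exact implementation, with a hypothetical helper name:

    import numpy as np
    from transformers.image_utils import ChannelDimension, get_image_size

    def max_height_width_sketch(images, input_data_format=None):
        # Illustrative only: collect each image's (height, width) and keep the maxima,
        # letting get_image_size resolve the channel layout when it is given explicitly.
        sizes = [get_image_size(img, channel_dim=input_data_format) for img in images]
        return max(h for h, _ in sizes), max(w for _, w in sizes)

    batch = [np.zeros((3, 480, 640), dtype=np.uint8), np.zeros((3, 512, 600), dtype=np.uint8)]
    print(max_height_width_sketch(batch, input_data_format=ChannelDimension.FIRST))  # (512, 640)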
@@ -20,6 +20,7 @@ import numpy as np
 from datasets import load_dataset
 from huggingface_hub import hf_hub_download
 
+from transformers.image_utils import ChannelDimension
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_torch_available, is_vision_available
 
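ChannelDimension is the small enum in transformers.image_utils that the new tests use to name the layout explicitly. A quick check of its members (assuming their current string values):

    from transformers.image_utils import ChannelDimension

    # The enum members name the two supported layouts.
    print(ChannelDimension.FIRST)  # "channels_first": (num_channels, height, width)
    print(ChannelDimension.LAST)   # "channels_last":  (height, width, num_channels)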
@@ -180,31 +181,44 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
         self.assertEqual(image_processor.size_divisor, 8)
 
     def comm_get_image_processing_inputs(
-        self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np"
+        self,
+        image_processor_tester,
+        with_segmentation_maps=False,
+        is_instance_map=False,
+        segmentation_type="np",
+        numpify=False,
+        input_data_format=None,
     ):
-        image_processing = self.image_processing_class(**self.image_processor_dict)
+        image_processing = self.image_processing_class(**image_processor_tester.prepare_image_processor_dict())
         # prepare image and target
-        num_labels = self.image_processor_tester.num_labels
+        num_labels = image_processor_tester.num_labels
         annotations = None
         instance_id_to_semantic_id = None
-        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
+        image_inputs = image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=numpify)
         if with_segmentation_maps:
             high = num_labels
             if is_instance_map:
                 labels_expanded = list(range(num_labels)) * 2
                 instance_id_to_semantic_id = dict(enumerate(labels_expanded))
             annotations = [
-                np.random.randint(0, high * 2, (img.size[1], img.size[0])).astype(np.uint8) for img in image_inputs
+                np.random.randint(0, high * 2, img.shape[:2] if numpify else (img.size[1], img.size[0])).astype(
+                    np.uint8
+                )
+                for img in image_inputs
             ]
             if segmentation_type == "pil":
                 annotations = [Image.fromarray(annotation) for annotation in annotations]
 
+        if input_data_format is ChannelDimension.FIRST and numpify:
+            image_inputs = [np.moveaxis(img, -1, 0) for img in image_inputs]
+
         inputs = image_processing(
             image_inputs,
             annotations,
             return_tensors="pt",
             instance_id_to_semantic_id=instance_id_to_semantic_id,
             pad_and_return_pixel_mask=True,
+            input_data_format=input_data_format,
         )
 
         return inputs
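When the test requests ChannelDimension.FIRST with numpy inputs, the helper moves the channel axis itself before calling the processor; the (H, W) annotation maps are left untouched. A standalone sketch of that conversion:

    import numpy as np

    # A channels-last (H, W, C) image becomes channels-first (C, H, W);
    # only the image needs the axis move, the 2D segmentation map does not.
    img_hwc = np.random.randint(0, 256, (30, 400, 1), dtype=np.uint8)
    img_chw = np.moveaxis(img_hwc, -1, 0)
    print(img_hwc.shape, "->", img_chw.shape)  # (30, 400, 1) -> (1, 30, 400)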
@@ -223,9 +237,29 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
         self.assertTrue((pixel_values.shape[-2] % size_divisor) == 0)
 
     def test_call_with_segmentation_maps(self):
-        def common(is_instance_map=False, segmentation_type=None):
+        def common(
+            is_instance_map=False,
+            segmentation_type=None,
+            numpify=False,
+            num_channels=3,
+            input_data_format=None,
+            do_resize=True,
+        ):
+            image_processor_tester = Mask2FormerImageProcessingTester(
+                self,
+                num_channels=num_channels,
+                do_resize=do_resize,
+                image_mean=[0.5] * num_channels,
+                image_std=[0.5] * num_channels,
+            )
+
             inputs = self.comm_get_image_processing_inputs(
-                with_segmentation_maps=True, is_instance_map=is_instance_map, segmentation_type=segmentation_type
+                image_processor_tester=image_processor_tester,
+                with_segmentation_maps=True,
+                is_instance_map=is_instance_map,
+                segmentation_type=segmentation_type,
+                numpify=numpify,
+                input_data_format=input_data_format,
             )
 
             mask_labels = inputs["mask_labels"]
@@ -243,6 +277,18 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
         common(is_instance_map=False, segmentation_type="pil")
         common(is_instance_map=True, segmentation_type="pil")
 
+        common(num_channels=1, numpify=True)
+        common(num_channels=1, numpify=True, input_data_format=ChannelDimension.FIRST)
+        common(num_channels=2, numpify=True, input_data_format=ChannelDimension.LAST)
+        common(num_channels=5, numpify=True, input_data_format=ChannelDimension.LAST, do_resize=False)
+        common(num_channels=5, numpify=True, input_data_format=ChannelDimension.FIRST, do_resize=False)
+
+        with self.assertRaisesRegex(ValueError, expected_regex="Unable to infer channel dimension format"):
+            common(num_channels=5, numpify=True, do_resize=False)
+
+        with self.assertRaisesRegex(TypeError, expected_regex=r"Cannot handle this data type: .*"):
+            common(num_channels=5, numpify=True, input_data_format=ChannelDimension.LAST)
+
     def test_integration_instance_segmentation(self):
         # load 2 images and corresponding annotations from the hub
         repo_id = "nielsr/image-segmentation-toy-data"
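The two assertRaisesRegex cases pin down the failure modes for 5-channel inputs: without an explicit input_data_format the channel axis cannot be inferred, and resizing a 5-channel channels-last array still fails inside PIL. A sketch of the first failure using the inference helper the processors rely on (assuming its current error message):

    import numpy as np
    from transformers.image_utils import infer_channel_dimension_format

    # Neither axis of a (5, H, W) array looks like 1 or 3 channels,
    # so the layout cannot be inferred and a ValueError is raised.
    ambiguous = np.zeros((5, 32, 32), dtype=np.uint8)
    try:
        infer_channel_dimension_format(ambiguous)
    except ValueError as err:
        print(err)  # "Unable to infer channel dimension format"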