diff --git a/src/transformers/models/mask2former/image_processing_mask2former.py b/src/transformers/models/mask2former/image_processing_mask2former.py
index 695ae654ccb..28ad6002958 100644
--- a/src/transformers/models/mask2former/image_processing_mask2former.py
+++ b/src/transformers/models/mask2former/image_processing_mask2former.py
@@ -935,7 +935,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
         if segmentation_maps is not None:
             mask_labels = []
             class_labels = []
-            pad_size = get_max_height_width(pixel_values_list)
+            pad_size = get_max_height_width(pixel_values_list, input_data_format=input_data_format)
             # Convert to list of binary masks and labels
             for idx, segmentation_map in enumerate(segmentation_maps):
                 segmentation_map = to_numpy_array(segmentation_map)
diff --git a/tests/models/mask2former/test_image_processing_mask2former.py b/tests/models/mask2former/test_image_processing_mask2former.py
index 98ffd906e5b..7468c3fd476 100644
--- a/tests/models/mask2former/test_image_processing_mask2former.py
+++ b/tests/models/mask2former/test_image_processing_mask2former.py
@@ -20,6 +20,7 @@ import numpy as np
 from datasets import load_dataset
 from huggingface_hub import hf_hub_download
 
+from transformers.image_utils import ChannelDimension
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_torch_available, is_vision_available
 
@@ -180,31 +181,44 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
         self.assertEqual(image_processor.size_divisor, 8)
 
     def comm_get_image_processing_inputs(
-        self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np"
+        self,
+        image_processor_tester,
+        with_segmentation_maps=False,
+        is_instance_map=False,
+        segmentation_type="np",
+        numpify=False,
+        input_data_format=None,
     ):
-        image_processing = self.image_processing_class(**self.image_processor_dict)
+        image_processing = self.image_processing_class(**image_processor_tester.prepare_image_processor_dict())
         # prepare image and target
-        num_labels = self.image_processor_tester.num_labels
+        num_labels = image_processor_tester.num_labels
         annotations = None
         instance_id_to_semantic_id = None
-        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
+        image_inputs = image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=numpify)
         if with_segmentation_maps:
             high = num_labels
             if is_instance_map:
                 labels_expanded = list(range(num_labels)) * 2
                 instance_id_to_semantic_id = dict(enumerate(labels_expanded))
             annotations = [
-                np.random.randint(0, high * 2, (img.size[1], img.size[0])).astype(np.uint8) for img in image_inputs
+                np.random.randint(0, high * 2, img.shape[:2] if numpify else (img.size[1], img.size[0])).astype(
+                    np.uint8
+                )
+                for img in image_inputs
             ]
             if segmentation_type == "pil":
                 annotations = [Image.fromarray(annotation) for annotation in annotations]
 
+        if input_data_format is ChannelDimension.FIRST and numpify:
+            image_inputs = [np.moveaxis(img, -1, 0) for img in image_inputs]
+
         inputs = image_processing(
             image_inputs,
             annotations,
             return_tensors="pt",
             instance_id_to_semantic_id=instance_id_to_semantic_id,
             pad_and_return_pixel_mask=True,
+            input_data_format=input_data_format,
         )
 
         return inputs
@@ -223,9 +237,29 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
             self.assertTrue((pixel_values.shape[-2] % size_divisor) == 0)
 
     def test_call_with_segmentation_maps(self):
-        def common(is_instance_map=False, segmentation_type=None):
+        def common(
+            is_instance_map=False,
+            segmentation_type=None,
+            numpify=False,
+            num_channels=3,
+            input_data_format=None,
+            do_resize=True,
+        ):
+            image_processor_tester = Mask2FormerImageProcessingTester(
+                self,
+                num_channels=num_channels,
+                do_resize=do_resize,
+                image_mean=[0.5] * num_channels,
+                image_std=[0.5] * num_channels,
+            )
+
             inputs = self.comm_get_image_processing_inputs(
-                with_segmentation_maps=True, is_instance_map=is_instance_map, segmentation_type=segmentation_type
+                image_processor_tester=image_processor_tester,
+                with_segmentation_maps=True,
+                is_instance_map=is_instance_map,
+                segmentation_type=segmentation_type,
+                numpify=numpify,
+                input_data_format=input_data_format,
             )
 
             mask_labels = inputs["mask_labels"]
@@ -243,6 +277,18 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
         common(is_instance_map=False, segmentation_type="pil")
         common(is_instance_map=True, segmentation_type="pil")
 
+        common(num_channels=1, numpify=True)
+        common(num_channels=1, numpify=True, input_data_format=ChannelDimension.FIRST)
+        common(num_channels=2, numpify=True, input_data_format=ChannelDimension.LAST)
+        common(num_channels=5, numpify=True, input_data_format=ChannelDimension.LAST, do_resize=False)
+        common(num_channels=5, numpify=True, input_data_format=ChannelDimension.FIRST, do_resize=False)
+
+        with self.assertRaisesRegex(ValueError, expected_regex="Unable to infer channel dimension format"):
+            common(num_channels=5, numpify=True, do_resize=False)
+
+        with self.assertRaisesRegex(TypeError, expected_regex=r"Cannot handle this data type: .*"):
+            common(num_channels=5, numpify=True, input_data_format=ChannelDimension.LAST)
+
    def test_integration_instance_segmentation(self):
        # load 2 images and corresponding annotations from the hub
        repo_id = "nielsr/image-segmentation-toy-data"
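A minimal sketch of the user-facing scenario this diff unblocks, mirroring the new num_channels=5, do_resize=False test case: with five channels the channel axis cannot be inferred, so the explicit input_data_format must now be forwarded down to get_max_height_width() when segmentation maps are padded. The array shapes and constructor arguments below are illustrative assumptions, not code from the PR.

import numpy as np

from transformers import Mask2FormerImageProcessor
from transformers.image_utils import ChannelDimension

# Two channels-first 5-channel images of different sizes (illustrative shapes).
images = [
    np.zeros((5, 56, 84), dtype=np.float32),
    np.zeros((5, 64, 64), dtype=np.float32),
]
segmentation_maps = [
    np.zeros((56, 84), dtype=np.uint8),
    np.zeros((64, 64), dtype=np.uint8),
]

# do_resize=False keeps the 5-channel arrays away from PIL; image_mean/image_std
# need one entry per channel (the 0.5 values are assumptions for this sketch).
image_processor = Mask2FormerImageProcessor(
    do_resize=False,
    image_mean=[0.5] * 5,
    image_std=[0.5] * 5,
)
inputs = image_processor(
    images,
    segmentation_maps,
    return_tensors="pt",
    input_data_format=ChannelDimension.FIRST,
)
# pixel_values and mask_labels are both padded to the batch max height/width
# (64 x 84 here). Before the fix the segmentation-map branch called
# get_max_height_width() without the format hint, which raises
# "Unable to infer channel dimension format" for ambiguous channel counts.
print(inputs["pixel_values"].shape)    # torch.Size([2, 5, 64, 84])
print(inputs["mask_labels"][0].shape)  # (num classes in map, 64, 84)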