[SegGPT] Fix seggpt image processor (#29550)

* Fixed SegGptImageProcessor to handle 2D and 3D prompt mask inputs

* Added new test to check prompt mask equivalence

* New proposal

* Better proposal

* Removed unnecessary method

* Updated seggpt docs

* Introduced do_convert_rgb

* nits
Eduardo Pacheco 2024-04-26 20:40:12 +02:00 committed by GitHub
parent c793b26f2e
commit 6d4cabda26
4 changed files with 148 additions and 65 deletions

docs/source/en/model_doc/seggpt.md

@@ -26,7 +26,8 @@ The abstract from the paper is the following:
Tips:
- One can use [`SegGptImageProcessor`] to prepare image input, prompt and mask to the model.
- It's highly advisable to pass `num_labels` (not considering background) during preprocessing and postprocessing with [`SegGptImageProcessor`] for your use case.
- One can either use segmentation maps or RGB images as prompt masks. If using the latter, make sure to set `do_convert_rgb=False` in the `preprocess` method.
- It's highly advisable to pass `num_labels` (not considering background) when using `segmentation_maps` during preprocessing and postprocessing with [`SegGptImageProcessor`] for your use case.
- When doing inference with [`SegGptForImageSegmentation`], if your `batch_size` is greater than 1, you can use feature ensemble across your images by passing `feature_ensemble=True` in the forward method.
Here's how to use the model for one-shot semantic segmentation:
@@ -53,7 +54,7 @@ mask_prompt = ds[29]["label"]
inputs = image_processor(
images=image_input,
prompt_images=image_prompt,
prompt_masks=mask_prompt,
segmentation_maps=mask_prompt,
num_labels=num_labels,
return_tensors="pt"
)
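For the RGB prompt-mask path described in the tips, a minimal sketch of the alternative call (hypothetical `mask_prompt_rgb`, assumed to be a mask already in RGB format):

```python
# Sketch: passing an RGB prompt mask instead of a segmentation map.
# Since the mask already has 3 channels, palette conversion is disabled
# with do_convert_rgb=False and num_labels is not needed.
inputs = image_processor(
    images=image_input,
    prompt_images=image_prompt,
    prompt_masks=mask_prompt_rgb,  # hypothetical: a PIL image in RGB mode
    do_convert_rgb=False,
    return_tensors="pt",
)
```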

src/transformers/models/seggpt/image_processing_seggpt.py

@@ -26,19 +26,21 @@ from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
get_channel_dimension_axis,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
)
from ...utils import TensorType, is_torch_available, logging, requires_backends
from ...utils import TensorType, is_torch_available, is_vision_available, logging, requires_backends
if is_torch_available():
import torch
if is_vision_available():
pass
logger = logging.get_logger(__name__)
@@ -65,29 +67,10 @@ def build_palette(num_labels: int) -> List[Tuple[int, int]]:
return color_list
def get_num_channels(image: np.ndarray, input_data_format: ChannelDimension) -> int:
if image.ndim == 2:
return 0
channel_idx = get_channel_dimension_axis(image, input_data_format)
return image.shape[channel_idx]
def mask_to_rgb(
mask: np.ndarray,
palette: Optional[List[Tuple[int, int]]] = None,
input_data_format: Optional[ChannelDimension] = None,
data_format: Optional[ChannelDimension] = None,
mask: np.ndarray, palette: Optional[List[Tuple[int, int]]] = None, data_format: Optional[ChannelDimension] = None
) -> np.ndarray:
if input_data_format is None and mask.ndim > 2:
input_data_format = infer_channel_dimension_format(mask)
data_format = data_format if data_format is not None else input_data_format
num_channels = get_num_channels(mask, input_data_format)
if num_channels == 3:
return to_channel_dimension_format(mask, data_format, input_data_format) if data_format is not None else mask
data_format = data_format if data_format is not None else ChannelDimension.FIRST
if palette is not None:
height, width = mask.shape
@@ -109,9 +92,7 @@ def mask_to_rgb(
else:
rgb_mask = np.repeat(mask[None, ...], 3, axis=0)
return (
to_channel_dimension_format(rgb_mask, data_format, input_data_format) if data_format is not None else rgb_mask
)
return to_channel_dimension_format(rgb_mask, data_format)
class SegGptImageProcessor(BaseImageProcessor):
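To make the simplified `mask_to_rgb` easier to review, here is a toy sketch of the two conversion paths it implements (the two-class palette below is invented for illustration; the real one comes from `build_palette`):

```python
import numpy as np

mask = np.array([[0, 1], [1, 0]])  # (height, width) segmentation map

# Path 1 - no palette: the single channel is repeated 3 times,
# channels-first, matching np.repeat(mask[None, ...], 3, axis=0) above.
rgb_no_palette = np.repeat(mask[None, ...], 3, axis=0)  # shape (3, 2, 2)

# Path 2 - with a palette: every class index is looked up in an RGB palette.
palette = [(0, 0, 0), (255, 0, 0)]  # toy palette: class 0 -> black, class 1 -> red
rgb_palette = np.zeros((3, *mask.shape), dtype=np.uint8)
for class_idx, color in enumerate(palette):
    rgb_palette[:, mask == class_idx] = np.array(color)[:, None]
```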
@@ -143,6 +124,9 @@ class SegGptImageProcessor(BaseImageProcessor):
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the prompt mask to RGB format. Can be overridden by the `do_convert_rgb` parameter in the
`preprocess` method.
"""
model_input_names = ["pixel_values"]
@@ -157,6 +141,7 @@ class SegGptImageProcessor(BaseImageProcessor):
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -170,6 +155,7 @@ class SegGptImageProcessor(BaseImageProcessor):
self.rescale_factor = rescale_factor
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_convert_rgb = do_convert_rgb
def get_palette(self, num_labels: int) -> List[Tuple[int, int]]:
"""Build a palette to map the prompt mask from a single channel to a 3 channel RGB.
@@ -188,13 +174,12 @@ class SegGptImageProcessor(BaseImageProcessor):
image: np.ndarray,
palette: Optional[List[Tuple[int, int]]] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Convert a mask to RGB format.
"""Converts a segmentation map to RGB format.
Args:
image (`np.ndarray`):
Mask to convert to RGB format. If the mask is already in RGB format, it will be passed through.
Segmentation map with dimensions (height, width) where pixel values represent the class index.
palette (`List[Tuple[int, int]]`, *optional*, defaults to `None`):
Palette to use to convert the mask to RGB format. If unset, the mask is duplicated across the channel
dimension.
@@ -203,21 +188,11 @@ class SegGptImageProcessor(BaseImageProcessor):
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The mask in RGB format.
"""
return mask_to_rgb(
image,
palette=palette,
data_format=data_format,
input_data_format=input_data_format,
)
return mask_to_rgb(image, palette=palette, data_format=data_format)
# Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
def resize(
@@ -271,7 +246,6 @@ class SegGptImageProcessor(BaseImageProcessor):
def _preprocess_step(
self,
images: ImageInput,
is_mask: bool = False,
do_resize: Optional[bool] = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
@@ -282,6 +256,7 @@ class SegGptImageProcessor(BaseImageProcessor):
image_std: Optional[Union[float, List[float]]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
do_convert_rgb: Optional[bool] = None,
num_labels: Optional[int] = None,
**kwargs,
):
@@ -292,9 +267,6 @@ class SegGptImageProcessor(BaseImageProcessor):
images (`ImageInput`):
Image to _preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
is_mask (`bool`, *optional*, defaults to `False`):
Whether the image is a mask. If True, the image is converted to RGB using the palette if
`self.num_labels` is specified otherwise RGB is achieved by duplicating the channel.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
@@ -331,6 +303,10 @@ class SegGptImageProcessor(BaseImageProcessor):
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the prompt mask to RGB format. If `num_labels` is specified, a palette will be built
to map the prompt mask from a single channel to a 3 channel RGB. If unset, the prompt mask is duplicated
across the channel dimension. Must be set to `False` if the prompt mask is already in RGB format.
num_labels: (`int`, *optional*):
Number of classes in the segmentation task (excluding the background). If specified, a palette will be
built, assuming that class_idx 0 is the background, to map the prompt mask from a single class_idx
@@ -340,6 +316,7 @@ class SegGptImageProcessor(BaseImageProcessor):
do_resize = do_resize if do_resize is not None else self.do_resize
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
resample = resample if resample is not None else self.resample
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
image_mean = image_mean if image_mean is not None else self.image_mean
@@ -348,7 +325,8 @@ class SegGptImageProcessor(BaseImageProcessor):
size = size if size is not None else self.size
size_dict = get_size_dict(size)
images = make_list_of_images(images)
# If a segmentation map is passed, we expect 2D images
images = make_list_of_images(images, expected_ndims=2 if do_convert_rgb else 3)
if not valid_images(images):
raise ValueError(
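The new `expected_ndims` argument is what lets plain 2D segmentation maps batch correctly; a rough sketch of the assumed `make_list_of_images` semantics:

```python
import numpy as np

# Assumed behavior: an array with expected_ndims + 1 dimensions is treated
# as a batch, one with exactly expected_ndims dimensions as a single input.
seg_maps = np.ones((2, 448, 448))   # batch of two (height, width) maps
rgb_image = np.ones((3, 448, 448))  # one channels-first RGB image

# make_list_of_images(seg_maps, expected_ndims=2)  -> [map_0, map_1]
# make_list_of_images(rgb_image, expected_ndims=3) -> [rgb_image]
```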
@@ -374,11 +352,11 @@ class SegGptImageProcessor(BaseImageProcessor):
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None and not is_mask:
if input_data_format is None and not do_convert_rgb:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if is_mask:
if do_convert_rgb:
palette = self.get_palette(num_labels) if num_labels is not None else None
# Since this is the input for the next transformations, its format should be the same as the input_data_format
images = [
@@ -423,6 +401,7 @@ class SegGptImageProcessor(BaseImageProcessor):
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: Optional[bool] = None,
num_labels: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
@@ -440,9 +419,12 @@ class SegGptImageProcessor(BaseImageProcessor):
Prompt image to _preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
prompt_masks (`ImageInput`):
Prompt mask from prompt image to _preprocess. Expects a single or batch of masks. If the mask masks are
a single channel then it will be converted to RGB using the palette if `self.num_labels` is specified
or by just repeating the channel if not. If the mask is already in RGB format, it will be passed through.
Prompt mask from the prompt image to _preprocess, which specifies the prompt_masks value in the preprocessed output.
Can either be in the format of segmentation maps (no channels) or RGB images. If in the format of
RGB images, `do_convert_rgb` should be set to `False`. If in the format of segmentation maps,
specifying `num_labels` is recommended to build a palette to map the prompt mask from a single channel to
a 3 channel RGB. If `num_labels` is not specified, the prompt mask will be duplicated across the channel
dimension.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
@@ -461,6 +443,16 @@ class SegGptImageProcessor(BaseImageProcessor):
Image mean to use if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use if `do_normalize` is set to `True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the prompt mask to RGB format. If `num_labels` is specified, a palette will be built
to map the prompt mask from a single channel to a 3 channel RGB. If unset, the prompt mask is duplicated
across the channel dimension. Must be set to `False` if the prompt mask is already in RGB format.
num_labels: (`int`, *optional*):
Number of classes in the segmentation task (excluding the background). If specified, a palette will be
built, assuming that class_idx 0 is the background, to map the prompt mask from a plain segmentation map
with no channels to a 3 channel RGB. Not specifying this will result in the prompt mask either being passed
through as is if it is already in RGB format (if `do_convert_rgb` is false) or being duplicated
across the channel dimension.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
@@ -479,11 +471,6 @@ class SegGptImageProcessor(BaseImageProcessor):
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
num_labels: (`int`, *optional*):
Number of classes in the segmentation task (excluding the background). If specified, a palette will be
built, assuming that class_idx 0 is the background, to map the prompt mask from a single class_idx
channel to a 3 channel RGB. Not specifying this will result in the prompt mask either being passed
through as is if it is already in RGB format or being duplicated across the channel dimension.
"""
if all(v is None for v in [images, prompt_images, prompt_masks]):
raise ValueError("At least one of images, prompt_images, prompt_masks must be specified.")
@@ -502,6 +489,7 @@ class SegGptImageProcessor(BaseImageProcessor):
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_convert_rgb=False,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
@@ -521,6 +509,7 @@ class SegGptImageProcessor(BaseImageProcessor):
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_convert_rgb=False,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
@@ -531,7 +520,6 @@ class SegGptImageProcessor(BaseImageProcessor):
if prompt_masks is not None:
prompt_masks = self._preprocess_step(
prompt_masks,
is_mask=True,
do_resize=do_resize,
size=size,
resample=PILImageResampling.NEAREST,
@@ -540,9 +528,10 @@ class SegGptImageProcessor(BaseImageProcessor):
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_convert_rgb=do_convert_rgb,
num_labels=num_labels,
data_format=data_format,
input_data_format=input_data_format,
num_labels=num_labels,
**kwargs,
)
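Note that the prompt-mask step keeps `resample=PILImageResampling.NEAREST` hard-coded; a toy illustration of why nearest-neighbor matters for label maps:

```python
import numpy as np
from PIL import Image

seg = np.array([[0, 4], [4, 0]], dtype=np.uint8)  # two classes: 0 and 4

# Bilinear interpolation can invent intermediate values that are not valid
# class indices; nearest-neighbor preserves the discrete label set.
up_bilinear = np.array(Image.fromarray(seg).resize((4, 4), Image.BILINEAR))
up_nearest = np.array(Image.fromarray(seg).resize((4, 4), Image.NEAREST))

print(sorted(set(up_bilinear.ravel())))  # may contain values like 1, 2, 3
print(sorted(set(up_nearest.ravel())))   # only 0 and 4
```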

tests/models/seggpt/test_image_processing_seggpt.py

@@ -30,6 +30,8 @@ if is_torch_available():
from transformers.models.seggpt.modeling_seggpt import SegGptImageSegmentationOutput
if is_vision_available():
from PIL import Image
from transformers import SegGptImageProcessor
@@ -147,7 +149,7 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
mask_rgb = mask_binary.convert("RGB")
inputs_binary = image_processor(images=None, prompt_masks=mask_binary, return_tensors="pt")
inputs_rgb = image_processor(images=None, prompt_masks=mask_rgb, return_tensors="pt")
inputs_rgb = image_processor(images=None, prompt_masks=mask_rgb, return_tensors="pt", do_convert_rgb=False)
self.assertTrue((inputs_binary["prompt_masks"] == inputs_rgb["prompt_masks"]).all().item())
@@ -196,7 +198,11 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processor = SegGptImageProcessor.from_pretrained("BAAI/seggpt-vit-large")
inputs = image_processor(
images=input_image, prompt_images=prompt_image, prompt_masks=prompt_mask, return_tensors="pt"
images=input_image,
prompt_images=prompt_image,
prompt_masks=prompt_mask,
return_tensors="pt",
do_convert_rgb=False,
)
# Verify pixel values
@@ -229,3 +235,76 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
torch.allclose(inputs.prompt_pixel_values[0, :, :3, :3], expected_prompt_pixel_values, atol=1e-4)
)
self.assertTrue(torch.allclose(inputs.prompt_masks[0, :, :3, :3], expected_prompt_masks, atol=1e-4))
def test_prompt_mask_equivalence(self):
image_processor = self.image_processing_class(**self.image_processor_dict)
image_size = self.image_processor_tester.image_size
# Single Mask Examples
expected_single_shape = [1, 3, image_size, image_size]
# Single Semantic Map (2D)
image_np_2d = np.ones((image_size, image_size))
image_pt_2d = torch.ones((image_size, image_size))
image_pil_2d = Image.fromarray(image_np_2d)
inputs_np_2d = image_processor(images=None, prompt_masks=image_np_2d, return_tensors="pt")
inputs_pt_2d = image_processor(images=None, prompt_masks=image_pt_2d, return_tensors="pt")
inputs_pil_2d = image_processor(images=None, prompt_masks=image_pil_2d, return_tensors="pt")
self.assertTrue((inputs_np_2d["prompt_masks"] == inputs_pt_2d["prompt_masks"]).all().item())
self.assertTrue((inputs_np_2d["prompt_masks"] == inputs_pil_2d["prompt_masks"]).all().item())
self.assertEqual(list(inputs_np_2d["prompt_masks"].shape), expected_single_shape)
# Single RGB Images (3D)
image_np_3d = np.ones((3, image_size, image_size))
image_pt_3d = torch.ones((3, image_size, image_size))
image_pil_3d = Image.fromarray(image_np_3d.transpose(1, 2, 0).astype(np.uint8))
inputs_np_3d = image_processor(
images=None, prompt_masks=image_np_3d, return_tensors="pt", do_convert_rgb=False
)
inputs_pt_3d = image_processor(
images=None, prompt_masks=image_pt_3d, return_tensors="pt", do_convert_rgb=False
)
inputs_pil_3d = image_processor(
images=None, prompt_masks=image_pil_3d, return_tensors="pt", do_convert_rgb=False
)
self.assertTrue((inputs_np_3d["prompt_masks"] == inputs_pt_3d["prompt_masks"]).all().item())
self.assertTrue((inputs_np_3d["prompt_masks"] == inputs_pil_3d["prompt_masks"]).all().item())
self.assertEqual(list(inputs_np_3d["prompt_masks"].shape), expected_single_shape)
# Batched Examples
expected_batched_shape = [2, 3, image_size, image_size]
# Batched Semantic Maps (3D)
image_np_2d_batched = np.ones((2, image_size, image_size))
image_pt_2d_batched = torch.ones((2, image_size, image_size))
inputs_np_2d_batched = image_processor(images=None, prompt_masks=image_np_2d_batched, return_tensors="pt")
inputs_pt_2d_batched = image_processor(images=None, prompt_masks=image_pt_2d_batched, return_tensors="pt")
self.assertTrue((inputs_np_2d_batched["prompt_masks"] == inputs_pt_2d_batched["prompt_masks"]).all().item())
self.assertEqual(list(inputs_np_2d_batched["prompt_masks"].shape), expected_batched_shape)
# Batched RGB images
image_np_4d = np.ones((2, 3, image_size, image_size))
image_pt_4d = torch.ones((2, 3, image_size, image_size))
inputs_np_4d = image_processor(
images=None, prompt_masks=image_np_4d, return_tensors="pt", do_convert_rgb=False
)
inputs_pt_4d = image_processor(
images=None, prompt_masks=image_pt_4d, return_tensors="pt", do_convert_rgb=False
)
self.assertTrue((inputs_np_4d["prompt_masks"] == inputs_pt_4d["prompt_masks"]).all().item())
self.assertEqual(list(inputs_np_4d["prompt_masks"].shape), expected_batched_shape)
# Comparing Single and Batched Examples
self.assertTrue((inputs_np_2d["prompt_masks"][0] == inputs_np_3d["prompt_masks"][0]).all().item())
self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_2d["prompt_masks"][0]).all().item())
self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_3d["prompt_masks"][0]).all().item())
self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_4d["prompt_masks"][0]).all().item())

tests/models/seggpt/test_modeling_seggpt.py

@@ -363,7 +363,11 @@ class SegGptModelIntegrationTest(unittest.TestCase):
prompt_mask = masks[0]
inputs = image_processor(
images=input_image, prompt_images=prompt_image, prompt_masks=prompt_mask, return_tensors="pt"
images=input_image,
prompt_images=prompt_image,
prompt_masks=prompt_mask,
return_tensors="pt",
do_convert_rgb=False,
)
inputs = inputs.to(torch_device)
@@ -404,7 +408,11 @@ class SegGptModelIntegrationTest(unittest.TestCase):
prompt_masks = [masks[0], masks[2]]
inputs = image_processor(
images=input_images, prompt_images=prompt_images, prompt_masks=prompt_masks, return_tensors="pt"
images=input_images,
prompt_images=prompt_images,
prompt_masks=prompt_masks,
return_tensors="pt",
do_convert_rgb=False,
)
inputs = {k: v.to(torch_device) for k, v in inputs.items()}
@@ -437,10 +445,16 @@ class SegGptModelIntegrationTest(unittest.TestCase):
prompt_mask = masks[0]
inputs = image_processor(
images=input_image, prompt_masks=prompt_mask, prompt_images=prompt_image, return_tensors="pt"
images=input_image,
prompt_masks=prompt_mask,
prompt_images=prompt_image,
return_tensors="pt",
do_convert_rgb=False,
).to(torch_device)
labels = image_processor(images=None, prompt_masks=label, return_tensors="pt")["prompt_masks"].to(torch_device)
labels = image_processor(images=None, prompt_masks=label, return_tensors="pt", do_convert_rgb=False)[
"prompt_masks"
].to(torch_device)
bool_masked_pos = prepare_bool_masked_pos(model.config).to(torch_device)