[SegGPT] Fix seggpt image processor (#29550)
* Fixed SegGptImageProcessor to handle 2D and 3D prompt mask inputs
* Added new test to check prompt mask equivalence
* New proposal
* Better proposal
* Removed unnecessary method
* Updated seggpt docs
* Introduced do_convert_rgb
* nits
This commit is contained in:
parent c793b26f2e
commit 6d4cabda26
docs/source/en/model_doc/seggpt.md
@@ -26,7 +26,8 @@ The abstract from the paper is the following:

 Tips:
 - One can use [`SegGptImageProcessor`] to prepare image input, prompt and mask to the model.
-- It's highly advisable to pass `num_labels` (not considering background) during preprocessing and postprocessing with [`SegGptImageProcessor`] for your use case.
+- One can either use segmentation maps or RGB images as prompt masks. If using the latter, make sure to set `do_convert_rgb=False` in the `preprocess` method.
+- It's highly advisable to pass `num_labels` (not considering background) when using `segmentation_maps` during preprocessing and postprocessing with [`SegGptImageProcessor`] for your use case.
 - When doing inference with [`SegGptForImageSegmentation`] if your `batch_size` is greater than 1 you can use feature ensemble across your images by passing `feature_ensemble=True` in the forward method.

 Here's how to use the model for one-shot semantic segmentation:
@@ -53,7 +54,7 @@ mask_prompt = ds[29]["label"]

 inputs = image_processor(
     images=image_input,
     prompt_images=image_prompt,
-    prompt_masks=mask_prompt,
+    segmentation_maps=mask_prompt,
     num_labels=num_labels,
     return_tensors="pt"
 )
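For the other prompt-mask path the new tip describes (an RGB mask rather than a segmentation map), here is a minimal sketch. It reuses the `image_input`, `image_prompt`, and `mask_prompt` variables from the example above and assumes `mask_prompt` is a PIL image:

```python
# Sketch: passing an RGB prompt mask instead of a segmentation map.
# do_convert_rgb=False skips the palette conversion, since the mask
# already has three channels.
rgb_mask_prompt = mask_prompt.convert("RGB")

inputs = image_processor(
    images=image_input,
    prompt_images=image_prompt,
    prompt_masks=rgb_mask_prompt,
    do_convert_rgb=False,
    return_tensors="pt",
)
```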
src/transformers/models/seggpt/image_processing_seggpt.py
@@ -26,19 +26,21 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
-    get_channel_dimension_axis,
     infer_channel_dimension_format,
     is_scaled_image,
     make_list_of_images,
     to_numpy_array,
     valid_images,
 )
-from ...utils import TensorType, is_torch_available, logging, requires_backends
+from ...utils import TensorType, is_torch_available, is_vision_available, logging, requires_backends


 if is_torch_available():
     import torch

+if is_vision_available():
+    pass


 logger = logging.get_logger(__name__)
@@ -65,29 +67,10 @@ def build_palette(num_labels: int) -> List[Tuple[int, int]]:
     return color_list


-def get_num_channels(image: np.ndarray, input_data_format: ChannelDimension) -> int:
-    if image.ndim == 2:
-        return 0
-
-    channel_idx = get_channel_dimension_axis(image, input_data_format)
-    return image.shape[channel_idx]
-
-
 def mask_to_rgb(
-    mask: np.ndarray,
-    palette: Optional[List[Tuple[int, int]]] = None,
-    input_data_format: Optional[ChannelDimension] = None,
-    data_format: Optional[ChannelDimension] = None,
+    mask: np.ndarray, palette: Optional[List[Tuple[int, int]]] = None, data_format: Optional[ChannelDimension] = None
 ) -> np.ndarray:
-    if input_data_format is None and mask.ndim > 2:
-        input_data_format = infer_channel_dimension_format(mask)
-
-    data_format = data_format if data_format is not None else input_data_format
-
-    num_channels = get_num_channels(mask, input_data_format)
-
-    if num_channels == 3:
-        return to_channel_dimension_format(mask, data_format, input_data_format) if data_format is not None else mask
+    data_format = data_format if data_format is not None else ChannelDimension.FIRST

     if palette is not None:
         height, width = mask.shape
@@ -109,9 +92,7 @@ def mask_to_rgb(
     else:
         rgb_mask = np.repeat(mask[None, ...], 3, axis=0)

-    return (
-        to_channel_dimension_format(rgb_mask, data_format, input_data_format) if data_format is not None else rgb_mask
-    )
+    return to_channel_dimension_format(rgb_mask, data_format)


 class SegGptImageProcessor(BaseImageProcessor):
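To make the reworked `mask_to_rgb` concrete, here is an illustrative standalone sketch of the palette path, not the library code itself: each class index in a 2D mask is looked up in a palette and written into a channels-first RGB array. The palette colors below are hypothetical.

```python
import numpy as np

# Illustrative palette: made-up colors for classes 0..2.
palette = [(0, 0, 0), (255, 0, 0), (0, 255, 0)]
mask = np.array([[0, 1], [2, 1]])  # (height, width) of class indices

# Channels-first RGB output, matching the ChannelDimension.FIRST default above.
rgb_mask = np.zeros((3, *mask.shape), dtype=np.uint8)
for class_idx, color in enumerate(palette):
    rgb_mask[:, mask == class_idx] = np.array(color)[:, None]

# Without a palette, mask_to_rgb instead repeats the single channel,
# as in the diff above:
repeated = np.repeat(mask[None, ...], 3, axis=0)
```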
@@ -143,6 +124,9 @@ class SegGptImageProcessor(BaseImageProcessor):
         image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
             Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
             number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the prompt mask to RGB format. Can be overridden by the `do_convert_rgb` parameter in
+            the `preprocess` method.
     """

     model_input_names = ["pixel_values"]
@@ -157,6 +141,7 @@ class SegGptImageProcessor(BaseImageProcessor):
         do_normalize: bool = True,
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
+        do_convert_rgb: bool = True,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -170,6 +155,7 @@ class SegGptImageProcessor(BaseImageProcessor):
         self.rescale_factor = rescale_factor
         self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
         self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+        self.do_convert_rgb = do_convert_rgb

     def get_palette(self, num_labels: int) -> List[Tuple[int, int]]:
         """Build a palette to map the prompt mask from a single channel to a 3 channel RGB.
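A minimal sketch of the new constructor flag, assuming default values for the remaining arguments:

```python
from transformers import SegGptImageProcessor

# Default: prompt masks are treated as segmentation maps and converted to RGB.
processor = SegGptImageProcessor()

# If your prompt masks are already RGB images, opt out at construction time
# instead of passing do_convert_rgb=False on every preprocess call.
processor_rgb = SegGptImageProcessor(do_convert_rgb=False)
```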
@@ -188,13 +174,12 @@ class SegGptImageProcessor(BaseImageProcessor):
         image: np.ndarray,
         palette: Optional[List[Tuple[int, int]]] = None,
         data_format: Optional[Union[str, ChannelDimension]] = None,
-        input_data_format: Optional[Union[str, ChannelDimension]] = None,
     ) -> np.ndarray:
-        """Convert a mask to RGB format.
+        """Converts a segmentation map to RGB format.

         Args:
             image (`np.ndarray`):
-                Mask to convert to RGB format. If the mask is already in RGB format, it will be passed through.
+                Segmentation map with dimensions (height, width) where pixel values represent the class index.
             palette (`List[Tuple[int, int]]`, *optional*, defaults to `None`):
                 Palette to use to convert the mask to RGB format. If unset, the mask is duplicated across the channel
                 dimension.
@@ -203,21 +188,11 @@ class SegGptImageProcessor(BaseImageProcessor):
                 image is used. Can be one of:
                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
-            input_data_format (`ChannelDimension` or `str`, *optional*):
-                The channel dimension format for the input image. If unset, the channel dimension format is inferred
-                from the input image. Can be one of:
-                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
-                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

         Returns:
             `np.ndarray`: The mask in RGB format.
         """
-        return mask_to_rgb(
-            image,
-            palette=palette,
-            data_format=data_format,
-            input_data_format=input_data_format,
-        )
+        return mask_to_rgb(image, palette=palette, data_format=data_format)

     # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
@@ -271,7 +246,6 @@ def resize(
     def _preprocess_step(
         self,
         images: ImageInput,
-        is_mask: bool = False,
         do_resize: Optional[bool] = None,
         size: Dict[str, int] = None,
         resample: PILImageResampling = None,
@@ -282,6 +256,7 @@ class SegGptImageProcessor(BaseImageProcessor):
         image_std: Optional[Union[float, List[float]]] = None,
         data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        do_convert_rgb: Optional[bool] = None,
         num_labels: Optional[int] = None,
         **kwargs,
     ):
@@ -292,9 +267,6 @@ class SegGptImageProcessor(BaseImageProcessor):
             images (`ImageInput`):
                 Image to _preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                 passing in images with pixel values between 0 and 1, set `do_rescale=False`.
-            is_mask (`bool`, *optional*, defaults to `False`):
-                Whether the image is a mask. If True, the image is converted to RGB using the palette if
-                `self.num_labels` is specified otherwise RGB is achieved by duplicating the channel.
             do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                 Whether to resize the image.
             size (`Dict[str, int]`, *optional*, defaults to `self.size`):
@@ -331,6 +303,10 @@ class SegGptImageProcessor(BaseImageProcessor):
                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+                Whether to convert the prompt mask to RGB format. If `num_labels` is specified, a palette will be built
+                to map the prompt mask from a single channel to a 3 channel RGB. If unset, the prompt mask is duplicated
+                across the channel dimension. Must be set to `False` if the prompt mask is already in RGB format.
             num_labels: (`int`, *optional*):
                 Number of classes in the segmentation task (excluding the background). If specified, a palette will be
                 built, assuming that class_idx 0 is the background, to map the prompt mask from a single class_idx
@@ -340,6 +316,7 @@ class SegGptImageProcessor(BaseImageProcessor):
         do_resize = do_resize if do_resize is not None else self.do_resize
         do_rescale = do_rescale if do_rescale is not None else self.do_rescale
         do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
         resample = resample if resample is not None else self.resample
         rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
         image_mean = image_mean if image_mean is not None else self.image_mean
@@ -348,7 +325,8 @@ class SegGptImageProcessor(BaseImageProcessor):
         size = size if size is not None else self.size
         size_dict = get_size_dict(size)

-        images = make_list_of_images(images)
+        # If segmentation map is passed we expect 2D images
+        images = make_list_of_images(images, expected_ndims=2 if do_convert_rgb else 3)

         if not valid_images(images):
             raise ValueError(
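The `expected_ndims` argument matters because a 3D array is ambiguous: it can be one channels-first RGB mask or a batch of 2D segmentation maps. A small sketch of how `make_list_of_images` resolves this, assuming it behaves as in `transformers.image_utils`:

```python
import numpy as np
from transformers.image_utils import make_list_of_images

array_3d = np.ones((3, 64, 64))

# do_convert_rgb=False -> expected_ndims=3: treated as one channels-first RGB image.
as_one_image = make_list_of_images(array_3d, expected_ndims=3)
print(len(as_one_image), as_one_image[0].shape)  # 1 (3, 64, 64)

# do_convert_rgb=True -> expected_ndims=2: treated as a batch of three 2D maps.
as_three_maps = make_list_of_images(array_3d, expected_ndims=2)
print(len(as_three_maps), as_three_maps[0].shape)  # 3 (64, 64)
```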
@@ -374,11 +352,11 @@ class SegGptImageProcessor(BaseImageProcessor):
                 " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
             )

-        if input_data_format is None and not is_mask:
+        if input_data_format is None and not do_convert_rgb:
             # We assume that all images have the same channel dimension format.
             input_data_format = infer_channel_dimension_format(images[0])

-        if is_mask:
+        if do_convert_rgb:
             palette = self.get_palette(num_labels) if num_labels is not None else None
             # Since this is the input for the next transformations its format should be the same as the input_data_format
             images = [
@@ -423,6 +401,7 @@ class SegGptImageProcessor(BaseImageProcessor):
         do_normalize: Optional[bool] = None,
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
+        do_convert_rgb: Optional[bool] = None,
         num_labels: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
@@ -440,9 +419,12 @@ class SegGptImageProcessor(BaseImageProcessor):
                 Prompt image to _preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                 passing in images with pixel values between 0 and 1, set `do_rescale=False`.
             prompt_masks (`ImageInput`):
-                Prompt mask from prompt image to _preprocess. Expects a single or batch of masks. If the mask masks are
-                a single channel then it will be converted to RGB using the palette if `self.num_labels` is specified
-                or by just repeating the channel if not. If the mask is already in RGB format, it will be passed through.
+                Prompt mask from prompt image to _preprocess that specifies the prompt_masks value in the preprocessed
+                output. Can either be in the format of segmentation maps (no channels) or RGB images. If in the format
+                of RGB images, `do_convert_rgb` should be set to `False`. If in the format of segmentation maps,
+                specifying `num_labels` is recommended to build a palette to map the prompt mask from a single channel
+                to a 3 channel RGB. If `num_labels` is not specified, the prompt mask will be duplicated across the
+                channel dimension.
             do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                 Whether to resize the image.
             size (`Dict[str, int]`, *optional*, defaults to `self.size`):
@@ -461,6 +443,16 @@ class SegGptImageProcessor(BaseImageProcessor):
                 Image mean to use if `do_normalize` is set to `True`.
             image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                 Image standard deviation to use if `do_normalize` is set to `True`.
+            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+                Whether to convert the prompt mask to RGB format. If `num_labels` is specified, a palette will be built
+                to map the prompt mask from a single channel to a 3 channel RGB. If unset, the prompt mask is duplicated
+                across the channel dimension. Must be set to `False` if the prompt mask is already in RGB format.
+            num_labels: (`int`, *optional*):
+                Number of classes in the segmentation task (excluding the background). If specified, a palette will be
+                built, assuming that class_idx 0 is the background, to map the prompt mask from a plain segmentation
+                map with no channels to a 3 channel RGB. Not specifying this will result in the prompt mask either
+                being passed through as is if it is already in RGB format (if `do_convert_rgb` is false) or being
+                duplicated across the channel dimension.
             return_tensors (`str` or `TensorType`, *optional*):
                 The type of tensors to return. Can be one of:
                 - Unset: Return a list of `np.ndarray`.
@@ -479,11 +471,6 @@ class SegGptImageProcessor(BaseImageProcessor):
                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
-            num_labels: (`int`, *optional*):
-                Number of classes in the segmentation task (excluding the background). If specified, a palette will be
-                built, assuming that class_idx 0 is the background, to map the prompt mask from a single class_idx
-                channel to a 3 channel RGB. Not specifying this will result in the prompt mask either being passed
-                through as is if it is already in RGB format or being duplicated across the channel dimension.
         """
         if all(v is None for v in [images, prompt_images, prompt_masks]):
             raise ValueError("At least one of images, prompt_images, prompt_masks must be specified.")
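Putting the documented behavior together, a hedged end-to-end example; the checkpoint name is taken from the tests further below, and the output shape follows the new `test_prompt_mask_equivalence`:

```python
import numpy as np
from transformers import SegGptImageProcessor

processor = SegGptImageProcessor.from_pretrained("BAAI/seggpt-vit-large")

# A 2D segmentation map: class indices, no channel dimension.
seg_map = np.zeros((448, 448), dtype=np.uint8)
seg_map[100:200, 100:200] = 1  # one foreground class

inputs = processor(images=None, prompt_masks=seg_map, num_labels=1, return_tensors="pt")
print(inputs["prompt_masks"].shape)  # (1, 3, H, W): palette-converted 3-channel mask
```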
@@ -502,6 +489,7 @@ class SegGptImageProcessor(BaseImageProcessor):
                 do_normalize=do_normalize,
                 image_mean=image_mean,
                 image_std=image_std,
+                do_convert_rgb=False,
                 data_format=data_format,
                 input_data_format=input_data_format,
                 **kwargs,
@@ -521,6 +509,7 @@ class SegGptImageProcessor(BaseImageProcessor):
                 do_normalize=do_normalize,
                 image_mean=image_mean,
                 image_std=image_std,
+                do_convert_rgb=False,
                 data_format=data_format,
                 input_data_format=input_data_format,
                 **kwargs,
@@ -531,7 +520,6 @@ class SegGptImageProcessor(BaseImageProcessor):
         if prompt_masks is not None:
             prompt_masks = self._preprocess_step(
                 prompt_masks,
-                is_mask=True,
                 do_resize=do_resize,
                 size=size,
                 resample=PILImageResampling.NEAREST,
@@ -540,9 +528,10 @@ class SegGptImageProcessor(BaseImageProcessor):
                 do_normalize=do_normalize,
                 image_mean=image_mean,
                 image_std=image_std,
-                num_labels=num_labels,
+                do_convert_rgb=do_convert_rgb,
                 data_format=data_format,
                 input_data_format=input_data_format,
+                num_labels=num_labels,
                 **kwargs,
             )
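In condensed form, the routing the three call sites above implement — a sketch rather than the actual `preprocess` body: regular images and prompt images always bypass the mask-to-RGB step, and the user-facing `do_convert_rgb` flag only reaches the prompt-mask branch.

```python
# Condensed sketch of the control flow above (not the real preprocess body).
def preprocess_routing(processor, images, prompt_images, prompt_masks, do_convert_rgb=None, num_labels=None):
    if images is not None:
        images = processor._preprocess_step(images, do_convert_rgb=False)
    if prompt_images is not None:
        prompt_images = processor._preprocess_step(prompt_images, do_convert_rgb=False)
    if prompt_masks is not None:
        # Only the prompt masks honor the user-facing flag and num_labels.
        prompt_masks = processor._preprocess_step(
            prompt_masks, do_convert_rgb=do_convert_rgb, num_labels=num_labels
        )
    return images, prompt_images, prompt_masks
```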
tests/models/seggpt/test_image_processing_seggpt.py
@@ -30,6 +30,8 @@ if is_torch_available():
     from transformers.models.seggpt.modeling_seggpt import SegGptImageSegmentationOutput

 if is_vision_available():
+    from PIL import Image
+
     from transformers import SegGptImageProcessor
@@ -147,7 +149,7 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
         mask_rgb = mask_binary.convert("RGB")

         inputs_binary = image_processor(images=None, prompt_masks=mask_binary, return_tensors="pt")
-        inputs_rgb = image_processor(images=None, prompt_masks=mask_rgb, return_tensors="pt")
+        inputs_rgb = image_processor(images=None, prompt_masks=mask_rgb, return_tensors="pt", do_convert_rgb=False)

         self.assertTrue((inputs_binary["prompt_masks"] == inputs_rgb["prompt_masks"]).all().item())
@@ -196,7 +198,11 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
         image_processor = SegGptImageProcessor.from_pretrained("BAAI/seggpt-vit-large")

         inputs = image_processor(
-            images=input_image, prompt_images=prompt_image, prompt_masks=prompt_mask, return_tensors="pt"
+            images=input_image,
+            prompt_images=prompt_image,
+            prompt_masks=prompt_mask,
+            return_tensors="pt",
+            do_convert_rgb=False,
         )

         # Verify pixel values
@@ -229,3 +235,76 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
             torch.allclose(inputs.prompt_pixel_values[0, :, :3, :3], expected_prompt_pixel_values, atol=1e-4)
         )
         self.assertTrue(torch.allclose(inputs.prompt_masks[0, :, :3, :3], expected_prompt_masks, atol=1e-4))
+
+    def test_prompt_mask_equivalence(self):
+        image_processor = self.image_processing_class(**self.image_processor_dict)
+        image_size = self.image_processor_tester.image_size
+
+        # Single Mask Examples
+        expected_single_shape = [1, 3, image_size, image_size]
+
+        # Single Semantic Map (2D)
+        image_np_2d = np.ones((image_size, image_size))
+        image_pt_2d = torch.ones((image_size, image_size))
+        image_pil_2d = Image.fromarray(image_np_2d)
+
+        inputs_np_2d = image_processor(images=None, prompt_masks=image_np_2d, return_tensors="pt")
+        inputs_pt_2d = image_processor(images=None, prompt_masks=image_pt_2d, return_tensors="pt")
+        inputs_pil_2d = image_processor(images=None, prompt_masks=image_pil_2d, return_tensors="pt")
+
+        self.assertTrue((inputs_np_2d["prompt_masks"] == inputs_pt_2d["prompt_masks"]).all().item())
+        self.assertTrue((inputs_np_2d["prompt_masks"] == inputs_pil_2d["prompt_masks"]).all().item())
+        self.assertEqual(list(inputs_np_2d["prompt_masks"].shape), expected_single_shape)
+
+        # Single RGB Images (3D)
+        image_np_3d = np.ones((3, image_size, image_size))
+        image_pt_3d = torch.ones((3, image_size, image_size))
+        image_pil_3d = Image.fromarray(image_np_3d.transpose(1, 2, 0).astype(np.uint8))
+
+        inputs_np_3d = image_processor(
+            images=None, prompt_masks=image_np_3d, return_tensors="pt", do_convert_rgb=False
+        )
+        inputs_pt_3d = image_processor(
+            images=None, prompt_masks=image_pt_3d, return_tensors="pt", do_convert_rgb=False
+        )
+        inputs_pil_3d = image_processor(
+            images=None, prompt_masks=image_pil_3d, return_tensors="pt", do_convert_rgb=False
+        )
+
+        self.assertTrue((inputs_np_3d["prompt_masks"] == inputs_pt_3d["prompt_masks"]).all().item())
+        self.assertTrue((inputs_np_3d["prompt_masks"] == inputs_pil_3d["prompt_masks"]).all().item())
+        self.assertEqual(list(inputs_np_3d["prompt_masks"].shape), expected_single_shape)
+
+        # Batched Examples
+        expected_batched_shape = [2, 3, image_size, image_size]
+
+        # Batched Semantic Maps (3D)
+        image_np_2d_batched = np.ones((2, image_size, image_size))
+        image_pt_2d_batched = torch.ones((2, image_size, image_size))
+
+        inputs_np_2d_batched = image_processor(images=None, prompt_masks=image_np_2d_batched, return_tensors="pt")
+        inputs_pt_2d_batched = image_processor(images=None, prompt_masks=image_pt_2d_batched, return_tensors="pt")
+
+        self.assertTrue((inputs_np_2d_batched["prompt_masks"] == inputs_pt_2d_batched["prompt_masks"]).all().item())
+        self.assertEqual(list(inputs_np_2d_batched["prompt_masks"].shape), expected_batched_shape)
+
+        # Batched RGB images (4D)
+        image_np_4d = np.ones((2, 3, image_size, image_size))
+        image_pt_4d = torch.ones((2, 3, image_size, image_size))
+
+        inputs_np_4d = image_processor(
+            images=None, prompt_masks=image_np_4d, return_tensors="pt", do_convert_rgb=False
+        )
+        inputs_pt_4d = image_processor(
+            images=None, prompt_masks=image_pt_4d, return_tensors="pt", do_convert_rgb=False
+        )
+
+        self.assertTrue((inputs_np_4d["prompt_masks"] == inputs_pt_4d["prompt_masks"]).all().item())
+        self.assertEqual(list(inputs_np_4d["prompt_masks"].shape), expected_batched_shape)
+
+        # Comparing Single and Batched Examples
+        self.assertTrue((inputs_np_2d["prompt_masks"][0] == inputs_np_3d["prompt_masks"][0]).all().item())
+        self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_2d["prompt_masks"][0]).all().item())
+        self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_3d["prompt_masks"][0]).all().item())
+        self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_4d["prompt_masks"][0]).all().item())
+        self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_3d["prompt_masks"][0]).all().item())
tests/models/seggpt/test_modeling_seggpt.py
@@ -363,7 +363,11 @@ class SegGptModelIntegrationTest(unittest.TestCase):
         prompt_mask = masks[0]

         inputs = image_processor(
-            images=input_image, prompt_images=prompt_image, prompt_masks=prompt_mask, return_tensors="pt"
+            images=input_image,
+            prompt_images=prompt_image,
+            prompt_masks=prompt_mask,
+            return_tensors="pt",
+            do_convert_rgb=False,
         )

         inputs = inputs.to(torch_device)
@@ -404,7 +408,11 @@ class SegGptModelIntegrationTest(unittest.TestCase):
         prompt_masks = [masks[0], masks[2]]

         inputs = image_processor(
-            images=input_images, prompt_images=prompt_images, prompt_masks=prompt_masks, return_tensors="pt"
+            images=input_images,
+            prompt_images=prompt_images,
+            prompt_masks=prompt_masks,
+            return_tensors="pt",
+            do_convert_rgb=False,
         )

         inputs = {k: v.to(torch_device) for k, v in inputs.items()}
@@ -437,10 +445,16 @@ class SegGptModelIntegrationTest(unittest.TestCase):
         prompt_mask = masks[0]

         inputs = image_processor(
-            images=input_image, prompt_masks=prompt_mask, prompt_images=prompt_image, return_tensors="pt"
+            images=input_image,
+            prompt_masks=prompt_mask,
+            prompt_images=prompt_image,
+            return_tensors="pt",
+            do_convert_rgb=False,
         ).to(torch_device)

-        labels = image_processor(images=None, prompt_masks=label, return_tensors="pt")["prompt_masks"].to(torch_device)
+        labels = image_processor(images=None, prompt_masks=label, return_tensors="pt", do_convert_rgb=False)[
+            "prompt_masks"
+        ].to(torch_device)

         bool_masked_pos = prepare_bool_masked_pos(model.config).to(torch_device)