Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
Merge 23231d2e35 into 37a239ca50
Commit a9c5d3962e

@@ -114,6 +114,7 @@ print(f"The predicted class label is: {predicted_class_label}")

[[autodoc]] MobileNetV2ImageProcessor
    - preprocess
    - post_process_semantic_segmentation

## MobileNetV2ImageProcessorFast
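
For orientation, a minimal usage sketch of the two documented methods; the checkpoint name and the example image are illustrative assumptions, not part of this change:

import torch
from PIL import Image
from transformers import MobileNetV2ForSemanticSegmentation, MobileNetV2ImageProcessor

# Assumed semantic segmentation checkpoint; any MobileNetV2 segmentation checkpoint works the same way.
checkpoint = "google/deeplabv3_mobilenet_v2_1.0_513"
processor = MobileNetV2ImageProcessor.from_pretrained(checkpoint)
model = MobileNetV2ForSemanticSegmentation.from_pretrained(checkpoint)

image = Image.open("example.jpg")  # any RGB image
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Rescale the logits to the original (height, width) and take the per-pixel argmax.
segmentation_map = processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]
print(segmentation_map.shape)  # (height, width) of the input image
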
@@ -88,6 +88,11 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
        image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_reduce_labels (`bool`, *optional*, defaults to `False`):
            Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
            used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
            background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
            `preprocess` method.
    """

    model_input_names = ["pixel_values"]

@@ -104,6 +109,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
        do_normalize: bool = True,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        do_reduce_labels: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

@@ -121,6 +127,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        self.do_reduce_labels = do_reduce_labels

    # Copied from transformers.models.mobilenet_v1.image_processing_mobilenet_v1.MobileNetV1ImageProcessor.resize
    def resize(

@@ -172,10 +179,151 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
            **kwargs,
        )

    # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.reduce_label
    def reduce_label(self, label: ImageInput) -> np.ndarray:
        label = to_numpy_array(label)
        # Avoid using underflow conversion
        label[label == 0] = 255
        label = label - 1
        label[label == 254] = 255
        return label
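
As a quick illustration of the label reduction above (a standalone sketch, not part of the diff): background pixels (0) become the ignore index 255, and every other class id shifts down by one.

import numpy as np

label = np.array([[0, 1, 2],
                  [3, 0, 150]])
reduced = label.copy()
reduced[reduced == 0] = 255   # mark background as ignore
reduced = reduced - 1         # shift remaining class ids down by one
reduced[reduced == 254] = 255 # former background pixels stay at 255
print(reduced)
# [[255   0   1]
#  [  2 255 149]]
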
    def __call__(self, images, segmentation_maps=None, **kwargs):
        """
        Preprocesses a batch of images and optionally segmentation maps.

        Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
        passed in as positional arguments.
        """
        return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
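
In practice this means an image and its segmentation map can be passed positionally in one call; a hedged sketch of a typical call (variable names are illustrative):

# `image` is a PIL.Image, `segmentation_map` is a same-sized label map (PIL image or array).
encoding = image_processor(image, segmentation_map, return_tensors="pt")
encoding["pixel_values"].shape  # (1, 3, crop_height, crop_width)
encoding["labels"].shape        # (1, crop_height, crop_width), dtype torch.int64
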
    def _preprocess(
        self,
        image: ImageInput,
        do_reduce_labels: bool,
        do_resize: bool,
        do_rescale: bool,
        do_center_crop: bool,
        do_normalize: bool,
        size: Optional[dict[str, int]] = None,
        resample: PILImageResampling = None,
        rescale_factor: Optional[float] = None,
        crop_size: Optional[dict[str, int]] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        if do_reduce_labels:
            image = self.reduce_label(image)

        if do_resize:
            image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)

        if do_center_crop:
            image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)

        if do_rescale:
            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

        if do_normalize:
            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

        return image

    def _preprocess_image(
        self,
        image: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[dict[str, int]] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_center_crop: Optional[bool] = None,
        crop_size: Optional[dict[str, int]] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """Preprocesses a single image."""
        # All transformations expect numpy arrays.
        image = to_numpy_array(image)
        if do_rescale and is_scaled_image(image):
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )
        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)

        image = self._preprocess(
            image=image,
            do_reduce_labels=False,
            do_resize=do_resize,
            size=size,
            resample=resample,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            input_data_format=input_data_format,
        )

        image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)

        return image

    def _preprocess_mask(
        self,
        segmentation_map: ImageInput,
        do_reduce_labels: Optional[bool] = None,
        do_resize: Optional[bool] = None,
        size: Optional[dict[str, int]] = None,
        do_center_crop: Optional[bool] = None,
        crop_size: Optional[dict[str, int]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """Preprocesses a single mask."""
        segmentation_map = to_numpy_array(segmentation_map)
        # Add channel dimension if missing - needed for certain transformations
        if segmentation_map.ndim == 2:
            added_channel_dim = True
            segmentation_map = segmentation_map[None, ...]
            input_data_format = ChannelDimension.FIRST
        else:
            added_channel_dim = False
            if input_data_format is None:
                input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)

        segmentation_map = self._preprocess(
            image=segmentation_map,
            do_reduce_labels=do_reduce_labels,
            do_resize=do_resize,
            size=size,
            resample=PILImageResampling.NEAREST,
            do_rescale=False,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_normalize=False,
            image_mean=None,
            image_std=None,
            input_data_format=input_data_format,
        )
        # Remove extra channel dimension if added for processing
        if added_channel_dim:
            segmentation_map = segmentation_map.squeeze(0)
        segmentation_map = segmentation_map.astype(np.int64)
        return segmentation_map

    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images: ImageInput,
        segmentation_maps: Optional[ImageInput] = None,
        do_resize: Optional[bool] = None,
        size: Optional[dict[str, int]] = None,
        resample: PILImageResampling = None,

@@ -186,6 +334,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        do_reduce_labels: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,

@@ -197,6 +346,8 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            segmentation_maps (`ImageInput`, *optional*):
                Segmentation map to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`dict[str, int]`, *optional*, defaults to `self.size`):

@@ -219,6 +370,10 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
                Image mean to use if `do_normalize` is set to `True`.
            image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use if `do_normalize` is set to `True`.
            do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
                Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
                is used for background, and background itself is not included in all classes of a dataset (e.g.
                ADE20k). The background label will be replaced by 255.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.

@@ -241,6 +396,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size, default_to_square=False)
        do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
        resample = resample if resample is not None else self.resample
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
        crop_size = crop_size if crop_size is not None else self.crop_size

@@ -253,11 +409,21 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):

        images = make_list_of_images(images)

        if segmentation_maps is not None:
            segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if segmentation_maps is not None and not valid_images(segmentation_maps):
            raise ValueError(
                "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,

@@ -270,42 +436,43 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
            size=size,
            resample=resample,
        )
        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if do_rescale and is_scaled_image(images[0]):
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        all_images = []
        for image in images:
            if do_resize:
                image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)

            if do_center_crop:
                image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)

            if do_rescale:
                image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

            if do_normalize:
                image = self.normalize(
                    image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
                )

            all_images.append(image)
        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
            for image in all_images
            self._preprocess_image(
                image=img,
                do_resize=do_resize,
                size=size,
                resample=resample,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_center_crop=do_center_crop,
                crop_size=crop_size,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                data_format=data_format,
                input_data_format=input_data_format,
            )
            for img in images
        ]

        data = {"pixel_values": images}

        if segmentation_maps is not None:
            segmentation_maps = [
                self._preprocess_mask(
                    segmentation_map=segmentation_map,
                    do_reduce_labels=do_reduce_labels,
                    do_resize=do_resize,
                    size=size,
                    do_center_crop=do_center_crop,
                    crop_size=crop_size,
                    input_data_format=input_data_format,
                )
                for segmentation_map in segmentation_maps
            ]
            data["labels"] = segmentation_maps

        return BatchFeature(data=data, tensor_type=return_tensors)

    # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->MobileNetV2

@@ -14,16 +14,57 @@
# limitations under the License.
"""Fast Image processor class for MobileNetV2."""

from typing import Optional
from typing import Optional, Union

from ...image_processing_utils_fast import BaseImageProcessorFast
from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling
from ...utils import auto_docstring, is_torch_available, is_torch_tensor
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
    group_images_by_shape,
    reorder_images,
)
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    SizeDict,
    is_torch_tensor,
    make_list_of_images,
    pil_torch_interpolation_mapping,
    validate_kwargs,
)
from ...processing_utils import Unpack
from ...utils import (
    TensorType,
    auto_docstring,
    is_torch_available,
    is_torchvision_available,
    is_torchvision_v2_available,
)


if is_torch_available():
    import torch

if is_torchvision_available():
    if is_torchvision_v2_available():
        from torchvision.transforms.v2 import functional as F
    else:
        from torchvision.transforms import functional as F


class MobileNetV2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    """
    do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
        Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
        is used for background, and background itself is not included in all classes of a dataset (e.g.
        ADE20k). The background label will be replaced by 255.
    """

    do_reduce_labels: Optional[bool]


@auto_docstring
class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):

@@ -37,8 +78,177 @@ class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):
    do_center_crop = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = None
    do_reduce_labels = False
    valid_kwargs = MobileNetV2FastImageProcessorKwargs

    def __init__(self, **kwargs: Unpack[MobileNetV2FastImageProcessorKwargs]):
        super().__init__(**kwargs)

    # Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.reduce_label
    def reduce_label(self, labels: list["torch.Tensor"]):
        for idx in range(len(labels)):
            label = labels[idx]
            label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label)
            label = label - 1
            label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label)
            labels[idx] = label

        return label

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_reduce_labels: bool,
        do_resize: bool,
        do_rescale: bool,
        do_center_crop: bool,
        do_normalize: bool,
        size: Optional[SizeDict],
        interpolation: Optional["F.InterpolationMode"],
        rescale_factor: Optional[float],
        crop_size: Optional[SizeDict],
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        disable_grouping: bool,
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        processed_images = []

        if do_reduce_labels:
            images = self.reduce_label(images)

        # Group images by shape for more efficient batch processing
        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}

        # Process each group of images with the same shape
        for shape, stacked_images in grouped_images.items():
            if do_resize:
                stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
            resized_images_grouped[shape] = stacked_images

        # Reorder images to original sequence
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        # Group again after resizing (in case resize produced different sizes)
        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}

        for shape, stacked_images in grouped_images.items():
            if do_center_crop:
                stacked_images = self.center_crop(stacked_images, crop_size)
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            processed_images_grouped[shape] = stacked_images

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)

        # Stack all processed images if return_tensors is specified
        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

        return processed_images
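
The grouping helpers above let each torchvision op run once per stack of same-shaped tensors instead of once per image. A rough standalone sketch of the idea (not the library implementation; the names are illustrative):

from collections import defaultdict

import torch


def group_by_shape(images):
    # Map shape -> stacked tensor, and remember each image's original slot.
    buckets = defaultdict(list)
    index = []
    for img in images:
        shape = tuple(img.shape)
        index.append((shape, len(buckets[shape])))
        buckets[shape].append(img)
    return {shape: torch.stack(imgs) for shape, imgs in buckets.items()}, index


def restore_order(processed, index):
    # Undo the grouping so outputs line up with the original input order.
    return [processed[shape][pos] for shape, pos in index]

With this layout, resize, center_crop and the fused rescale-and-normalize each see one batched tensor per shape group, which is the point of the two group/reorder passes in the method above.
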
    def _preprocess_images(
        self,
        images,
        **kwargs,
    ):
        """Preprocesses images."""
        kwargs["do_reduce_labels"] = False
        processed_images = self._preprocess(images=images, **kwargs)
        return processed_images

    def _preprocess_segmentation_maps(
        self,
        segmentation_maps,
        **kwargs,
    ):
        """Preprocesses segmentation maps."""
        processed_segmentation_maps = []
        for segmentation_map in segmentation_maps:
            segmentation_map = self._process_image(
                segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST
            )

            if segmentation_map.ndim == 2:
                segmentation_map = segmentation_map[None, ...]

            processed_segmentation_maps.append(segmentation_map)

        kwargs["do_normalize"] = False
        kwargs["do_rescale"] = False
        kwargs["interpolation"] = pil_torch_interpolation_mapping[PILImageResampling.NEAREST]
        processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs)

        processed_segmentation_maps = processed_segmentation_maps.squeeze(1)

        processed_segmentation_maps = processed_segmentation_maps.to(torch.int64)
        return processed_segmentation_maps

    @auto_docstring
    def preprocess(
        self,
        images: ImageInput,
        segmentation_maps: Optional[ImageInput] = None,
        **kwargs: Unpack[MobileNetV2FastImageProcessorKwargs],
    ) -> BatchFeature:
        r"""
        segmentation_maps (`ImageInput`, *optional*):
            The segmentation maps to preprocess.
        """
        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
        # Set default kwargs from self. This ensures that if a kwarg is not provided
        # by the user, it gets its default value from the instance, or is set to None.
        for kwarg_name in self.valid_kwargs.__annotations__:
            kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))

        # Extract parameters that are only used for preparing the input images
        do_convert_rgb = kwargs.pop("do_convert_rgb")
        input_data_format = kwargs.pop("input_data_format")
        device = kwargs.pop("device")
        # Prepare input images
        images = self._prepare_input_images(
            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
        )

        # Prepare segmentation maps
        if segmentation_maps is not None:
            segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2)

        # Update kwargs that need further processing before being validated
        kwargs = self._further_process_kwargs(**kwargs)

        # Validate kwargs
        self._validate_preprocess_kwargs(**kwargs)

        # torch resize uses interpolation instead of resample
        resample = kwargs.pop("resample")
        kwargs["interpolation"] = (
            pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
        )

        # Pop kwargs that are not needed in _preprocess
        kwargs.pop("default_to_square")
        kwargs.pop("data_format")

        images = self._preprocess_images(
            images=images,
            **kwargs,
        )

        if segmentation_maps is not None:
            segmentation_maps = self._preprocess_segmentation_maps(
                segmentation_maps=segmentation_maps,
                **kwargs,
            )
            return BatchFeature(data={"pixel_values": images, "labels": segmentation_maps})

        return BatchFeature(data={"pixel_values": images})

    # Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.post_process_semantic_segmentation with Beit->MobileNetV2
    def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
        """
        Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.

@@ -83,6 +83,11 @@ class MobileViTImageProcessor(BaseImageProcessor):
        do_flip_channel_order (`bool`, *optional*, defaults to `True`):
            Whether to flip the color channels from RGB to BGR. Can be overridden by the `do_flip_channel_order`
            parameter in the `preprocess` method.
        do_reduce_labels (`bool`, *optional*, defaults to `False`):
            Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
            used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
            background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
            `preprocess` method.
    """

    model_input_names = ["pixel_values"]

@@ -97,6 +102,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
        do_center_crop: bool = True,
        crop_size: Optional[dict[str, int]] = None,
        do_flip_channel_order: bool = True,
        do_reduce_labels: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

@@ -113,6 +119,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_flip_channel_order = do_flip_channel_order
        self.do_reduce_labels = do_reduce_labels

    # Copied from transformers.models.mobilenet_v1.image_processing_mobilenet_v1.MobileNetV1ImageProcessor.resize with PILImageResampling.BICUBIC->PILImageResampling.BILINEAR
    def resize(

@@ -183,6 +190,15 @@ class MobileViTImageProcessor(BaseImageProcessor):
        """
        return flip_channel_order(image, data_format=data_format, input_data_format=input_data_format)

    # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.reduce_label
    def reduce_label(self, label: ImageInput) -> np.ndarray:
        label = to_numpy_array(label)
        # Avoid using underflow conversion
        label[label == 0] = 255
        label = label - 1
        label[label == 254] = 255
        return label

    def __call__(self, images, segmentation_maps=None, **kwargs):
        """
        Preprocesses a batch of images and optionally segmentation maps.

@@ -195,6 +211,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
    def _preprocess(
        self,
        image: ImageInput,
        do_reduce_labels: bool,
        do_resize: bool,
        do_rescale: bool,
        do_center_crop: bool,

@@ -205,6 +222,9 @@ class MobileViTImageProcessor(BaseImageProcessor):
        crop_size: Optional[dict[str, int]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        if do_reduce_labels:
            image = self.reduce_label(image)

        if do_resize:
            image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)

@@ -246,6 +266,7 @@ class MobileViTImageProcessor(BaseImageProcessor):

        image = self._preprocess(
            image=image,
            do_reduce_labels=False,
            do_resize=do_resize,
            size=size,
            resample=resample,

@@ -264,6 +285,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
    def _preprocess_mask(
        self,
        segmentation_map: ImageInput,
        do_reduce_labels: Optional[bool] = None,
        do_resize: Optional[bool] = None,
        size: Optional[dict[str, int]] = None,
        do_center_crop: Optional[bool] = None,

@@ -284,6 +306,7 @@ class MobileViTImageProcessor(BaseImageProcessor):

        segmentation_map = self._preprocess(
            image=segmentation_map,
            do_reduce_labels=do_reduce_labels,
            do_resize=do_resize,
            size=size,
            resample=PILImageResampling.NEAREST,

@@ -312,6 +335,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
        do_center_crop: Optional[bool] = None,
        crop_size: Optional[dict[str, int]] = None,
        do_flip_channel_order: Optional[bool] = None,
        do_reduce_labels: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,

@@ -342,6 +366,10 @@ class MobileViTImageProcessor(BaseImageProcessor):
                Size of the center crop if `do_center_crop` is set to `True`.
            do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`):
                Whether to flip the channel order of the image.
            do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
                Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
                is used for background, and background itself is not included in all classes of a dataset (e.g.
                ADE20k). The background label will be replaced by 255.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.

@@ -374,6 +402,8 @@ class MobileViTImageProcessor(BaseImageProcessor):
        crop_size = crop_size if crop_size is not None else self.crop_size
        crop_size = get_size_dict(crop_size, param_name="crop_size")

        do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels

        images = make_list_of_images(images)

        if segmentation_maps is not None:

@@ -426,6 +456,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
            segmentation_maps = [
                self._preprocess_mask(
                    segmentation_map=segmentation_map,
                    do_reduce_labels=do_reduce_labels,
                    do_resize=do_resize,
                    size=size,
                    do_center_crop=do_center_crop,

@@ -14,9 +14,7 @@
# limitations under the License.
"""Fast Image processor class for MobileViT."""

from typing import Optional

import torch
from typing import Optional, Union

from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (

@@ -27,23 +25,46 @@ from ...image_processing_utils_fast import (
)
from ...image_utils import (
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    SizeDict,
    is_torch_tensor,
    make_list_of_images,
    pil_torch_interpolation_mapping,
    validate_kwargs,
)
from ...processing_utils import Unpack
from ...utils import auto_docstring
from ...utils import (
    TensorType,
    auto_docstring,
    is_torch_available,
    is_torchvision_available,
    is_torchvision_v2_available,
)


if is_torch_available():
    import torch

if is_torchvision_available():
    if is_torchvision_v2_available():
        from torchvision.transforms.v2 import functional as F
    else:
        from torchvision.transforms import functional as F


class MobileVitFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    """
    do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`):
        Whether to flip the color channels from RGB to BGR or vice versa.
    do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
        Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
        is used for background, and background itself is not included in all classes of a dataset (e.g.
        ADE20k). The background label will be replaced by 255.
    """

    do_flip_channel_order: Optional[bool]
    do_reduce_labels: Optional[bool]
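
As a quick aside on the channel-order flip described above (a standalone sketch, not part of the diff): on a channels-first tensor it simply reverses the channel axis, turning RGB into BGR.

import torch

rgb = torch.tensor([[[1.0]], [[2.0]], [[3.0]]])  # (3, 1, 1) tensor with R=1, G=2, B=3
bgr = rgb.flip(0)                                # reverse the channel axis
print(bgr[:, 0, 0])                              # tensor([3., 2., 1.])
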
@auto_docstring

@@ -58,28 +79,44 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
    do_normalize = None
    do_convert_rgb = None
    do_flip_channel_order = True
    do_reduce_labels = False
    valid_kwargs = MobileVitFastImageProcessorKwargs

    def __init__(self, **kwargs: Unpack[MobileVitFastImageProcessorKwargs]):
        super().__init__(**kwargs)

    # Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.reduce_label
    def reduce_label(self, labels: list["torch.Tensor"]):
        for idx in range(len(labels)):
            label = labels[idx]
            label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label)
            label = label - 1
            label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label)
            labels[idx] = label

        return label

    def _preprocess(
        self,
        images,
        images: list["torch.Tensor"],
        do_reduce_labels: bool,
        do_resize: bool,
        size: Optional[dict],
        interpolation: Optional[str],
        size: Optional[SizeDict],
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: Optional[float],
        do_center_crop: bool,
        crop_size: Optional[dict],
        crop_size: Optional[SizeDict],
        do_flip_channel_order: bool,
        disable_grouping: bool,
        return_tensors: Optional[str],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ):
    ) -> BatchFeature:
        processed_images = []

        if do_reduce_labels:
            images = self.reduce_label(images)

        # Group images by shape for more efficient batch processing
        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}

@@ -119,6 +156,16 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):

        return processed_images

    def _preprocess_images(
        self,
        images,
        **kwargs,
    ):
        """Preprocesses images."""
        kwargs["do_reduce_labels"] = False
        processed_images = self._preprocess(images=images, **kwargs)
        return processed_images

    def _preprocess_segmentation_maps(
        self,
        segmentation_maps,

@@ -149,8 +196,8 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
    @auto_docstring
    def preprocess(
        self,
        images,
        segmentation_maps=None,
        images: ImageInput,
        segmentation_maps: Optional[ImageInput] = None,
        **kwargs: Unpack[MobileVitFastImageProcessorKwargs],
    ) -> BatchFeature:
        r"""

@@ -192,7 +239,7 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
        kwargs.pop("default_to_square")
        kwargs.pop("data_format")

        images = self._preprocess(
        images = self._preprocess_images(
            images=images,
            **kwargs,
        )

@@ -207,6 +254,21 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
        return BatchFeature(data={"pixel_values": images})

    def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
        """
        Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.

        Args:
            outputs ([`MobileViTForSemanticSegmentation`]):
                Raw outputs of the model.
            target_sizes (`list[Tuple]` of length `batch_size`, *optional*):
                List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
                predictions will not be resized.

        Returns:
            semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic
            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
            specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
        """
        logits = outputs.logits

        # Resize logits and compute semantic segmentation maps
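
A rough sketch of what this post-processing amounts to, assuming logits of shape (batch_size, num_labels, height, width); it mirrors the documented behaviour rather than reproducing the exact implementation:

import torch


def post_process_sketch(logits, target_sizes=None):
    # logits: (batch_size, num_labels, height, width)
    results = []
    for idx in range(logits.shape[0]):
        scores = logits[idx].unsqueeze(0)
        if target_sizes is not None:
            # Upsample to the requested (height, width) before taking the argmax.
            scores = torch.nn.functional.interpolate(
                scores, size=target_sizes[idx], mode="bilinear", align_corners=False
            )
        results.append(scores.argmax(dim=1).squeeze(0))  # (height, width) class ids
    return results
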
@@ -15,13 +15,21 @@

import unittest

import requests
from datasets import load_dataset

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torchvision_available, is_vision_available
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs


if is_torch_available():
    import torch

if is_vision_available():
    from PIL import Image

    from transformers import MobileNetV2ImageProcessor

if is_torchvision_available():

@@ -41,6 +49,7 @@ class MobileNetV2ImageProcessingTester:
        size=None,
        do_center_crop=True,
        crop_size=None,
        do_reduce_labels=False,
    ):
        size = size if size is not None else {"shortest_edge": 20}
        crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}

@@ -54,6 +63,7 @@ class MobileNetV2ImageProcessingTester:
        self.size = size
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_reduce_labels = do_reduce_labels

    def prepare_image_processor_dict(self):
        return {

@@ -61,6 +71,7 @@ class MobileNetV2ImageProcessingTester:
            "size": self.size,
            "do_center_crop": self.do_center_crop,
            "crop_size": self.crop_size,
            "do_reduce_labels": self.do_reduce_labels,
        }

    def expected_output_image_shape(self, images):

@@ -78,6 +89,17 @@ class MobileNetV2ImageProcessingTester:
        )


def prepare_semantic_single_inputs():
    ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
    example = ds[0]
    return example["image"], example["map"]


def prepare_semantic_batch_inputs():
    ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
    return list(ds["image"][:2]), list(ds["map"][:2])

@require_torch
@require_vision
class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):

@@ -99,13 +121,167 @@ class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
        self.assertTrue(hasattr(image_processor, "size"))
        self.assertTrue(hasattr(image_processor, "do_center_crop"))
        self.assertTrue(hasattr(image_processor, "crop_size"))
        self.assertTrue(hasattr(image_processor, "do_reduce_labels"))

    def test_image_processor_from_dict_with_kwargs(self):
        for image_processing_class in self.image_processor_list:
            image_processor = image_processing_class.from_dict(self.image_processor_dict)
            self.assertEqual(image_processor.size, {"shortest_edge": 20})
            self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
            self.assertEqual(image_processor.do_reduce_labels, False)

            image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
            image_processor = image_processing_class.from_dict(
                self.image_processor_dict, size=42, crop_size=84, do_reduce_labels=True
            )
            self.assertEqual(image_processor.size, {"shortest_edge": 42})
            self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
            self.assertEqual(image_processor.do_reduce_labels, True)

    def test_call_segmentation_maps(self):
        # Initialize image_processing
        for image_processing_class in self.image_processor_list:
            image_processing = image_processing_class(**self.image_processor_dict)
            # create random PyTorch tensors
            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
            maps = []
            for image in image_inputs:
                self.assertIsInstance(image, torch.Tensor)
                maps.append(torch.zeros(image.shape[-2:]).long())

            # Test not batched input
            encoding = image_processing(image_inputs[0], maps[0], return_tensors="pt")
            self.assertEqual(
                encoding["pixel_values"].shape,
                (
                    1,
                    self.image_processor_tester.num_channels,
                    self.image_processor_tester.crop_size["height"],
                    self.image_processor_tester.crop_size["width"],
                ),
            )
            self.assertEqual(
                encoding["labels"].shape,
                (
                    1,
                    self.image_processor_tester.crop_size["height"],
                    self.image_processor_tester.crop_size["width"],
                ),
            )
            self.assertEqual(encoding["labels"].dtype, torch.long)
            self.assertTrue(encoding["labels"].min().item() >= 0)
            self.assertTrue(encoding["labels"].max().item() <= 255)

            # Test batched
            encoding = image_processing(image_inputs, maps, return_tensors="pt")
            self.assertEqual(
                encoding["pixel_values"].shape,
                (
                    self.image_processor_tester.batch_size,
                    self.image_processor_tester.num_channels,
                    self.image_processor_tester.crop_size["height"],
                    self.image_processor_tester.crop_size["width"],
                ),
            )
            self.assertEqual(
                encoding["labels"].shape,
                (
                    self.image_processor_tester.batch_size,
                    self.image_processor_tester.crop_size["height"],
                    self.image_processor_tester.crop_size["width"],
                ),
            )
            self.assertEqual(encoding["labels"].dtype, torch.long)
            self.assertTrue(encoding["labels"].min().item() >= 0)
            self.assertTrue(encoding["labels"].max().item() <= 255)

            # Test not batched input (PIL images)
            image, segmentation_map = prepare_semantic_single_inputs()

            encoding = image_processing(image, segmentation_map, return_tensors="pt")
            self.assertEqual(
                encoding["pixel_values"].shape,
                (
                    1,
                    self.image_processor_tester.num_channels,
                    self.image_processor_tester.crop_size["height"],
                    self.image_processor_tester.crop_size["width"],
                ),
            )
            self.assertEqual(
                encoding["labels"].shape,
                (
                    1,
                    self.image_processor_tester.crop_size["height"],
                    self.image_processor_tester.crop_size["width"],
                ),
            )
            self.assertEqual(encoding["labels"].dtype, torch.long)
            self.assertTrue(encoding["labels"].min().item() >= 0)
            self.assertTrue(encoding["labels"].max().item() <= 255)

            # Test batched input (PIL images)
            images, segmentation_maps = prepare_semantic_batch_inputs()

            encoding = image_processing(images, segmentation_maps, return_tensors="pt")
            self.assertEqual(
                encoding["pixel_values"].shape,
                (
                    2,
                    self.image_processor_tester.num_channels,
                    self.image_processor_tester.crop_size["height"],
                    self.image_processor_tester.crop_size["width"],
                ),
            )
            self.assertEqual(
                encoding["labels"].shape,
                (
                    2,
                    self.image_processor_tester.crop_size["height"],
                    self.image_processor_tester.crop_size["width"],
                ),
            )
            self.assertEqual(encoding["labels"].dtype, torch.long)
            self.assertTrue(encoding["labels"].min().item() >= 0)
            self.assertTrue(encoding["labels"].max().item() <= 255)

    def test_reduce_labels(self):
        # Initialize image_processing
        for image_processing_class in self.image_processor_list:
            image_processing = image_processing_class(**self.image_processor_dict)

            # ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150
            image, map = prepare_semantic_single_inputs()
            encoding = image_processing(image, map, return_tensors="pt")
            self.assertTrue(encoding["labels"].min().item() >= 0)
            self.assertTrue(encoding["labels"].max().item() <= 150)

            image_processing.do_reduce_labels = True
            encoding = image_processing(image, map, return_tensors="pt")
            self.assertTrue(encoding["labels"].min().item() >= 0)
            self.assertTrue(encoding["labels"].max().item() <= 255)

    def test_slow_fast_equivalence(self):
        if not self.test_slow_image_processor or not self.test_fast_image_processor:
            self.skipTest(reason="Skipping slow/fast equivalence test")

        if self.image_processing_class is None or self.fast_image_processing_class is None:
            self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")

        # Test with single image
        dummy_image = Image.open(
            requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
        )
        image_processor_slow = self.image_processing_class(**self.image_processor_dict)
        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)

        encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
        encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)

        # Test with single image and segmentation map
        image, segmentation_map = prepare_semantic_single_inputs()

        encoding_slow = image_processor_slow(image, segmentation_map, return_tensors="pt")
        encoding_fast = image_processor_fast(image, segmentation_map, return_tensors="pt")
        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
        torch.testing.assert_close(encoding_slow.labels, encoding_fast.labels, atol=1e-1, rtol=1e-3)

@@ -50,6 +50,7 @@ class MobileViTImageProcessingTester:
        do_center_crop=True,
        crop_size=None,
        do_flip_channel_order=True,
        do_reduce_labels=False,
    ):
        size = size if size is not None else {"shortest_edge": 20}
        crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}

@@ -64,6 +65,7 @@ class MobileViTImageProcessingTester:
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_flip_channel_order = do_flip_channel_order
        self.do_reduce_labels = do_reduce_labels

    def prepare_image_processor_dict(self):
        return {

@@ -72,6 +74,7 @@ class MobileViTImageProcessingTester:
            "do_center_crop": self.do_center_crop,
            "crop_size": self.crop_size,
            "do_flip_channel_order": self.do_flip_channel_order,
            "do_reduce_labels": self.do_reduce_labels,
        }

    def expected_output_image_shape(self, images):

@@ -122,16 +125,21 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        self.assertTrue(hasattr(image_processing, "do_center_crop"))
        self.assertTrue(hasattr(image_processing, "center_crop"))
        self.assertTrue(hasattr(image_processing, "do_flip_channel_order"))
        self.assertTrue(hasattr(image_processing, "do_reduce_labels"))

    def test_image_processor_from_dict_with_kwargs(self):
        for image_processing_class in self.image_processor_list:
            image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
            self.assertEqual(image_processor.size, {"shortest_edge": 20})
            self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
            self.assertEqual(image_processor.do_reduce_labels, False)

            image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
            image_processor = self.image_processing_class.from_dict(
                self.image_processor_dict, size=42, crop_size=84, do_reduce_labels=True
            )
            self.assertEqual(image_processor.size, {"shortest_edge": 42})
            self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
            self.assertEqual(image_processor.do_reduce_labels, True)

    def test_call_segmentation_maps(self):
        for image_processing_class in self.image_processor_list:

@@ -240,6 +248,22 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
            self.assertTrue(encoding["labels"].min().item() >= 0)
            self.assertTrue(encoding["labels"].max().item() <= 255)

    def test_reduce_labels(self):
        for image_processing_class in self.image_processor_list:
            # Initialize image_processing
            image_processing = self.image_processing_class(**self.image_processor_dict)

            # ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150
            image, map = prepare_semantic_single_inputs()
            encoding = image_processing(image, map, return_tensors="pt")
            self.assertTrue(encoding["labels"].min().item() >= 0)
            self.assertTrue(encoding["labels"].max().item() <= 150)

            image_processing.do_reduce_labels = True
            encoding = image_processing(image, map, return_tensors="pt")
            self.assertTrue(encoding["labels"].min().item() >= 0)
            self.assertTrue(encoding["labels"].max().item() <= 255)

    @require_vision
    @require_torch
    def test_slow_fast_equivalence(self):