diff --git a/docs/source/en/model_doc/flava.md b/docs/source/en/model_doc/flava.md index b32f93fc8bc..c809be73589 100644 --- a/docs/source/en/model_doc/flava.md +++ b/docs/source/en/model_doc/flava.md @@ -72,6 +72,11 @@ This model was contributed by [aps](https://huggingface.co/aps). The original co [[autodoc]] FlavaImageProcessor - preprocess +## FlavaImageProcessorFast + +[[autodoc]] FlavaImageProcessorFast + - preprocess + ## FlavaForPreTraining [[autodoc]] FlavaForPreTraining diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 0eccd9ae123..5e8232d9411 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -84,7 +84,7 @@ else: ("dpt", ("DPTImageProcessor",)), ("efficientformer", ("EfficientFormerImageProcessor",)), ("efficientnet", ("EfficientNetImageProcessor",)), - ("flava", ("FlavaImageProcessor",)), + ("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")), ("focalnet", ("BitImageProcessor",)), ("fuyu", ("FuyuImageProcessor",)), ("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")), diff --git a/src/transformers/models/flava/__init__.py b/src/transformers/models/flava/__init__.py index c258a8afc8e..292593cb4a2 100644 --- a/src/transformers/models/flava/__init__.py +++ b/src/transformers/models/flava/__init__.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from .configuration_flava import * from .feature_extraction_flava import * from .image_processing_flava import * + from .image_processing_flava_fast import * from .modeling_flava import * from .processing_flava import * else: diff --git a/src/transformers/models/flava/image_processing_flava_fast.py b/src/transformers/models/flava/image_processing_flava_fast.py new file mode 100644 index 00000000000..89beb9ab5f5 --- /dev/null +++ b/src/transformers/models/flava/image_processing_flava_fast.py @@ -0,0 +1,549 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
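+#
+# NOTE: compared to the slow `FlavaImageProcessor`, this fast processor defaults
+# `codebook_resample` to `PILImageResampling.BICUBIC`, because LANCZOS resampling is not
+# supported for `torch.Tensor` inputs (see the class attributes and the test changes below).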
+"""Fast Image processor class for Flava.""" + +import math +import random +from functools import lru_cache +from typing import Any, Dict, Iterable, Optional, Tuple, Union + +from ...image_processing_utils_fast import ( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + BaseImageProcessorFast, + BatchFeature, + DefaultFastImageProcessorKwargs, + get_size_dict, +) +from ...image_transforms import ChannelDimension, group_images_by_shape, reorder_images +from ...image_utils import ImageInput, PILImageResampling, SizeDict +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + add_start_docstrings, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, +) +from .image_processing_flava import ( + FLAVA_CODEBOOK_MEAN, + FLAVA_CODEBOOK_STD, + FLAVA_IMAGE_MEAN, + FLAVA_IMAGE_STD, + LOGIT_LAPLACE_EPS, +) + + +if is_torch_available(): + import torch + +if is_torchvision_available(): + from ...image_utils import pil_torch_interpolation_mapping + + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +class FlavaMaskingGenerator: + def __init__( + self, + input_size: Union[int, Tuple[int, int]] = 14, + total_mask_patches: int = 75, + mask_group_max_patches: Optional[int] = None, + mask_group_min_patches: int = 16, + mask_group_min_aspect_ratio: Optional[float] = 0.3, + mask_group_max_aspect_ratio: float = None, + ): + if not isinstance(input_size, tuple): + input_size = (input_size,) * 2 + self.height, self.width = input_size + + self.num_patches = self.height * self.width + self.total_mask_patches = total_mask_patches + + self.mask_group_min_patches = mask_group_min_patches + self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches + + mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio + self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio)) + + def __repr__(self): + repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( + self.height, + self.width, + self.mask_group_min_patches, + self.mask_group_max_patches, + self.total_mask_patches, + self.log_aspect_ratio[0], + self.log_aspect_ratio[1], + ) + return repr_str + + def get_shape(self): + return self.height, self.width + + def _mask(self, mask, max_mask_patches): + delta = 0 + for _attempt in range(10): + target_area = random.uniform(self.mask_group_min_patches, max_mask_patches) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + height = int(round(math.sqrt(target_area * aspect_ratio))) + width = int(round(math.sqrt(target_area / aspect_ratio))) + if width < self.width and height < self.height: + top = random.randint(0, self.height - height) + left = random.randint(0, self.width - width) + + num_masked = mask[top : top + height, left : left + width].sum() + # Overlap + if 0 < height * width - num_masked <= max_mask_patches: + zeros_pos = mask[top : top + height, left : left + width] == 0 + mask[top : top + height, left : left + width][zeros_pos] = 1 + delta += zeros_pos.sum() + + if delta > 0: + break + return delta + + def __call__(self): + mask = torch.zeros(self.get_shape(), dtype=torch.int) + mask_count = 0 + while mask_count < self.total_mask_patches: + max_mask_patches = self.total_mask_patches - mask_count + max_mask_patches = min(max_mask_patches, 
self.mask_group_max_patches)
+
+            delta = self._mask(mask, max_mask_patches)
+            if delta == 0:
+                break
+            else:
+                mask_count += delta
+
+        return mask
+
+
+class FlavaFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+    # Mask related params
+    return_image_mask: Optional[bool]
+    input_size_patches: Optional[int]
+    total_mask_patches: Optional[int]
+    mask_group_min_patches: Optional[int]
+    mask_group_max_patches: Optional[int]
+    mask_group_min_aspect_ratio: Optional[float]
+    mask_group_max_aspect_ratio: Optional[float]
+    # Codebook related params
+    return_codebook_pixels: Optional[bool]
+    codebook_do_resize: Optional[bool]
+    codebook_size: Optional[Dict[str, int]]
+    codebook_resample: Optional[int]
+    codebook_do_center_crop: Optional[bool]
+    codebook_crop_size: Optional[Dict[str, int]]
+    codebook_do_rescale: Optional[bool]
+    codebook_rescale_factor: Optional[Union[int, float]]
+    codebook_do_map_pixels: Optional[bool]
+    codebook_do_normalize: Optional[bool]
+    codebook_image_mean: Optional[Union[float, Iterable[float]]]
+    codebook_image_std: Optional[Union[float, Iterable[float]]]
+
+
+@add_start_docstrings(
+    "Constructs a fast Flava image processor.",
+    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
+    """
+        return_image_mask (`bool`, *optional*, defaults to `False`):
+            Whether to return the image mask. Can be overridden by the `return_image_mask` parameter in `preprocess`.
+        input_size_patches (`int`, *optional*, defaults to 14):
+            Number of patches in the image in height and width direction. 14x14 = 196 total patches. Can be overridden
+            by the `input_size_patches` parameter in `preprocess`.
+        total_mask_patches (`int`, *optional*, defaults to 75):
+            Total number of patches that should be masked. Can be overridden by the `total_mask_patches` parameter in
+            `preprocess`.
+        mask_group_min_patches (`int`, *optional*, defaults to 16):
+            Minimum number of patches that should be masked. Can be overridden by the `mask_group_min_patches`
+            parameter in `preprocess`.
+        mask_group_max_patches (`int`, *optional*):
+            Maximum number of patches that should be masked. Can be overridden by the `mask_group_max_patches`
+            parameter in `preprocess`.
+        mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
+            Minimum aspect ratio of the mask window. Can be overridden by the `mask_group_min_aspect_ratio` parameter
+            in `preprocess`.
+        mask_group_max_aspect_ratio (`float`, *optional*):
+            Maximum aspect ratio of the mask window. Can be overridden by the `mask_group_max_aspect_ratio` parameter
+            in `preprocess`.
+        codebook_do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to resize the input for codebook to a certain `codebook_size`. Can be overridden by the
+            `codebook_do_resize` parameter in `preprocess`.
+        codebook_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
+            Resize the input for codebook to the given size. Can be overridden by the `codebook_size` parameter in
+            `preprocess`.
+        codebook_resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+            Resampling filter to use if resizing the codebook image. Can be overridden by the `codebook_resample`
+            parameter in `preprocess`.
+        codebook_do_center_crop (`bool`, *optional*, defaults to `True`):
+            Whether to crop the input for codebook at the center. If the input size is smaller than
+            `codebook_crop_size` along any edge, the image is padded with 0's and then center cropped. Can be
+            overridden by the `codebook_do_center_crop` parameter in `preprocess`.
+        codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
+            Desired output size for codebook input when applying center-cropping. Can be overridden by the
+            `codebook_crop_size` parameter in `preprocess`.
+        codebook_do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the input for codebook by the specified scale `codebook_rescale_factor`. Can be
+            overridden by the `codebook_do_rescale` parameter in `preprocess`.
+        codebook_rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Defines the scale factor to use if rescaling the codebook image. Can be overridden by the
+            `codebook_rescale_factor` parameter in `preprocess`.
+        codebook_do_map_pixels (`bool`, *optional*, defaults to `True`):
+            Whether to map the pixel values of the codebook input to (1 - 2e)x + e. Can be overridden by the
+            `codebook_do_map_pixels` parameter in `preprocess`.
+        codebook_do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`. Can
+            be overridden by the `codebook_do_normalize` parameter in `preprocess`.
+        codebook_image_mean (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0, 0, 0]`):
+            The sequence of means for each channel, to be used when normalizing images for codebook. Can be overridden
+            by the `codebook_image_mean` parameter in `preprocess`.
+        codebook_image_std (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
+            The sequence of standard deviations for each channel, to be used when normalizing images for codebook. Can
+            be overridden by the `codebook_image_std` parameter in `preprocess`.
+    """,
+)
+class FlavaImageProcessorFast(BaseImageProcessorFast):
+    resample = PILImageResampling.BICUBIC
+    image_mean = FLAVA_IMAGE_MEAN
+    image_std = FLAVA_IMAGE_STD
+    size = {"height": 224, "width": 224}
+    crop_size = {"height": 224, "width": 224}
+    do_resize = True
+    do_center_crop = True
+    do_rescale = True
+    do_normalize = True
+
+    # Mask related params
+    return_image_mask = False
+    input_size_patches = 14
+    total_mask_patches = 75
+    mask_group_min_patches = 16
+    mask_group_max_patches = None
+    mask_group_min_aspect_ratio = 0.3
+    mask_group_max_aspect_ratio = None
+    # Codebook related params
+    return_codebook_pixels = False
+    codebook_do_resize = True
+    codebook_size = {"height": 112, "width": 112}
+    # LANCZOS resample does not support torch Tensor. Use BICUBIC as closest alternative
+    codebook_resample = PILImageResampling.BICUBIC
+    codebook_do_center_crop = True
+    codebook_crop_size = {"height": 112, "width": 112}
+    codebook_do_rescale = True
+    codebook_rescale_factor = 1 / 255
+    codebook_do_map_pixels = True
+    codebook_do_normalize = True
+    codebook_image_mean = FLAVA_CODEBOOK_MEAN
+    codebook_image_std = FLAVA_CODEBOOK_STD
+    valid_kwargs = FlavaFastImageProcessorKwargs
+
+    def __init__(self, **kwargs: Unpack[FlavaFastImageProcessorKwargs]):
+        super().__init__(**kwargs)
+
+    @add_start_docstrings(
+        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
+        """
+        return_image_mask (`bool`, *optional*, defaults to `False`):
+            Whether to return the image mask. Can be overridden by the `return_image_mask` parameter in `preprocess`.
+        input_size_patches (`int`, *optional*, defaults to 14):
+            Number of patches in the image in height and width direction. 14x14 = 196 total patches. Can be overridden
+            by the `input_size_patches` parameter in `preprocess`.
+        total_mask_patches (`int`, *optional*, defaults to 75):
+            Total number of patches that should be masked. Can be overridden by the `total_mask_patches` parameter in
+            `preprocess`.
+        mask_group_min_patches (`int`, *optional*, defaults to 16):
+            Minimum number of patches that should be masked. Can be overridden by the `mask_group_min_patches`
+            parameter in `preprocess`.
+        mask_group_max_patches (`int`, *optional*):
+            Maximum number of patches that should be masked. Can be overridden by the `mask_group_max_patches`
+            parameter in `preprocess`.
+        mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
+            Minimum aspect ratio of the mask window. Can be overridden by the `mask_group_min_aspect_ratio` parameter
+            in `preprocess`.
+        mask_group_max_aspect_ratio (`float`, *optional*):
+            Maximum aspect ratio of the mask window. Can be overridden by the `mask_group_max_aspect_ratio` parameter
+            in `preprocess`.
+        codebook_do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to resize the input for codebook to a certain `codebook_size`. Can be overridden by the
+            `codebook_do_resize` parameter in `preprocess`.
+        codebook_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
+            Resize the input for codebook to the given size. Can be overridden by the `codebook_size` parameter in
+            `preprocess`.
+        codebook_resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+            Resampling filter to use if resizing the codebook image. Can be overridden by the `codebook_resample`
+            parameter in `preprocess`.
+        codebook_do_center_crop (`bool`, *optional*, defaults to `True`):
+            Whether to crop the input for codebook at the center. If the input size is smaller than
+            `codebook_crop_size` along any edge, the image is padded with 0's and then center cropped. Can be
+            overridden by the `codebook_do_center_crop` parameter in `preprocess`.
+        codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
+            Desired output size for codebook input when applying center-cropping. Can be overridden by the
+            `codebook_crop_size` parameter in `preprocess`.
+        codebook_do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the input for codebook by the specified scale `codebook_rescale_factor`. Can be
+            overridden by the `codebook_do_rescale` parameter in `preprocess`.
+        codebook_rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Defines the scale factor to use if rescaling the codebook image. Can be overridden by the
+            `codebook_rescale_factor` parameter in `preprocess`.
+        codebook_do_map_pixels (`bool`, *optional*, defaults to `True`):
+            Whether to map the pixel values of the codebook input to (1 - 2e)x + e. Can be overridden by the
+            `codebook_do_map_pixels` parameter in `preprocess`.
+        codebook_do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`. Can
+            be overridden by the `codebook_do_normalize` parameter in `preprocess`.
+        codebook_image_mean (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0, 0, 0]`):
+            The sequence of means for each channel, to be used when normalizing images for codebook. Can be overridden
+            by the `codebook_image_mean` parameter in `preprocess`.
+        codebook_image_std (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
+            The sequence of standard deviations for each channel, to be used when normalizing images for codebook. Can
+            be overridden by the `codebook_image_std` parameter in `preprocess`.
+        """,
+    )
+    def preprocess(self, images: ImageInput, **kwargs: Unpack[FlavaFastImageProcessorKwargs]) -> BatchFeature:
+        return super().preprocess(images, **kwargs)
+
+    @classmethod
+    def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
+        """
+        Overrides the `from_dict` method from the base class to make sure parameters are updated if the image
+        processor is created using `from_dict` and kwargs, e.g.
+        `FlavaImageProcessorFast.from_pretrained(checkpoint, codebook_size=600)`.
+        """
+        image_processor_dict = image_processor_dict.copy()
+        if "codebook_size" in kwargs:
+            image_processor_dict["codebook_size"] = kwargs.pop("codebook_size")
+        if "codebook_crop_size" in kwargs:
+            image_processor_dict["codebook_crop_size"] = kwargs.pop("codebook_crop_size")
+        return super().from_dict(image_processor_dict, **kwargs)
+
+    @lru_cache()
+    def masking_generator(
+        self,
+        input_size_patches,
+        total_mask_patches,
+        mask_group_min_patches,
+        mask_group_max_patches,
+        mask_group_min_aspect_ratio,
+        mask_group_max_aspect_ratio,
+    ) -> FlavaMaskingGenerator:
+        return FlavaMaskingGenerator(
+            input_size=input_size_patches,
+            total_mask_patches=total_mask_patches,
+            mask_group_min_patches=mask_group_min_patches,
+            mask_group_max_patches=mask_group_max_patches,
+            mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
+            mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
+        )
+
+    def map_pixels(self, image: "torch.Tensor") -> "torch.Tensor":
+        return (1 - 2 * LOGIT_LAPLACE_EPS) * image + LOGIT_LAPLACE_EPS
+
+    def _further_process_kwargs(
+        self,
+        size: Optional[SizeDict] = None,
+        crop_size: Optional[SizeDict] = None,
+        default_to_square: Optional[bool] = None,
+        image_mean: Optional[Union[float, list[float]]] = None,
+        image_std: Optional[Union[float, list[float]]] = None,
+        codebook_size: Optional[SizeDict] = None,
+        codebook_crop_size: Optional[SizeDict] = None,
+        codebook_image_mean: Optional[Union[float, list[float]]] = None,
+        codebook_image_std: Optional[Union[float, list[float]]] = None,
+        codebook_resample: Optional[PILImageResampling] = None,
+        data_format: Optional[ChannelDimension] = None,
+        **kwargs,
+    ) -> dict:
+        """
+        Update kwargs that need further processing before being validated.
+        Can be overridden by subclasses to customize the processing of kwargs.
+ """ + if kwargs is None: + kwargs = {} + if size is not None: + size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) + if crop_size is not None: + crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) + if isinstance(image_mean, list): + image_mean = tuple(image_mean) + if isinstance(image_std, list): + image_std = tuple(image_std) + if data_format is None: + data_format = ChannelDimension.FIRST + if codebook_size is not None: + codebook_size = SizeDict(**get_size_dict(size=codebook_size, default_to_square=default_to_square)) + if codebook_crop_size is not None: + codebook_crop_size = SizeDict(**get_size_dict(codebook_crop_size, param_name="codebook_crop_size")) + if isinstance(codebook_image_mean, list): + codebook_image_mean = tuple(codebook_image_mean) + if isinstance(codebook_image_std, list): + codebook_image_std = tuple(codebook_image_std) + + kwargs["size"] = size + kwargs["crop_size"] = crop_size + kwargs["default_to_square"] = default_to_square + kwargs["image_mean"] = image_mean + kwargs["image_std"] = image_std + kwargs["codebook_size"] = codebook_size + kwargs["codebook_crop_size"] = codebook_crop_size + kwargs["codebook_image_mean"] = codebook_image_mean + kwargs["codebook_image_std"] = codebook_image_std + kwargs["data_format"] = data_format + kwargs["codebook_interpolation"] = ( + pil_torch_interpolation_mapping[codebook_resample] + if isinstance(codebook_resample, (PILImageResampling, int)) + else codebook_resample + ) + + return kwargs + + def _preprocess_image( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + do_map_pixels: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + return_tensors: Optional[Union[str, TensorType]], + ) -> "torch.Tensor": + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + if do_map_pixels: + stacked_images = self.map_pixels(image=stacked_images) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return processed_images + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + 
rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + # Mask related params + return_image_mask: Optional[bool], + input_size_patches: Optional[int], + total_mask_patches: Optional[int], + mask_group_min_patches: Optional[int], + mask_group_max_patches: Optional[int], + mask_group_min_aspect_ratio: Optional[float], + mask_group_max_aspect_ratio: Optional[float], + # Codebook related params + return_codebook_pixels: Optional[bool], + codebook_do_resize: Optional[bool], + codebook_size: Optional[SizeDict], + codebook_interpolation: Optional["F.InterpolationMode"], + codebook_do_center_crop: Optional[bool], + codebook_crop_size: Optional[SizeDict], + codebook_do_rescale: Optional[bool], + codebook_rescale_factor: Optional[float], + codebook_do_map_pixels: Optional[bool], + codebook_do_normalize: Optional[bool], + codebook_image_mean: Optional[Union[float, list[float]]], + codebook_image_std: Optional[Union[float, list[float]]], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + processed_images = self._preprocess_image( + images=images, + do_resize=do_resize, + size=size, + interpolation=interpolation, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + do_map_pixels=False, + image_mean=image_mean, + image_std=image_std, + return_tensors=return_tensors, + ) + data = { + "pixel_values": processed_images, + } + + if return_codebook_pixels: + codebook_processed_images = self._preprocess_image( + images=images, + do_resize=codebook_do_resize, + size=codebook_size, + interpolation=codebook_interpolation, + do_center_crop=codebook_do_center_crop, + crop_size=codebook_crop_size, + do_rescale=codebook_do_rescale, + rescale_factor=codebook_rescale_factor, + do_normalize=codebook_do_normalize, + do_map_pixels=codebook_do_map_pixels, + image_mean=codebook_image_mean, + image_std=codebook_image_std, + return_tensors=return_tensors, + ) + data["codebook_pixel_values"] = codebook_processed_images + + if return_image_mask: + mask_generator = self.masking_generator( + input_size_patches=input_size_patches, + total_mask_patches=total_mask_patches, + mask_group_min_patches=mask_group_min_patches, + mask_group_max_patches=mask_group_max_patches, + mask_group_min_aspect_ratio=mask_group_min_aspect_ratio, + mask_group_max_aspect_ratio=mask_group_max_aspect_ratio, + ) + masks = [mask_generator() for _ in range(len(images))] + masks = torch.stack(masks, dim=0) if return_tensors else masks + data["bool_masked_pos"] = masks + + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["FlavaImageProcessorFast"] diff --git a/tests/models/flava/test_image_processing_flava.py b/tests/models/flava/test_image_processing_flava.py index b7371c8f9b9..5edb1997abb 100644 --- a/tests/models/flava/test_image_processing_flava.py +++ b/tests/models/flava/test_image_processing_flava.py @@ -16,9 +16,11 @@ import random import unittest import numpy as np +import requests +from PIL import Image from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_torch_available, is_vision_available +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -30,6 +32,9 @@ if is_vision_available(): import PIL from transformers import 
FlavaImageProcessor + + if is_torchvision_available(): + from transformers import FlavaImageProcessorFast from transformers.image_utils import PILImageResampling from transformers.models.flava.image_processing_flava import ( FLAVA_CODEBOOK_MEAN, @@ -105,7 +110,8 @@ class FlavaImageProcessingTester: self.codebook_do_resize = codebook_do_resize self.codebook_size = codebook_size - self.codebook_resample = codebook_resample if codebook_resample is not None else PILImageResampling.LANCZOS + # LANCZOS resample does not support torch Tensor. Use BICUBIC as closest alternative + self.codebook_resample = codebook_resample if codebook_resample is not None else PILImageResampling.BICUBIC self.codebook_do_center_crop = codebook_do_center_crop self.codebook_crop_size = codebook_crop_size self.codebook_do_map_pixels = codebook_do_map_pixels @@ -171,6 +177,7 @@ class FlavaImageProcessingTester: @require_vision class FlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = FlavaImageProcessor if is_vision_available() else None + fast_image_processing_class = FlavaImageProcessorFast if is_torchvision_available() else None maxDiff = None def setUp(self): @@ -182,157 +189,161 @@ class FlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "image_mean")) - self.assertTrue(hasattr(image_processing, "image_std")) - self.assertTrue(hasattr(image_processing, "do_normalize")) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "resample")) - self.assertTrue(hasattr(image_processing, "crop_size")) - self.assertTrue(hasattr(image_processing, "do_center_crop")) - self.assertTrue(hasattr(image_processing, "do_rescale")) - self.assertTrue(hasattr(image_processing, "rescale_factor")) - self.assertTrue(hasattr(image_processing, "masking_generator")) - self.assertTrue(hasattr(image_processing, "codebook_do_resize")) - self.assertTrue(hasattr(image_processing, "codebook_size")) - self.assertTrue(hasattr(image_processing, "codebook_resample")) - self.assertTrue(hasattr(image_processing, "codebook_do_center_crop")) - self.assertTrue(hasattr(image_processing, "codebook_crop_size")) - self.assertTrue(hasattr(image_processing, "codebook_do_map_pixels")) - self.assertTrue(hasattr(image_processing, "codebook_do_normalize")) - self.assertTrue(hasattr(image_processing, "codebook_image_mean")) - self.assertTrue(hasattr(image_processing, "codebook_image_std")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "resample")) + self.assertTrue(hasattr(image_processing, "crop_size")) + self.assertTrue(hasattr(image_processing, "do_center_crop")) + self.assertTrue(hasattr(image_processing, "do_rescale")) + self.assertTrue(hasattr(image_processing, "rescale_factor")) + self.assertTrue(hasattr(image_processing, "masking_generator")) + self.assertTrue(hasattr(image_processing, "codebook_do_resize")) + self.assertTrue(hasattr(image_processing, 
"codebook_size")) + self.assertTrue(hasattr(image_processing, "codebook_resample")) + self.assertTrue(hasattr(image_processing, "codebook_do_center_crop")) + self.assertTrue(hasattr(image_processing, "codebook_crop_size")) + self.assertTrue(hasattr(image_processing, "codebook_do_map_pixels")) + self.assertTrue(hasattr(image_processing, "codebook_do_normalize")) + self.assertTrue(hasattr(image_processing, "codebook_image_mean")) + self.assertTrue(hasattr(image_processing, "codebook_image_std")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"height": 224, "width": 224}) - self.assertEqual(image_processor.crop_size, {"height": 224, "width": 224}) - self.assertEqual(image_processor.codebook_size, {"height": 112, "width": 112}) - self.assertEqual(image_processor.codebook_crop_size, {"height": 112, "width": 112}) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 224, "width": 224}) + self.assertEqual(image_processor.crop_size, {"height": 224, "width": 224}) + self.assertEqual(image_processor.codebook_size, {"height": 112, "width": 112}) + self.assertEqual(image_processor.codebook_crop_size, {"height": 112, "width": 112}) - image_processor = self.image_processing_class.from_dict( - self.image_processor_dict, size=42, crop_size=84, codebook_size=33, codebook_crop_size=66 - ) - self.assertEqual(image_processor.size, {"height": 42, "width": 42}) - self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) - self.assertEqual(image_processor.codebook_size, {"height": 33, "width": 33}) - self.assertEqual(image_processor.codebook_crop_size, {"height": 66, "width": 66}) + image_processor = self.image_processing_class.from_dict( + self.image_processor_dict, size=42, crop_size=84, codebook_size=33, codebook_crop_size=66 + ) + self.assertEqual(image_processor.size, {"height": 42, "width": 42}) + self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) + self.assertEqual(image_processor.codebook_size, {"height": 33, "width": 33}) + self.assertEqual(image_processor.codebook_crop_size, {"height": 66, "width": 66}) def test_call_pil(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random PIL images - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) - for image in image_inputs: - self.assertIsInstance(image, PIL.Image.Image) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, PIL.Image.Image) - # Test not batched input - encoded_images = image_processing(image_inputs[0], return_tensors="pt") + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt") - # Test no bool masked pos - self.assertFalse("bool_masked_pos" in encoded_images) + # Test no bool masked pos + self.assertFalse("bool_masked_pos" in encoded_images) - expected_height, expected_width = self.image_processor_tester.get_expected_image_size() + expected_height, expected_width = 
self.image_processor_tester.get_expected_image_size() - self.assertEqual( - encoded_images.pixel_values.shape, - (1, self.image_processor_tester.num_channels, expected_height, expected_width), - ) + self.assertEqual( + encoded_images.pixel_values.shape, + (1, self.image_processor_tester.num_channels, expected_height, expected_width), + ) - # Test batched - encoded_images = image_processing(image_inputs, return_tensors="pt") - expected_height, expected_width = self.image_processor_tester.get_expected_image_size() + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt") + expected_height, expected_width = self.image_processor_tester.get_expected_image_size() - # Test no bool masked pos - self.assertFalse("bool_masked_pos" in encoded_images) + # Test no bool masked pos + self.assertFalse("bool_masked_pos" in encoded_images) - self.assertEqual( - encoded_images.pixel_values.shape, - ( - self.image_processor_tester.batch_size, - self.image_processor_tester.num_channels, - expected_height, - expected_width, - ), - ) + self.assertEqual( + encoded_images.pixel_values.shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.num_channels, + expected_height, + expected_width, + ), + ) def _test_call_framework(self, instance_class, prepare_kwargs): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random tensors - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, **prepare_kwargs) - for image in image_inputs: - self.assertIsInstance(image, instance_class) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, **prepare_kwargs) + for image in image_inputs: + self.assertIsInstance(image, instance_class) - # Test not batched input - encoded_images = image_processing(image_inputs[0], return_tensors="pt") + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt") - expected_height, expected_width = self.image_processor_tester.get_expected_image_size() - self.assertEqual( - encoded_images.pixel_values.shape, - (1, self.image_processor_tester.num_channels, expected_height, expected_width), - ) + expected_height, expected_width = self.image_processor_tester.get_expected_image_size() + self.assertEqual( + encoded_images.pixel_values.shape, + (1, self.image_processor_tester.num_channels, expected_height, expected_width), + ) - encoded_images = image_processing(image_inputs, return_image_mask=True, return_tensors="pt") + encoded_images = image_processing(image_inputs, return_image_mask=True, return_tensors="pt") - expected_height, expected_width = self.image_processor_tester.get_expected_image_size() - self.assertEqual( - encoded_images.pixel_values.shape, - ( - self.image_processor_tester.batch_size, - self.image_processor_tester.num_channels, - expected_height, - expected_width, - ), - ) + expected_height, expected_width = self.image_processor_tester.get_expected_image_size() + self.assertEqual( + encoded_images.pixel_values.shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.num_channels, + expected_height, + expected_width, + ), + ) - expected_height, expected_width = self.image_processor_tester.get_expected_mask_size() - self.assertEqual( - 
encoded_images.bool_masked_pos.shape, - ( - self.image_processor_tester.batch_size, - expected_height, - expected_width, - ), - ) + expected_height, expected_width = self.image_processor_tester.get_expected_mask_size() + self.assertEqual( + encoded_images.bool_masked_pos.shape, + ( + self.image_processor_tester.batch_size, + expected_height, + expected_width, + ), + ) - # Test batched - encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values - expected_height, expected_width = self.image_processor_tester.get_expected_image_size() - self.assertEqual( - encoded_images.shape, - ( - self.image_processor_tester.batch_size, - self.image_processor_tester.num_channels, - expected_height, - expected_width, - ), - ) + expected_height, expected_width = self.image_processor_tester.get_expected_image_size() + self.assertEqual( + encoded_images.shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.num_channels, + expected_height, + expected_width, + ), + ) - # Test masking - encoded_images = image_processing(image_inputs, return_image_mask=True, return_tensors="pt") + # Test masking + encoded_images = image_processing(image_inputs, return_image_mask=True, return_tensors="pt") - expected_height, expected_width = self.image_processor_tester.get_expected_image_size() - self.assertEqual( - encoded_images.pixel_values.shape, - ( - self.image_processor_tester.batch_size, - self.image_processor_tester.num_channels, - expected_height, - expected_width, - ), - ) + expected_height, expected_width = self.image_processor_tester.get_expected_image_size() + self.assertEqual( + encoded_images.pixel_values.shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.num_channels, + expected_height, + expected_width, + ), + ) - expected_height, expected_width = self.image_processor_tester.get_expected_mask_size() - self.assertEqual( - encoded_images.bool_masked_pos.shape, - ( - self.image_processor_tester.batch_size, - expected_height, - expected_width, - ), - ) + expected_height, expected_width = self.image_processor_tester.get_expected_mask_size() + self.assertEqual( + encoded_images.bool_masked_pos.shape, + ( + self.image_processor_tester.batch_size, + expected_height, + expected_width, + ), + ) def test_call_numpy(self): self._test_call_framework(np.ndarray, prepare_kwargs={"numpify": True}) @@ -346,40 +357,76 @@ class FlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): self._test_call_framework(torch.Tensor, prepare_kwargs={"torchify": True}) def test_masking(self): - # Initialize image_processing - random.seed(1234) - image_processing = self.image_processing_class(**self.image_processor_dict) - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + random.seed(1234) + image_processing = image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) - # Test not batched input - encoded_images = image_processing(image_inputs[0], return_image_mask=True, return_tensors="pt") - self.assertEqual(encoded_images.bool_masked_pos.sum().item(), 75) + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_image_mask=True, return_tensors="pt") + 
self.assertEqual(encoded_images.bool_masked_pos.sum().item(), 75) def test_codebook_pixels(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random PIL images - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) - for image in image_inputs: - self.assertIsInstance(image, PIL.Image.Image) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, PIL.Image.Image) - # Test not batched input - encoded_images = image_processing(image_inputs[0], return_codebook_pixels=True, return_tensors="pt") - expected_height, expected_width = self.image_processor_tester.get_expected_codebook_image_size() - self.assertEqual( - encoded_images.codebook_pixel_values.shape, - (1, self.image_processor_tester.num_channels, expected_height, expected_width), + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_codebook_pixels=True, return_tensors="pt") + expected_height, expected_width = self.image_processor_tester.get_expected_codebook_image_size() + self.assertEqual( + encoded_images.codebook_pixel_values.shape, + (1, self.image_processor_tester.num_channels, expected_height, expected_width), + ) + + # Test batched + encoded_images = image_processing(image_inputs, return_codebook_pixels=True, return_tensors="pt") + expected_height, expected_width = self.image_processor_tester.get_expected_codebook_image_size() + self.assertEqual( + encoded_images.codebook_pixel_values.shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.num_channels, + expected_height, + expected_width, + ), + ) + + @require_vision + @require_torch + def test_slow_fast_equivalence(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + dummy_image = Image.open( + requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw + ) + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow( + dummy_image, return_tensors="pt", return_codebook_pixels=True, return_image_mask=True + ) + encoding_fast = image_processor_fast( + dummy_image, return_tensors="pt", return_codebook_pixels=True, return_image_mask=True + ) + self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1)) + self.assertLessEqual( + torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3 ) - # Test batched - encoded_images = image_processing(image_inputs, return_codebook_pixels=True, return_tensors="pt") - expected_height, expected_width = self.image_processor_tester.get_expected_codebook_image_size() - self.assertEqual( - encoded_images.codebook_pixel_values.shape, - ( - self.image_processor_tester.batch_size, - self.image_processor_tester.num_channels, - expected_height, - expected_width, - ), + 
self.assertTrue( + torch.allclose(encoding_slow.codebook_pixel_values, encoding_fast.codebook_pixel_values, atol=1e-1) + ) + self.assertLessEqual( + torch.mean(torch.abs(encoding_slow.codebook_pixel_values - encoding_fast.codebook_pixel_values)).item(), + 1e-3, )
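
Below is a minimal usage sketch for the new fast processor, mirroring the slow/fast equivalence check added in the test above. It is illustrative only: it assumes this branch and torchvision are installed, and the `facebook/flava-full` checkpoint name is an assumption, not something referenced in this diff.

```python
import requests
import torch
from PIL import Image

from transformers import FlavaImageProcessor, FlavaImageProcessorFast

# Same COCO image used in test_slow_fast_equivalence
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Assumed checkpoint; any repo with a FLAVA image processor config should work
slow_processor = FlavaImageProcessor.from_pretrained("facebook/flava-full")
fast_processor = FlavaImageProcessorFast.from_pretrained("facebook/flava-full")

slow_out = slow_processor(image, return_tensors="pt", return_codebook_pixels=True, return_image_mask=True)
fast_out = fast_processor(image, return_tensors="pt", return_codebook_pixels=True, return_image_mask=True)

# With the default sizes, pixel_values is 224x224 and codebook_pixel_values is 112x112
print(fast_out.pixel_values.shape)           # torch.Size([1, 3, 224, 224])
print(fast_out.codebook_pixel_values.shape)  # torch.Size([1, 3, 112, 112])
print(fast_out.bool_masked_pos.shape)        # torch.Size([1, 14, 14])

# Mean absolute difference between slow and fast outputs should stay around the
# 1e-3 tolerance asserted in test_slow_fast_equivalence
print(torch.mean(torch.abs(slow_out.pixel_values - fast_out.pixel_values)).item())
```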