Add Fast Image Processor for Flava (#37135)

* support flava fast image processor

* run style and quality

* update test

* update according to reviews

* make style

* update comment on BICUBIC

* make style

---------

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
This commit is contained in:
Vinh H. Pham 2025-04-14 20:05:31 +07:00 committed by GitHub
parent a5079a2c84
commit 49b9a69a36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 762 additions and 160 deletions

View File

@ -72,6 +72,11 @@ This model was contributed by [aps](https://huggingface.co/aps). The original co
[[autodoc]] FlavaImageProcessor
- preprocess
## FlavaImageProcessorFast
[[autodoc]] FlavaImageProcessorFast
- preprocess
## FlavaForPreTraining
[[autodoc]] FlavaForPreTraining

View File

@ -84,7 +84,7 @@ else:
("dpt", ("DPTImageProcessor",)),
("efficientformer", ("EfficientFormerImageProcessor",)),
("efficientnet", ("EfficientNetImageProcessor",)),
("flava", ("FlavaImageProcessor",)),
("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")),
("focalnet", ("BitImageProcessor",)),
("fuyu", ("FuyuImageProcessor",)),
("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),

View File

@ -21,6 +21,7 @@ if TYPE_CHECKING:
from .configuration_flava import *
from .feature_extraction_flava import *
from .image_processing_flava import *
from .image_processing_flava_fast import *
from .modeling_flava import *
from .processing_flava import *
else:

View File

@ -0,0 +1,549 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for Flava."""
import math
import random
from functools import lru_cache
from typing import Any, Dict, Iterable, Optional, Tuple, Union
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
BatchFeature,
DefaultFastImageProcessorKwargs,
get_size_dict,
)
from ...image_transforms import ChannelDimension, group_images_by_shape, reorder_images
from ...image_utils import ImageInput, PILImageResampling, SizeDict
from ...processing_utils import Unpack
from ...utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
)
from .image_processing_flava import (
FLAVA_CODEBOOK_MEAN,
FLAVA_CODEBOOK_STD,
FLAVA_IMAGE_MEAN,
FLAVA_IMAGE_STD,
LOGIT_LAPLACE_EPS,
)
if is_torch_available():
import torch
if is_torchvision_available():
from ...image_utils import pil_torch_interpolation_mapping
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
class FlavaMaskingGenerator:
    """
    Random block-wise mask generator over a 2D grid of image patches.

    Calling an instance returns an integer `torch.Tensor` of shape `(height, width)` in which `1` marks a masked
    patch and `0` an unmasked one. Rectangular blocks are drawn with Python's `random` module until
    `total_mask_patches` patches are masked or no further block can be placed.

    Args:
        input_size (`int` or `Tuple[int, int]`, *optional*, defaults to 14):
            Patch grid size. A single value is used for both height and width; a 2-element tuple or list is read as
            `(height, width)`.
        total_mask_patches (`int`, *optional*, defaults to 75):
            Upper bound on the number of patches masked per generated mask.
        mask_group_max_patches (`int`, *optional*):
            Largest number of patches a single block may newly mask. Falls back to `total_mask_patches`.
        mask_group_min_patches (`int`, *optional*, defaults to 16):
            Lower bound of the target area sampled for a single block.
        mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
            Smallest aspect ratio sampled for a block.
        mask_group_max_aspect_ratio (`float`, *optional*):
            Largest aspect ratio sampled for a block. Falls back to `1 / mask_group_min_aspect_ratio`.
    """

    def __init__(
        self,
        input_size: Union[int, Tuple[int, int]] = 14,
        total_mask_patches: int = 75,
        mask_group_max_patches: Optional[int] = None,
        mask_group_min_patches: int = 16,
        mask_group_min_aspect_ratio: float = 0.3,
        mask_group_max_aspect_ratio: Optional[float] = None,
    ):
        # Lists are accepted alongside tuples: a (height, width) tuple comes back as a list after a JSON round-trip.
        if not isinstance(input_size, (tuple, list)):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_patches = self.height * self.width
        self.total_mask_patches = total_mask_patches

        self.mask_group_min_patches = mask_group_min_patches
        self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches

        mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio
        # Aspect ratios are sampled uniformly in log space, so only the log bounds are stored.
        self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio))

    def __repr__(self):
        repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.mask_group_min_patches,
            self.mask_group_max_patches,
            self.total_mask_patches,
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )
        return repr_str

    def get_shape(self):
        """Return the patch grid size as `(height, width)`."""
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        """
        Try (at most 10 times) to mask one more rectangular block of `mask` in place.

        Args:
            mask: 2D mask that is updated in place.
            max_mask_patches: Largest number of patches this block is allowed to newly mask.

        Returns:
            `int`: Number of newly masked patches, `0` when no attempt succeeded.
        """
        delta = 0
        for _attempt in range(10):
            target_area = random.uniform(self.mask_group_min_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            height = int(round(math.sqrt(target_area * aspect_ratio)))
            width = int(round(math.sqrt(target_area / aspect_ratio)))
            if width < self.width and height < self.height:
                top = random.randint(0, self.height - height)
                left = random.randint(0, self.width - width)

                # Basic slicing yields a view, so writes to `region` land in `mask`.
                region = mask[top : top + height, left : left + width]
                num_masked = int(region.sum())
                # Accept the block only if it adds at least one new patch without exceeding the budget
                # (it may overlap patches that are already masked).
                if 0 < height * width - num_masked <= max_mask_patches:
                    zeros_pos = region == 0
                    region[zeros_pos] = 1
                    # Keep `delta` a plain int so callers never have to deal with 0-dim tensors.
                    delta += int(zeros_pos.sum())

                if delta > 0:
                    break
        return delta

    def __call__(self):
        """Generate a new mask of shape `(height, width)` with at most `total_mask_patches` masked patches."""
        mask = torch.zeros(self.get_shape(), dtype=torch.int)
        mask_count = 0
        while mask_count < self.total_mask_patches:
            max_mask_patches = self.total_mask_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.mask_group_max_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                # No block could be placed within the attempt limit: stop instead of looping forever.
                break
            mask_count += delta

        return mask
class FlavaFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    """Flava-specific keyword arguments accepted on top of `DefaultFastImageProcessorKwargs`."""

    # Mask related params
    return_image_mask: Optional[bool]
    input_size_patches: Optional[int]
    total_mask_patches: Optional[int]
    mask_group_min_patches: Optional[int]
    mask_group_max_patches: Optional[int]
    mask_group_min_aspect_ratio: Optional[float]
    mask_group_max_aspect_ratio: Optional[float]
    # Codebook related params
    return_codebook_pixels: Optional[bool]
    codebook_do_resize: Optional[bool]
    # Sizes are given as an int or as a {"height": ..., "width": ...} dict, never as a bool.
    codebook_size: Optional[Union[int, Dict[str, int]]]
    # A PIL resampling filter (or its int value); other values are forwarded unchanged as the interpolation mode.
    codebook_resample: Optional[Union[PILImageResampling, "F.InterpolationMode", int]]
    codebook_do_center_crop: Optional[bool]
    codebook_crop_size: Optional[Union[int, Dict[str, int]]]
    codebook_do_rescale: Optional[bool]
    codebook_rescale_factor: Optional[Union[int, float]]
    codebook_do_map_pixels: Optional[bool]
    codebook_do_normalize: Optional[bool]
    codebook_image_mean: Optional[Union[float, Iterable[float]]]
    codebook_image_std: Optional[Union[float, Iterable[float]]]
@add_start_docstrings(
    "Constructs a fast Flava image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    """
        return_image_mask (`bool`, *optional*, defaults to `False`):
            Whether to return the image mask. Can be overridden by the `return_image_mask` parameter in `preprocess`.
        input_size_patches (`int`, *optional*, defaults to 14):
            Number of patches in the image in height and width direction. 14x14 = 196 total patches. Can be overridden
            by the `input_size_patches` parameter in `preprocess`.
        total_mask_patches (`int`, *optional*, defaults to 75):
            Total number of patches that should be masked. Can be overridden by the `total_mask_patches` parameter in
            `preprocess`.
        mask_group_min_patches (`int`, *optional*, defaults to 16):
            Minimum number of patches that should be masked in a single mask block. Can be overridden by the
            `mask_group_min_patches` parameter in `preprocess`.
        mask_group_max_patches (`int`, *optional*):
            Maximum number of patches that should be masked in a single mask block. Can be overridden by the
            `mask_group_max_patches` parameter in `preprocess`.
        mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
            Minimum aspect ratio of the mask window. Can be overridden by the `mask_group_min_aspect_ratio` parameter
            in `preprocess`.
        mask_group_max_aspect_ratio (`float`, *optional*):
            Maximum aspect ratio of the mask window. Can be overridden by the `mask_group_max_aspect_ratio` parameter
            in `preprocess`.
        return_codebook_pixels (`bool`, *optional*, defaults to `False`):
            Whether to return the codebook pixel values. Can be overridden by the `return_codebook_pixels` parameter
            in `preprocess`.
        codebook_do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the input for codebook to `codebook_size`. Can be overridden by the
            `codebook_do_resize` parameter in `preprocess`.
        codebook_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
            Resize the input for codebook to the given size. Can be overridden by the `codebook_size` parameter in
            `preprocess`.
        codebook_resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
            Resampling filter to use if resizing the codebook image. `LANCZOS` is not supported on torch tensors, so
            `BICUBIC` is used as the closest alternative. Can be overridden by the `codebook_resample` parameter in
            `preprocess`.
        codebook_do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to crop the input for codebook at the center. If the input size is smaller than
            `codebook_crop_size` along any edge, the image is padded with 0's and then center cropped. Can be
            overridden by the `codebook_do_center_crop` parameter in `preprocess`.
        codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
            Desired output size for codebook input when applying center-cropping. Can be overridden by the
            `codebook_crop_size` parameter in `preprocess`.
        codebook_do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the input for codebook by the specified scale `codebook_rescale_factor`. Can be
            overridden by the `codebook_do_rescale` parameter in `preprocess`.
        codebook_rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Defines the scale factor to use if rescaling the codebook image. Can be overridden by the
            `codebook_rescale_factor` parameter in `preprocess`.
        codebook_do_map_pixels (`bool`, *optional*, defaults to `True`):
            Whether to map the pixel values of the codebook input to (1 - 2e)x + e. Can be overridden by the
            `codebook_do_map_pixels` parameter in `preprocess`.
        codebook_do_normalize (`bool`, *optional*, defaults to `True`):
            Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`. Can
            be overridden by the `codebook_do_normalize` parameter in `preprocess`.
        codebook_image_mean (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `FLAVA_CODEBOOK_MEAN`):
            The sequence of means for each channel, to be used when normalizing images for codebook. Can be overridden
            by the `codebook_image_mean` parameter in `preprocess`.
        codebook_image_std (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `FLAVA_CODEBOOK_STD`):
            The sequence of standard deviations for each channel, to be used when normalizing images for codebook. Can
            be overridden by the `codebook_image_std` parameter in `preprocess`.
    """,
)
class FlavaImageProcessorFast(BaseImageProcessorFast):
    # Defaults for the regular image pipeline
    resample = PILImageResampling.BICUBIC
    image_mean = FLAVA_IMAGE_MEAN
    image_std = FLAVA_IMAGE_STD
    size = {"height": 224, "width": 224}
    crop_size = {"height": 224, "width": 224}
    do_resize = True
    do_center_crop = True
    do_rescale = True
    do_normalize = True

    # Mask related params
    return_image_mask = False
    input_size_patches = 14
    total_mask_patches = 75
    mask_group_min_patches = 16
    mask_group_max_patches = None
    mask_group_min_aspect_ratio = 0.3
    mask_group_max_aspect_ratio = None

    # Codebook related params
    return_codebook_pixels = False
    codebook_do_resize = True
    codebook_size = {"height": 112, "width": 112}
    # LANCZOS resample does not support torch Tensor. Use BICUBIC as closest alternative
    codebook_resample = PILImageResampling.BICUBIC
    codebook_do_center_crop = True
    codebook_crop_size = {"height": 112, "width": 112}
    codebook_do_rescale = True
    codebook_rescale_factor = 1 / 255
    codebook_do_map_pixels = True
    codebook_do_normalize = True
    codebook_image_mean = FLAVA_CODEBOOK_MEAN
    codebook_image_std = FLAVA_CODEBOOK_STD

    valid_kwargs = FlavaFastImageProcessorKwargs

    def __init__(self, **kwargs: Unpack[FlavaFastImageProcessorKwargs]):
        super().__init__(**kwargs)

    @add_start_docstrings(
        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
        """
            return_image_mask (`bool`, *optional*, defaults to `False`):
                Whether to return the image mask.
            input_size_patches (`int`, *optional*, defaults to 14):
                Number of patches in the image in height and width direction. 14x14 = 196 total patches.
            total_mask_patches (`int`, *optional*, defaults to 75):
                Total number of patches that should be masked.
            mask_group_min_patches (`int`, *optional*, defaults to 16):
                Minimum number of patches that should be masked in a single mask block.
            mask_group_max_patches (`int`, *optional*):
                Maximum number of patches that should be masked in a single mask block.
            mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
                Minimum aspect ratio of the mask window.
            mask_group_max_aspect_ratio (`float`, *optional*):
                Maximum aspect ratio of the mask window.
            return_codebook_pixels (`bool`, *optional*, defaults to `False`):
                Whether to return the codebook pixel values.
            codebook_do_resize (`bool`, *optional*, defaults to `True`):
                Whether to resize the input for codebook to `codebook_size`.
            codebook_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
                Resize the input for codebook to the given size.
            codebook_resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use if resizing the codebook image. `LANCZOS` is not supported on torch tensors.
            codebook_do_center_crop (`bool`, *optional*, defaults to `True`):
                Whether to crop the input for codebook at the center. If the input size is smaller than
                `codebook_crop_size` along any edge, the image is padded with 0's and then center cropped.
            codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
                Desired output size for codebook input when applying center-cropping.
            codebook_do_rescale (`bool`, *optional*, defaults to `True`):
                Whether to rescale the input for codebook by the specified scale `codebook_rescale_factor`.
            codebook_rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
                Defines the scale factor to use if rescaling the codebook image.
            codebook_do_map_pixels (`bool`, *optional*, defaults to `True`):
                Whether to map the pixel values of the codebook input to (1 - 2e)x + e.
            codebook_do_normalize (`bool`, *optional*, defaults to `True`):
                Whether or not to normalize the input for codebook with `codebook_image_mean` and
                `codebook_image_std`.
            codebook_image_mean (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `FLAVA_CODEBOOK_MEAN`):
                The sequence of means for each channel, to be used when normalizing images for codebook.
            codebook_image_std (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `FLAVA_CODEBOOK_STD`):
                The sequence of standard deviations for each channel, to be used when normalizing images for codebook.
        """,
    )
    def preprocess(self, images: ImageInput, **kwargs: Unpack[FlavaFastImageProcessorKwargs]) -> BatchFeature:
        # Typed with the Flava kwargs so the mask/codebook options are visible to type checkers and IDEs.
        return super().preprocess(images, **kwargs)

    @classmethod
    def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
        """
        Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
        created using from_dict and kwargs e.g. `FlavaImageProcessorFast.from_pretrained(checkpoint, codebook_size=600)`
        """
        # Work on a copy so the caller's dict is left untouched.
        image_processor_dict = image_processor_dict.copy()
        if "codebook_size" in kwargs:
            image_processor_dict["codebook_size"] = kwargs.pop("codebook_size")
        if "codebook_crop_size" in kwargs:
            image_processor_dict["codebook_crop_size"] = kwargs.pop("codebook_crop_size")
        return super().from_dict(image_processor_dict, **kwargs)

    # The generator depends only on its arguments, so the cache lives on a static method: caching an instance
    # method would key on `self` and keep every processor alive for the lifetime of the cache.
    @staticmethod
    @lru_cache(maxsize=128)
    def masking_generator(
        input_size_patches,
        total_mask_patches,
        mask_group_min_patches,
        mask_group_max_patches,
        mask_group_min_aspect_ratio,
        mask_group_max_aspect_ratio,
    ) -> FlavaMaskingGenerator:
        """Return a (cached) `FlavaMaskingGenerator` for the given masking parameters. Arguments must be hashable."""
        return FlavaMaskingGenerator(
            input_size=input_size_patches,
            total_mask_patches=total_mask_patches,
            mask_group_min_patches=mask_group_min_patches,
            mask_group_max_patches=mask_group_max_patches,
            mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
            mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
        )

    def map_pixels(self, image: "torch.Tensor") -> "torch.Tensor":
        """Apply the affine map `(1 - 2 * LOGIT_LAPLACE_EPS) * image + LOGIT_LAPLACE_EPS` to `image`."""
        return (1 - 2 * LOGIT_LAPLACE_EPS) * image + LOGIT_LAPLACE_EPS

    def _further_process_kwargs(
        self,
        size: Optional[SizeDict] = None,
        crop_size: Optional[SizeDict] = None,
        default_to_square: Optional[bool] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        codebook_size: Optional[SizeDict] = None,
        codebook_crop_size: Optional[SizeDict] = None,
        codebook_image_mean: Optional[Union[float, list[float]]] = None,
        codebook_image_std: Optional[Union[float, list[float]]] = None,
        codebook_resample: Optional[PILImageResampling] = None,
        data_format: Optional[ChannelDimension] = None,
        **kwargs,
    ) -> dict:
        """
        Update kwargs that need further processing before being validated
        Can be overridden by subclasses to customize the processing of kwargs.

        Sizes are converted to `SizeDict`, list means/stds to tuples, `data_format` defaults to
        `ChannelDimension.FIRST`, and `codebook_resample` is replaced by a `codebook_interpolation` entry.
        """
        # Regular image parameters
        if size is not None:
            size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
        if crop_size is not None:
            crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
        if isinstance(image_mean, list):
            image_mean = tuple(image_mean)
        if isinstance(image_std, list):
            image_std = tuple(image_std)
        if data_format is None:
            data_format = ChannelDimension.FIRST

        # Codebook counterparts, handled the same way
        if codebook_size is not None:
            codebook_size = SizeDict(**get_size_dict(size=codebook_size, default_to_square=default_to_square))
        if codebook_crop_size is not None:
            codebook_crop_size = SizeDict(**get_size_dict(codebook_crop_size, param_name="codebook_crop_size"))
        if isinstance(codebook_image_mean, list):
            codebook_image_mean = tuple(codebook_image_mean)
        if isinstance(codebook_image_std, list):
            codebook_image_std = tuple(codebook_image_std)

        kwargs["size"] = size
        kwargs["crop_size"] = crop_size
        kwargs["default_to_square"] = default_to_square
        kwargs["image_mean"] = image_mean
        kwargs["image_std"] = image_std
        kwargs["codebook_size"] = codebook_size
        kwargs["codebook_crop_size"] = codebook_crop_size
        kwargs["codebook_image_mean"] = codebook_image_mean
        kwargs["codebook_image_std"] = codebook_image_std
        kwargs["data_format"] = data_format
        # PIL resampling filters (or their int values) are translated through the PIL -> torch interpolation mapping;
        # any other value is forwarded unchanged.
        kwargs["codebook_interpolation"] = (
            pil_torch_interpolation_mapping[codebook_resample]
            if isinstance(codebook_resample, (PILImageResampling, int))
            else codebook_resample
        )

        return kwargs

    def _preprocess_image(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_center_crop: bool,
        crop_size: SizeDict,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        do_map_pixels: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        return_tensors: Optional[Union[str, TensorType]],
    ) -> "torch.Tensor":
        """
        Run one pipeline (resize -> center crop -> rescale/normalize -> optional pixel mapping) over `images`.

        Returns the processed images stacked along a new first dimension when `return_tensors` is set, otherwise
        the list of processed images in the input order.
        """
        # Group images by size for batched resizing
        grouped_images, grouped_images_index = group_images_by_shape(images)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
                stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:
                stacked_images = self.center_crop(stacked_images, crop_size)
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            if do_map_pixels:
                stacked_images = self.map_pixels(image=stacked_images)
            processed_images_grouped[shape] = stacked_images

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

        return processed_images

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_center_crop: bool,
        crop_size: SizeDict,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        # Mask related params
        return_image_mask: Optional[bool],
        input_size_patches: Optional[int],
        total_mask_patches: Optional[int],
        mask_group_min_patches: Optional[int],
        mask_group_max_patches: Optional[int],
        mask_group_min_aspect_ratio: Optional[float],
        mask_group_max_aspect_ratio: Optional[float],
        # Codebook related params
        return_codebook_pixels: Optional[bool],
        codebook_do_resize: Optional[bool],
        codebook_size: Optional[SizeDict],
        codebook_interpolation: Optional["F.InterpolationMode"],
        codebook_do_center_crop: Optional[bool],
        codebook_crop_size: Optional[SizeDict],
        codebook_do_rescale: Optional[bool],
        codebook_rescale_factor: Optional[float],
        codebook_do_map_pixels: Optional[bool],
        codebook_do_normalize: Optional[bool],
        codebook_image_mean: Optional[Union[float, list[float]]],
        codebook_image_std: Optional[Union[float, list[float]]],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """
        Build the output `BatchFeature`: always `pixel_values`, plus `codebook_pixel_values` when
        `return_codebook_pixels` is set and `bool_masked_pos` when `return_image_mask` is set.
        """
        # Regular pipeline: pixel mapping is only ever applied to the codebook input.
        processed_images = self._preprocess_image(
            images=images,
            do_resize=do_resize,
            size=size,
            interpolation=interpolation,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            do_map_pixels=False,
            image_mean=image_mean,
            image_std=image_std,
            return_tensors=return_tensors,
        )
        data = {
            "pixel_values": processed_images,
        }

        if return_codebook_pixels:
            # Same input images, processed a second time with the codebook-specific settings.
            codebook_processed_images = self._preprocess_image(
                images=images,
                do_resize=codebook_do_resize,
                size=codebook_size,
                interpolation=codebook_interpolation,
                do_center_crop=codebook_do_center_crop,
                crop_size=codebook_crop_size,
                do_rescale=codebook_do_rescale,
                rescale_factor=codebook_rescale_factor,
                do_normalize=codebook_do_normalize,
                do_map_pixels=codebook_do_map_pixels,
                image_mean=codebook_image_mean,
                image_std=codebook_image_std,
                return_tensors=return_tensors,
            )
            data["codebook_pixel_values"] = codebook_processed_images

        if return_image_mask:
            mask_generator = self.masking_generator(
                input_size_patches=input_size_patches,
                total_mask_patches=total_mask_patches,
                mask_group_min_patches=mask_group_min_patches,
                mask_group_max_patches=mask_group_max_patches,
                mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
                mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
            )
            # One independently sampled mask per input image
            masks = [mask_generator() for _ in images]
            masks = torch.stack(masks, dim=0) if return_tensors else masks
            data["bool_masked_pos"] = masks

        return BatchFeature(data=data, tensor_type=return_tensors)
__all__ = ["FlavaImageProcessorFast"]

View File

@ -16,9 +16,11 @@ import random
import unittest
import numpy as np
import requests
from PIL import Image
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@ -30,6 +32,9 @@ if is_vision_available():
import PIL
from transformers import FlavaImageProcessor
if is_torchvision_available():
from transformers import FlavaImageProcessorFast
from transformers.image_utils import PILImageResampling
from transformers.models.flava.image_processing_flava import (
FLAVA_CODEBOOK_MEAN,
@ -105,7 +110,8 @@ class FlavaImageProcessingTester:
self.codebook_do_resize = codebook_do_resize
self.codebook_size = codebook_size
self.codebook_resample = codebook_resample if codebook_resample is not None else PILImageResampling.LANCZOS
# LANCZOS resample does not support torch Tensor. Use BICUBIC as closest alternative
self.codebook_resample = codebook_resample if codebook_resample is not None else PILImageResampling.BICUBIC
self.codebook_do_center_crop = codebook_do_center_crop
self.codebook_crop_size = codebook_crop_size
self.codebook_do_map_pixels = codebook_do_map_pixels
@ -171,6 +177,7 @@ class FlavaImageProcessingTester:
@require_vision
class FlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = FlavaImageProcessor if is_vision_available() else None
fast_image_processing_class = FlavaImageProcessorFast if is_torchvision_available() else None
maxDiff = None
def setUp(self):
@ -182,157 +189,161 @@ class FlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "resample"))
self.assertTrue(hasattr(image_processing, "crop_size"))
self.assertTrue(hasattr(image_processing, "do_center_crop"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "masking_generator"))
self.assertTrue(hasattr(image_processing, "codebook_do_resize"))
self.assertTrue(hasattr(image_processing, "codebook_size"))
self.assertTrue(hasattr(image_processing, "codebook_resample"))
self.assertTrue(hasattr(image_processing, "codebook_do_center_crop"))
self.assertTrue(hasattr(image_processing, "codebook_crop_size"))
self.assertTrue(hasattr(image_processing, "codebook_do_map_pixels"))
self.assertTrue(hasattr(image_processing, "codebook_do_normalize"))
self.assertTrue(hasattr(image_processing, "codebook_image_mean"))
self.assertTrue(hasattr(image_processing, "codebook_image_std"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "resample"))
self.assertTrue(hasattr(image_processing, "crop_size"))
self.assertTrue(hasattr(image_processing, "do_center_crop"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "masking_generator"))
self.assertTrue(hasattr(image_processing, "codebook_do_resize"))
self.assertTrue(hasattr(image_processing, "codebook_size"))
self.assertTrue(hasattr(image_processing, "codebook_resample"))
self.assertTrue(hasattr(image_processing, "codebook_do_center_crop"))
self.assertTrue(hasattr(image_processing, "codebook_crop_size"))
self.assertTrue(hasattr(image_processing, "codebook_do_map_pixels"))
self.assertTrue(hasattr(image_processing, "codebook_do_normalize"))
self.assertTrue(hasattr(image_processing, "codebook_image_mean"))
self.assertTrue(hasattr(image_processing, "codebook_image_std"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 224, "width": 224})
self.assertEqual(image_processor.crop_size, {"height": 224, "width": 224})
self.assertEqual(image_processor.codebook_size, {"height": 112, "width": 112})
self.assertEqual(image_processor.codebook_crop_size, {"height": 112, "width": 112})
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 224, "width": 224})
self.assertEqual(image_processor.crop_size, {"height": 224, "width": 224})
self.assertEqual(image_processor.codebook_size, {"height": 112, "width": 112})
self.assertEqual(image_processor.codebook_crop_size, {"height": 112, "width": 112})
image_processor = self.image_processing_class.from_dict(
self.image_processor_dict, size=42, crop_size=84, codebook_size=33, codebook_crop_size=66
)
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
self.assertEqual(image_processor.codebook_size, {"height": 33, "width": 33})
self.assertEqual(image_processor.codebook_crop_size, {"height": 66, "width": 66})
image_processor = self.image_processing_class.from_dict(
self.image_processor_dict, size=42, crop_size=84, codebook_size=33, codebook_crop_size=66
)
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
self.assertEqual(image_processor.codebook_size, {"height": 33, "width": 33})
self.assertEqual(image_processor.codebook_crop_size, {"height": 66, "width": 66})
def test_call_pil(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, PIL.Image.Image)
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, PIL.Image.Image)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt")
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt")
# Test no bool masked pos
self.assertFalse("bool_masked_pos" in encoded_images)
# Test no bool masked pos
self.assertFalse("bool_masked_pos" in encoded_images)
expected_height, expected_width = self.image_processor_tester.get_expected_image_size()
expected_height, expected_width = self.image_processor_tester.get_expected_image_size()
self.assertEqual(
encoded_images.pixel_values.shape,
(1, self.image_processor_tester.num_channels, expected_height, expected_width),
)
self.assertEqual(
encoded_images.pixel_values.shape,
(1, self.image_processor_tester.num_channels, expected_height, expected_width),
)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt")
expected_height, expected_width = self.image_processor_tester.get_expected_image_size()
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt")
expected_height, expected_width = self.image_processor_tester.get_expected_image_size()
# Test no bool masked pos
self.assertFalse("bool_masked_pos" in encoded_images)
# Test no bool masked pos
self.assertFalse("bool_masked_pos" in encoded_images)
self.assertEqual(
encoded_images.pixel_values.shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
expected_height,
expected_width,
),
)
self.assertEqual(
encoded_images.pixel_values.shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
expected_height,
expected_width,
),
)
def _test_call_framework(self, instance_class, prepare_kwargs):
    """Check output shapes for every image processor class under test.

    Each class in ``self.image_processor_list`` is instantiated from
    ``self.image_processor_dict`` and called on a single image and on the
    whole batch, with and without ``return_image_mask=True``.

    Args:
        instance_class: Type every prepared input image must be an instance of.
        prepare_kwargs: Extra keyword arguments forwarded to
            ``self.image_processor_tester.prepare_image_inputs``.
    """
    tester = self.image_processor_tester

    for image_processing_class in self.image_processor_list:
        # Initialize image_processing
        image_processing = image_processing_class(**self.image_processor_dict)
        # create random tensors
        image_inputs = tester.prepare_image_inputs(equal_resolution=False, **prepare_kwargs)
        for image in image_inputs:
            self.assertIsInstance(image, instance_class)

        # Expected shapes are the same for every call below, so build them once.
        # NOTE(review): assumes the tester's size getters are side-effect free.
        image_height, image_width = tester.get_expected_image_size()
        mask_height, mask_width = tester.get_expected_mask_size()
        single_pixel_shape = (1, tester.num_channels, image_height, image_width)
        batched_pixel_shape = (tester.batch_size, tester.num_channels, image_height, image_width)
        batched_mask_shape = (tester.batch_size, mask_height, mask_width)

        # Test not batched input
        encoded_images = image_processing(image_inputs[0], return_tensors="pt")
        self.assertEqual(encoded_images.pixel_values.shape, single_pixel_shape)

        # Batched input with an image mask requested
        encoded_images = image_processing(image_inputs, return_image_mask=True, return_tensors="pt")
        self.assertEqual(encoded_images.pixel_values.shape, batched_pixel_shape)
        self.assertEqual(encoded_images.bool_masked_pos.shape, batched_mask_shape)

        # Test batched
        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
        self.assertEqual(encoded_images.shape, batched_pixel_shape)

        # Test masking
        encoded_images = image_processing(image_inputs, return_image_mask=True, return_tensors="pt")
        self.assertEqual(encoded_images.pixel_values.shape, batched_pixel_shape)
        self.assertEqual(encoded_images.bool_masked_pos.shape, batched_mask_shape)
def test_call_numpy(self):
    """Run the shared call checks with NumPy array inputs."""
    numpy_prepare_kwargs = {"numpify": True}
    self._test_call_framework(np.ndarray, prepare_kwargs=numpy_prepare_kwargs)
@ -346,40 +357,76 @@ class FlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
self._test_call_framework(torch.Tensor, prepare_kwargs={"torchify": True})
def test_masking(self):
    """Check the number of masked positions for every image processor class.

    For each class in ``self.image_processor_list`` the ``random`` module is
    re-seeded, one image is processed with ``return_image_mask=True``, and
    the mask is expected to have exactly 75 positions set.
    """
    for image_processing_class in self.image_processor_list:
        # Initialize image_processing
        # Re-seed per processor so each one starts from the same RNG state.
        # NOTE(review): presumably the mask generator draws from Python's
        # `random` module, which is what makes the count of 75 reproducible.
        random.seed(1234)
        image_processing = image_processing_class(**self.image_processor_dict)
        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)

        # Test not batched input
        encoded_images = image_processing(image_inputs[0], return_image_mask=True, return_tensors="pt")
        self.assertEqual(encoded_images.bool_masked_pos.sum().item(), 75)
def test_codebook_pixels(self):
    """Check ``codebook_pixel_values`` shapes for every image processor class.

    Each class in ``self.image_processor_list`` is called with
    ``return_codebook_pixels=True`` on a single PIL image and on the whole
    batch; the result must match the tester's expected codebook image size.
    """
    tester = self.image_processor_tester

    for image_processing_class in self.image_processor_list:
        # Initialize image_processing
        image_processing = image_processing_class(**self.image_processor_dict)
        # create random PIL images
        image_inputs = tester.prepare_image_inputs(equal_resolution=False)
        for image in image_inputs:
            self.assertIsInstance(image, PIL.Image.Image)

        # Test not batched input
        encoded_images = image_processing(image_inputs[0], return_codebook_pixels=True, return_tensors="pt")
        expected_height, expected_width = tester.get_expected_codebook_image_size()
        self.assertEqual(
            encoded_images.codebook_pixel_values.shape,
            (1, tester.num_channels, expected_height, expected_width),
        )

        # Test batched
        encoded_images = image_processing(image_inputs, return_codebook_pixels=True, return_tensors="pt")
        expected_height, expected_width = tester.get_expected_codebook_image_size()
        self.assertEqual(
            encoded_images.codebook_pixel_values.shape,
            (
                tester.batch_size,
                tester.num_channels,
                expected_height,
                expected_width,
            ),
        )
@require_vision
@require_torch
def test_slow_fast_equivalence(self):
    """Compare slow and fast processor outputs on one reference image.

    Both processors are built from ``self.image_processor_dict`` and called
    with codebook pixels and an image mask requested. ``pixel_values`` and
    ``codebook_pixel_values`` must agree element-wise within ``atol=1e-1``
    and have a mean absolute difference of at most ``1e-3``. The returned
    ``bool_masked_pos`` is not compared.

    Skipped when either processor is disabled for this test class or not
    defined.
    """
    if not self.test_slow_image_processor or not self.test_fast_image_processor:
        self.skipTest(reason="Skipping slow/fast equivalence test")

    if self.image_processing_class is None or self.fast_image_processing_class is None:
        self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")

    # Downloads the reference image, so this test needs network access.
    dummy_image = Image.open(
        requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
    )
    image_processor_slow = self.image_processing_class(**self.image_processor_dict)
    image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)

    encoding_slow = image_processor_slow(
        dummy_image, return_tensors="pt", return_codebook_pixels=True, return_image_mask=True
    )
    encoding_fast = image_processor_fast(
        dummy_image, return_tensors="pt", return_codebook_pixels=True, return_image_mask=True
    )

    # Loose per-element tolerance, tight bound on the average deviation.
    self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
    self.assertLessEqual(
        torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
    )

    self.assertTrue(
        torch.allclose(encoding_slow.codebook_pixel_values, encoding_fast.codebook_pixel_values, atol=1e-1)
    )
    self.assertLessEqual(
        torch.mean(torch.abs(encoding_slow.codebook_pixel_values - encoding_fast.codebook_pixel_values)).item(),
        1e-3,
    )