Add args support for fast image processors (#37018)

* add args support to fast image processors

* add comment for clarity

* fix-copies

* Handle child class args passed as either args or kwargs in __call__ and preprocess (see the usage sketch below)

* revert support for args passed as kwargs in overridden preprocess

* fix image processor errors
Yoni Gozlan 2025-05-16 12:01:46 -04:00 committed by GitHub
parent d69945e5fc
commit 0ba95564b7
12 changed files with 68 additions and 71 deletions
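Illustrative usage (a minimal sketch, not taken from this diff): with the *args support added here, model-specific inputs such as DETR-style annotations and masks_path can be passed to a fast image processor either positionally or as keywords; __call__ forwards them untouched to preprocess and on to _preprocess. The checkpoint name and the dummy annotation below are placeholders for illustration only.

```python
from PIL import Image
from transformers import AutoImageProcessor

# Placeholder checkpoint; any model with a fast image processor that takes extra
# inputs (here, DETR-style `annotations`) behaves the same way.
processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50", use_fast=True)

image = Image.new("RGB", (640, 480))
annotations = {
    "image_id": 0,
    "annotations": [
        {"id": 1, "image_id": 0, "category_id": 1, "bbox": [10, 10, 50, 50], "area": 2500, "iscrowd": 0}
    ],
}

# Keyword form (works as before):
out_kw = processor(image, annotations=annotations, return_tensors="pt")

# Positional form (enabled by this PR): extra inputs are matched by position
# against the child class's `preprocess` signature (images, annotations, masks_path, ...).
out_pos = processor(image, annotations, return_tensors="pt")
```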

View File

@@ -18,11 +18,7 @@ from typing import Any, Optional, TypedDict, Union
import numpy as np
from .image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_size_dict,
)
from .image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from .image_transforms import (
convert_to_rgb,
get_resize_output_image_size,
@@ -233,6 +229,9 @@ class BaseImageProcessorFast(BaseImageProcessor):
else:
setattr(self, key, getattr(self, key, None))
# get valid kwargs names
self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys())
def resize(
self,
image: "torch.Tensor",
@@ -566,12 +565,16 @@ class BaseImageProcessorFast(BaseImageProcessor):
data_format=data_format,
)
def __call__(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
return self.preprocess(images, *args, **kwargs)
@auto_docstring
def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
# args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names)
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_kwargs.__annotations__:
for kwarg_name in self._valid_kwargs_names:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
# Extract parameters that are only used for preparing the input images
@@ -603,7 +606,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
kwargs.pop("default_to_square")
kwargs.pop("data_format")
return self._preprocess(images=images, **kwargs)
return self._preprocess(images, *args, **kwargs)
def _preprocess(
self,
@@ -651,6 +654,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
def to_dict(self):
encoder_dict = super().to_dict()
encoder_dict.pop("_valid_processor_keys", None)
encoder_dict.pop("_valid_kwargs_names", None)
return encoder_dict
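
The comment added in preprocess above ("args are not validated, but their order in the preprocess and _preprocess signatures must be the same") is the one contract child classes have to respect. A hypothetical subclass sketch (not part of this diff, loosely modeled on VitMatte's trimaps input) shows what that means in practice:

```python
from transformers.image_processing_utils_fast import BaseImageProcessorFast


class MyImageProcessorFast(BaseImageProcessorFast):
    # Hypothetical child class with one extra, model-specific input.
    def preprocess(self, images, trimaps=None, **kwargs):
        # `trimaps` is forwarded positionally; the base class never validates *args,
        # it simply ends with `return self._preprocess(images, *args, **kwargs)`.
        return super().preprocess(images, trimaps, **kwargs)

    def _preprocess(self, images, trimaps, do_resize, size, interpolation, **kwargs):
        # `trimaps` must occupy the same position here as in `preprocess` above,
        # since it arrives positionally rather than by name. The remaining standard
        # options (do_resize, size, ...) still arrive as keywords.
        ...
```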

View File

@@ -587,8 +587,6 @@ class ConditionalDetrImageProcessorFast(BaseImageProcessorFast):
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty.
- "file_name" (`str`): The file name of the image.
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
masks_path (`str` or `pathlib.Path`, *optional*):
Path to the directory containing the segmentation masks.
"""
@@ -606,19 +604,17 @@ class ConditionalDetrImageProcessorFast(BaseImageProcessorFast):
)
kwargs["size"] = kwargs.pop("max_size")
return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
return super().preprocess(images, annotations, masks_path, **kwargs)
def _preprocess(
self,
images: List["torch.Tensor"],
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
return_segmentation_masks: bool,
masks_path: Optional[Union[str, pathlib.Path]],
return_segmentation_masks: bool,
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
@@ -629,6 +625,7 @@ class ConditionalDetrImageProcessorFast(BaseImageProcessorFast):
pad_size: Optional[Dict[str, int]],
format: Optional[Union[str, AnnotationFormat]],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
"""
Preprocess an image or a batch of images so that it can be used by the model.

View File

@@ -578,8 +578,6 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast):
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty.
- "file_name" (`str`): The file name of the image.
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
masks_path (`str` or `pathlib.Path`, *optional*):
Path to the directory containing the segmentation masks.
"""
@@ -597,19 +595,17 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast):
)
kwargs["size"] = kwargs.pop("max_size")
return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
return super().preprocess(images, annotations, masks_path, **kwargs)
def _preprocess(
self,
images: List["torch.Tensor"],
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
return_segmentation_masks: bool,
masks_path: Optional[Union[str, pathlib.Path]],
return_segmentation_masks: bool,
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
@@ -620,6 +616,7 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast):
pad_size: Optional[Dict[str, int]],
format: Optional[Union[str, AnnotationFormat]],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
"""
Preprocess an image or a batch of images so that it can be used by the model.

View File

@@ -28,11 +28,7 @@ from ...image_processing_utils_fast import (
get_max_height_width,
safe_squeeze,
)
from ...image_transforms import (
center_to_corners_format,
corners_to_center_format,
id_to_rgb,
)
from ...image_transforms import center_to_corners_format, corners_to_center_format, id_to_rgb
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
@@ -603,8 +599,6 @@ class DetrImageProcessorFast(BaseImageProcessorFast):
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty.
- "file_name" (`str`): The file name of the image.
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
masks_path (`str` or `pathlib.Path`, *optional*):
Path to the directory containing the segmentation masks.
"""
@@ -622,19 +616,17 @@ class DetrImageProcessorFast(BaseImageProcessorFast):
)
kwargs["size"] = kwargs.pop("max_size")
return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
return super().preprocess(images, annotations, masks_path, **kwargs)
def _preprocess(
self,
images: List["torch.Tensor"],
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
return_segmentation_masks: bool,
masks_path: Optional[Union[str, pathlib.Path]],
return_segmentation_masks: bool,
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
@@ -645,6 +637,7 @@ class DetrImageProcessorFast(BaseImageProcessorFast):
pad_size: Optional[Dict[str, int]],
format: Optional[Union[str, AnnotationFormat]],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
"""
Preprocess an image or a batch of images so that it can be used by the model.

View File

@@ -609,8 +609,6 @@ class GroundingDinoImageProcessorFast(BaseImageProcessorFast):
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty.
- "file_name" (`str`): The file name of the image.
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
masks_path (`str` or `pathlib.Path`, *optional*):
Path to the directory containing the segmentation masks.
"""
@@ -628,19 +626,17 @@ class GroundingDinoImageProcessorFast(BaseImageProcessorFast):
)
kwargs["size"] = kwargs.pop("max_size")
return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
return super().preprocess(images, annotations, masks_path, **kwargs)
def _preprocess(
self,
images: List["torch.Tensor"],
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
return_segmentation_masks: bool,
masks_path: Optional[Union[str, pathlib.Path]],
return_segmentation_masks: bool,
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
@@ -651,6 +647,7 @@ class GroundingDinoImageProcessorFast(BaseImageProcessorFast):
pad_size: Optional[Dict[str, int]],
format: Optional[Union[str, AnnotationFormat]],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
"""
Preprocess an image or a batch of images so that it can be used by the model.

View File

@@ -26,11 +26,7 @@ from ...image_processing_utils_fast import (
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
ImageInput,
PILImageResampling,
SizeDict,
)
from ...image_utils import ImageInput, PILImageResampling, SizeDict
from ...processing_utils import Unpack
from ...utils import (
TensorType,
@@ -320,7 +316,7 @@ def get_best_fit(
else:
optimal_canvas = chosen_canvas[0]
return tuple(optimal_canvas.tolist())
return optimal_canvas
class Llama4ImageProcessorKwargs(DefaultFastImageProcessorKwargs):
@@ -356,6 +352,8 @@ class Llama4ImageProcessorFast(BaseImageProcessorFast):
def __init__(self, **kwargs: Unpack[Llama4ImageProcessorKwargs]):
super().__init__(**kwargs)
# Disable compilation here as conversion to bfloat16 causes differences in the output of the compiled and non-compiled versions
@torch.compiler.disable
def rescale_and_normalize(
self,
images: "torch.Tensor",
@@ -399,7 +397,7 @@ class Llama4ImageProcessorFast(BaseImageProcessorFast):
**kwargs,
) -> BatchFeature:
possible_resolutions = find_supported_resolutions(max_num_chunks=max_patches, patch_size=size)
possible_resolutions = torch.tensor(possible_resolutions)
possible_resolutions = torch.tensor(possible_resolutions, device=images[0].device)
# process images by batch, grouped by shape
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_processed_images = {}
@@ -438,7 +436,9 @@ class Llama4ImageProcessorFast(BaseImageProcessorFast):
# split into tiles
processed_images = split_to_tiles(processed_images, ratio_h, ratio_w)
grouped_processed_images[shape] = processed_images
grouped_aspect_ratios[shape] = torch.tensor([[ratio_h, ratio_w]] * stacked_images.shape[0])
grouped_aspect_ratios[shape] = torch.tensor(
[[ratio_h, ratio_w]] * stacked_images.shape[0], device=images[0].device
)
# add a global tile to the processed tile if there are more than one tile
if ratio_h * ratio_w > 1:
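
The @torch.compiler.disable added above keeps the bfloat16 rescale-and-normalize step out of the compiled graph so compiled and eager outputs match, while the device=images[0].device changes keep tensors created inside _preprocess on the same device as the input images. A standalone sketch of the compiler-disable pattern (the class name and toy math are illustrative, not taken from Llama4):

```python
import torch


class TinyPipeline(torch.nn.Module):
    # Excluded from compilation: runs eagerly even when the module is compiled,
    # so the bfloat16 conversion produces identical results in both modes.
    @torch.compiler.disable
    def rescale_and_normalize(self, images: torch.Tensor) -> torch.Tensor:
        return (images.to(torch.bfloat16) / 255.0 - 0.5) / 0.5

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        images = self.rescale_and_normalize(images)
        return images.flip(-1)  # stand-in for the rest of the processing


compiled = torch.compile(TinyPipeline(), mode="reduce-overhead")
out = compiled(torch.randint(0, 255, (2, 3, 8, 8), dtype=torch.uint8))
```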

View File

@@ -397,24 +397,20 @@ class RTDetrImageProcessorFast(BaseImageProcessorFast):
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty.
- "file_name" (`str`): The file name of the image.
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
masks_path (`str` or `pathlib.Path`, *optional*):
Path to the directory containing the segmentation masks.
"""
return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
return super().preprocess(images, annotations, masks_path, **kwargs)
def _preprocess(
self,
images: List["torch.Tensor"],
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
return_segmentation_masks: bool,
masks_path: Optional[Union[str, pathlib.Path]],
return_segmentation_masks: bool,
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
@@ -425,6 +421,7 @@ class RTDetrImageProcessorFast(BaseImageProcessorFast):
pad_size: Optional[Dict[str, int]],
format: Optional[Union[str, AnnotationFormat]],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
"""
Preprocess an image or a batch of images so that it can be used by the model.

View File

@@ -137,7 +137,7 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
masks_path: Optional[Union[str, pathlib.Path]] = None,
**kwargs: Unpack[RTDetrFastImageProcessorKwargs],
) -> BatchFeature:
return BaseImageProcessorFast().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
return BaseImageProcessorFast().preprocess(images, annotations, masks_path, **kwargs)
def prepare_annotation(
self,
@@ -163,13 +163,11 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
self,
images: List["torch.Tensor"],
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
return_segmentation_masks: bool,
masks_path: Optional[Union[str, pathlib.Path]],
return_segmentation_masks: bool,
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
@@ -180,6 +178,7 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
pad_size: Optional[Dict[str, int]],
format: Optional[Union[str, AnnotationFormat]],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
"""
Preprocess an image or a batch of images so that it can be used by the model.

View File

@@ -131,7 +131,7 @@ class VitMatteImageProcessorFast(BaseImageProcessorFast):
kwargs.pop("size")
kwargs.pop("crop_size")
return self._preprocess(images=images, trimaps=trimaps, **kwargs)
return self._preprocess(images, trimaps, **kwargs)
def _prepare_input_trimaps(
self, trimaps: ImageInput, device: Optional["torch.device"] = None

View File

@@ -626,8 +626,6 @@ class YolosImageProcessorFast(BaseImageProcessorFast):
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty.
- "file_name" (`str`): The file name of the image.
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
masks_path (`str` or `pathlib.Path`, *optional*):
Path to the directory containing the segmentation masks.
"""
@@ -645,19 +643,17 @@ class YolosImageProcessorFast(BaseImageProcessorFast):
)
kwargs["size"] = kwargs.pop("max_size")
return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
return super().preprocess(images, annotations, masks_path, **kwargs)
def _preprocess(
self,
images: List["torch.Tensor"],
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
return_segmentation_masks: bool,
masks_path: Optional[Union[str, pathlib.Path]],
return_segmentation_masks: bool,
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
@@ -668,6 +664,7 @@ class YolosImageProcessorFast(BaseImageProcessorFast):
pad_size: Optional[Dict[str, int]],
format: Optional[Union[str, AnnotationFormat]],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
"""
Preprocess an image or a batch of images so that it can be used by the model.

View File

@@ -19,13 +19,7 @@ import numpy as np
import requests
from packaging import version
from transformers.testing_utils import (
require_torch,
require_torch_gpu,
require_vision,
slow,
torch_device,
)
from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow, torch_device
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs

View File

@@ -19,8 +19,9 @@ import warnings
import numpy as np
import requests
from packaging import version
from transformers.testing_utils import is_flaky, require_torch, require_vision
from transformers.testing_utils import is_flaky, require_torch, require_torch_gpu, require_vision, slow, torch_device
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@@ -334,3 +335,24 @@ class VitMatteImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)
@slow
@require_torch_gpu
@require_vision
def test_can_compile_fast_image_processor(self):
# override as trimaps are needed for the image processor
if self.fast_image_processing_class is None:
self.skipTest("Skipping compilation test as fast image processor is not defined")
if version.parse(torch.__version__) < version.parse("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
torch.compiler.reset()
input_image = torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8)
dummy_trimap = np.random.randint(0, 3, size=input_image.shape[1:])
image_processor = self.fast_image_processing_class(**self.image_processor_dict)
output_eager = image_processor(input_image, dummy_trimap, device=torch_device, return_tensors="pt")
image_processor = torch.compile(image_processor, mode="reduce-overhead")
output_compiled = image_processor(input_image, dummy_trimap, device=torch_device, return_tensors="pt")
torch.testing.assert_close(output_eager.pixel_values, output_compiled.pixel_values, rtol=1e-4, atol=1e-4)