diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
index feb254f66a3..b1e26141220 100644
--- a/src/transformers/image_processing_utils_fast.py
+++ b/src/transformers/image_processing_utils_fast.py
@@ -18,11 +18,7 @@ from typing import Any, Optional, TypedDict, Union
 
 import numpy as np
 
-from .image_processing_utils import (
-    BaseImageProcessor,
-    BatchFeature,
-    get_size_dict,
-)
+from .image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
 from .image_transforms import (
     convert_to_rgb,
     get_resize_output_image_size,
@@ -233,6 +229,9 @@ class BaseImageProcessorFast(BaseImageProcessor):
             else:
                 setattr(self, key, getattr(self, key, None))
 
+        # get valid kwargs names
+        self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys())
+
     def resize(
         self,
         image: "torch.Tensor",
@@ -566,12 +565,16 @@ class BaseImageProcessorFast(BaseImageProcessor):
             data_format=data_format,
         )
 
+    def __call__(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
+        return self.preprocess(images, *args, **kwargs)
+
     @auto_docstring
-    def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
-        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
+    def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
+        # args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same
+        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names)
         # Set default kwargs from self. This ensures that if a kwarg is not provided
         # by the user, it gets its default value from the instance, or is set to None.
-        for kwarg_name in self.valid_kwargs.__annotations__:
+        for kwarg_name in self._valid_kwargs_names:
             kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
 
         # Extract parameters that are only used for preparing the input images
@@ -603,7 +606,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
         kwargs.pop("default_to_square")
         kwargs.pop("data_format")
 
-        return self._preprocess(images=images, **kwargs)
+        return self._preprocess(images, *args, **kwargs)
 
     def _preprocess(
         self,
@@ -651,6 +654,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
     def to_dict(self):
         encoder_dict = super().to_dict()
         encoder_dict.pop("_valid_processor_keys", None)
+        encoder_dict.pop("_valid_kwargs_names", None)
         return encoder_dict
 
 
diff --git a/src/transformers/models/conditional_detr/image_processing_conditional_detr_fast.py b/src/transformers/models/conditional_detr/image_processing_conditional_detr_fast.py
index c1448fe6128..c51356ddf56 100644
--- a/src/transformers/models/conditional_detr/image_processing_conditional_detr_fast.py
+++ b/src/transformers/models/conditional_detr/image_processing_conditional_detr_fast.py
@@ -587,8 +587,6 @@ class ConditionalDetrImageProcessorFast(BaseImageProcessorFast):
             - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
               An image can have no segments, in which case the list should be empty.
             - "file_name" (`str`): The file name of the image.
-        format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
-            Data format of the annotations. One of "coco_detection" or "coco_panoptic".
         masks_path (`str` or `pathlib.Path`, *optional*):
             Path to the directory containing the segmentation masks.
""" @@ -606,19 +604,17 @@ class ConditionalDetrImageProcessorFast(BaseImageProcessorFast): ) kwargs["size"] = kwargs.pop("max_size") - return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs) + return super().preprocess(images, annotations, masks_path, **kwargs) def _preprocess( self, images: List["torch.Tensor"], annotations: Optional[Union[AnnotationType, List[AnnotationType]]], - return_segmentation_masks: bool, masks_path: Optional[Union[str, pathlib.Path]], + return_segmentation_masks: bool, do_resize: bool, size: SizeDict, interpolation: Optional["F.InterpolationMode"], - do_center_crop: bool, - crop_size: SizeDict, do_rescale: bool, rescale_factor: float, do_normalize: bool, @@ -629,6 +625,7 @@ class ConditionalDetrImageProcessorFast(BaseImageProcessorFast): pad_size: Optional[Dict[str, int]], format: Optional[Union[str, AnnotationFormat]], return_tensors: Optional[Union[str, TensorType]], + **kwargs, ) -> BatchFeature: """ Preprocess an image or a batch of images so that it can be used by the model. diff --git a/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py b/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py index cd25fb7a9d7..5bdf903f211 100644 --- a/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +++ b/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py @@ -578,8 +578,6 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast): - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary. An image can have no segments, in which case the list should be empty. - "file_name" (`str`): The file name of the image. - format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`): - Data format of the annotations. One of "coco_detection" or "coco_panoptic". masks_path (`str` or `pathlib.Path`, *optional*): Path to the directory containing the segmentation masks. """ @@ -597,19 +595,17 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast): ) kwargs["size"] = kwargs.pop("max_size") - return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs) + return super().preprocess(images, annotations, masks_path, **kwargs) def _preprocess( self, images: List["torch.Tensor"], annotations: Optional[Union[AnnotationType, List[AnnotationType]]], - return_segmentation_masks: bool, masks_path: Optional[Union[str, pathlib.Path]], + return_segmentation_masks: bool, do_resize: bool, size: SizeDict, interpolation: Optional["F.InterpolationMode"], - do_center_crop: bool, - crop_size: SizeDict, do_rescale: bool, rescale_factor: float, do_normalize: bool, @@ -620,6 +616,7 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast): pad_size: Optional[Dict[str, int]], format: Optional[Union[str, AnnotationFormat]], return_tensors: Optional[Union[str, TensorType]], + **kwargs, ) -> BatchFeature: """ Preprocess an image or a batch of images so that it can be used by the model. 
diff --git a/src/transformers/models/detr/image_processing_detr_fast.py b/src/transformers/models/detr/image_processing_detr_fast.py
index cf9d4e52d8c..6840d367c63 100644
--- a/src/transformers/models/detr/image_processing_detr_fast.py
+++ b/src/transformers/models/detr/image_processing_detr_fast.py
@@ -28,11 +28,7 @@ from ...image_processing_utils_fast import (
     get_max_height_width,
     safe_squeeze,
 )
-from ...image_transforms import (
-    center_to_corners_format,
-    corners_to_center_format,
-    id_to_rgb,
-)
+from ...image_transforms import center_to_corners_format, corners_to_center_format, id_to_rgb
 from ...image_utils import (
     IMAGENET_DEFAULT_MEAN,
     IMAGENET_DEFAULT_STD,
@@ -603,8 +599,6 @@ class DetrImageProcessorFast(BaseImageProcessorFast):
             - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
               An image can have no segments, in which case the list should be empty.
             - "file_name" (`str`): The file name of the image.
-        format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
-            Data format of the annotations. One of "coco_detection" or "coco_panoptic".
         masks_path (`str` or `pathlib.Path`, *optional*):
             Path to the directory containing the segmentation masks.
         """
@@ -622,19 +616,17 @@
             )
             kwargs["size"] = kwargs.pop("max_size")
 
-        return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
+        return super().preprocess(images, annotations, masks_path, **kwargs)
 
     def _preprocess(
         self,
         images: List["torch.Tensor"],
         annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
-        return_segmentation_masks: bool,
         masks_path: Optional[Union[str, pathlib.Path]],
+        return_segmentation_masks: bool,
         do_resize: bool,
         size: SizeDict,
         interpolation: Optional["F.InterpolationMode"],
-        do_center_crop: bool,
-        crop_size: SizeDict,
         do_rescale: bool,
         rescale_factor: float,
         do_normalize: bool,
@@ -645,6 +637,7 @@
         pad_size: Optional[Dict[str, int]],
         format: Optional[Union[str, AnnotationFormat]],
         return_tensors: Optional[Union[str, TensorType]],
+        **kwargs,
     ) -> BatchFeature:
         """
         Preprocess an image or a batch of images so that it can be used by the model.
diff --git a/src/transformers/models/grounding_dino/image_processing_grounding_dino_fast.py b/src/transformers/models/grounding_dino/image_processing_grounding_dino_fast.py
index 69166a16070..775f648def1 100644
--- a/src/transformers/models/grounding_dino/image_processing_grounding_dino_fast.py
+++ b/src/transformers/models/grounding_dino/image_processing_grounding_dino_fast.py
@@ -609,8 +609,6 @@ class GroundingDinoImageProcessorFast(BaseImageProcessorFast):
             - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
               An image can have no segments, in which case the list should be empty.
             - "file_name" (`str`): The file name of the image.
-        format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
-            Data format of the annotations. One of "coco_detection" or "coco_panoptic".
         masks_path (`str` or `pathlib.Path`, *optional*):
             Path to the directory containing the segmentation masks.
""" @@ -628,19 +626,17 @@ class GroundingDinoImageProcessorFast(BaseImageProcessorFast): ) kwargs["size"] = kwargs.pop("max_size") - return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs) + return super().preprocess(images, annotations, masks_path, **kwargs) def _preprocess( self, images: List["torch.Tensor"], annotations: Optional[Union[AnnotationType, List[AnnotationType]]], - return_segmentation_masks: bool, masks_path: Optional[Union[str, pathlib.Path]], + return_segmentation_masks: bool, do_resize: bool, size: SizeDict, interpolation: Optional["F.InterpolationMode"], - do_center_crop: bool, - crop_size: SizeDict, do_rescale: bool, rescale_factor: float, do_normalize: bool, @@ -651,6 +647,7 @@ class GroundingDinoImageProcessorFast(BaseImageProcessorFast): pad_size: Optional[Dict[str, int]], format: Optional[Union[str, AnnotationFormat]], return_tensors: Optional[Union[str, TensorType]], + **kwargs, ) -> BatchFeature: """ Preprocess an image or a batch of images so that it can be used by the model. diff --git a/src/transformers/models/llama4/image_processing_llama4_fast.py b/src/transformers/models/llama4/image_processing_llama4_fast.py index 63aa53bc60d..6bc57efb0a5 100644 --- a/src/transformers/models/llama4/image_processing_llama4_fast.py +++ b/src/transformers/models/llama4/image_processing_llama4_fast.py @@ -26,11 +26,7 @@ from ...image_processing_utils_fast import ( group_images_by_shape, reorder_images, ) -from ...image_utils import ( - ImageInput, - PILImageResampling, - SizeDict, -) +from ...image_utils import ImageInput, PILImageResampling, SizeDict from ...processing_utils import Unpack from ...utils import ( TensorType, @@ -320,7 +316,7 @@ def get_best_fit( else: optimal_canvas = chosen_canvas[0] - return tuple(optimal_canvas.tolist()) + return optimal_canvas class Llama4ImageProcessorKwargs(DefaultFastImageProcessorKwargs): @@ -356,6 +352,8 @@ class Llama4ImageProcessorFast(BaseImageProcessorFast): def __init__(self, **kwargs: Unpack[Llama4ImageProcessorKwargs]): super().__init__(**kwargs) + # Disable compilation here as conversion to bfloat16 causes differences in the output of the compiled and non-compiled versions + @torch.compiler.disable def rescale_and_normalize( self, images: "torch.Tensor", @@ -399,7 +397,7 @@ class Llama4ImageProcessorFast(BaseImageProcessorFast): **kwargs, ) -> BatchFeature: possible_resolutions = find_supported_resolutions(max_num_chunks=max_patches, patch_size=size) - possible_resolutions = torch.tensor(possible_resolutions) + possible_resolutions = torch.tensor(possible_resolutions, device=images[0].device) # process images by batch, grouped by shape grouped_images, grouped_images_index = group_images_by_shape(images) grouped_processed_images = {} @@ -438,7 +436,9 @@ class Llama4ImageProcessorFast(BaseImageProcessorFast): # split into tiles processed_images = split_to_tiles(processed_images, ratio_h, ratio_w) grouped_processed_images[shape] = processed_images - grouped_aspect_ratios[shape] = torch.tensor([[ratio_h, ratio_w]] * stacked_images.shape[0]) + grouped_aspect_ratios[shape] = torch.tensor( + [[ratio_h, ratio_w]] * stacked_images.shape[0], device=images[0].device + ) # add a global tile to the processed tile if there are more than one tile if ratio_h * ratio_w > 1: diff --git a/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py b/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py index 75403e133d9..2742e228bb1 100644 --- 
+++ b/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py
@@ -397,24 +397,20 @@ class RTDetrImageProcessorFast(BaseImageProcessorFast):
             - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
               An image can have no segments, in which case the list should be empty.
             - "file_name" (`str`): The file name of the image.
-        format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
-            Data format of the annotations. One of "coco_detection" or "coco_panoptic".
         masks_path (`str` or `pathlib.Path`, *optional*):
             Path to the directory containing the segmentation masks.
         """
-        return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
+        return super().preprocess(images, annotations, masks_path, **kwargs)
 
     def _preprocess(
         self,
         images: List["torch.Tensor"],
         annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
-        return_segmentation_masks: bool,
         masks_path: Optional[Union[str, pathlib.Path]],
+        return_segmentation_masks: bool,
         do_resize: bool,
         size: SizeDict,
         interpolation: Optional["F.InterpolationMode"],
-        do_center_crop: bool,
-        crop_size: SizeDict,
         do_rescale: bool,
         rescale_factor: float,
         do_normalize: bool,
@@ -425,6 +421,7 @@
         pad_size: Optional[Dict[str, int]],
         format: Optional[Union[str, AnnotationFormat]],
         return_tensors: Optional[Union[str, TensorType]],
+        **kwargs,
     ) -> BatchFeature:
         """
         Preprocess an image or a batch of images so that it can be used by the model.
diff --git a/src/transformers/models/rt_detr/modular_rt_detr.py b/src/transformers/models/rt_detr/modular_rt_detr.py
index 15408b5edc4..b987603192c 100644
--- a/src/transformers/models/rt_detr/modular_rt_detr.py
+++ b/src/transformers/models/rt_detr/modular_rt_detr.py
@@ -137,7 +137,7 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
         masks_path: Optional[Union[str, pathlib.Path]] = None,
         **kwargs: Unpack[RTDetrFastImageProcessorKwargs],
     ) -> BatchFeature:
-        return BaseImageProcessorFast().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
+        return BaseImageProcessorFast().preprocess(images, annotations, masks_path, **kwargs)
 
     def prepare_annotation(
         self,
@@ -163,13 +163,11 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
         self,
         images: List["torch.Tensor"],
         annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
-        return_segmentation_masks: bool,
         masks_path: Optional[Union[str, pathlib.Path]],
+        return_segmentation_masks: bool,
         do_resize: bool,
         size: SizeDict,
         interpolation: Optional["F.InterpolationMode"],
-        do_center_crop: bool,
-        crop_size: SizeDict,
         do_rescale: bool,
         rescale_factor: float,
         do_normalize: bool,
@@ -180,6 +178,7 @@
         pad_size: Optional[Dict[str, int]],
         format: Optional[Union[str, AnnotationFormat]],
         return_tensors: Optional[Union[str, TensorType]],
+        **kwargs,
     ) -> BatchFeature:
         """
         Preprocess an image or a batch of images so that it can be used by the model.
diff --git a/src/transformers/models/vitmatte/image_processing_vitmatte_fast.py b/src/transformers/models/vitmatte/image_processing_vitmatte_fast.py
index 59f8a383356..174bebb45fc 100644
--- a/src/transformers/models/vitmatte/image_processing_vitmatte_fast.py
+++ b/src/transformers/models/vitmatte/image_processing_vitmatte_fast.py
@@ -131,7 +131,7 @@ class VitMatteImageProcessorFast(BaseImageProcessorFast):
         kwargs.pop("size")
         kwargs.pop("crop_size")
 
-        return self._preprocess(images=images, trimaps=trimaps, **kwargs)
+        return self._preprocess(images, trimaps, **kwargs)
 
     def _prepare_input_trimaps(
         self, trimaps: ImageInput, device: Optional["torch.device"] = None
diff --git a/src/transformers/models/yolos/image_processing_yolos_fast.py b/src/transformers/models/yolos/image_processing_yolos_fast.py
index cd29e647a01..120722f05a4 100644
--- a/src/transformers/models/yolos/image_processing_yolos_fast.py
+++ b/src/transformers/models/yolos/image_processing_yolos_fast.py
@@ -626,8 +626,6 @@ class YolosImageProcessorFast(BaseImageProcessorFast):
             - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
               An image can have no segments, in which case the list should be empty.
             - "file_name" (`str`): The file name of the image.
-        format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
-            Data format of the annotations. One of "coco_detection" or "coco_panoptic".
         masks_path (`str` or `pathlib.Path`, *optional*):
             Path to the directory containing the segmentation masks.
         """
@@ -645,19 +643,17 @@
             )
             kwargs["size"] = kwargs.pop("max_size")
 
-        return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
+        return super().preprocess(images, annotations, masks_path, **kwargs)
 
     def _preprocess(
         self,
         images: List["torch.Tensor"],
         annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
-        return_segmentation_masks: bool,
         masks_path: Optional[Union[str, pathlib.Path]],
+        return_segmentation_masks: bool,
         do_resize: bool,
         size: SizeDict,
         interpolation: Optional["F.InterpolationMode"],
-        do_center_crop: bool,
-        crop_size: SizeDict,
         do_rescale: bool,
         rescale_factor: float,
         do_normalize: bool,
@@ -668,6 +664,7 @@
         pad_size: Optional[Dict[str, int]],
         format: Optional[Union[str, AnnotationFormat]],
         return_tensors: Optional[Union[str, TensorType]],
+        **kwargs,
     ) -> BatchFeature:
         """
         Preprocess an image or a batch of images so that it can be used by the model.
diff --git a/tests/models/pixtral/test_image_processing_pixtral.py b/tests/models/pixtral/test_image_processing_pixtral.py
index caaeb78a2e9..afe0c4674d3 100644
--- a/tests/models/pixtral/test_image_processing_pixtral.py
+++ b/tests/models/pixtral/test_image_processing_pixtral.py
@@ -19,13 +19,7 @@ import numpy as np
 import requests
 from packaging import version
 
-from transformers.testing_utils import (
-    require_torch,
-    require_torch_gpu,
-    require_vision,
-    slow,
-    torch_device,
-)
+from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow, torch_device
 from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
 
 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
diff --git a/tests/models/vitmatte/test_image_processing_vitmatte.py b/tests/models/vitmatte/test_image_processing_vitmatte.py
index f3a8a1507b3..49b8ff281a5 100644
--- a/tests/models/vitmatte/test_image_processing_vitmatte.py
+++ b/tests/models/vitmatte/test_image_processing_vitmatte.py
@@ -19,8 +19,9 @@ import warnings
 
 import numpy as np
 import requests
+from packaging import version
 
-from transformers.testing_utils import is_flaky, require_torch, require_vision
+from transformers.testing_utils import is_flaky, require_torch, require_torch_gpu, require_vision, slow, torch_device
 from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
 
 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@@ -334,3 +335,24 @@ class VitMatteImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
         self.assertLessEqual(
             torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
         )
+
+    @slow
+    @require_torch_gpu
+    @require_vision
+    def test_can_compile_fast_image_processor(self):
+        # override as trimaps are needed for the image processor
+        if self.fast_image_processing_class is None:
+            self.skipTest("Skipping compilation test as fast image processor is not defined")
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest(reason="This test requires torch >= 2.3 to run.")
+
+        torch.compiler.reset()
+        input_image = torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8)
+        dummy_trimap = np.random.randint(0, 3, size=input_image.shape[1:])
+        image_processor = self.fast_image_processing_class(**self.image_processor_dict)
+        output_eager = image_processor(input_image, dummy_trimap, device=torch_device, return_tensors="pt")
+
+        image_processor = torch.compile(image_processor, mode="reduce-overhead")
+        output_compiled = image_processor(input_image, dummy_trimap, device=torch_device, return_tensors="pt")
+
+        torch.testing.assert_close(output_eager.pixel_values, output_compiled.pixel_values, rtol=1e-4, atol=1e-4)
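
For reference, a minimal usage sketch (not part of the patch) of the positional-argument API the base-class changes enable; the checkpoint id, dummy image, and annotation values below are illustrative assumptions, not taken from the diff:

# Sketch: after this change, annotation-style fast image processors accept `annotations`
# and `masks_path` positionally, and instances are callable (__call__ forwards *args/**kwargs
# to preprocess()). Checkpoint id and inputs are assumptions for illustration only.
import torch
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50", use_fast=True)
image = torch.randint(0, 255, (3, 480, 640), dtype=torch.uint8)
annotation = {
    "image_id": 0,
    "annotations": [{"bbox": [10, 10, 50, 50], "category_id": 1, "area": 2500, "iscrowd": 0}],
}  # COCO-detection style annotation

# Positional form mirrors the order of `preprocess(images, annotations, masks_path, ...)`
outputs = processor(image, annotation, None, return_tensors="pt")
print(outputs.pixel_values.shape)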