diff --git a/docs/source/en/model_doc/nougat.md b/docs/source/en/model_doc/nougat.md index c3d6ef54f47..accde09ffdd 100644 --- a/docs/source/en/model_doc/nougat.md +++ b/docs/source/en/model_doc/nougat.md @@ -107,6 +107,11 @@ The model is identical to [Donut](donut) in terms of architecture. [[autodoc]] NougatImageProcessor - preprocess +## NougatImageProcessorFast + +[[autodoc]] NougatImageProcessorFast + - preprocess + ## NougatTokenizerFast [[autodoc]] NougatTokenizerFast diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 4ad74482ebc..4586627b919 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -126,7 +126,7 @@ else: ("mobilevit", ("MobileViTImageProcessor",)), ("mobilevitv2", ("MobileViTImageProcessor",)), ("nat", ("ViTImageProcessor", "ViTImageProcessorFast")), - ("nougat", ("NougatImageProcessor",)), + ("nougat", ("NougatImageProcessor", "NougatImageProcessorFast")), ("oneformer", ("OneFormerImageProcessor",)), ("owlv2", ("Owlv2ImageProcessor",)), ("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")), diff --git a/src/transformers/models/nougat/__init__.py b/src/transformers/models/nougat/__init__.py index 4c87d75e58e..6cd3208bfa2 100644 --- a/src/transformers/models/nougat/__init__.py +++ b/src/transformers/models/nougat/__init__.py @@ -19,6 +19,7 @@ from ...utils.import_utils import define_import_structure if TYPE_CHECKING: from .image_processing_nougat import * + from .image_processing_nougat_fast import * from .processing_nougat import * from .tokenization_nougat_fast import * else: diff --git a/src/transformers/models/nougat/image_processing_nougat.py b/src/transformers/models/nougat/image_processing_nougat.py index 3447c0ab151..827686a6066 100644 --- a/src/transformers/models/nougat/image_processing_nougat.py +++ b/src/transformers/models/nougat/image_processing_nougat.py @@ -169,6 +169,7 @@ class NougatImageProcessor(BaseImageProcessor): min_val = data.min() if max_val == min_val: image = np.array(image) + image = to_channel_dimension_format(image, input_data_format, ChannelDimension.LAST) image = ( to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None diff --git a/src/transformers/models/nougat/image_processing_nougat_fast.py b/src/transformers/models/nougat/image_processing_nougat_fast.py new file mode 100644 index 00000000000..29e1d6e2175 --- /dev/null +++ b/src/transformers/models/nougat/image_processing_nougat_fast.py @@ -0,0 +1,327 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for Nougat.""" + +from typing import Optional, Union + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_transforms import ( + get_resize_output_image_size, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + SizeDict, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, +) + + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +class NougatFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + """ + Args: + do_crop_margin (`bool`, *optional*, defaults to `True`): + Whether to crop the image margins. + do_thumbnail (`bool`, *optional*, defaults to `True`): + Whether to resize the image using thumbnail method. + do_align_long_axis (`bool`, *optional*, defaults to `False`): + Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the images to the largest image size in the batch. + """ + + do_crop_margin: Optional[bool] + do_thumbnail: Optional[bool] + do_align_long_axis: Optional[bool] + do_pad: Optional[bool] + + +@auto_docstring +class NougatImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + size = {"height": 896, "width": 672} + do_resize: bool = (True,) + do_normalize: bool = True + do_thumbnail: bool = True + do_align_long_axis: bool = False + do_pad: bool = True + do_rescale = True + do_crop_margin: bool = True + valid_kwargs = NougatFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[NougatFastImageProcessorKwargs]): + super().__init__(**kwargs) + + @auto_docstring + def preprocess(self, images: ImageInput, **kwargs: Unpack[NougatFastImageProcessorKwargs]) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def python_find_non_zero( + self, + image: "torch.Tensor", + ): + """This is a reimplementation of a findNonZero function equivalent to cv2.""" + + non_zero_indices = torch.nonzero(image, as_tuple=False) + idxvec = non_zero_indices[:, [2, 1]] + idxvec = idxvec.reshape(-1, 1, 2) + return idxvec + + def python_bounding_rect(self, coordinates): + """This is a reimplementation of a BoundingRect function equivalent to cv2.""" + + min_values = torch.amin(coordinates, axis=(0, 1)).to(torch.int) + max_values = torch.amax(coordinates, axis=(0, 1)).to(torch.int) + + x_min, y_min = min_values[0], min_values[1] + width = max_values[0] - x_min + 1 + height = max_values[1] - y_min + 1 + return x_min, y_min, width, height + + def crop_margin( + self, + image: "torch.Tensor", + gray_threshold: int = 200, + ) -> "torch.Tensor": + """ + Crops the margin of the image. Gray pixels are considered margin (i.e., pixels with a value below the + threshold). + + Args: + image (`torch.Tensor`): + The image to be cropped. + gray_threshold (`int`, *optional*, defaults to `200`) + Value below which pixels are considered to be gray. + """ + data = F.rgb_to_grayscale(image, num_output_channels=1) + + max_val = torch.max(data) + min_val = torch.min(data) + + if max_val == min_val: + return image + data = (data - min_val) / (max_val - min_val) * 255 + gray = data < gray_threshold + coords = self.python_find_non_zero(gray) + x_min, y_min, width, height = self.python_bounding_rect(coords) + image = image[:, y_min : y_min + height, x_min : x_min + width] + + return image + + def align_long_axis( + self, + image: "torch.Tensor", + size: SizeDict, + ) -> "torch.Tensor": + """ + Align the long axis of the image to the longest axis of the specified size. + + Args: + image (`torch.Tensor`): + The image to be aligned. + size (`Dict[str, int]`): + The size `{"height": h, "width": w}` to align the long axis to. + Returns: + `torch.Tensor`: The aligned image. + """ + input_height, input_width = image.shape[-2:] + output_height, output_width = size.height, size.width + + if (output_width < output_height and input_width > input_height) or ( + output_width > output_height and input_width < input_height + ): + image = torch.rot90(image, 3, dims=[1, 2]) + + return image + + def thumbnail( + self, + image: "torch.Tensor", + size: SizeDict, + ) -> "torch.Tensor": + """ + Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any + corresponding dimension of the specified size. + + Args: + image (`torch.tensor`): + The image to be resized. + size (`Dict[str, int]`): + The size `{"height": h, "width": w}` to resize the image to. + """ + + input_height, input_width = image.shape[-2:] + output_height, output_width = size.height, size.width + + # We always resize to the smallest of either the input or output size. + height = min(input_height, output_height) + width = min(input_width, output_width) + + if height == input_height and width == input_width: + return image + + if input_height > input_width: + width = int(input_width * height / input_height) + elif input_width > input_height: + height = int(input_height * width / input_width) + + new_size = (height, width) + + return F.resize(image, new_size, interpolation=F.InterpolationMode.BICUBIC) + + def pad_images( + self, + image: "torch.Tensor", + size: SizeDict, + ) -> "torch.Tensor": + """ + Pads a batch of images to the specified size at the top, bottom, left and right. + + Args: + image (`torch.tensor`): + The image to be padded. + size (`Dict[str, int]`): + The size `{"height": h, "width": w}` to pad the image to. + """ + input_height, input_width = image.shape[-2:] + output_height, output_width = size.height, size.width + + delta_width = output_width - input_width + delta_height = output_height - input_height + + pad_top = delta_height // 2 + pad_left = delta_width // 2 + + pad_bottom = delta_height - pad_top + pad_right = delta_width - pad_left + + padding = (pad_left, pad_top, pad_right, pad_bottom) + return F.pad(image, padding) + + def resize( + self, + image: "torch.Tensor", + size: SizeDict, + interpolation: "F.InterpolationMode" = None, + antialias: bool = True, + **kwargs, + ) -> "torch.Tensor": + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`torch.Tensor`): + Image to resize. + size (`SizeDict`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BICUBIC`): + `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`. + + Returns: + `torch.Tensor`: The resized image. + """ + interpolation = interpolation if interpolation is not None else F.InterpolationMode.BICUBIC + + shortest_edge = min(size["height"], size["width"]) + + new_size = get_resize_output_image_size( + image, size=shortest_edge, default_to_square=False, input_data_format=ChannelDimension.FIRST + ) + return F.resize(image, new_size, interpolation=interpolation, antialias=antialias) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + do_align_long_axis: bool, + do_thumbnail: bool, + do_pad: bool, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + do_crop_margin: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + disable_grouping: bool, + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + # Crop images + images = [self.crop_margin(image) for image in images] + + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_align_long_axis: + stacked_images = self.align_long_axis(image=stacked_images, size=size) + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size) + if do_thumbnail: + stacked_images = self.thumbnail(image=stacked_images, size=size) + if do_pad: + stacked_images = self.pad_images(image=stacked_images, size=size) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + +__all__ = ["NougatImageProcessorFast"] diff --git a/tests/models/nougat/test_image_processing_nougat.py b/tests/models/nougat/test_image_processing_nougat.py index 996860da6ed..6be868e39e9 100644 --- a/tests/models/nougat/test_image_processing_nougat.py +++ b/tests/models/nougat/test_image_processing_nougat.py @@ -16,10 +16,12 @@ import unittest import numpy as np +import requests from huggingface_hub import hf_hub_download +from transformers.image_utils import SizeDict from transformers.testing_utils import require_torch, require_vision -from transformers.utils import cached_property, is_torch_available, is_vision_available +from transformers.utils import cached_property, is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -32,6 +34,9 @@ if is_vision_available(): from transformers import NougatImageProcessor + if is_torchvision_available(): + from transformers import NougatImageProcessorFast + class NougatImageProcessingTester: def __init__( @@ -68,6 +73,8 @@ class NougatImageProcessingTester: self.do_normalize = do_normalize self.image_mean = image_mean self.image_std = image_std + self.data_format = "channels_first" + self.input_data_format = "channels_first" def prepare_image_processor_dict(self): return { @@ -112,6 +119,7 @@ class NougatImageProcessingTester: @require_vision class NougatImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = NougatImageProcessor if is_vision_available() else None + fast_image_processing_class = NougatImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() @@ -126,61 +134,106 @@ class NougatImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): return self.image_processing_class(**self.image_processor_dict) def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "size")) - self.assertTrue(hasattr(image_processing, "do_normalize")) - self.assertTrue(hasattr(image_processing, "image_mean")) - self.assertTrue(hasattr(image_processing, "image_std")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"height": 20, "width": 20}) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 20, "width": 20}) - image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42) - self.assertEqual(image_processor.size, {"height": 42, "width": 42}) + kwargs = dict(self.image_processor_dict) + kwargs.pop("size", None) + image_processor = self.image_processing_class(**kwargs, size=42) + self.assertEqual(image_processor.size, {"height": 42, "width": 42}) def test_expected_output(self): dummy_image = self.image_processor_tester.prepare_dummy_image() - image_processor = self.image_processor - inputs = image_processor(dummy_image, return_tensors="pt") - torch.testing.assert_close(inputs["pixel_values"].mean(), torch.tensor(0.4906), rtol=1e-3, atol=1e-3) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + inputs = image_processor(dummy_image, return_tensors="pt") + torch.testing.assert_close(inputs["pixel_values"].mean(), torch.tensor(0.4906), rtol=1e-3, atol=1e-3) def test_crop_margin_all_white(self): - image = np.uint8(np.ones((100, 100, 3)) * 255) - image_processor = self.image_processor - cropped_image = image_processor.crop_margin(image) - self.assertTrue(np.array_equal(image, cropped_image)) + image = np.uint8(np.ones((3, 100, 100)) * 255) + for image_processing_class in self.image_processor_list: + if image_processing_class == NougatImageProcessorFast: + image = torch.from_numpy(image) + image_processor = image_processing_class(**self.image_processor_dict) + cropped_image = image_processor.crop_margin(image) + self.assertTrue(torch.equal(image, cropped_image)) + else: + image_processor = image_processing_class(**self.image_processor_dict) + cropped_image = image_processor.crop_margin(image) + self.assertTrue(np.array_equal(image, cropped_image)) def test_crop_margin_centered_black_square(self): - image = np.ones((100, 100, 3), dtype=np.uint8) * 255 - image[45:55, 45:55, :] = 0 - image_processor = self.image_processor - cropped_image = image_processor.crop_margin(image) - expected_cropped = image[45:55, 45:55, :] - self.assertTrue(np.array_equal(expected_cropped, cropped_image)) + image = np.ones((3, 100, 100), dtype=np.uint8) * 255 + image[:, 45:55, 45:55] = 0 + expected_cropped = image[:, 45:55, 45:55] + for image_processing_class in self.image_processor_list: + if image_processing_class == NougatImageProcessorFast: + image = torch.from_numpy(image) + expected_cropped = torch.from_numpy(expected_cropped) + image_processor = image_processing_class(**self.image_processor_dict) + cropped_image = image_processor.crop_margin(image) + self.assertTrue(torch.equal(expected_cropped, cropped_image)) + else: + image_processor = image_processing_class(**self.image_processor_dict) + cropped_image = image_processor.crop_margin(image) + self.assertTrue(np.array_equal(expected_cropped, cropped_image)) def test_align_long_axis_no_rotation(self): - image = np.uint8(np.ones((100, 200, 3)) * 255) - image_processor = self.image_processor - size = {"height": 200, "width": 300} - aligned_image = image_processor.align_long_axis(image, size) - self.assertEqual(image.shape, aligned_image.shape) + image = np.uint8(np.ones((3, 100, 200)) * 255) + for image_processing_class in self.image_processor_list: + if image_processing_class == NougatImageProcessorFast: + image = torch.from_numpy(image) + size = SizeDict(height=200, width=300) + image_processor = image_processing_class(**self.image_processor_dict) + aligned_image = image_processor.align_long_axis(image, size) + self.assertEqual(image.shape, aligned_image.shape) + else: + size = {"height": 200, "width": 300} + image_processor = image_processing_class(**self.image_processor_dict) + aligned_image = image_processor.align_long_axis(image, size) + self.assertEqual(image.shape, aligned_image.shape) def test_align_long_axis_with_rotation(self): - image = np.uint8(np.ones((200, 100, 3)) * 255) - image_processor = self.image_processor - size = {"height": 300, "width": 200} - aligned_image = image_processor.align_long_axis(image, size) - self.assertEqual((200, 100, 3), aligned_image.shape) + image = np.uint8(np.ones((3, 200, 100)) * 255) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + if image_processing_class == NougatImageProcessorFast: + image = torch.from_numpy(image) + size = SizeDict(height=300, width=200) + image_processor = image_processing_class(**self.image_processor_dict) + aligned_image = image_processor.align_long_axis(image, size) + self.assertEqual(torch.Size([3, 200, 100]), aligned_image.shape) + else: + size = {"height": 300, "width": 200} + image_processor = image_processing_class(**self.image_processor_dict) + aligned_image = image_processor.align_long_axis(image, size) + self.assertEqual((3, 200, 100), aligned_image.shape) def test_align_long_axis_data_format(self): - image = np.uint8(np.ones((100, 200, 3)) * 255) - data_format = "channels_first" - size = {"height": 200, "width": 300} - image_processor = self.image_processor - aligned_image = image_processor.align_long_axis(image, size, data_format=data_format) - self.assertEqual((3, 100, 200), aligned_image.shape) + image = np.uint8(np.ones((3, 100, 200)) * 255) + for image_processing_class in self.image_processor_list: + if image_processing_class == NougatImageProcessorFast: + image = torch.from_numpy(image) + image_processor = image_processing_class(**self.image_processor_dict) + size = SizeDict(height=200, width=300) + aligned_image = image_processor.align_long_axis(image, size) + self.assertEqual(torch.Size([3, 100, 200]), aligned_image.shape) + else: + size = {"height": 200, "width": 300} + data_format = "channels_first" + image_processor = image_processing_class(**self.image_processor_dict) + aligned_image = image_processor.align_long_axis(image, size, data_format) + self.assertEqual((3, 100, 200), aligned_image.shape) def prepare_dummy_np_image(self): revision = "ec57bf8c8b1653a209c13f6e9ee66b12df0fc2db" @@ -191,12 +244,77 @@ class NougatImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): revision=revision, ) image = Image.open(filepath).convert("RGB") - return np.array(image) + return np.array(image).transpose(2, 0, 1) def test_crop_margin_equality_cv2_python(self): image = self.prepare_dummy_np_image() - image_processor = self.image_processor - image_cropped_python = image_processor.crop_margin(image) + for image_processing_class in self.image_processor_list: + if image_processing_class == NougatImageProcessorFast: + image = torch.from_numpy(image) + image_processor = image_processing_class(**self.image_processor_dict) + image_cropped_python = image_processor.crop_margin(image) + self.assertEqual(image_cropped_python.shape, torch.Size([3, 850, 685])) + self.assertAlmostEqual(image_cropped_python.float().mean().item(), 237.43881150708458, delta=0.001) + else: + image_processor = image_processing_class(**self.image_processor_dict) + image_cropped_python = image_processor.crop_margin(image) + self.assertEqual(image_cropped_python.shape, (3, 850, 685)) + self.assertAlmostEqual(image_cropped_python.mean(), 237.43881150708458, delta=0.001) - self.assertEqual(image_cropped_python.shape, (850, 685, 3)) - self.assertEqual(image_cropped_python.mean(), 237.43881150708458) + def test_call_numpy_4_channels(self): + for image_processing_class in self.image_processor_list: + if image_processing_class == NougatImageProcessor: + # Test that can process images which have an arbitrary number of channels + # Initialize image_processing + image_processor = image_processing_class(**self.image_processor_dict) + + # create random numpy tensors + self.image_processor_tester.num_channels = 4 + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) + + # Test not batched input + encoded_images = image_processor( + image_inputs[0], + return_tensors="pt", + input_data_format="channels_last", + image_mean=0, + image_std=1, + ).pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape( + [image_inputs[0]] + ) + self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) + + # Test batched + encoded_images = image_processor( + image_inputs, + return_tensors="pt", + input_data_format="channels_last", + image_mean=0, + image_std=1, + ).pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) + self.assertEqual( + tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) + ) + + def test_slow_fast_equivalence(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + dummy_image = Image.open( + requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw + ) + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_image, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_image, return_tensors="pt") + # Adding a larget than usual tolerance because the slow processor uses reducing_gap=2.0 during resizing. + torch.testing.assert_close(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=2e-1, rtol=0) + self.assertLessEqual( + torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 2e-2 + )