mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add args support for fast image processors (#37018)
* add args support to fast image processors * add comment for clarity * fix-copies * Handle child class args passed as both args or kwargs in call and preprocess functions * revert support args passed as kwargs in overwritten preprocess * fix image processor errors
This commit is contained in:
parent
d69945e5fc
commit
0ba95564b7
@ -18,11 +18,7 @@ from typing import Any, Optional, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .image_processing_utils import (
|
||||
BaseImageProcessor,
|
||||
BatchFeature,
|
||||
get_size_dict,
|
||||
)
|
||||
from .image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from .image_transforms import (
|
||||
convert_to_rgb,
|
||||
get_resize_output_image_size,
|
||||
@ -233,6 +229,9 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
else:
|
||||
setattr(self, key, getattr(self, key, None))
|
||||
|
||||
# get valid kwargs names
|
||||
self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys())
|
||||
|
||||
def resize(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
@ -566,12 +565,16 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
data_format=data_format,
|
||||
)
|
||||
|
||||
def __call__(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
|
||||
return self.preprocess(images, *args, **kwargs)
|
||||
|
||||
@auto_docstring
|
||||
def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
|
||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
|
||||
def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
|
||||
# args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same
|
||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names)
|
||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||
# by the user, it gets its default value from the instance, or is set to None.
|
||||
for kwarg_name in self.valid_kwargs.__annotations__:
|
||||
for kwarg_name in self._valid_kwargs_names:
|
||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||
|
||||
# Extract parameters that are only used for preparing the input images
|
||||
@ -603,7 +606,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
kwargs.pop("default_to_square")
|
||||
kwargs.pop("data_format")
|
||||
|
||||
return self._preprocess(images=images, **kwargs)
|
||||
return self._preprocess(images, *args, **kwargs)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
@ -651,6 +654,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
def to_dict(self):
|
||||
encoder_dict = super().to_dict()
|
||||
encoder_dict.pop("_valid_processor_keys", None)
|
||||
encoder_dict.pop("_valid_kwargs_names", None)
|
||||
return encoder_dict
|
||||
|
||||
|
||||
|
@ -587,8 +587,6 @@ class ConditionalDetrImageProcessorFast(BaseImageProcessorFast):
|
||||
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
|
||||
An image can have no segments, in which case the list should be empty.
|
||||
- "file_name" (`str`): The file name of the image.
|
||||
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
|
||||
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
|
||||
masks_path (`str` or `pathlib.Path`, *optional*):
|
||||
Path to the directory containing the segmentation masks.
|
||||
"""
|
||||
@ -606,19 +604,17 @@ class ConditionalDetrImageProcessorFast(BaseImageProcessorFast):
|
||||
)
|
||||
kwargs["size"] = kwargs.pop("max_size")
|
||||
|
||||
return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
|
||||
return super().preprocess(images, annotations, masks_path, **kwargs)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: List["torch.Tensor"],
|
||||
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
|
||||
return_segmentation_masks: bool,
|
||||
masks_path: Optional[Union[str, pathlib.Path]],
|
||||
return_segmentation_masks: bool,
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
@ -629,6 +625,7 @@ class ConditionalDetrImageProcessorFast(BaseImageProcessorFast):
|
||||
pad_size: Optional[Dict[str, int]],
|
||||
format: Optional[Union[str, AnnotationFormat]],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Preprocess an image or a batch of images so that it can be used by the model.
|
||||
|
@ -578,8 +578,6 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast):
|
||||
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
|
||||
An image can have no segments, in which case the list should be empty.
|
||||
- "file_name" (`str`): The file name of the image.
|
||||
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
|
||||
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
|
||||
masks_path (`str` or `pathlib.Path`, *optional*):
|
||||
Path to the directory containing the segmentation masks.
|
||||
"""
|
||||
@ -597,19 +595,17 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast):
|
||||
)
|
||||
kwargs["size"] = kwargs.pop("max_size")
|
||||
|
||||
return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
|
||||
return super().preprocess(images, annotations, masks_path, **kwargs)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: List["torch.Tensor"],
|
||||
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
|
||||
return_segmentation_masks: bool,
|
||||
masks_path: Optional[Union[str, pathlib.Path]],
|
||||
return_segmentation_masks: bool,
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
@ -620,6 +616,7 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast):
|
||||
pad_size: Optional[Dict[str, int]],
|
||||
format: Optional[Union[str, AnnotationFormat]],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Preprocess an image or a batch of images so that it can be used by the model.
|
||||
|
@ -28,11 +28,7 @@ from ...image_processing_utils_fast import (
|
||||
get_max_height_width,
|
||||
safe_squeeze,
|
||||
)
|
||||
from ...image_transforms import (
|
||||
center_to_corners_format,
|
||||
corners_to_center_format,
|
||||
id_to_rgb,
|
||||
)
|
||||
from ...image_transforms import center_to_corners_format, corners_to_center_format, id_to_rgb
|
||||
from ...image_utils import (
|
||||
IMAGENET_DEFAULT_MEAN,
|
||||
IMAGENET_DEFAULT_STD,
|
||||
@ -603,8 +599,6 @@ class DetrImageProcessorFast(BaseImageProcessorFast):
|
||||
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
|
||||
An image can have no segments, in which case the list should be empty.
|
||||
- "file_name" (`str`): The file name of the image.
|
||||
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
|
||||
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
|
||||
masks_path (`str` or `pathlib.Path`, *optional*):
|
||||
Path to the directory containing the segmentation masks.
|
||||
"""
|
||||
@ -622,19 +616,17 @@ class DetrImageProcessorFast(BaseImageProcessorFast):
|
||||
)
|
||||
kwargs["size"] = kwargs.pop("max_size")
|
||||
|
||||
return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
|
||||
return super().preprocess(images, annotations, masks_path, **kwargs)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: List["torch.Tensor"],
|
||||
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
|
||||
return_segmentation_masks: bool,
|
||||
masks_path: Optional[Union[str, pathlib.Path]],
|
||||
return_segmentation_masks: bool,
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
@ -645,6 +637,7 @@ class DetrImageProcessorFast(BaseImageProcessorFast):
|
||||
pad_size: Optional[Dict[str, int]],
|
||||
format: Optional[Union[str, AnnotationFormat]],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Preprocess an image or a batch of images so that it can be used by the model.
|
||||
|
@ -609,8 +609,6 @@ class GroundingDinoImageProcessorFast(BaseImageProcessorFast):
|
||||
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
|
||||
An image can have no segments, in which case the list should be empty.
|
||||
- "file_name" (`str`): The file name of the image.
|
||||
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
|
||||
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
|
||||
masks_path (`str` or `pathlib.Path`, *optional*):
|
||||
Path to the directory containing the segmentation masks.
|
||||
"""
|
||||
@ -628,19 +626,17 @@ class GroundingDinoImageProcessorFast(BaseImageProcessorFast):
|
||||
)
|
||||
kwargs["size"] = kwargs.pop("max_size")
|
||||
|
||||
return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
|
||||
return super().preprocess(images, annotations, masks_path, **kwargs)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: List["torch.Tensor"],
|
||||
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
|
||||
return_segmentation_masks: bool,
|
||||
masks_path: Optional[Union[str, pathlib.Path]],
|
||||
return_segmentation_masks: bool,
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
@ -651,6 +647,7 @@ class GroundingDinoImageProcessorFast(BaseImageProcessorFast):
|
||||
pad_size: Optional[Dict[str, int]],
|
||||
format: Optional[Union[str, AnnotationFormat]],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Preprocess an image or a batch of images so that it can be used by the model.
|
||||
|
@ -26,11 +26,7 @@ from ...image_processing_utils_fast import (
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
from ...image_utils import (
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
)
|
||||
from ...image_utils import ImageInput, PILImageResampling, SizeDict
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
@ -320,7 +316,7 @@ def get_best_fit(
|
||||
else:
|
||||
optimal_canvas = chosen_canvas[0]
|
||||
|
||||
return tuple(optimal_canvas.tolist())
|
||||
return optimal_canvas
|
||||
|
||||
|
||||
class Llama4ImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||
@ -356,6 +352,8 @@ class Llama4ImageProcessorFast(BaseImageProcessorFast):
|
||||
def __init__(self, **kwargs: Unpack[Llama4ImageProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Disable compilation here as conversion to bfloat16 causes differences in the output of the compiled and non-compiled versions
|
||||
@torch.compiler.disable
|
||||
def rescale_and_normalize(
|
||||
self,
|
||||
images: "torch.Tensor",
|
||||
@ -399,7 +397,7 @@ class Llama4ImageProcessorFast(BaseImageProcessorFast):
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
possible_resolutions = find_supported_resolutions(max_num_chunks=max_patches, patch_size=size)
|
||||
possible_resolutions = torch.tensor(possible_resolutions)
|
||||
possible_resolutions = torch.tensor(possible_resolutions, device=images[0].device)
|
||||
# process images by batch, grouped by shape
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images)
|
||||
grouped_processed_images = {}
|
||||
@ -438,7 +436,9 @@ class Llama4ImageProcessorFast(BaseImageProcessorFast):
|
||||
# split into tiles
|
||||
processed_images = split_to_tiles(processed_images, ratio_h, ratio_w)
|
||||
grouped_processed_images[shape] = processed_images
|
||||
grouped_aspect_ratios[shape] = torch.tensor([[ratio_h, ratio_w]] * stacked_images.shape[0])
|
||||
grouped_aspect_ratios[shape] = torch.tensor(
|
||||
[[ratio_h, ratio_w]] * stacked_images.shape[0], device=images[0].device
|
||||
)
|
||||
|
||||
# add a global tile to the processed tile if there are more than one tile
|
||||
if ratio_h * ratio_w > 1:
|
||||
|
@ -397,24 +397,20 @@ class RTDetrImageProcessorFast(BaseImageProcessorFast):
|
||||
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
|
||||
An image can have no segments, in which case the list should be empty.
|
||||
- "file_name" (`str`): The file name of the image.
|
||||
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
|
||||
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
|
||||
masks_path (`str` or `pathlib.Path`, *optional*):
|
||||
Path to the directory containing the segmentation masks.
|
||||
"""
|
||||
return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
|
||||
return super().preprocess(images, annotations, masks_path, **kwargs)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: List["torch.Tensor"],
|
||||
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
|
||||
return_segmentation_masks: bool,
|
||||
masks_path: Optional[Union[str, pathlib.Path]],
|
||||
return_segmentation_masks: bool,
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
@ -425,6 +421,7 @@ class RTDetrImageProcessorFast(BaseImageProcessorFast):
|
||||
pad_size: Optional[Dict[str, int]],
|
||||
format: Optional[Union[str, AnnotationFormat]],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Preprocess an image or a batch of images so that it can be used by the model.
|
||||
|
@ -137,7 +137,7 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
|
||||
masks_path: Optional[Union[str, pathlib.Path]] = None,
|
||||
**kwargs: Unpack[RTDetrFastImageProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
return BaseImageProcessorFast().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
|
||||
return BaseImageProcessorFast().preprocess(images, annotations, masks_path, **kwargs)
|
||||
|
||||
def prepare_annotation(
|
||||
self,
|
||||
@ -163,13 +163,11 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
|
||||
self,
|
||||
images: List["torch.Tensor"],
|
||||
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
|
||||
return_segmentation_masks: bool,
|
||||
masks_path: Optional[Union[str, pathlib.Path]],
|
||||
return_segmentation_masks: bool,
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
@ -180,6 +178,7 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
|
||||
pad_size: Optional[Dict[str, int]],
|
||||
format: Optional[Union[str, AnnotationFormat]],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Preprocess an image or a batch of images so that it can be used by the model.
|
||||
|
@ -131,7 +131,7 @@ class VitMatteImageProcessorFast(BaseImageProcessorFast):
|
||||
kwargs.pop("size")
|
||||
kwargs.pop("crop_size")
|
||||
|
||||
return self._preprocess(images=images, trimaps=trimaps, **kwargs)
|
||||
return self._preprocess(images, trimaps, **kwargs)
|
||||
|
||||
def _prepare_input_trimaps(
|
||||
self, trimaps: ImageInput, device: Optional["torch.device"] = None
|
||||
|
@ -626,8 +626,6 @@ class YolosImageProcessorFast(BaseImageProcessorFast):
|
||||
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
|
||||
An image can have no segments, in which case the list should be empty.
|
||||
- "file_name" (`str`): The file name of the image.
|
||||
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
|
||||
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
|
||||
masks_path (`str` or `pathlib.Path`, *optional*):
|
||||
Path to the directory containing the segmentation masks.
|
||||
"""
|
||||
@ -645,19 +643,17 @@ class YolosImageProcessorFast(BaseImageProcessorFast):
|
||||
)
|
||||
kwargs["size"] = kwargs.pop("max_size")
|
||||
|
||||
return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs)
|
||||
return super().preprocess(images, annotations, masks_path, **kwargs)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: List["torch.Tensor"],
|
||||
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
|
||||
return_segmentation_masks: bool,
|
||||
masks_path: Optional[Union[str, pathlib.Path]],
|
||||
return_segmentation_masks: bool,
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
@ -668,6 +664,7 @@ class YolosImageProcessorFast(BaseImageProcessorFast):
|
||||
pad_size: Optional[Dict[str, int]],
|
||||
format: Optional[Union[str, AnnotationFormat]],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Preprocess an image or a batch of images so that it can be used by the model.
|
||||
|
@ -19,13 +19,7 @@ import numpy as np
|
||||
import requests
|
||||
from packaging import version
|
||||
|
||||
from transformers.testing_utils import (
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow, torch_device
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
@ -19,8 +19,9 @@ import warnings
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from packaging import version
|
||||
|
||||
from transformers.testing_utils import is_flaky, require_torch, require_vision
|
||||
from transformers.testing_utils import is_flaky, require_torch, require_torch_gpu, require_vision, slow, torch_device
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
@ -334,3 +335,24 @@ class VitMatteImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
self.assertLessEqual(
|
||||
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_vision
|
||||
def test_can_compile_fast_image_processor(self):
|
||||
# override as trimaps are needed for the image processor
|
||||
if self.fast_image_processing_class is None:
|
||||
self.skipTest("Skipping compilation test as fast image processor is not defined")
|
||||
if version.parse(torch.__version__) < version.parse("2.3"):
|
||||
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||
|
||||
torch.compiler.reset()
|
||||
input_image = torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8)
|
||||
dummy_trimap = np.random.randint(0, 3, size=input_image.shape[1:])
|
||||
image_processor = self.fast_image_processing_class(**self.image_processor_dict)
|
||||
output_eager = image_processor(input_image, dummy_trimap, device=torch_device, return_tensors="pt")
|
||||
|
||||
image_processor = torch.compile(image_processor, mode="reduce-overhead")
|
||||
output_compiled = image_processor(input_image, dummy_trimap, device=torch_device, return_tensors="pt")
|
||||
|
||||
torch.testing.assert_close(output_eager.pixel_values, output_compiled.pixel_values, rtol=1e-4, atol=1e-4)
|
||||
|
Loading…
Reference in New Issue
Block a user