mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
Add Fast Image Processor for Donut (#37081)
* add donut fast image processor support * run make style * Update src/transformers/models/donut/image_processing_donut_fast.py Co-authored-by: Parteek <parteekkamboj112@gmail.com> * update test, remove none default values * add do_align_axis = True test, fix bug in slow image processor * run make style * remove np usage * make style * Apply suggestions from code review * Update src/transformers/models/donut/image_processing_donut_fast.py Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> * add size revert in preprocess * make style * fix copies * add test for preprocess with kwargs * make style * handle None input_data_format in align_long_axis --------- Co-authored-by: Parteek <parteekkamboj112@gmail.com> Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
This commit is contained in:
parent
4e53840920
commit
7cc9e61a3a
@ -208,6 +208,11 @@ print(answer)
|
|||||||
[[autodoc]] DonutImageProcessor
|
[[autodoc]] DonutImageProcessor
|
||||||
- preprocess
|
- preprocess
|
||||||
|
|
||||||
|
## DonutImageProcessorFast
|
||||||
|
|
||||||
|
[[autodoc]] DonutImageProcessorFast
|
||||||
|
- preprocess
|
||||||
|
|
||||||
## DonutFeatureExtractor
|
## DonutFeatureExtractor
|
||||||
|
|
||||||
[[autodoc]] DonutFeatureExtractor
|
[[autodoc]] DonutFeatureExtractor
|
||||||
|
@ -80,7 +80,7 @@ else:
|
|||||||
("detr", ("DetrImageProcessor", "DetrImageProcessorFast")),
|
("detr", ("DetrImageProcessor", "DetrImageProcessorFast")),
|
||||||
("dinat", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
("dinat", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||||
("dinov2", ("BitImageProcessor",)),
|
("dinov2", ("BitImageProcessor",)),
|
||||||
("donut-swin", ("DonutImageProcessor",)),
|
("donut-swin", ("DonutImageProcessor", "DonutImageProcessorFast")),
|
||||||
("dpt", ("DPTImageProcessor",)),
|
("dpt", ("DPTImageProcessor",)),
|
||||||
("efficientformer", ("EfficientFormerImageProcessor",)),
|
("efficientformer", ("EfficientFormerImageProcessor",)),
|
||||||
("efficientnet", ("EfficientNetImageProcessor",)),
|
("efficientnet", ("EfficientNetImageProcessor",)),
|
||||||
|
@ -21,6 +21,7 @@ if TYPE_CHECKING:
|
|||||||
from .configuration_donut_swin import *
|
from .configuration_donut_swin import *
|
||||||
from .feature_extraction_donut import *
|
from .feature_extraction_donut import *
|
||||||
from .image_processing_donut import *
|
from .image_processing_donut import *
|
||||||
|
from .image_processing_donut_fast import *
|
||||||
from .modeling_donut_swin import *
|
from .modeling_donut_swin import *
|
||||||
from .processing_donut import *
|
from .processing_donut import *
|
||||||
else:
|
else:
|
||||||
|
@ -20,6 +20,7 @@ import numpy as np
|
|||||||
|
|
||||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
from ...image_transforms import (
|
from ...image_transforms import (
|
||||||
|
convert_to_rgb,
|
||||||
get_resize_output_image_size,
|
get_resize_output_image_size,
|
||||||
pad,
|
pad,
|
||||||
resize,
|
resize,
|
||||||
@ -151,10 +152,21 @@ class DonutImageProcessor(BaseImageProcessor):
|
|||||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||||
output_height, output_width = size["height"], size["width"]
|
output_height, output_width = size["height"], size["width"]
|
||||||
|
|
||||||
|
if input_data_format is None:
|
||||||
|
# We assume that all images have the same channel dimension format.
|
||||||
|
input_data_format = infer_channel_dimension_format(image)
|
||||||
|
|
||||||
|
if input_data_format == ChannelDimension.LAST:
|
||||||
|
rot_axes = (0, 1)
|
||||||
|
elif input_data_format == ChannelDimension.FIRST:
|
||||||
|
rot_axes = (1, 2)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported data format: {input_data_format}")
|
||||||
|
|
||||||
if (output_width < output_height and input_width > input_height) or (
|
if (output_width < output_height and input_width > input_height) or (
|
||||||
output_width > output_height and input_width < input_height
|
output_width > output_height and input_width < input_height
|
||||||
):
|
):
|
||||||
image = np.rot90(image, 3)
|
image = np.rot90(image, 3, axes=rot_axes)
|
||||||
|
|
||||||
if data_format is not None:
|
if data_format is not None:
|
||||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||||
@ -407,6 +419,8 @@ class DonutImageProcessor(BaseImageProcessor):
|
|||||||
resample=resample,
|
resample=resample,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
images = [convert_to_rgb(image) for image in images]
|
||||||
|
|
||||||
# All transformations expect numpy arrays.
|
# All transformations expect numpy arrays.
|
||||||
images = [to_numpy_array(image) for image in images]
|
images = [to_numpy_array(image) for image in images]
|
||||||
|
|
||||||
|
289
src/transformers/models/donut/image_processing_donut_fast.py
Normal file
289
src/transformers/models/donut/image_processing_donut_fast.py
Normal file
@ -0,0 +1,289 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Fast Image processor class for Donut."""
|
||||||
|
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from ...image_processing_utils_fast import (
|
||||||
|
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
||||||
|
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
||||||
|
BaseImageProcessorFast,
|
||||||
|
BatchFeature,
|
||||||
|
DefaultFastImageProcessorKwargs,
|
||||||
|
)
|
||||||
|
from ...image_transforms import group_images_by_shape, reorder_images
|
||||||
|
from ...image_utils import (
|
||||||
|
IMAGENET_STANDARD_MEAN,
|
||||||
|
IMAGENET_STANDARD_STD,
|
||||||
|
ImageInput,
|
||||||
|
PILImageResampling,
|
||||||
|
SizeDict,
|
||||||
|
)
|
||||||
|
from ...processing_utils import Unpack
|
||||||
|
from ...utils import (
|
||||||
|
TensorType,
|
||||||
|
add_start_docstrings,
|
||||||
|
is_torch_available,
|
||||||
|
is_torchvision_available,
|
||||||
|
is_torchvision_v2_available,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if is_torchvision_available():
|
||||||
|
if is_torchvision_v2_available():
|
||||||
|
from torchvision.transforms.v2 import functional as F
|
||||||
|
else:
|
||||||
|
from torchvision.transforms import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class DonutFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||||
|
do_thumbnail: Optional[bool]
|
||||||
|
do_align_long_axis: Optional[bool]
|
||||||
|
do_pad: Optional[bool]
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"Constructs a fast Donut image processor.",
|
||||||
|
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
||||||
|
"""
|
||||||
|
do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`):
|
||||||
|
Whether to resize the image using thumbnail method.
|
||||||
|
do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`):
|
||||||
|
Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees.
|
||||||
|
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
|
||||||
|
Whether to pad the image. If `random_padding` is set to `True`, each image is padded with a random
|
||||||
|
amount of padding on each size, up to the largest image size in the batch. Otherwise, all images are
|
||||||
|
padded to the largest image size in the batch.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
class DonutImageProcessorFast(BaseImageProcessorFast):
|
||||||
|
resample = PILImageResampling.BILINEAR
|
||||||
|
image_mean = IMAGENET_STANDARD_MEAN
|
||||||
|
image_std = IMAGENET_STANDARD_STD
|
||||||
|
size = {"height": 2560, "width": 1920}
|
||||||
|
do_resize = True
|
||||||
|
do_rescale = True
|
||||||
|
do_normalize = True
|
||||||
|
do_thumbnail = True
|
||||||
|
do_align_long_axis = False
|
||||||
|
do_pad = True
|
||||||
|
valid_kwargs = DonutFastImageProcessorKwargs
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Unpack[DonutFastImageProcessorKwargs]):
|
||||||
|
size = kwargs.pop("size", None)
|
||||||
|
if isinstance(size, (tuple, list)):
|
||||||
|
size = size[::-1]
|
||||||
|
kwargs["size"] = size
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
||||||
|
"""
|
||||||
|
do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`):
|
||||||
|
Whether to resize the image using thumbnail method.
|
||||||
|
do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`):
|
||||||
|
Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees.
|
||||||
|
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
|
||||||
|
Whether to pad the image. If `random_padding` is set to `True`, each image is padded with a random
|
||||||
|
amount of padding on each size, up to the largest image size in the batch. Otherwise, all images are
|
||||||
|
padded to the largest image size in the batch.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
def preprocess(self, images: ImageInput, **kwargs: Unpack[DonutFastImageProcessorKwargs]) -> BatchFeature:
|
||||||
|
if "size" in kwargs:
|
||||||
|
size = kwargs.pop("size")
|
||||||
|
if isinstance(size, (tuple, list)):
|
||||||
|
size = size[::-1]
|
||||||
|
kwargs["size"] = size
|
||||||
|
return super().preprocess(images, **kwargs)
|
||||||
|
|
||||||
|
def align_long_axis(
|
||||||
|
self,
|
||||||
|
image: "torch.Tensor",
|
||||||
|
size: SizeDict,
|
||||||
|
) -> "torch.Tensor":
|
||||||
|
"""
|
||||||
|
Align the long axis of the image to the longest axis of the specified size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`torch.Tensor`):
|
||||||
|
The image to be aligned.
|
||||||
|
size (`Dict[str, int]`):
|
||||||
|
The size `{"height": h, "width": w}` to align the long axis to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.Tensor`: The aligned image.
|
||||||
|
"""
|
||||||
|
input_height, input_width = image.shape[-2:]
|
||||||
|
output_height, output_width = size.height, size.width
|
||||||
|
|
||||||
|
if (output_width < output_height and input_width > input_height) or (
|
||||||
|
output_width > output_height and input_width < input_height
|
||||||
|
):
|
||||||
|
height_dim, width_dim = image.dim() - 2, image.dim() - 1
|
||||||
|
image = torch.rot90(image, 3, dims=[height_dim, width_dim])
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
def pad_image(
|
||||||
|
self,
|
||||||
|
image: "torch.Tensor",
|
||||||
|
size: SizeDict,
|
||||||
|
random_padding: bool = False,
|
||||||
|
) -> "torch.Tensor":
|
||||||
|
"""
|
||||||
|
Pad the image to the specified size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`torch.Tensor`):
|
||||||
|
The image to be padded.
|
||||||
|
size (`Dict[str, int]`):
|
||||||
|
The size `{"height": h, "width": w}` to pad the image to.
|
||||||
|
random_padding (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use random padding or not.
|
||||||
|
data_format (`str` or `ChannelDimension`, *optional*):
|
||||||
|
The data format of the output image. If unset, the same format as the input image is used.
|
||||||
|
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||||
|
"""
|
||||||
|
output_height, output_width = size.height, size.width
|
||||||
|
input_height, input_width = image.shape[-2:]
|
||||||
|
|
||||||
|
delta_width = output_width - input_width
|
||||||
|
delta_height = output_height - input_height
|
||||||
|
|
||||||
|
if random_padding:
|
||||||
|
pad_top = torch.random.randint(low=0, high=delta_height + 1)
|
||||||
|
pad_left = torch.random.randint(low=0, high=delta_width + 1)
|
||||||
|
else:
|
||||||
|
pad_top = delta_height // 2
|
||||||
|
pad_left = delta_width // 2
|
||||||
|
|
||||||
|
pad_bottom = delta_height - pad_top
|
||||||
|
pad_right = delta_width - pad_left
|
||||||
|
|
||||||
|
padding = (pad_left, pad_top, pad_right, pad_bottom)
|
||||||
|
return F.pad(image, padding)
|
||||||
|
|
||||||
|
def pad(self, *args, **kwargs):
|
||||||
|
logger.info("pad is deprecated and will be removed in version 4.27. Please use pad_image instead.")
|
||||||
|
return self.pad_image(*args, **kwargs)
|
||||||
|
|
||||||
|
def thumbnail(
|
||||||
|
self,
|
||||||
|
image: "torch.Tensor",
|
||||||
|
size: SizeDict,
|
||||||
|
) -> "torch.Tensor":
|
||||||
|
"""
|
||||||
|
Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any
|
||||||
|
corresponding dimension of the specified size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`torch.Tensor`):
|
||||||
|
The image to be resized.
|
||||||
|
size (`Dict[str, int]`):
|
||||||
|
The size `{"height": h, "width": w}` to resize the image to.
|
||||||
|
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||||
|
The resampling filter to use.
|
||||||
|
data_format (`Optional[Union[str, ChannelDimension]]`, *optional*):
|
||||||
|
The data format of the output image. If unset, the same format as the input image is used.
|
||||||
|
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||||
|
"""
|
||||||
|
input_height, input_width = image.shape[-2:]
|
||||||
|
output_height, output_width = size.height, size.width
|
||||||
|
|
||||||
|
# We always resize to the smallest of either the input or output size.
|
||||||
|
height = min(input_height, output_height)
|
||||||
|
width = min(input_width, output_width)
|
||||||
|
|
||||||
|
if height == input_height and width == input_width:
|
||||||
|
return image
|
||||||
|
|
||||||
|
if input_height > input_width:
|
||||||
|
width = int(input_width * height / input_height)
|
||||||
|
elif input_width > input_height:
|
||||||
|
height = int(input_height * width / input_width)
|
||||||
|
|
||||||
|
return self.resize(
|
||||||
|
image,
|
||||||
|
size=SizeDict(width=width, height=height),
|
||||||
|
interpolation=F.InterpolationMode.BICUBIC,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _preprocess(
|
||||||
|
self,
|
||||||
|
images: list["torch.Tensor"],
|
||||||
|
do_resize: bool,
|
||||||
|
do_thumbnail: bool,
|
||||||
|
do_align_long_axis: bool,
|
||||||
|
do_pad: bool,
|
||||||
|
size: SizeDict,
|
||||||
|
interpolation: Optional["F.InterpolationMode"],
|
||||||
|
do_center_crop: bool,
|
||||||
|
crop_size: SizeDict,
|
||||||
|
do_rescale: bool,
|
||||||
|
rescale_factor: float,
|
||||||
|
do_normalize: bool,
|
||||||
|
image_mean: Optional[Union[float, list[float]]],
|
||||||
|
image_std: Optional[Union[float, list[float]]],
|
||||||
|
return_tensors: Optional[Union[str, TensorType]],
|
||||||
|
**kwargs,
|
||||||
|
) -> BatchFeature:
|
||||||
|
# Group images by size for batched resizing
|
||||||
|
grouped_images, grouped_images_index = group_images_by_shape(images)
|
||||||
|
resized_images_grouped = {}
|
||||||
|
for shape, stacked_images in grouped_images.items():
|
||||||
|
if do_align_long_axis:
|
||||||
|
stacked_images = self.align_long_axis(image=stacked_images, size=size)
|
||||||
|
if do_resize:
|
||||||
|
shortest_edge = min(size.height, size.width)
|
||||||
|
stacked_images = self.resize(
|
||||||
|
image=stacked_images, size=SizeDict(shortest_edge=shortest_edge), interpolation=interpolation
|
||||||
|
)
|
||||||
|
if do_thumbnail:
|
||||||
|
stacked_images = self.thumbnail(image=stacked_images, size=size)
|
||||||
|
if do_pad:
|
||||||
|
stacked_images = self.pad_image(image=stacked_images, size=size, random_padding=False)
|
||||||
|
|
||||||
|
resized_images_grouped[shape] = stacked_images
|
||||||
|
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
|
||||||
|
|
||||||
|
# Group images by size for further processing
|
||||||
|
# Needed in case do_resize is False, or resize returns images with different sizes
|
||||||
|
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
|
||||||
|
processed_images_grouped = {}
|
||||||
|
for shape, stacked_images in grouped_images.items():
|
||||||
|
if do_center_crop:
|
||||||
|
stacked_images = self.center_crop(stacked_images, crop_size)
|
||||||
|
# Fused rescale and normalize
|
||||||
|
stacked_images = self.rescale_and_normalize(
|
||||||
|
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||||
|
)
|
||||||
|
processed_images_grouped[shape] = stacked_images
|
||||||
|
|
||||||
|
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||||
|
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||||
|
|
||||||
|
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["DonutImageProcessorFast"]
|
@ -220,10 +220,21 @@ class NougatImageProcessor(BaseImageProcessor):
|
|||||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||||
output_height, output_width = size["height"], size["width"]
|
output_height, output_width = size["height"], size["width"]
|
||||||
|
|
||||||
|
if input_data_format is None:
|
||||||
|
# We assume that all images have the same channel dimension format.
|
||||||
|
input_data_format = infer_channel_dimension_format(image)
|
||||||
|
|
||||||
|
if input_data_format == ChannelDimension.LAST:
|
||||||
|
rot_axes = (0, 1)
|
||||||
|
elif input_data_format == ChannelDimension.FIRST:
|
||||||
|
rot_axes = (1, 2)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported data format: {input_data_format}")
|
||||||
|
|
||||||
if (output_width < output_height and input_width > input_height) or (
|
if (output_width < output_height and input_width > input_height) or (
|
||||||
output_width > output_height and input_width < input_height
|
output_width > output_height and input_width < input_height
|
||||||
):
|
):
|
||||||
image = np.rot90(image, 3)
|
image = np.rot90(image, 3, axes=rot_axes)
|
||||||
|
|
||||||
if data_format is not None:
|
if data_format is not None:
|
||||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||||
|
@ -18,7 +18,7 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers.testing_utils import is_flaky, require_torch, require_vision
|
from transformers.testing_utils import is_flaky, require_torch, require_vision
|
||||||
from transformers.utils import is_torch_available, is_vision_available
|
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||||
|
|
||||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||||
|
|
||||||
@ -31,6 +31,9 @@ if is_vision_available():
|
|||||||
|
|
||||||
from transformers import DonutImageProcessor
|
from transformers import DonutImageProcessor
|
||||||
|
|
||||||
|
if is_torchvision_available():
|
||||||
|
from transformers import DonutImageProcessorFast
|
||||||
|
|
||||||
|
|
||||||
class DonutImageProcessingTester:
|
class DonutImageProcessingTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -96,6 +99,7 @@ class DonutImageProcessingTester:
|
|||||||
@require_vision
|
@require_vision
|
||||||
class DonutImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
class DonutImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||||
image_processing_class = DonutImageProcessor if is_vision_available() else None
|
image_processing_class = DonutImageProcessor if is_vision_available() else None
|
||||||
|
fast_image_processing_class = DonutImageProcessorFast if is_torchvision_available() else None
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
@ -106,7 +110,8 @@ class DonutImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
return self.image_processor_tester.prepare_image_processor_dict()
|
return self.image_processor_tester.prepare_image_processor_dict()
|
||||||
|
|
||||||
def test_image_processor_properties(self):
|
def test_image_processor_properties(self):
|
||||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
for image_processing_class in self.image_processor_list:
|
||||||
|
image_processing = image_processing_class(**self.image_processor_dict)
|
||||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||||
self.assertTrue(hasattr(image_processing, "size"))
|
self.assertTrue(hasattr(image_processing, "size"))
|
||||||
self.assertTrue(hasattr(image_processing, "do_thumbnail"))
|
self.assertTrue(hasattr(image_processing, "do_thumbnail"))
|
||||||
@ -117,20 +122,43 @@ class DonutImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||||
|
|
||||||
def test_image_processor_from_dict_with_kwargs(self):
|
def test_image_processor_from_dict_with_kwargs(self):
|
||||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
for image_processing_class in self.image_processor_list:
|
||||||
|
image_processor = image_processing_class.from_dict(self.image_processor_dict)
|
||||||
self.assertEqual(image_processor.size, {"height": 18, "width": 20})
|
self.assertEqual(image_processor.size, {"height": 18, "width": 20})
|
||||||
|
|
||||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
|
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42)
|
||||||
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
||||||
|
|
||||||
# Previous config had dimensions in (width, height) order
|
# Previous config had dimensions in (width, height) order
|
||||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=(42, 84))
|
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=(42, 84))
|
||||||
self.assertEqual(image_processor.size, {"height": 84, "width": 42})
|
self.assertEqual(image_processor.size, {"height": 84, "width": 42})
|
||||||
|
|
||||||
|
def test_image_processor_preprocess_with_kwargs(self):
|
||||||
|
for image_processing_class in self.image_processor_list:
|
||||||
|
# Initialize image_processing
|
||||||
|
image_processing = image_processing_class(**self.image_processor_dict)
|
||||||
|
# create random PyTorch tensors
|
||||||
|
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
|
||||||
|
|
||||||
|
height = 84
|
||||||
|
width = 42
|
||||||
|
# Previous config had dimensions in (width, height) order
|
||||||
|
encoded_images = image_processing(image_inputs[0], size=(width, height), return_tensors="pt").pixel_values
|
||||||
|
self.assertEqual(
|
||||||
|
encoded_images.shape,
|
||||||
|
(
|
||||||
|
1,
|
||||||
|
self.image_processor_tester.num_channels,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
@is_flaky()
|
@is_flaky()
|
||||||
def test_call_pil(self):
|
def test_call_pil(self):
|
||||||
|
for image_processing_class in self.image_processor_list:
|
||||||
# Initialize image_processing
|
# Initialize image_processing
|
||||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
image_processing = image_processing_class(**self.image_processor_dict)
|
||||||
# create random PIL images
|
# create random PIL images
|
||||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
|
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
|
||||||
for image in image_inputs:
|
for image in image_inputs:
|
||||||
@ -162,8 +190,9 @@ class DonutImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
|
|
||||||
@is_flaky()
|
@is_flaky()
|
||||||
def test_call_numpy(self):
|
def test_call_numpy(self):
|
||||||
|
for image_processing_class in self.image_processor_list:
|
||||||
# Initialize image_processing
|
# Initialize image_processing
|
||||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
image_processing = image_processing_class(**self.image_processor_dict)
|
||||||
# create random numpy tensors
|
# create random numpy tensors
|
||||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
|
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
|
||||||
for image in image_inputs:
|
for image in image_inputs:
|
||||||
@ -195,8 +224,9 @@ class DonutImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
|
|
||||||
@is_flaky()
|
@is_flaky()
|
||||||
def test_call_pytorch(self):
|
def test_call_pytorch(self):
|
||||||
|
for image_processing_class in self.image_processor_list:
|
||||||
# Initialize image_processing
|
# Initialize image_processing
|
||||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
image_processing = image_processing_class(**self.image_processor_dict)
|
||||||
# create random PyTorch tensors
|
# create random PyTorch tensors
|
||||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
|
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
|
||||||
for image in image_inputs:
|
for image in image_inputs:
|
||||||
@ -225,3 +255,11 @@ class DonutImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
self.image_processor_tester.size["width"],
|
self.image_processor_tester.size["width"],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
class DonutImageProcessingAlignAxisTest(DonutImageProcessingTest):
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.image_processor_tester = DonutImageProcessingTester(self, do_align_axis=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user