mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
add fast image processor nougat (#37661)
* add fast image processor nougat * test fixes * docstring white space * last fixes * docstring_type * tolerance unit test * fix tolerance * fix rtol * remove traling white space * remove white space * note for tolerance unit test * fix tests * remove print --------- Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co> Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
This commit is contained in:
parent
0c35280e58
commit
4336ecd1ea
@ -107,6 +107,11 @@ The model is identical to [Donut](donut) in terms of architecture.
|
||||
[[autodoc]] NougatImageProcessor
|
||||
- preprocess
|
||||
|
||||
## NougatImageProcessorFast
|
||||
|
||||
[[autodoc]] NougatImageProcessorFast
|
||||
- preprocess
|
||||
|
||||
## NougatTokenizerFast
|
||||
|
||||
[[autodoc]] NougatTokenizerFast
|
||||
|
@ -126,7 +126,7 @@ else:
|
||||
("mobilevit", ("MobileViTImageProcessor",)),
|
||||
("mobilevitv2", ("MobileViTImageProcessor",)),
|
||||
("nat", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||
("nougat", ("NougatImageProcessor",)),
|
||||
("nougat", ("NougatImageProcessor", "NougatImageProcessorFast")),
|
||||
("oneformer", ("OneFormerImageProcessor",)),
|
||||
("owlv2", ("Owlv2ImageProcessor",)),
|
||||
("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")),
|
||||
|
@ -19,6 +19,7 @@ from ...utils.import_utils import define_import_structure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .image_processing_nougat import *
|
||||
from .image_processing_nougat_fast import *
|
||||
from .processing_nougat import *
|
||||
from .tokenization_nougat_fast import *
|
||||
else:
|
||||
|
@ -169,6 +169,7 @@ class NougatImageProcessor(BaseImageProcessor):
|
||||
min_val = data.min()
|
||||
if max_val == min_val:
|
||||
image = np.array(image)
|
||||
image = to_channel_dimension_format(image, input_data_format, ChannelDimension.LAST)
|
||||
image = (
|
||||
to_channel_dimension_format(image, data_format, input_data_format)
|
||||
if data_format is not None
|
||||
|
327
src/transformers/models/nougat/image_processing_nougat_fast.py
Normal file
327
src/transformers/models/nougat/image_processing_nougat_fast.py
Normal file
@ -0,0 +1,327 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Fast Image processor class for Nougat."""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_processing_utils_fast import (
|
||||
BaseImageProcessorFast,
|
||||
DefaultFastImageProcessorKwargs,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
from ...image_transforms import (
|
||||
get_resize_output_image_size,
|
||||
)
|
||||
from ...image_utils import (
|
||||
IMAGENET_DEFAULT_MEAN,
|
||||
IMAGENET_DEFAULT_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
)
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
auto_docstring,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchvision_available():
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
else:
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
class NougatFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||
"""
|
||||
Args:
|
||||
do_crop_margin (`bool`, *optional*, defaults to `True`):
|
||||
Whether to crop the image margins.
|
||||
do_thumbnail (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image using thumbnail method.
|
||||
do_align_long_axis (`bool`, *optional*, defaults to `False`):
|
||||
Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees.
|
||||
do_pad (`bool`, *optional*, defaults to `True`):
|
||||
Whether to pad the images to the largest image size in the batch.
|
||||
"""
|
||||
|
||||
do_crop_margin: Optional[bool]
|
||||
do_thumbnail: Optional[bool]
|
||||
do_align_long_axis: Optional[bool]
|
||||
do_pad: Optional[bool]
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class NougatImageProcessorFast(BaseImageProcessorFast):
|
||||
resample = PILImageResampling.BILINEAR
|
||||
image_mean = IMAGENET_DEFAULT_MEAN
|
||||
image_std = IMAGENET_DEFAULT_STD
|
||||
size = {"height": 896, "width": 672}
|
||||
do_resize: bool = (True,)
|
||||
do_normalize: bool = True
|
||||
do_thumbnail: bool = True
|
||||
do_align_long_axis: bool = False
|
||||
do_pad: bool = True
|
||||
do_rescale = True
|
||||
do_crop_margin: bool = True
|
||||
valid_kwargs = NougatFastImageProcessorKwargs
|
||||
|
||||
def __init__(self, **kwargs: Unpack[NougatFastImageProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@auto_docstring
|
||||
def preprocess(self, images: ImageInput, **kwargs: Unpack[NougatFastImageProcessorKwargs]) -> BatchFeature:
|
||||
return super().preprocess(images, **kwargs)
|
||||
|
||||
def python_find_non_zero(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
):
|
||||
"""This is a reimplementation of a findNonZero function equivalent to cv2."""
|
||||
|
||||
non_zero_indices = torch.nonzero(image, as_tuple=False)
|
||||
idxvec = non_zero_indices[:, [2, 1]]
|
||||
idxvec = idxvec.reshape(-1, 1, 2)
|
||||
return idxvec
|
||||
|
||||
def python_bounding_rect(self, coordinates):
|
||||
"""This is a reimplementation of a BoundingRect function equivalent to cv2."""
|
||||
|
||||
min_values = torch.amin(coordinates, axis=(0, 1)).to(torch.int)
|
||||
max_values = torch.amax(coordinates, axis=(0, 1)).to(torch.int)
|
||||
|
||||
x_min, y_min = min_values[0], min_values[1]
|
||||
width = max_values[0] - x_min + 1
|
||||
height = max_values[1] - y_min + 1
|
||||
return x_min, y_min, width, height
|
||||
|
||||
def crop_margin(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
gray_threshold: int = 200,
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Crops the margin of the image. Gray pixels are considered margin (i.e., pixels with a value below the
|
||||
threshold).
|
||||
|
||||
Args:
|
||||
image (`torch.Tensor`):
|
||||
The image to be cropped.
|
||||
gray_threshold (`int`, *optional*, defaults to `200`)
|
||||
Value below which pixels are considered to be gray.
|
||||
"""
|
||||
data = F.rgb_to_grayscale(image, num_output_channels=1)
|
||||
|
||||
max_val = torch.max(data)
|
||||
min_val = torch.min(data)
|
||||
|
||||
if max_val == min_val:
|
||||
return image
|
||||
data = (data - min_val) / (max_val - min_val) * 255
|
||||
gray = data < gray_threshold
|
||||
coords = self.python_find_non_zero(gray)
|
||||
x_min, y_min, width, height = self.python_bounding_rect(coords)
|
||||
image = image[:, y_min : y_min + height, x_min : x_min + width]
|
||||
|
||||
return image
|
||||
|
||||
def align_long_axis(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
size: SizeDict,
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Align the long axis of the image to the longest axis of the specified size.
|
||||
|
||||
Args:
|
||||
image (`torch.Tensor`):
|
||||
The image to be aligned.
|
||||
size (`Dict[str, int]`):
|
||||
The size `{"height": h, "width": w}` to align the long axis to.
|
||||
Returns:
|
||||
`torch.Tensor`: The aligned image.
|
||||
"""
|
||||
input_height, input_width = image.shape[-2:]
|
||||
output_height, output_width = size.height, size.width
|
||||
|
||||
if (output_width < output_height and input_width > input_height) or (
|
||||
output_width > output_height and input_width < input_height
|
||||
):
|
||||
image = torch.rot90(image, 3, dims=[1, 2])
|
||||
|
||||
return image
|
||||
|
||||
def thumbnail(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
size: SizeDict,
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any
|
||||
corresponding dimension of the specified size.
|
||||
|
||||
Args:
|
||||
image (`torch.tensor`):
|
||||
The image to be resized.
|
||||
size (`Dict[str, int]`):
|
||||
The size `{"height": h, "width": w}` to resize the image to.
|
||||
"""
|
||||
|
||||
input_height, input_width = image.shape[-2:]
|
||||
output_height, output_width = size.height, size.width
|
||||
|
||||
# We always resize to the smallest of either the input or output size.
|
||||
height = min(input_height, output_height)
|
||||
width = min(input_width, output_width)
|
||||
|
||||
if height == input_height and width == input_width:
|
||||
return image
|
||||
|
||||
if input_height > input_width:
|
||||
width = int(input_width * height / input_height)
|
||||
elif input_width > input_height:
|
||||
height = int(input_height * width / input_width)
|
||||
|
||||
new_size = (height, width)
|
||||
|
||||
return F.resize(image, new_size, interpolation=F.InterpolationMode.BICUBIC)
|
||||
|
||||
def pad_images(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
size: SizeDict,
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Pads a batch of images to the specified size at the top, bottom, left and right.
|
||||
|
||||
Args:
|
||||
image (`torch.tensor`):
|
||||
The image to be padded.
|
||||
size (`Dict[str, int]`):
|
||||
The size `{"height": h, "width": w}` to pad the image to.
|
||||
"""
|
||||
input_height, input_width = image.shape[-2:]
|
||||
output_height, output_width = size.height, size.width
|
||||
|
||||
delta_width = output_width - input_width
|
||||
delta_height = output_height - input_height
|
||||
|
||||
pad_top = delta_height // 2
|
||||
pad_left = delta_width // 2
|
||||
|
||||
pad_bottom = delta_height - pad_top
|
||||
pad_right = delta_width - pad_left
|
||||
|
||||
padding = (pad_left, pad_top, pad_right, pad_bottom)
|
||||
return F.pad(image, padding)
|
||||
|
||||
def resize(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
size: SizeDict,
|
||||
interpolation: "F.InterpolationMode" = None,
|
||||
antialias: bool = True,
|
||||
**kwargs,
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Resize an image to `(size["height"], size["width"])`.
|
||||
|
||||
Args:
|
||||
image (`torch.Tensor`):
|
||||
Image to resize.
|
||||
size (`SizeDict`):
|
||||
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
||||
interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BICUBIC`):
|
||||
`InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The resized image.
|
||||
"""
|
||||
interpolation = interpolation if interpolation is not None else F.InterpolationMode.BICUBIC
|
||||
|
||||
shortest_edge = min(size["height"], size["width"])
|
||||
|
||||
new_size = get_resize_output_image_size(
|
||||
image, size=shortest_edge, default_to_square=False, input_data_format=ChannelDimension.FIRST
|
||||
)
|
||||
return F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list["torch.Tensor"],
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
do_align_long_axis: bool,
|
||||
do_thumbnail: bool,
|
||||
do_pad: bool,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
do_crop_margin: bool,
|
||||
image_mean: Optional[Union[float, list[float]]],
|
||||
image_std: Optional[Union[float, list[float]]],
|
||||
disable_grouping: bool,
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
# Crop images
|
||||
images = [self.crop_margin(image) for image in images]
|
||||
|
||||
# Group images by size for batched resizing
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
|
||||
resized_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
if do_align_long_axis:
|
||||
stacked_images = self.align_long_axis(image=stacked_images, size=size)
|
||||
if do_resize:
|
||||
stacked_images = self.resize(image=stacked_images, size=size)
|
||||
if do_thumbnail:
|
||||
stacked_images = self.thumbnail(image=stacked_images, size=size)
|
||||
if do_pad:
|
||||
stacked_images = self.pad_images(image=stacked_images, size=size)
|
||||
resized_images_grouped[shape] = stacked_images
|
||||
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
|
||||
|
||||
# Group images by size for further processing
|
||||
# Needed in case do_resize is False, or resize returns images with different sizes
|
||||
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
|
||||
processed_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
# Fused rescale and normalize
|
||||
stacked_images = self.rescale_and_normalize(
|
||||
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||
)
|
||||
processed_images_grouped[shape] = stacked_images
|
||||
|
||||
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||
|
||||
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
|
||||
|
||||
|
||||
__all__ = ["NougatImageProcessorFast"]
|
@ -16,10 +16,12 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers.image_utils import SizeDict
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||
from transformers.utils import cached_property, is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
@ -32,6 +34,9 @@ if is_vision_available():
|
||||
|
||||
from transformers import NougatImageProcessor
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import NougatImageProcessorFast
|
||||
|
||||
|
||||
class NougatImageProcessingTester:
|
||||
def __init__(
|
||||
@ -68,6 +73,8 @@ class NougatImageProcessingTester:
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.data_format = "channels_first"
|
||||
self.input_data_format = "channels_first"
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
@ -112,6 +119,7 @@ class NougatImageProcessingTester:
|
||||
@require_vision
|
||||
class NougatImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = NougatImageProcessor if is_vision_available() else None
|
||||
fast_image_processing_class = NougatImageProcessorFast if is_torchvision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@ -126,61 +134,106 @@ class NougatImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
return self.image_processing_class(**self.image_processor_dict)
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"height": 20, "width": 20})
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"height": 20, "width": 20})
|
||||
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
|
||||
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
||||
kwargs = dict(self.image_processor_dict)
|
||||
kwargs.pop("size", None)
|
||||
image_processor = self.image_processing_class(**kwargs, size=42)
|
||||
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
||||
|
||||
def test_expected_output(self):
|
||||
dummy_image = self.image_processor_tester.prepare_dummy_image()
|
||||
image_processor = self.image_processor
|
||||
inputs = image_processor(dummy_image, return_tensors="pt")
|
||||
torch.testing.assert_close(inputs["pixel_values"].mean(), torch.tensor(0.4906), rtol=1e-3, atol=1e-3)
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
inputs = image_processor(dummy_image, return_tensors="pt")
|
||||
torch.testing.assert_close(inputs["pixel_values"].mean(), torch.tensor(0.4906), rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_crop_margin_all_white(self):
|
||||
image = np.uint8(np.ones((100, 100, 3)) * 255)
|
||||
image_processor = self.image_processor
|
||||
cropped_image = image_processor.crop_margin(image)
|
||||
self.assertTrue(np.array_equal(image, cropped_image))
|
||||
image = np.uint8(np.ones((3, 100, 100)) * 255)
|
||||
for image_processing_class in self.image_processor_list:
|
||||
if image_processing_class == NougatImageProcessorFast:
|
||||
image = torch.from_numpy(image)
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
cropped_image = image_processor.crop_margin(image)
|
||||
self.assertTrue(torch.equal(image, cropped_image))
|
||||
else:
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
cropped_image = image_processor.crop_margin(image)
|
||||
self.assertTrue(np.array_equal(image, cropped_image))
|
||||
|
||||
def test_crop_margin_centered_black_square(self):
|
||||
image = np.ones((100, 100, 3), dtype=np.uint8) * 255
|
||||
image[45:55, 45:55, :] = 0
|
||||
image_processor = self.image_processor
|
||||
cropped_image = image_processor.crop_margin(image)
|
||||
expected_cropped = image[45:55, 45:55, :]
|
||||
self.assertTrue(np.array_equal(expected_cropped, cropped_image))
|
||||
image = np.ones((3, 100, 100), dtype=np.uint8) * 255
|
||||
image[:, 45:55, 45:55] = 0
|
||||
expected_cropped = image[:, 45:55, 45:55]
|
||||
for image_processing_class in self.image_processor_list:
|
||||
if image_processing_class == NougatImageProcessorFast:
|
||||
image = torch.from_numpy(image)
|
||||
expected_cropped = torch.from_numpy(expected_cropped)
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
cropped_image = image_processor.crop_margin(image)
|
||||
self.assertTrue(torch.equal(expected_cropped, cropped_image))
|
||||
else:
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
cropped_image = image_processor.crop_margin(image)
|
||||
self.assertTrue(np.array_equal(expected_cropped, cropped_image))
|
||||
|
||||
def test_align_long_axis_no_rotation(self):
|
||||
image = np.uint8(np.ones((100, 200, 3)) * 255)
|
||||
image_processor = self.image_processor
|
||||
size = {"height": 200, "width": 300}
|
||||
aligned_image = image_processor.align_long_axis(image, size)
|
||||
self.assertEqual(image.shape, aligned_image.shape)
|
||||
image = np.uint8(np.ones((3, 100, 200)) * 255)
|
||||
for image_processing_class in self.image_processor_list:
|
||||
if image_processing_class == NougatImageProcessorFast:
|
||||
image = torch.from_numpy(image)
|
||||
size = SizeDict(height=200, width=300)
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
aligned_image = image_processor.align_long_axis(image, size)
|
||||
self.assertEqual(image.shape, aligned_image.shape)
|
||||
else:
|
||||
size = {"height": 200, "width": 300}
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
aligned_image = image_processor.align_long_axis(image, size)
|
||||
self.assertEqual(image.shape, aligned_image.shape)
|
||||
|
||||
def test_align_long_axis_with_rotation(self):
|
||||
image = np.uint8(np.ones((200, 100, 3)) * 255)
|
||||
image_processor = self.image_processor
|
||||
size = {"height": 300, "width": 200}
|
||||
aligned_image = image_processor.align_long_axis(image, size)
|
||||
self.assertEqual((200, 100, 3), aligned_image.shape)
|
||||
image = np.uint8(np.ones((3, 200, 100)) * 255)
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
if image_processing_class == NougatImageProcessorFast:
|
||||
image = torch.from_numpy(image)
|
||||
size = SizeDict(height=300, width=200)
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
aligned_image = image_processor.align_long_axis(image, size)
|
||||
self.assertEqual(torch.Size([3, 200, 100]), aligned_image.shape)
|
||||
else:
|
||||
size = {"height": 300, "width": 200}
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
aligned_image = image_processor.align_long_axis(image, size)
|
||||
self.assertEqual((3, 200, 100), aligned_image.shape)
|
||||
|
||||
def test_align_long_axis_data_format(self):
|
||||
image = np.uint8(np.ones((100, 200, 3)) * 255)
|
||||
data_format = "channels_first"
|
||||
size = {"height": 200, "width": 300}
|
||||
image_processor = self.image_processor
|
||||
aligned_image = image_processor.align_long_axis(image, size, data_format=data_format)
|
||||
self.assertEqual((3, 100, 200), aligned_image.shape)
|
||||
image = np.uint8(np.ones((3, 100, 200)) * 255)
|
||||
for image_processing_class in self.image_processor_list:
|
||||
if image_processing_class == NougatImageProcessorFast:
|
||||
image = torch.from_numpy(image)
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
size = SizeDict(height=200, width=300)
|
||||
aligned_image = image_processor.align_long_axis(image, size)
|
||||
self.assertEqual(torch.Size([3, 100, 200]), aligned_image.shape)
|
||||
else:
|
||||
size = {"height": 200, "width": 300}
|
||||
data_format = "channels_first"
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
aligned_image = image_processor.align_long_axis(image, size, data_format)
|
||||
self.assertEqual((3, 100, 200), aligned_image.shape)
|
||||
|
||||
def prepare_dummy_np_image(self):
|
||||
revision = "ec57bf8c8b1653a209c13f6e9ee66b12df0fc2db"
|
||||
@ -191,12 +244,77 @@ class NougatImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
revision=revision,
|
||||
)
|
||||
image = Image.open(filepath).convert("RGB")
|
||||
return np.array(image)
|
||||
return np.array(image).transpose(2, 0, 1)
|
||||
|
||||
def test_crop_margin_equality_cv2_python(self):
|
||||
image = self.prepare_dummy_np_image()
|
||||
image_processor = self.image_processor
|
||||
image_cropped_python = image_processor.crop_margin(image)
|
||||
for image_processing_class in self.image_processor_list:
|
||||
if image_processing_class == NougatImageProcessorFast:
|
||||
image = torch.from_numpy(image)
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
image_cropped_python = image_processor.crop_margin(image)
|
||||
self.assertEqual(image_cropped_python.shape, torch.Size([3, 850, 685]))
|
||||
self.assertAlmostEqual(image_cropped_python.float().mean().item(), 237.43881150708458, delta=0.001)
|
||||
else:
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
image_cropped_python = image_processor.crop_margin(image)
|
||||
self.assertEqual(image_cropped_python.shape, (3, 850, 685))
|
||||
self.assertAlmostEqual(image_cropped_python.mean(), 237.43881150708458, delta=0.001)
|
||||
|
||||
self.assertEqual(image_cropped_python.shape, (850, 685, 3))
|
||||
self.assertEqual(image_cropped_python.mean(), 237.43881150708458)
|
||||
def test_call_numpy_4_channels(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
if image_processing_class == NougatImageProcessor:
|
||||
# Test that can process images which have an arbitrary number of channels
|
||||
# Initialize image_processing
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
|
||||
# create random numpy tensors
|
||||
self.image_processor_tester.num_channels = 4
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processor(
|
||||
image_inputs[0],
|
||||
return_tensors="pt",
|
||||
input_data_format="channels_last",
|
||||
image_mean=0,
|
||||
image_std=1,
|
||||
).pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(
|
||||
[image_inputs[0]]
|
||||
)
|
||||
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processor(
|
||||
image_inputs,
|
||||
return_tensors="pt",
|
||||
input_data_format="channels_last",
|
||||
image_mean=0,
|
||||
image_std=1,
|
||||
).pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
|
||||
self.assertEqual(
|
||||
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
|
||||
)
|
||||
|
||||
def test_slow_fast_equivalence(self):
|
||||
if not self.test_slow_image_processor or not self.test_fast_image_processor:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test")
|
||||
|
||||
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
|
||||
|
||||
dummy_image = Image.open(
|
||||
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
|
||||
)
|
||||
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
||||
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
||||
|
||||
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
|
||||
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
|
||||
# Adding a larget than usual tolerance because the slow processor uses reducing_gap=2.0 during resizing.
|
||||
torch.testing.assert_close(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=2e-1, rtol=0)
|
||||
self.assertLessEqual(
|
||||
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 2e-2
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user