mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add EfficientNet Image PreProcessor (#37055)
* added efficientnet image preprocessor but tests fail * ruff checks pass * ruff formatted * properly pass rescale_offset through the functions * - corrected indentation, ordering of methods - reshape test passes when casted to float64 - equivalence test doesn't pass * all tests now pass - changes order of rescale, normalize acc to slow - rescale_offset defaults to False acc to slow - resample was causing difference in fast and slow. Changing test to bilinear resolves this difference * ruff reformat * F.InterpolationMode.NEAREST_EXACT gives TypeError: Object of type InterpolationMode is not JSON serializable * fixes offset not being applied when do_rescale and do_normalization are both true * - using nearest_exact sampling - added tests for rescale + normalize * resolving reviews --------- Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
This commit is contained in:
parent
32eca7197a
commit
a7d2bbaaa8
@ -43,6 +43,11 @@ The original code can be found [here](https://github.com/tensorflow/tpu/tree/mas
|
||||
[[autodoc]] EfficientNetImageProcessor
|
||||
- preprocess
|
||||
|
||||
## EfficientNetImageProcessorFast
|
||||
|
||||
[[autodoc]] EfficientNetImageProcessorFast
|
||||
- preprocess
|
||||
|
||||
## EfficientNetModel
|
||||
|
||||
[[autodoc]] EfficientNetModel
|
||||
|
@ -65,7 +65,7 @@ if is_vision_available():
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
pil_torch_interpolation_mapping = {
|
||||
PILImageResampling.NEAREST: InterpolationMode.NEAREST,
|
||||
PILImageResampling.NEAREST: InterpolationMode.NEAREST_EXACT,
|
||||
PILImageResampling.BOX: InterpolationMode.BOX,
|
||||
PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
|
||||
PILImageResampling.HAMMING: InterpolationMode.HAMMING,
|
||||
|
@ -56,7 +56,7 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("align", ("EfficientNetImageProcessor",)),
|
||||
("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
|
||||
("aria", ("AriaImageProcessor",)),
|
||||
("beit", ("BeitImageProcessor",)),
|
||||
("bit", ("BitImageProcessor", "BitImageProcessorFast")),
|
||||
@ -83,7 +83,7 @@ else:
|
||||
("donut-swin", ("DonutImageProcessor", "DonutImageProcessorFast")),
|
||||
("dpt", ("DPTImageProcessor",)),
|
||||
("efficientformer", ("EfficientFormerImageProcessor",)),
|
||||
("efficientnet", ("EfficientNetImageProcessor",)),
|
||||
("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
|
||||
("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")),
|
||||
("focalnet", ("BitImageProcessor", "BitImageProcessorFast")),
|
||||
("fuyu", ("FuyuImageProcessor",)),
|
||||
|
@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_efficientnet import *
|
||||
from .image_processing_efficientnet import *
|
||||
from .image_processing_efficientnet_fast import *
|
||||
from .modeling_efficientnet import *
|
||||
else:
|
||||
import sys
|
||||
|
@ -0,0 +1,226 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Fast Image processor class for EfficientNet."""
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Union
|
||||
|
||||
from ...image_processing_utils_fast import (
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
||||
BaseImageProcessorFast,
|
||||
BatchFeature,
|
||||
DefaultFastImageProcessorKwargs,
|
||||
)
|
||||
from ...image_transforms import group_images_by_shape, reorder_images
|
||||
from ...image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
)
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
add_start_docstrings,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchvision_available():
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
else:
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
class EfficientNetFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||
rescale_offset: bool
|
||||
include_top: bool
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Constructs a fast EfficientNet image processor.",
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
||||
)
|
||||
class EfficientNetImageProcessorFast(BaseImageProcessorFast):
|
||||
resample = PILImageResampling.NEAREST
|
||||
image_mean = IMAGENET_STANDARD_MEAN
|
||||
image_std = IMAGENET_STANDARD_STD
|
||||
size = {"height": 346, "width": 346}
|
||||
crop_size = {"height": 289, "width": 289}
|
||||
do_resize = True
|
||||
do_center_crop = False
|
||||
do_rescale = True
|
||||
rescale_factor = 1 / 255
|
||||
rescale_offset = False
|
||||
do_normalize = True
|
||||
include_top = True
|
||||
valid_kwargs = EfficientNetFastImageProcessorKwargs
|
||||
|
||||
def __init__(self, **kwargs: Unpack[EfficientNetFastImageProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def rescale(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
scale: float,
|
||||
offset: Optional[bool] = True,
|
||||
**kwargs,
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Rescale an image by a scale factor.
|
||||
|
||||
If `offset` is `True`, the image has its values rescaled by `scale` and then offset by 1. If `scale` is
|
||||
1/127.5, the image is rescaled between [-1, 1].
|
||||
image = image * scale - 1
|
||||
|
||||
If `offset` is `False`, and `scale` is 1/255, the image is rescaled between [0, 1].
|
||||
image = image * scale
|
||||
|
||||
Args:
|
||||
image (`torch.Tensor`):
|
||||
Image to rescale.
|
||||
scale (`float`):
|
||||
The scaling factor to rescale pixel values by.
|
||||
offset (`bool`, *optional*):
|
||||
Whether to scale the image in both negative and positive directions.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The rescaled image.
|
||||
"""
|
||||
|
||||
rescaled_image = image * scale
|
||||
|
||||
if offset:
|
||||
rescaled_image -= 1
|
||||
|
||||
return rescaled_image
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def _fuse_mean_std_and_rescale_factor(
|
||||
self,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, list[float]]] = None,
|
||||
image_std: Optional[Union[float, list[float]]] = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
device: Optional["torch.device"] = None,
|
||||
rescale_offset: Optional[bool] = False,
|
||||
) -> tuple:
|
||||
if do_rescale and do_normalize and not rescale_offset:
|
||||
# Fused rescale and normalize
|
||||
image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
|
||||
image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
|
||||
do_rescale = False
|
||||
return image_mean, image_std, do_rescale
|
||||
|
||||
def rescale_and_normalize(
|
||||
self,
|
||||
images: "torch.Tensor",
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: Union[float, list[float]],
|
||||
image_std: Union[float, list[float]],
|
||||
rescale_offset: bool = False,
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Rescale and normalize images.
|
||||
"""
|
||||
image_mean, image_std, do_rescale = self._fuse_mean_std_and_rescale_factor(
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
device=images.device,
|
||||
rescale_offset=rescale_offset,
|
||||
)
|
||||
# if/elif as we use fused rescale and normalize if both are set to True
|
||||
if do_rescale:
|
||||
images = self.rescale(images, rescale_factor, rescale_offset)
|
||||
if do_normalize:
|
||||
images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
|
||||
|
||||
return images
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list["torch.Tensor"],
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
rescale_offset: bool,
|
||||
do_normalize: bool,
|
||||
include_top: bool,
|
||||
image_mean: Optional[Union[float, list[float]]],
|
||||
image_std: Optional[Union[float, list[float]]],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
# Group images by size for batched resizing
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images)
|
||||
resized_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
if do_resize:
|
||||
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
|
||||
resized_images_grouped[shape] = stacked_images
|
||||
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
|
||||
|
||||
# Group images by size for further processing
|
||||
# Needed in case do_resize is False, or resize returns images with different sizes
|
||||
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
|
||||
processed_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
if do_center_crop:
|
||||
stacked_images = self.center_crop(stacked_images, crop_size)
|
||||
# Fused rescale and normalize
|
||||
stacked_images = self.rescale_and_normalize(
|
||||
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std, rescale_offset
|
||||
)
|
||||
if include_top:
|
||||
stacked_images = self.normalize(stacked_images, 0, image_std)
|
||||
processed_images_grouped[shape] = stacked_images
|
||||
|
||||
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||
|
||||
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
|
||||
|
||||
@add_start_docstrings(
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
||||
"""
|
||||
rescale_offset (`bool`, *optional*, defaults to `self.rescale_offset`):
|
||||
Whether to rescale the image between [-max_range/2, scale_range/2] instead of [0, scale_range].
|
||||
include_top (`bool`, *optional*, defaults to `self.include_top`):
|
||||
Normalize the image again with the standard deviation only for image classification if set to True.
|
||||
""",
|
||||
)
|
||||
def preprocess(self, images: ImageInput, **kwargs: Unpack[EfficientNetFastImageProcessorKwargs]) -> BatchFeature:
|
||||
return super().preprocess(images, **kwargs)
|
||||
|
||||
|
||||
__all__ = ["EfficientNetImageProcessorFast"]
|
@ -17,15 +17,26 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.image_utils import PILImageResampling
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
from transformers.utils import (
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_vision_available,
|
||||
)
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import EfficientNetImageProcessor
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import EfficientNetImageProcessorFast
|
||||
|
||||
|
||||
class EfficientNetImageProcessorTester:
|
||||
def __init__(
|
||||
@ -41,6 +52,10 @@ class EfficientNetImageProcessorTester:
|
||||
do_normalize=True,
|
||||
image_mean=[0.5, 0.5, 0.5],
|
||||
image_std=[0.5, 0.5, 0.5],
|
||||
do_rescale=True,
|
||||
rescale_offset=True,
|
||||
rescale_factor=1 / 127.5,
|
||||
resample=PILImageResampling.BILINEAR, # NEAREST is too different between PIL and torchvision
|
||||
):
|
||||
size = size if size is not None else {"height": 18, "width": 18}
|
||||
self.parent = parent
|
||||
@ -54,6 +69,7 @@ class EfficientNetImageProcessorTester:
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.resample = resample
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
@ -62,6 +78,7 @@ class EfficientNetImageProcessorTester:
|
||||
"do_normalize": self.do_normalize,
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"resample": self.resample,
|
||||
}
|
||||
|
||||
def expected_output_image_shape(self, images):
|
||||
@ -83,6 +100,7 @@ class EfficientNetImageProcessorTester:
|
||||
@require_vision
|
||||
class EfficientNetImageProcessorTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = EfficientNetImageProcessor if is_vision_available() else None
|
||||
fast_image_processing_class = EfficientNetImageProcessorFast if is_torchvision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@ -93,30 +111,80 @@ class EfficientNetImageProcessorTest(ImageProcessingTestMixin, unittest.TestCase
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"height": 18, "width": 18})
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"height": 18, "width": 18})
|
||||
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
|
||||
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42)
|
||||
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
||||
|
||||
def test_rescale(self):
|
||||
# EfficientNet optionally rescales between -1 and 1 instead of the usual 0 and 1
|
||||
image = np.arange(0, 256, 1, dtype=np.uint8).reshape(1, 8, 32)
|
||||
|
||||
image_processor = self.image_processing_class(**self.image_processor_dict)
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
if image_processing_class == EfficientNetImageProcessorFast:
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
rescaled_image = image_processor.rescale(image, scale=1 / 127.5)
|
||||
expected_image = (image * (1 / 127.5)).astype(np.float32) - 1
|
||||
self.assertTrue(np.allclose(rescaled_image, expected_image))
|
||||
# Scale between [-1, 1] with rescale_factor 1/127.5 and rescale_offset=True
|
||||
rescaled_image = image_processor.rescale(image, scale=1 / 127.5, offset=True)
|
||||
expected_image = (image * (1 / 127.5)) - 1
|
||||
self.assertTrue(torch.allclose(rescaled_image, expected_image))
|
||||
|
||||
rescaled_image = image_processor.rescale(image, scale=1 / 255, offset=False)
|
||||
expected_image = (image / 255.0).astype(np.float32)
|
||||
self.assertTrue(np.allclose(rescaled_image, expected_image))
|
||||
# Scale between [0, 1] with rescale_factor 1/255 and rescale_offset=True
|
||||
rescaled_image = image_processor.rescale(image, scale=1 / 255, offset=False)
|
||||
expected_image = image / 255.0
|
||||
self.assertTrue(torch.allclose(rescaled_image, expected_image))
|
||||
|
||||
else:
|
||||
rescaled_image = image_processor.rescale(image, scale=1 / 127.5, dtype=np.float64)
|
||||
expected_image = (image * (1 / 127.5)).astype(np.float64) - 1
|
||||
self.assertTrue(np.allclose(rescaled_image, expected_image))
|
||||
|
||||
rescaled_image = image_processor.rescale(image, scale=1 / 255, offset=False, dtype=np.float64)
|
||||
expected_image = (image / 255.0).astype(np.float64)
|
||||
self.assertTrue(np.allclose(rescaled_image, expected_image))
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
def test_rescale_normalize(self):
|
||||
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
|
||||
|
||||
image = torch.arange(0, 256, 1, dtype=torch.uint8).reshape(1, 8, 32).repeat(3, 1, 1)
|
||||
image_mean_0 = (0.0, 0.0, 0.0)
|
||||
image_std_0 = (1.0, 1.0, 1.0)
|
||||
image_mean_1 = (0.5, 0.5, 0.5)
|
||||
image_std_1 = (0.5, 0.5, 0.5)
|
||||
|
||||
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
||||
|
||||
# Rescale between [-1, 1] with rescale_factor=1/127.5 and rescale_offset=True. Then normalize
|
||||
rescaled_normalized = image_processor_fast.rescale_and_normalize(
|
||||
image, True, 1 / 127.5, True, image_mean_0, image_std_0, True
|
||||
)
|
||||
expected_image = (image * (1 / 127.5)) - 1
|
||||
expected_image = (expected_image - torch.tensor(image_mean_0).view(3, 1, 1)) / torch.tensor(image_std_0).view(
|
||||
3, 1, 1
|
||||
)
|
||||
self.assertTrue(torch.allclose(rescaled_normalized, expected_image, rtol=1e-3))
|
||||
|
||||
# Rescale between [0, 1] with rescale_factor=1/255 and rescale_offset=False. Then normalize
|
||||
rescaled_normalized = image_processor_fast.rescale_and_normalize(
|
||||
image, True, 1 / 255, True, image_mean_1, image_std_1, False
|
||||
)
|
||||
expected_image = image * (1 / 255.0)
|
||||
expected_image = (expected_image - torch.tensor(image_mean_1).view(3, 1, 1)) / torch.tensor(image_std_1).view(
|
||||
3, 1, 1
|
||||
)
|
||||
self.assertTrue(torch.allclose(rescaled_normalized, expected_image, rtol=1e-3))
|
||||
|
Loading…
Reference in New Issue
Block a user