mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
[Fast Processor] BEiT (#37005)
* adding fast processor for beit * adding resample * address review issues and add segmentation maps logic * style * chore: adding tests * reduce label test * adding batched tests * Update src/transformers/models/beit/image_processing_beit_fast.py Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> * fix imports and make segmentation masks * fix tests * build segmentation maps * all tests pass * style * style fix * style * chore: delete demo.py file * review suggestions * Update docs/source/en/model_doc/beit.md Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> --------- Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
This commit is contained in:
parent
ebbe9b12dd
commit
3c0796aaea
@ -150,6 +150,11 @@ If you're interested in submitting a resource to be included here, please feel f
|
||||
[[autodoc]] BeitImageProcessor
|
||||
- preprocess
|
||||
- post_process_semantic_segmentation
|
||||
## BeitImageProcessorFast
|
||||
|
||||
[[autodoc]] BeitImageProcessorFast
|
||||
- preprocess
|
||||
- post_process_semantic_segmentation
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
@ -105,6 +105,11 @@ BEiT の使用を開始するのに役立つ公式 Hugging Face およびコミ
|
||||
|
||||
[[autodoc]] BeitImageProcessor
|
||||
- preprocess
|
||||
|
||||
## BeitImageProcessorFast
|
||||
|
||||
[[autodoc]] BeitImageProcessorFast
|
||||
- preprocess
|
||||
- post_process_semantic_segmentation
|
||||
|
||||
## BeitModel
|
||||
|
@ -57,8 +57,8 @@ else:
|
||||
IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
|
||||
("aria", ("AriaImageProcessor",)),
|
||||
("beit", ("BeitImageProcessor",)),
|
||||
("aria", ("AriaImageProcessor")),
|
||||
("beit", ("BeitImageProcessor", "BeitImageProcessorFast")),
|
||||
("bit", ("BitImageProcessor", "BitImageProcessorFast")),
|
||||
("blip", ("BlipImageProcessor", "BlipImageProcessorFast")),
|
||||
("blip-2", ("BlipImageProcessor", "BlipImageProcessorFast")),
|
||||
@ -71,7 +71,7 @@ else:
|
||||
("convnext", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("cvt", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("data2vec-vision", ("BeitImageProcessor",)),
|
||||
("data2vec-vision", ("BeitImageProcessor", "BeitImageProcessorFast")),
|
||||
("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")),
|
||||
("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")),
|
||||
("depth_anything", ("DPTImageProcessor",)),
|
||||
|
@ -21,6 +21,7 @@ if TYPE_CHECKING:
|
||||
from .configuration_beit import *
|
||||
from .feature_extraction_beit import *
|
||||
from .image_processing_beit import *
|
||||
from .image_processing_beit_fast import *
|
||||
from .modeling_beit import *
|
||||
from .modeling_flax_beit import *
|
||||
else:
|
||||
|
284
src/transformers/models/beit/image_processing_beit_fast.py
Normal file
284
src/transformers/models/beit/image_processing_beit_fast.py
Normal file
@ -0,0 +1,284 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Fast Image processor class for Beit."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_processing_utils_fast import (
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
||||
BaseImageProcessorFast,
|
||||
DefaultFastImageProcessorKwargs,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
from ...image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
is_torch_tensor,
|
||||
make_list_of_images,
|
||||
pil_torch_interpolation_mapping,
|
||||
validate_kwargs,
|
||||
)
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TensorType, add_start_docstrings
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
|
||||
|
||||
class BeitFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||
do_reduce_labels: Optional[bool]
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Constructs a fast Beit image processor.",
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
||||
"""
|
||||
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
|
||||
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
|
||||
is used for background, and background itself is not included in all classes of a dataset (e.g.
|
||||
ADE20k). The background label will be replaced by 255.
|
||||
""",
|
||||
)
|
||||
class BeitImageProcessorFast(BaseImageProcessorFast):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = IMAGENET_STANDARD_MEAN
|
||||
image_std = IMAGENET_STANDARD_STD
|
||||
size = {"height": 224, "width": 224}
|
||||
default_to_square = True
|
||||
crop_size = {"height": 224, "width": 224}
|
||||
do_resize = True
|
||||
do_center_crop = False
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_reduce_labels = False
|
||||
valid_kwargs = BeitFastImageProcessorKwargs
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
||||
"""
|
||||
Overrides the `from_dict` method from the base class to save support of deprecated `reduce_labels` in old configs
|
||||
"""
|
||||
image_processor_dict = image_processor_dict.copy()
|
||||
if "reduce_labels" in image_processor_dict:
|
||||
image_processor_dict["do_reduce_labels"] = image_processor_dict.pop("reduce_labels")
|
||||
return super().from_dict(image_processor_dict, **kwargs)
|
||||
|
||||
def reduce_label(self, labels: list["torch.Tensor"]):
|
||||
for idx in range(len(labels)):
|
||||
label = labels[idx]
|
||||
label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label)
|
||||
label = label - 1
|
||||
label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label)
|
||||
labels[idx] = label
|
||||
|
||||
return label
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list["torch.Tensor"],
|
||||
do_reduce_labels: bool,
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: Optional[Union[float, list[float]]],
|
||||
image_std: Optional[Union[float, list[float]]],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
if do_reduce_labels:
|
||||
images = self.reduce_label(images)
|
||||
|
||||
# Group images by size for batched resizing
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images)
|
||||
resized_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
if do_resize:
|
||||
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
|
||||
resized_images_grouped[shape] = stacked_images
|
||||
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
|
||||
|
||||
# Group images by size for further processing
|
||||
# Needed in case do_resize is False, or resize returns images with different sizes
|
||||
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
|
||||
processed_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
if do_center_crop:
|
||||
stacked_images = self.center_crop(stacked_images, crop_size)
|
||||
# Fused rescale and normalize
|
||||
stacked_images = self.rescale_and_normalize(
|
||||
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||
)
|
||||
processed_images_grouped[shape] = stacked_images
|
||||
|
||||
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||
return processed_images
|
||||
|
||||
def _preprocess_segmentation_maps(
|
||||
self,
|
||||
segmentation_maps,
|
||||
**kwargs,
|
||||
):
|
||||
"""Preprocesses a single segmentation map."""
|
||||
processed_segmentation_maps = []
|
||||
for segmentation_map in segmentation_maps:
|
||||
segmentation_map = self._process_image(
|
||||
segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST
|
||||
)
|
||||
|
||||
if segmentation_map.ndim == 2:
|
||||
segmentation_map = segmentation_map[None, ...]
|
||||
|
||||
processed_segmentation_maps.append(segmentation_map)
|
||||
|
||||
kwargs["do_normalize"] = False
|
||||
kwargs["do_rescale"] = False
|
||||
kwargs["input_data_format"] = ChannelDimension.FIRST
|
||||
processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs)
|
||||
|
||||
processed_segmentation_maps = processed_segmentation_maps.squeeze(1)
|
||||
|
||||
processed_segmentation_maps = processed_segmentation_maps.to(torch.int64)
|
||||
return processed_segmentation_maps
|
||||
|
||||
def __call__(self, images, segmentation_maps=None, **kwargs):
|
||||
# Overrides the `__call__` method of the `Preprocessor` class such that the images and segmentation maps can both
|
||||
# be passed in as positional arguments.
|
||||
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
|
||||
|
||||
@deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0")
|
||||
@add_start_docstrings(
|
||||
"Constructs a fast Beit image processor.",
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
||||
"""
|
||||
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
|
||||
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
|
||||
is used for background, and background itself is not included in all classes of a dataset (e.g.
|
||||
ADE20k). The background label will be replaced by 255.
|
||||
""",
|
||||
)
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
segmentation_maps: Optional[ImageInput] = None,
|
||||
**kwargs: Unpack[DefaultFastImageProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
|
||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||
# by the user, it gets its default value from the instance, or is set to None.
|
||||
for kwarg_name in self.valid_kwargs.__annotations__:
|
||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||
|
||||
# Extract parameters that are only used for preparing the input images
|
||||
do_convert_rgb = kwargs.pop("do_convert_rgb")
|
||||
input_data_format = kwargs.pop("input_data_format")
|
||||
device = kwargs.pop("device")
|
||||
# Prepare input images
|
||||
images = self._prepare_input_images(
|
||||
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
|
||||
)
|
||||
|
||||
# Prepare segmentation maps
|
||||
if segmentation_maps is not None:
|
||||
segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2)
|
||||
|
||||
# Update kwargs that need further processing before being validated
|
||||
kwargs = self._further_process_kwargs(**kwargs)
|
||||
|
||||
# Validate kwargs
|
||||
self._validate_preprocess_kwargs(**kwargs)
|
||||
|
||||
# torch resize uses interpolation instead of resample
|
||||
resample = kwargs.pop("resample")
|
||||
kwargs["interpolation"] = (
|
||||
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
|
||||
)
|
||||
|
||||
# Pop kwargs that are not needed in _preprocess
|
||||
kwargs.pop("default_to_square")
|
||||
kwargs.pop("data_format")
|
||||
|
||||
images = self._preprocess(
|
||||
images=images,
|
||||
**kwargs,
|
||||
)
|
||||
data = {"pixel_values": images}
|
||||
|
||||
if segmentation_maps is not None:
|
||||
segmentation_maps = self._preprocess_segmentation_maps(
|
||||
segmentation_maps=segmentation_maps,
|
||||
**kwargs,
|
||||
)
|
||||
data["labels"] = segmentation_maps
|
||||
|
||||
return BatchFeature(data=data)
|
||||
|
||||
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
|
||||
"""
|
||||
Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
|
||||
|
||||
Args:
|
||||
outputs ([`BeitForSemanticSegmentation`]):
|
||||
Raw outputs of the model.
|
||||
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
|
||||
List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
|
||||
predictions will not be resized.
|
||||
|
||||
Returns:
|
||||
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
|
||||
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
|
||||
specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
|
||||
"""
|
||||
# TODO: add support for other frameworks
|
||||
logits = outputs.logits
|
||||
|
||||
# Resize logits and compute semantic segmentation maps
|
||||
if target_sizes is not None:
|
||||
if len(logits) != len(target_sizes):
|
||||
raise ValueError(
|
||||
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
|
||||
)
|
||||
|
||||
if is_torch_tensor(target_sizes):
|
||||
target_sizes = target_sizes.numpy()
|
||||
|
||||
semantic_segmentation = []
|
||||
|
||||
for idx in range(len(logits)):
|
||||
resized_logits = torch.nn.functional.interpolate(
|
||||
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
|
||||
)
|
||||
semantic_map = resized_logits[0].argmax(dim=0)
|
||||
semantic_segmentation.append(semantic_map)
|
||||
else:
|
||||
semantic_segmentation = logits.argmax(dim=1)
|
||||
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
|
||||
|
||||
return semantic_segmentation
|
||||
|
||||
|
||||
__all__ = ["BeitImageProcessorFast"]
|
@ -18,7 +18,7 @@ import unittest
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
@ -31,6 +31,9 @@ if is_vision_available():
|
||||
|
||||
from transformers import BeitImageProcessor
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import BeitImageProcessorFast
|
||||
|
||||
|
||||
class BeitImageProcessingTester:
|
||||
def __init__(
|
||||
@ -118,6 +121,7 @@ def prepare_semantic_batch_inputs():
|
||||
@require_vision
|
||||
class BeitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = BeitImageProcessor if is_vision_available() else None
|
||||
fast_image_processing_class = BeitImageProcessorFast if is_torchvision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@ -128,159 +132,196 @@ class BeitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_center_crop"))
|
||||
self.assertTrue(hasattr(image_processing, "center_crop"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_reduce_labels"))
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_center_crop"))
|
||||
self.assertTrue(hasattr(image_processing, "center_crop"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_reduce_labels"))
|
||||
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"height": 20, "width": 20})
|
||||
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
|
||||
self.assertEqual(image_processor.do_reduce_labels, False)
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"height": 20, "width": 20})
|
||||
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
|
||||
self.assertEqual(image_processor.do_reduce_labels, False)
|
||||
|
||||
image_processor = self.image_processing_class.from_dict(
|
||||
self.image_processor_dict, size=42, crop_size=84, do_reduce_labels=True
|
||||
)
|
||||
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
||||
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
|
||||
self.assertEqual(image_processor.do_reduce_labels, True)
|
||||
image_processor = image_processing_class.from_dict(
|
||||
self.image_processor_dict, size=42, crop_size=84, do_reduce_labels=True
|
||||
)
|
||||
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
||||
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
|
||||
self.assertEqual(image_processor.do_reduce_labels, True)
|
||||
|
||||
def test_call_segmentation_maps(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
|
||||
maps = []
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
maps.append(torch.zeros(image.shape[-2:]).long())
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processing
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
|
||||
maps = []
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
maps.append(torch.zeros(image.shape[-2:]).long())
|
||||
|
||||
# Test not batched input
|
||||
encoding = image_processing(image_inputs[0], maps[0], return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
# Test not batched input
|
||||
encoding = image_processing(image_inputs[0], maps[0], return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
|
||||
# Test batched
|
||||
encoding = image_processing(image_inputs, maps, return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
self.image_processor_tester.batch_size,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
self.image_processor_tester.batch_size,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
# Test batched
|
||||
encoding = image_processing(image_inputs, maps, return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
self.image_processor_tester.batch_size,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
self.image_processor_tester.batch_size,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
|
||||
# Test not batched input (PIL images)
|
||||
image, segmentation_map = prepare_semantic_single_inputs()
|
||||
# Test not batched input (PIL images)
|
||||
image, segmentation_map = prepare_semantic_single_inputs()
|
||||
|
||||
encoding = image_processing(image, segmentation_map, return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
encoding = image_processing(image, segmentation_map, return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
|
||||
# Test batched input (PIL images)
|
||||
images, segmentation_maps = prepare_semantic_batch_inputs()
|
||||
# Test batched input (PIL images)
|
||||
images, segmentation_maps = prepare_semantic_batch_inputs()
|
||||
|
||||
encoding = image_processing(images, segmentation_maps, return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
2,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
2,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
encoding = image_processing(images, segmentation_maps, return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
2,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
2,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
|
||||
def test_reduce_labels(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processing
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
|
||||
# ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150
|
||||
image, map = prepare_semantic_single_inputs()
|
||||
encoding = image_processing(image, map, return_tensors="pt")
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 150)
|
||||
# ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150
|
||||
image, map = prepare_semantic_single_inputs()
|
||||
encoding = image_processing(image, map, return_tensors="pt")
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 150)
|
||||
|
||||
image_processing.do_reduce_labels = True
|
||||
encoding = image_processing(image, map, return_tensors="pt")
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
image_processing.do_reduce_labels = True
|
||||
encoding = image_processing(image, map, return_tensors="pt")
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
|
||||
def test_removed_deprecated_kwargs(self):
|
||||
image_processor_dict = dict(self.image_processor_dict)
|
||||
image_processor_dict.pop("do_reduce_labels", None)
|
||||
image_processor_dict["reduce_labels"] = True
|
||||
def test_slow_fast_equivalence(self):
|
||||
if not self.test_slow_image_processor or not self.test_fast_image_processor:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test")
|
||||
|
||||
# test we are able to create the image processor with the deprecated kwargs
|
||||
image_processor = self.image_processing_class(**image_processor_dict)
|
||||
self.assertEqual(image_processor.do_reduce_labels, True)
|
||||
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
|
||||
|
||||
# test we still support reduce_labels with config
|
||||
image_processor = self.image_processing_class.from_dict(image_processor_dict)
|
||||
self.assertEqual(image_processor.do_reduce_labels, True)
|
||||
dummy_image, dummy_map = prepare_semantic_single_inputs()
|
||||
|
||||
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
||||
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
||||
|
||||
image_encoding_slow = image_processor_slow(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
|
||||
image_encoding_fast = image_processor_fast(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
|
||||
|
||||
self.assertTrue(torch.allclose(image_encoding_slow.pixel_values, image_encoding_fast.pixel_values, atol=1e-1))
|
||||
self.assertLessEqual(
|
||||
torch.mean(torch.abs(image_encoding_slow.pixel_values - image_encoding_fast.pixel_values)).item(), 1e-3
|
||||
)
|
||||
self.assertTrue(torch.allclose(image_encoding_slow.labels, image_encoding_fast.labels, atol=1e-1))
|
||||
|
||||
def test_slow_fast_equivalence_batched(self):
|
||||
if not self.test_slow_image_processor or not self.test_fast_image_processor:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test")
|
||||
|
||||
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
|
||||
|
||||
if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
|
||||
self.skipTest(
|
||||
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
|
||||
)
|
||||
|
||||
dummy_images, dummy_maps = prepare_semantic_batch_inputs()
|
||||
|
||||
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
||||
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
||||
|
||||
encoding_slow = image_processor_slow(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
|
||||
encoding_fast = image_processor_fast(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
|
||||
|
||||
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
|
||||
self.assertLessEqual(
|
||||
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user