mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
36978 | Fast image processor for DPT model (#37481)
* chore: ran codegen script * test: test_image_processor_properties * test: test_image_processor_from_dict_with_kwargs * test: wip - test_padding * test: test_padding * test: test_keep_aspect_ratio * wip * test * test: wip * test: wip * test: test_call_segmentation_maps, wip * chore: tidy up * test: test_call_segmentation_maps * fix: test_save_load_fast_slow * test: reduce labels * chore: make fixup * chore: rm comment * chore: tidy * chore remove comment * refactor: no need to infer channel dimesnion * refactor: encapsulate logic for preparing segmentation maps * refactor: improve readability of segmentation_map preparation * improvement: batched version of pad_image * chore: fixup * docs * chore: make quality * chore: remove unecessary comment * fix: add SemanticSegmentationMixin * feat: add post_process_depth_estimation to fast dpt image processor * chore: fix formatting * remove max_height, max_width * fix: better way of processin segmentation maps - copied from Beit Fast processor * chore: formatting + remove TODO * chore: fixup styles * chore: remove unecessary line break * chore: core review suggestion to remove autodocstring * fix: add do_reduce_labels logic + refactor - refactor preprocess logic to make it consistent with other processors - add missing reduce labels logic * refactor: remove deprecated mixin * chore: fixup * use modular for dpt + final nit changes * fix style --------- Co-authored-by: Samuel Rae <samuelrae@Samuels-Air.fritz.box> Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
This commit is contained in:
parent
c27f628e98
commit
b922b22ec2
@ -78,7 +78,13 @@ If you're interested in submitting a resource to be included here, please feel f
|
|||||||
|
|
||||||
[[autodoc]] DPTImageProcessor
|
[[autodoc]] DPTImageProcessor
|
||||||
- preprocess
|
- preprocess
|
||||||
|
|
||||||
|
## DPTImageProcessorFast
|
||||||
|
|
||||||
|
[[autodoc]] DPTImageProcessorFast
|
||||||
|
- preprocess
|
||||||
- post_process_semantic_segmentation
|
- post_process_semantic_segmentation
|
||||||
|
- post_process_depth_estimation
|
||||||
|
|
||||||
## DPTModel
|
## DPTModel
|
||||||
|
|
||||||
|
@ -74,14 +74,14 @@ else:
|
|||||||
("data2vec-vision", ("BeitImageProcessor", "BeitImageProcessorFast")),
|
("data2vec-vision", ("BeitImageProcessor", "BeitImageProcessorFast")),
|
||||||
("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")),
|
("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")),
|
||||||
("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")),
|
("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")),
|
||||||
("depth_anything", ("DPTImageProcessor",)),
|
("depth_anything", ("DPTImageProcessor", "DPTImageProcessorFast")),
|
||||||
("depth_pro", ("DepthProImageProcessor", "DepthProImageProcessorFast")),
|
("depth_pro", ("DepthProImageProcessor", "DepthProImageProcessorFast")),
|
||||||
("deta", ("DetaImageProcessor",)),
|
("deta", ("DetaImageProcessor",)),
|
||||||
("detr", ("DetrImageProcessor", "DetrImageProcessorFast")),
|
("detr", ("DetrImageProcessor", "DetrImageProcessorFast")),
|
||||||
("dinat", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
("dinat", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||||
("dinov2", ("BitImageProcessor", "BitImageProcessorFast")),
|
("dinov2", ("BitImageProcessor", "BitImageProcessorFast")),
|
||||||
("donut-swin", ("DonutImageProcessor", "DonutImageProcessorFast")),
|
("donut-swin", ("DonutImageProcessor", "DonutImageProcessorFast")),
|
||||||
("dpt", ("DPTImageProcessor",)),
|
("dpt", ("DPTImageProcessor", "DPTImageProcessorFast")),
|
||||||
("efficientformer", ("EfficientFormerImageProcessor",)),
|
("efficientformer", ("EfficientFormerImageProcessor",)),
|
||||||
("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
|
("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
|
||||||
("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")),
|
("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")),
|
||||||
|
@ -174,11 +174,6 @@ class BeitImageProcessorFast(BaseImageProcessorFast):
|
|||||||
processed_segmentation_maps = processed_segmentation_maps.to(torch.int64)
|
processed_segmentation_maps = processed_segmentation_maps.to(torch.int64)
|
||||||
return processed_segmentation_maps
|
return processed_segmentation_maps
|
||||||
|
|
||||||
def __call__(self, images, segmentation_maps=None, **kwargs):
|
|
||||||
# Overrides the `__call__` method of the `Preprocessor` class such that the images and segmentation maps can both
|
|
||||||
# be passed in as positional arguments.
|
|
||||||
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
|
|
||||||
|
|
||||||
@auto_docstring
|
@auto_docstring
|
||||||
def preprocess(
|
def preprocess(
|
||||||
self,
|
self,
|
||||||
|
@ -21,6 +21,7 @@ if TYPE_CHECKING:
|
|||||||
from .configuration_dpt import *
|
from .configuration_dpt import *
|
||||||
from .feature_extraction_dpt import *
|
from .feature_extraction_dpt import *
|
||||||
from .image_processing_dpt import *
|
from .image_processing_dpt import *
|
||||||
|
from .image_processing_dpt_fast import *
|
||||||
from .modeling_dpt import *
|
from .modeling_dpt import *
|
||||||
else:
|
else:
|
||||||
import sys
|
import sys
|
||||||
|
474
src/transformers/models/dpt/image_processing_dpt_fast.py
Normal file
474
src/transformers/models/dpt/image_processing_dpt_fast.py
Normal file
@ -0,0 +1,474 @@
|
|||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# This file was automatically generated from src/transformers/models/dpt/modular_dpt.py.
|
||||||
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
|
# modular_dpt.py file directly. One of our CI enforces this.
|
||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
|
from transformers.image_processing_base import BatchFeature
|
||||||
|
from transformers.image_transforms import group_images_by_shape, reorder_images
|
||||||
|
|
||||||
|
from ...image_processing_utils_fast import BaseImageProcessorFast, DefaultFastImageProcessorKwargs
|
||||||
|
from ...image_utils import (
|
||||||
|
IMAGENET_STANDARD_MEAN,
|
||||||
|
IMAGENET_STANDARD_STD,
|
||||||
|
ChannelDimension,
|
||||||
|
ImageInput,
|
||||||
|
PILImageResampling,
|
||||||
|
SizeDict,
|
||||||
|
is_torch_tensor,
|
||||||
|
make_list_of_images,
|
||||||
|
pil_torch_interpolation_mapping,
|
||||||
|
validate_kwargs,
|
||||||
|
)
|
||||||
|
from ...processing_utils import Unpack
|
||||||
|
from ...utils import (
|
||||||
|
TensorType,
|
||||||
|
auto_docstring,
|
||||||
|
is_torch_available,
|
||||||
|
is_torchvision_available,
|
||||||
|
is_torchvision_v2_available,
|
||||||
|
requires_backends,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...modeling_outputs import DepthEstimatorOutput
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if is_torchvision_v2_available():
|
||||||
|
from torchvision.transforms.v2 import functional as F
|
||||||
|
elif is_torchvision_available():
|
||||||
|
from torchvision.transforms import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class DPTFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||||
|
"""
|
||||||
|
ensure_multiple_of (`int`, *optional*, defaults to 1):
|
||||||
|
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overidden
|
||||||
|
by `ensure_multiple_of` in `preprocess`.
|
||||||
|
do_pad (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to apply center padding. This was introduced in the DINOv2 paper, which uses the model in
|
||||||
|
combination with DPT.
|
||||||
|
size_divisor (`int`, *optional*):
|
||||||
|
If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the
|
||||||
|
DINOv2 paper, which uses the model in combination with DPT.
|
||||||
|
keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
|
||||||
|
If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can
|
||||||
|
be overidden by `keep_aspect_ratio` in `preprocess`.
|
||||||
|
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
|
||||||
|
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
|
||||||
|
is used for background, and background itself is not included in all classes of a dataset (e.g.
|
||||||
|
ADE20k). The background label will be replaced by 255.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ensure_multiple_of: Optional[int]
|
||||||
|
size_divisor: Optional[int]
|
||||||
|
do_pad: Optional[bool]
|
||||||
|
keep_aspect_ratio: Optional[bool]
|
||||||
|
do_reduce_labels: Optional[bool]
|
||||||
|
|
||||||
|
|
||||||
|
def get_resize_output_image_size(
|
||||||
|
input_image: "torch.Tensor",
|
||||||
|
output_size: Union[int, Iterable[int]],
|
||||||
|
keep_aspect_ratio: bool,
|
||||||
|
multiple: int,
|
||||||
|
) -> SizeDict:
|
||||||
|
def constrain_to_multiple_of(val, multiple, min_val=0, max_val=None):
|
||||||
|
x = round(val / multiple) * multiple
|
||||||
|
|
||||||
|
if max_val is not None and x > max_val:
|
||||||
|
x = math.floor(val / multiple) * multiple
|
||||||
|
|
||||||
|
if x < min_val:
|
||||||
|
x = math.ceil(val / multiple) * multiple
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
input_height, input_width = input_image.shape[-2:]
|
||||||
|
output_height, output_width = output_size
|
||||||
|
|
||||||
|
# determine new height and width
|
||||||
|
scale_height = output_height / input_height
|
||||||
|
scale_width = output_width / input_width
|
||||||
|
|
||||||
|
if keep_aspect_ratio:
|
||||||
|
# scale as little as possible
|
||||||
|
if abs(1 - scale_width) < abs(1 - scale_height):
|
||||||
|
# fit width
|
||||||
|
scale_height = scale_width
|
||||||
|
else:
|
||||||
|
# fit height
|
||||||
|
scale_width = scale_height
|
||||||
|
|
||||||
|
new_height = constrain_to_multiple_of(scale_height * input_height, multiple=multiple)
|
||||||
|
new_width = constrain_to_multiple_of(scale_width * input_width, multiple=multiple)
|
||||||
|
|
||||||
|
return SizeDict(height=new_height, width=new_width)
|
||||||
|
|
||||||
|
|
||||||
|
@auto_docstring
|
||||||
|
class DPTImageProcessorFast(BaseImageProcessorFast):
|
||||||
|
resample = PILImageResampling.BICUBIC
|
||||||
|
image_mean = IMAGENET_STANDARD_MEAN
|
||||||
|
image_std = IMAGENET_STANDARD_STD
|
||||||
|
size = {"height": 384, "width": 384}
|
||||||
|
default_to_square = True
|
||||||
|
crop_size = None
|
||||||
|
do_resize = True
|
||||||
|
do_center_crop = None
|
||||||
|
do_rescale = True
|
||||||
|
do_normalize = True
|
||||||
|
do_reduce_labels = None
|
||||||
|
|
||||||
|
valid_kwargs = DPTFastImageProcessorKwargs
|
||||||
|
do_pad = False
|
||||||
|
rescale_factor = 1 / 255
|
||||||
|
ensure_multiple_of = 1
|
||||||
|
keep_aspect_ratio = False
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Unpack[DPTFastImageProcessorKwargs]):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def reduce_label(self, labels: list["torch.Tensor"]):
|
||||||
|
for idx in range(len(labels)):
|
||||||
|
label = labels[idx]
|
||||||
|
label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label)
|
||||||
|
label = label - 1
|
||||||
|
label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label)
|
||||||
|
labels[idx] = label
|
||||||
|
|
||||||
|
return label
|
||||||
|
|
||||||
|
def _preprocess(
|
||||||
|
self,
|
||||||
|
images: list["torch.Tensor"],
|
||||||
|
do_reduce_labels: bool,
|
||||||
|
do_resize: bool,
|
||||||
|
size: SizeDict,
|
||||||
|
interpolation: Optional["F.InterpolationMode"],
|
||||||
|
do_center_crop: bool,
|
||||||
|
crop_size: SizeDict,
|
||||||
|
do_rescale: bool,
|
||||||
|
rescale_factor: float,
|
||||||
|
do_normalize: bool,
|
||||||
|
image_mean: Optional[Union[float, list[float]]],
|
||||||
|
image_std: Optional[Union[float, list[float]]],
|
||||||
|
return_tensors: Optional[Union[str, TensorType]],
|
||||||
|
keep_aspect_ratio: bool,
|
||||||
|
ensure_multiple_of: Optional[int],
|
||||||
|
do_pad: bool,
|
||||||
|
size_divisor: Optional[int],
|
||||||
|
**kwargs,
|
||||||
|
) -> BatchFeature:
|
||||||
|
if do_reduce_labels:
|
||||||
|
images = self.reduce_label(images)
|
||||||
|
|
||||||
|
# Group images by size for batched resizing
|
||||||
|
grouped_images, grouped_images_index = group_images_by_shape(images)
|
||||||
|
resized_images_grouped = {}
|
||||||
|
for shape, stacked_images in grouped_images.items():
|
||||||
|
if do_resize:
|
||||||
|
stacked_images = self.resize(
|
||||||
|
image=stacked_images,
|
||||||
|
size=size,
|
||||||
|
interpolation=interpolation,
|
||||||
|
ensure_multiple_of=ensure_multiple_of,
|
||||||
|
keep_aspect_ratio=keep_aspect_ratio,
|
||||||
|
)
|
||||||
|
resized_images_grouped[shape] = stacked_images
|
||||||
|
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
|
||||||
|
|
||||||
|
# Group images by size for further processing
|
||||||
|
# Needed in case do_resize is False, or resize returns images with different sizes
|
||||||
|
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
|
||||||
|
processed_images_grouped = {}
|
||||||
|
for shape, stacked_images in grouped_images.items():
|
||||||
|
if do_center_crop:
|
||||||
|
stacked_images = self.center_crop(stacked_images, crop_size)
|
||||||
|
if do_pad:
|
||||||
|
stacked_images = self.pad_image(stacked_images, size_divisor)
|
||||||
|
# Fused rescale and normalize
|
||||||
|
stacked_images = self.rescale_and_normalize(
|
||||||
|
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||||
|
)
|
||||||
|
processed_images_grouped[shape] = stacked_images
|
||||||
|
|
||||||
|
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||||
|
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||||
|
return processed_images
|
||||||
|
|
||||||
|
def _preprocess_images(
|
||||||
|
self,
|
||||||
|
images,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Preprocesses images."""
|
||||||
|
kwargs["do_reduce_labels"] = False
|
||||||
|
processed_images = self._preprocess(images=images, **kwargs)
|
||||||
|
return processed_images
|
||||||
|
|
||||||
|
def _preprocess_segmentation_maps(
|
||||||
|
self,
|
||||||
|
segmentation_maps,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Preprocesses segmentation maps."""
|
||||||
|
processed_segmentation_maps = []
|
||||||
|
for segmentation_map in segmentation_maps:
|
||||||
|
segmentation_map = self._process_image(
|
||||||
|
segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST
|
||||||
|
)
|
||||||
|
|
||||||
|
if segmentation_map.ndim == 2:
|
||||||
|
segmentation_map = segmentation_map[None, ...]
|
||||||
|
|
||||||
|
processed_segmentation_maps.append(segmentation_map)
|
||||||
|
|
||||||
|
kwargs["do_normalize"] = False
|
||||||
|
kwargs["do_rescale"] = False
|
||||||
|
kwargs["input_data_format"] = ChannelDimension.FIRST
|
||||||
|
processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs)
|
||||||
|
|
||||||
|
processed_segmentation_maps = processed_segmentation_maps.squeeze(1)
|
||||||
|
|
||||||
|
processed_segmentation_maps = processed_segmentation_maps.to(torch.int64)
|
||||||
|
return processed_segmentation_maps
|
||||||
|
|
||||||
|
@auto_docstring
|
||||||
|
def preprocess(
|
||||||
|
self,
|
||||||
|
images: ImageInput,
|
||||||
|
segmentation_maps: Optional[ImageInput] = None,
|
||||||
|
**kwargs: Unpack[DPTFastImageProcessorKwargs],
|
||||||
|
) -> BatchFeature:
|
||||||
|
r"""
|
||||||
|
segmentation_maps (`ImageInput`, *optional*):
|
||||||
|
The segmentation maps to preprocess.
|
||||||
|
"""
|
||||||
|
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
|
||||||
|
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||||
|
# by the user, it gets its default value from the instance, or is set to None.
|
||||||
|
for kwarg_name in self.valid_kwargs.__annotations__:
|
||||||
|
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||||
|
|
||||||
|
# Extract parameters that are only used for preparing the input images
|
||||||
|
do_convert_rgb = kwargs.pop("do_convert_rgb")
|
||||||
|
input_data_format = kwargs.pop("input_data_format")
|
||||||
|
device = kwargs.pop("device")
|
||||||
|
# Prepare input images
|
||||||
|
images = self._prepare_input_images(
|
||||||
|
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare segmentation maps
|
||||||
|
if segmentation_maps is not None:
|
||||||
|
segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2)
|
||||||
|
|
||||||
|
# Update kwargs that need further processing before being validated
|
||||||
|
kwargs = self._further_process_kwargs(**kwargs)
|
||||||
|
|
||||||
|
# Validate kwargs
|
||||||
|
self._validate_preprocess_kwargs(**kwargs)
|
||||||
|
|
||||||
|
# torch resize uses interpolation instead of resample
|
||||||
|
resample = kwargs.pop("resample")
|
||||||
|
kwargs["interpolation"] = (
|
||||||
|
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pop kwargs that are not needed in _preprocess
|
||||||
|
kwargs.pop("default_to_square")
|
||||||
|
kwargs.pop("data_format")
|
||||||
|
|
||||||
|
images = self._preprocess_images(
|
||||||
|
images=images,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
data = {"pixel_values": images}
|
||||||
|
|
||||||
|
if segmentation_maps is not None:
|
||||||
|
segmentation_maps = self._preprocess_segmentation_maps(
|
||||||
|
segmentation_maps=segmentation_maps,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
data["labels"] = segmentation_maps
|
||||||
|
|
||||||
|
return BatchFeature(data=data)
|
||||||
|
|
||||||
|
def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
|
||||||
|
"""
|
||||||
|
Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs ([`DPTForSemanticSegmentation`]):
|
||||||
|
Raw outputs of the model.
|
||||||
|
target_sizes (`list[Tuple]` of length `batch_size`, *optional*):
|
||||||
|
List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
|
||||||
|
predictions will not be resized.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic
|
||||||
|
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
|
||||||
|
specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
|
||||||
|
"""
|
||||||
|
# TODO: add support for other frameworks
|
||||||
|
logits = outputs.logits
|
||||||
|
|
||||||
|
# Resize logits and compute semantic segmentation maps
|
||||||
|
if target_sizes is not None:
|
||||||
|
if len(logits) != len(target_sizes):
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_torch_tensor(target_sizes):
|
||||||
|
target_sizes = target_sizes.numpy()
|
||||||
|
|
||||||
|
semantic_segmentation = []
|
||||||
|
|
||||||
|
for idx in range(len(logits)):
|
||||||
|
resized_logits = torch.nn.functional.interpolate(
|
||||||
|
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
|
||||||
|
)
|
||||||
|
semantic_map = resized_logits[0].argmax(dim=0)
|
||||||
|
semantic_segmentation.append(semantic_map)
|
||||||
|
else:
|
||||||
|
semantic_segmentation = logits.argmax(dim=1)
|
||||||
|
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
|
||||||
|
|
||||||
|
return semantic_segmentation
|
||||||
|
|
||||||
|
def resize(
|
||||||
|
self,
|
||||||
|
image: "torch.Tensor",
|
||||||
|
size: SizeDict,
|
||||||
|
interpolation: "F.InterpolationMode" = None,
|
||||||
|
antialias: bool = True,
|
||||||
|
ensure_multiple_of: Optional[int] = 1,
|
||||||
|
keep_aspect_ratio: bool = False,
|
||||||
|
) -> "torch.Tensor":
|
||||||
|
"""
|
||||||
|
Resize an image to `(size["height"], size["width"])`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`torch.Tensor`):
|
||||||
|
Image to resize.
|
||||||
|
size (`SizeDict`):
|
||||||
|
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
||||||
|
interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
|
||||||
|
`InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
|
||||||
|
antialias (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use antialiasing when resizing the image
|
||||||
|
ensure_multiple_of (`int`, *optional*):
|
||||||
|
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value
|
||||||
|
keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
|
||||||
|
If `True`, and `do_resize` is `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.Tensor`: The resized image.
|
||||||
|
"""
|
||||||
|
if not size.height or not size.width:
|
||||||
|
raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
|
||||||
|
|
||||||
|
output_size = get_resize_output_image_size(
|
||||||
|
image,
|
||||||
|
output_size=(size.height, size.width),
|
||||||
|
keep_aspect_ratio=keep_aspect_ratio,
|
||||||
|
multiple=ensure_multiple_of,
|
||||||
|
)
|
||||||
|
return super().resize(image, output_size, interpolation=interpolation, antialias=antialias)
|
||||||
|
|
||||||
|
def pad_image(
|
||||||
|
self,
|
||||||
|
image: "torch.Tensor",
|
||||||
|
size_divisor: int = 1,
|
||||||
|
) -> "torch.Tensor":
|
||||||
|
r"""
|
||||||
|
Center pad a batch of images to be a multiple of `size_divisor`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`torch.Tensor`):
|
||||||
|
Image to pad. Can be a batch of images of dimensions (N, C, H, W) or a single image of dimensions (C, H, W).
|
||||||
|
size_divisor (`int`):
|
||||||
|
The width and height of the image will be padded to a multiple of this number.
|
||||||
|
"""
|
||||||
|
height, width = image.shape[-2:]
|
||||||
|
|
||||||
|
def _get_pad(size, size_divisor):
|
||||||
|
new_size = math.ceil(size / size_divisor) * size_divisor
|
||||||
|
pad_size = new_size - size
|
||||||
|
pad_size_left = pad_size // 2
|
||||||
|
pad_size_right = pad_size - pad_size_left
|
||||||
|
return pad_size_left, pad_size_right
|
||||||
|
|
||||||
|
pad_top, pad_bottom = _get_pad(height, size_divisor)
|
||||||
|
pad_left, pad_right = _get_pad(width, size_divisor)
|
||||||
|
padding = (pad_left, pad_top, pad_right, pad_bottom)
|
||||||
|
return F.pad(image, padding)
|
||||||
|
|
||||||
|
def post_process_depth_estimation(
|
||||||
|
self,
|
||||||
|
outputs: "DepthEstimatorOutput",
|
||||||
|
target_sizes: Optional[Union[TensorType, list[tuple[int, int]], None]] = None,
|
||||||
|
) -> list[dict[str, TensorType]]:
|
||||||
|
"""
|
||||||
|
Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images.
|
||||||
|
Only supports PyTorch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs ([`DepthEstimatorOutput`]):
|
||||||
|
Raw outputs of the model.
|
||||||
|
target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
|
||||||
|
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
|
||||||
|
(height, width) of each image in the batch. If left to None, predictions will not be resized.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
|
||||||
|
predictions.
|
||||||
|
"""
|
||||||
|
requires_backends(self, "torch")
|
||||||
|
|
||||||
|
predicted_depth = outputs.predicted_depth
|
||||||
|
|
||||||
|
if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
|
||||||
|
)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
|
||||||
|
for depth, target_size in zip(predicted_depth, target_sizes):
|
||||||
|
if target_size is not None:
|
||||||
|
depth = torch.nn.functional.interpolate(
|
||||||
|
depth.unsqueeze(0).unsqueeze(1), size=target_size, mode="bicubic", align_corners=False
|
||||||
|
).squeeze()
|
||||||
|
|
||||||
|
results.append({"predicted_depth": depth})
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["DPTImageProcessorFast"]
|
313
src/transformers/models/dpt/modular_dpt.py
Normal file
313
src/transformers/models/dpt/modular_dpt.py
Normal file
@ -0,0 +1,313 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
|
from transformers.image_processing_base import BatchFeature
|
||||||
|
from transformers.image_transforms import group_images_by_shape, reorder_images
|
||||||
|
from transformers.models.beit.image_processing_beit_fast import BeitImageProcessorFast
|
||||||
|
|
||||||
|
from ...image_processing_utils_fast import (
|
||||||
|
DefaultFastImageProcessorKwargs,
|
||||||
|
)
|
||||||
|
from ...image_utils import (
|
||||||
|
IMAGENET_STANDARD_MEAN,
|
||||||
|
IMAGENET_STANDARD_STD,
|
||||||
|
PILImageResampling,
|
||||||
|
SizeDict,
|
||||||
|
)
|
||||||
|
from ...utils import (
|
||||||
|
TensorType,
|
||||||
|
auto_docstring,
|
||||||
|
is_torch_available,
|
||||||
|
is_torchvision_available,
|
||||||
|
is_torchvision_v2_available,
|
||||||
|
requires_backends,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...modeling_outputs import DepthEstimatorOutput
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if is_torchvision_v2_available():
|
||||||
|
from torchvision.transforms.v2 import functional as F
|
||||||
|
elif is_torchvision_available():
|
||||||
|
from torchvision.transforms import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def get_resize_output_image_size(
|
||||||
|
input_image: "torch.Tensor",
|
||||||
|
output_size: Union[int, Iterable[int]],
|
||||||
|
keep_aspect_ratio: bool,
|
||||||
|
multiple: int,
|
||||||
|
) -> SizeDict:
|
||||||
|
def constrain_to_multiple_of(val, multiple, min_val=0, max_val=None):
|
||||||
|
x = round(val / multiple) * multiple
|
||||||
|
|
||||||
|
if max_val is not None and x > max_val:
|
||||||
|
x = math.floor(val / multiple) * multiple
|
||||||
|
|
||||||
|
if x < min_val:
|
||||||
|
x = math.ceil(val / multiple) * multiple
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
input_height, input_width = input_image.shape[-2:]
|
||||||
|
output_height, output_width = output_size
|
||||||
|
|
||||||
|
# determine new height and width
|
||||||
|
scale_height = output_height / input_height
|
||||||
|
scale_width = output_width / input_width
|
||||||
|
|
||||||
|
if keep_aspect_ratio:
|
||||||
|
# scale as little as possible
|
||||||
|
if abs(1 - scale_width) < abs(1 - scale_height):
|
||||||
|
# fit width
|
||||||
|
scale_height = scale_width
|
||||||
|
else:
|
||||||
|
# fit height
|
||||||
|
scale_width = scale_height
|
||||||
|
|
||||||
|
new_height = constrain_to_multiple_of(scale_height * input_height, multiple=multiple)
|
||||||
|
new_width = constrain_to_multiple_of(scale_width * input_width, multiple=multiple)
|
||||||
|
|
||||||
|
return SizeDict(height=new_height, width=new_width)
|
||||||
|
|
||||||
|
|
||||||
|
class DPTFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||||
|
"""
|
||||||
|
ensure_multiple_of (`int`, *optional*, defaults to 1):
|
||||||
|
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overidden
|
||||||
|
by `ensure_multiple_of` in `preprocess`.
|
||||||
|
do_pad (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to apply center padding. This was introduced in the DINOv2 paper, which uses the model in
|
||||||
|
combination with DPT.
|
||||||
|
size_divisor (`int`, *optional*):
|
||||||
|
If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the
|
||||||
|
DINOv2 paper, which uses the model in combination with DPT.
|
||||||
|
keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
|
||||||
|
If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can
|
||||||
|
be overidden by `keep_aspect_ratio` in `preprocess`.
|
||||||
|
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
|
||||||
|
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
|
||||||
|
is used for background, and background itself is not included in all classes of a dataset (e.g.
|
||||||
|
ADE20k). The background label will be replaced by 255.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ensure_multiple_of: Optional[int]
|
||||||
|
size_divisor: Optional[int]
|
||||||
|
do_pad: Optional[bool]
|
||||||
|
keep_aspect_ratio: Optional[bool]
|
||||||
|
do_reduce_labels: Optional[bool]
|
||||||
|
|
||||||
|
|
||||||
|
@auto_docstring
|
||||||
|
class DPTImageProcessorFast(BeitImageProcessorFast):
|
||||||
|
resample = PILImageResampling.BICUBIC
|
||||||
|
image_mean = IMAGENET_STANDARD_MEAN
|
||||||
|
image_std = IMAGENET_STANDARD_STD
|
||||||
|
size = {"height": 384, "width": 384}
|
||||||
|
do_resize = True
|
||||||
|
do_rescale = True
|
||||||
|
do_normalize = True
|
||||||
|
do_pad = False
|
||||||
|
rescale_factor = 1 / 255
|
||||||
|
ensure_multiple_of = 1
|
||||||
|
keep_aspect_ratio = False
|
||||||
|
do_reduce_labels = False
|
||||||
|
crop_size = None
|
||||||
|
do_center_crop = None
|
||||||
|
do_reduce_labels = None
|
||||||
|
|
||||||
|
valid_kwargs = DPTFastImageProcessorKwargs
|
||||||
|
|
||||||
|
def from_dict():
|
||||||
|
raise NotImplementedError("No need to override this method")
|
||||||
|
|
||||||
|
def resize(
|
||||||
|
self,
|
||||||
|
image: "torch.Tensor",
|
||||||
|
size: SizeDict,
|
||||||
|
interpolation: "F.InterpolationMode" = None,
|
||||||
|
antialias: bool = True,
|
||||||
|
ensure_multiple_of: Optional[int] = 1,
|
||||||
|
keep_aspect_ratio: bool = False,
|
||||||
|
) -> "torch.Tensor":
|
||||||
|
"""
|
||||||
|
Resize an image to `(size["height"], size["width"])`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`torch.Tensor`):
|
||||||
|
Image to resize.
|
||||||
|
size (`SizeDict`):
|
||||||
|
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
||||||
|
interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
|
||||||
|
`InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
|
||||||
|
antialias (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use antialiasing when resizing the image
|
||||||
|
ensure_multiple_of (`int`, *optional*):
|
||||||
|
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value
|
||||||
|
keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
|
||||||
|
If `True`, and `do_resize` is `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.Tensor`: The resized image.
|
||||||
|
"""
|
||||||
|
if not size.height or not size.width:
|
||||||
|
raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
|
||||||
|
|
||||||
|
output_size = get_resize_output_image_size(
|
||||||
|
image,
|
||||||
|
output_size=(size.height, size.width),
|
||||||
|
keep_aspect_ratio=keep_aspect_ratio,
|
||||||
|
multiple=ensure_multiple_of,
|
||||||
|
)
|
||||||
|
return BeitImageProcessorFast().resize(image, output_size, interpolation=interpolation, antialias=antialias)
|
||||||
|
|
||||||
|
def pad_image(
|
||||||
|
self,
|
||||||
|
image: "torch.Tensor",
|
||||||
|
size_divisor: int = 1,
|
||||||
|
) -> "torch.Tensor":
|
||||||
|
r"""
|
||||||
|
Center pad a batch of images to be a multiple of `size_divisor`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`torch.Tensor`):
|
||||||
|
Image to pad. Can be a batch of images of dimensions (N, C, H, W) or a single image of dimensions (C, H, W).
|
||||||
|
size_divisor (`int`):
|
||||||
|
The width and height of the image will be padded to a multiple of this number.
|
||||||
|
"""
|
||||||
|
height, width = image.shape[-2:]
|
||||||
|
|
||||||
|
def _get_pad(size, size_divisor):
|
||||||
|
new_size = math.ceil(size / size_divisor) * size_divisor
|
||||||
|
pad_size = new_size - size
|
||||||
|
pad_size_left = pad_size // 2
|
||||||
|
pad_size_right = pad_size - pad_size_left
|
||||||
|
return pad_size_left, pad_size_right
|
||||||
|
|
||||||
|
pad_top, pad_bottom = _get_pad(height, size_divisor)
|
||||||
|
pad_left, pad_right = _get_pad(width, size_divisor)
|
||||||
|
padding = (pad_left, pad_top, pad_right, pad_bottom)
|
||||||
|
return F.pad(image, padding)
|
||||||
|
|
||||||
|
def _preprocess(
|
||||||
|
self,
|
||||||
|
images: list["torch.Tensor"],
|
||||||
|
do_reduce_labels: bool,
|
||||||
|
do_resize: bool,
|
||||||
|
size: SizeDict,
|
||||||
|
interpolation: Optional["F.InterpolationMode"],
|
||||||
|
do_center_crop: bool,
|
||||||
|
crop_size: SizeDict,
|
||||||
|
do_rescale: bool,
|
||||||
|
rescale_factor: float,
|
||||||
|
do_normalize: bool,
|
||||||
|
image_mean: Optional[Union[float, list[float]]],
|
||||||
|
image_std: Optional[Union[float, list[float]]],
|
||||||
|
return_tensors: Optional[Union[str, TensorType]],
|
||||||
|
keep_aspect_ratio: bool,
|
||||||
|
ensure_multiple_of: Optional[int],
|
||||||
|
do_pad: bool,
|
||||||
|
size_divisor: Optional[int],
|
||||||
|
**kwargs,
|
||||||
|
) -> BatchFeature:
|
||||||
|
if do_reduce_labels:
|
||||||
|
images = self.reduce_label(images)
|
||||||
|
|
||||||
|
# Group images by size for batched resizing
|
||||||
|
grouped_images, grouped_images_index = group_images_by_shape(images)
|
||||||
|
resized_images_grouped = {}
|
||||||
|
for shape, stacked_images in grouped_images.items():
|
||||||
|
if do_resize:
|
||||||
|
stacked_images = self.resize(
|
||||||
|
image=stacked_images,
|
||||||
|
size=size,
|
||||||
|
interpolation=interpolation,
|
||||||
|
ensure_multiple_of=ensure_multiple_of,
|
||||||
|
keep_aspect_ratio=keep_aspect_ratio,
|
||||||
|
)
|
||||||
|
resized_images_grouped[shape] = stacked_images
|
||||||
|
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
|
||||||
|
|
||||||
|
# Group images by size for further processing
|
||||||
|
# Needed in case do_resize is False, or resize returns images with different sizes
|
||||||
|
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
|
||||||
|
processed_images_grouped = {}
|
||||||
|
for shape, stacked_images in grouped_images.items():
|
||||||
|
if do_center_crop:
|
||||||
|
stacked_images = self.center_crop(stacked_images, crop_size)
|
||||||
|
if do_pad:
|
||||||
|
stacked_images = self.pad_image(stacked_images, size_divisor)
|
||||||
|
# Fused rescale and normalize
|
||||||
|
stacked_images = self.rescale_and_normalize(
|
||||||
|
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||||
|
)
|
||||||
|
processed_images_grouped[shape] = stacked_images
|
||||||
|
|
||||||
|
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||||
|
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||||
|
return processed_images
|
||||||
|
|
||||||
|
def post_process_depth_estimation(
|
||||||
|
self,
|
||||||
|
outputs: "DepthEstimatorOutput",
|
||||||
|
target_sizes: Optional[Union[TensorType, list[tuple[int, int]], None]] = None,
|
||||||
|
) -> list[dict[str, TensorType]]:
|
||||||
|
"""
|
||||||
|
Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images.
|
||||||
|
Only supports PyTorch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs ([`DepthEstimatorOutput`]):
|
||||||
|
Raw outputs of the model.
|
||||||
|
target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
|
||||||
|
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
|
||||||
|
(height, width) of each image in the batch. If left to None, predictions will not be resized.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
|
||||||
|
predictions.
|
||||||
|
"""
|
||||||
|
requires_backends(self, "torch")
|
||||||
|
|
||||||
|
predicted_depth = outputs.predicted_depth
|
||||||
|
|
||||||
|
if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
|
||||||
|
)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
|
||||||
|
for depth, target_size in zip(predicted_depth, target_sizes):
|
||||||
|
if target_size is not None:
|
||||||
|
depth = torch.nn.functional.interpolate(
|
||||||
|
depth.unsqueeze(0).unsqueeze(1), size=target_size, mode="bicubic", align_corners=False
|
||||||
|
).squeeze()
|
||||||
|
|
||||||
|
results.append({"predicted_depth": depth})
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["DPTImageProcessorFast"]
|
@ -20,6 +20,7 @@ from datasets import load_dataset
|
|||||||
|
|
||||||
from transformers.file_utils import is_torch_available, is_vision_available
|
from transformers.file_utils import is_torch_available, is_vision_available
|
||||||
from transformers.testing_utils import require_torch, require_vision
|
from transformers.testing_utils import require_torch, require_vision
|
||||||
|
from transformers.utils import is_torchvision_available
|
||||||
|
|
||||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||||
|
|
||||||
@ -32,6 +33,9 @@ if is_vision_available():
|
|||||||
|
|
||||||
from transformers import DPTImageProcessor
|
from transformers import DPTImageProcessor
|
||||||
|
|
||||||
|
if is_torchvision_available():
|
||||||
|
from transformers import DPTImageProcessorFast
|
||||||
|
|
||||||
|
|
||||||
class DPTImageProcessingTester:
|
class DPTImageProcessingTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -114,6 +118,7 @@ def prepare_semantic_batch_inputs():
|
|||||||
@require_vision
|
@require_vision
|
||||||
class DPTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
class DPTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||||
image_processing_class = DPTImageProcessor if is_vision_available() else None
|
image_processing_class = DPTImageProcessor if is_vision_available() else None
|
||||||
|
fast_image_processing_class = DPTImageProcessorFast if is_torchvision_available() else None
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
@ -124,170 +129,236 @@ class DPTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
return self.image_processor_tester.prepare_image_processor_dict()
|
return self.image_processor_tester.prepare_image_processor_dict()
|
||||||
|
|
||||||
def test_image_processor_properties(self):
|
def test_image_processor_properties(self):
|
||||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
for image_processing_class in self.image_processor_list:
|
||||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
image_processing = image_processing_class(**self.image_processor_dict)
|
||||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||||
self.assertTrue(hasattr(image_processing, "size"))
|
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||||
self.assertTrue(hasattr(image_processing, "do_rescale"))
|
self.assertTrue(hasattr(image_processing, "size"))
|
||||||
self.assertTrue(hasattr(image_processing, "rescale_factor"))
|
self.assertTrue(hasattr(image_processing, "do_rescale"))
|
||||||
self.assertTrue(hasattr(image_processing, "do_pad"))
|
self.assertTrue(hasattr(image_processing, "rescale_factor"))
|
||||||
self.assertTrue(hasattr(image_processing, "size_divisor"))
|
self.assertTrue(hasattr(image_processing, "do_pad"))
|
||||||
self.assertTrue(hasattr(image_processing, "do_reduce_labels"))
|
self.assertTrue(hasattr(image_processing, "size_divisor"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "do_reduce_labels"))
|
||||||
|
|
||||||
def test_image_processor_from_dict_with_kwargs(self):
|
def test_image_processor_from_dict_with_kwargs(self):
|
||||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
for image_processing_class in self.image_processor_list:
|
||||||
self.assertEqual(image_processor.size, {"height": 18, "width": 18})
|
image_processing_class = image_processing_class(**self.image_processor_dict)
|
||||||
|
image_processor = image_processing_class.from_dict(self.image_processor_dict)
|
||||||
|
self.assertEqual(image_processor.size, {"height": 18, "width": 18})
|
||||||
|
|
||||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
|
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42)
|
||||||
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
||||||
|
|
||||||
def test_padding(self):
|
def test_padding(self):
|
||||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
for image_processing_class in self.image_processor_list:
|
||||||
image = np.random.randn(3, 249, 491)
|
if image_processing_class == DPTImageProcessorFast:
|
||||||
|
image = torch.arange(0, 366777, 1, dtype=torch.uint8).reshape(3, 249, 491)
|
||||||
# test individual method
|
image_processor = image_processing_class(**self.image_processor_dict)
|
||||||
image = image_processing.pad_image(image, size_divisor=4)
|
padded_image = image_processor.pad_image(image, size_divisor=4)
|
||||||
self.assertTrue(image.shape[1] % 4 == 0)
|
self.assertTrue(padded_image.shape[1] % 4 == 0)
|
||||||
self.assertTrue(image.shape[2] % 4 == 0)
|
self.assertTrue(padded_image.shape[2] % 4 == 0)
|
||||||
|
pixel_values = image_processor.preprocess(
|
||||||
# test by calling
|
image, do_rescale=False, do_resize=False, do_pad=True, size_divisor=4, return_tensors="pt"
|
||||||
pixel_values = image_processing.preprocess(
|
).pixel_values
|
||||||
image, do_rescale=False, do_resize=False, do_pad=True, size_divisor=4, return_tensors="pt"
|
self.assertTrue(pixel_values.shape[2] % 4 == 0)
|
||||||
).pixel_values
|
self.assertTrue(pixel_values.shape[3] % 4 == 0)
|
||||||
self.assertTrue(pixel_values.shape[2] % 4 == 0)
|
else:
|
||||||
self.assertTrue(pixel_values.shape[3] % 4 == 0)
|
image_processor = image_processing_class(**self.image_processor_dict)
|
||||||
|
image = np.random.randn(3, 249, 491)
|
||||||
|
image = image_processor.pad_image(image, size_divisor=4)
|
||||||
|
self.assertTrue(image.shape[1] % 4 == 0)
|
||||||
|
self.assertTrue(image.shape[2] % 4 == 0)
|
||||||
|
pixel_values = image_processor.preprocess(
|
||||||
|
image, do_rescale=False, do_resize=False, do_pad=True, size_divisor=4, return_tensors="pt"
|
||||||
|
).pixel_values
|
||||||
|
self.assertTrue(pixel_values.shape[2] % 4 == 0)
|
||||||
|
self.assertTrue(pixel_values.shape[3] % 4 == 0)
|
||||||
|
|
||||||
def test_keep_aspect_ratio(self):
|
def test_keep_aspect_ratio(self):
|
||||||
size = {"height": 512, "width": 512}
|
size = {"height": 512, "width": 512}
|
||||||
image_processor = DPTImageProcessor(size=size, keep_aspect_ratio=True, ensure_multiple_of=32)
|
for image_processing_class in self.image_processor_list:
|
||||||
|
image_processor = image_processing_class(size=size, keep_aspect_ratio=True, ensure_multiple_of=32)
|
||||||
|
|
||||||
image = np.zeros((489, 640, 3))
|
image = np.zeros((489, 640, 3))
|
||||||
|
|
||||||
pixel_values = image_processor(image, return_tensors="pt").pixel_values
|
pixel_values = image_processor(image, return_tensors="pt").pixel_values
|
||||||
|
|
||||||
self.assertEqual(list(pixel_values.shape), [1, 3, 512, 672])
|
self.assertEqual(list(pixel_values.shape), [1, 3, 512, 672])
|
||||||
|
|
||||||
# Copied from transformers.tests.models.beit.test_image_processing_beit.BeitImageProcessingTest.test_call_segmentation_maps
|
# Copied from transformers.tests.models.beit.test_image_processing_beit.BeitImageProcessingTest.test_call_segmentation_maps
|
||||||
def test_call_segmentation_maps(self):
|
def test_call_segmentation_maps(self):
|
||||||
# Initialize image_processor
|
for image_processing_class in self.image_processor_list:
|
||||||
image_processor = self.image_processing_class(**self.image_processor_dict)
|
# Initialize image_processor
|
||||||
# create random PyTorch tensors
|
image_processor = image_processing_class(**self.image_processor_dict)
|
||||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
|
# create random PyTorch tensors
|
||||||
maps = []
|
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
|
||||||
for image in image_inputs:
|
maps = []
|
||||||
self.assertIsInstance(image, torch.Tensor)
|
for image in image_inputs:
|
||||||
maps.append(torch.zeros(image.shape[-2:]).long())
|
self.assertIsInstance(image, torch.Tensor)
|
||||||
|
maps.append(torch.zeros(image.shape[-2:]).long())
|
||||||
|
|
||||||
# Test not batched input
|
# Test not batched input
|
||||||
encoding = image_processor(image_inputs[0], maps[0], return_tensors="pt")
|
encoding = image_processor(image_inputs[0], maps[0], return_tensors="pt")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
encoding["pixel_values"].shape,
|
encoding["pixel_values"].shape,
|
||||||
(
|
(
|
||||||
1,
|
1,
|
||||||
self.image_processor_tester.num_channels,
|
self.image_processor_tester.num_channels,
|
||||||
self.image_processor_tester.size["height"],
|
self.image_processor_tester.size["height"],
|
||||||
self.image_processor_tester.size["width"],
|
self.image_processor_tester.size["width"],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
encoding["labels"].shape,
|
encoding["labels"].shape,
|
||||||
(
|
(
|
||||||
1,
|
1,
|
||||||
self.image_processor_tester.size["height"],
|
self.image_processor_tester.size["height"],
|
||||||
self.image_processor_tester.size["width"],
|
self.image_processor_tester.size["width"],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||||
|
|
||||||
# Test batched
|
# Test batched
|
||||||
encoding = image_processor(image_inputs, maps, return_tensors="pt")
|
encoding = image_processor(image_inputs, maps, return_tensors="pt")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
encoding["pixel_values"].shape,
|
encoding["pixel_values"].shape,
|
||||||
(
|
(
|
||||||
self.image_processor_tester.batch_size,
|
self.image_processor_tester.batch_size,
|
||||||
self.image_processor_tester.num_channels,
|
self.image_processor_tester.num_channels,
|
||||||
self.image_processor_tester.size["height"],
|
self.image_processor_tester.size["height"],
|
||||||
self.image_processor_tester.size["width"],
|
self.image_processor_tester.size["width"],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
encoding["labels"].shape,
|
encoding["labels"].shape,
|
||||||
(
|
(
|
||||||
self.image_processor_tester.batch_size,
|
self.image_processor_tester.batch_size,
|
||||||
self.image_processor_tester.size["height"],
|
self.image_processor_tester.size["height"],
|
||||||
self.image_processor_tester.size["width"],
|
self.image_processor_tester.size["width"],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||||
|
|
||||||
# Test not batched input (PIL images)
|
# Test not batched input (PIL images)
|
||||||
image, segmentation_map = prepare_semantic_single_inputs()
|
image, segmentation_map = prepare_semantic_single_inputs()
|
||||||
|
|
||||||
encoding = image_processor(image, segmentation_map, return_tensors="pt")
|
encoding = image_processor(image, segmentation_map, return_tensors="pt")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
encoding["pixel_values"].shape,
|
encoding["pixel_values"].shape,
|
||||||
(
|
(
|
||||||
1,
|
1,
|
||||||
self.image_processor_tester.num_channels,
|
self.image_processor_tester.num_channels,
|
||||||
self.image_processor_tester.size["height"],
|
self.image_processor_tester.size["height"],
|
||||||
self.image_processor_tester.size["width"],
|
self.image_processor_tester.size["width"],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
encoding["labels"].shape,
|
encoding["labels"].shape,
|
||||||
(
|
(
|
||||||
1,
|
1,
|
||||||
self.image_processor_tester.size["height"],
|
self.image_processor_tester.size["height"],
|
||||||
self.image_processor_tester.size["width"],
|
self.image_processor_tester.size["width"],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||||
|
|
||||||
# Test batched input (PIL images)
|
# Test batched input (PIL images)
|
||||||
images, segmentation_maps = prepare_semantic_batch_inputs()
|
images, segmentation_maps = prepare_semantic_batch_inputs()
|
||||||
|
|
||||||
encoding = image_processor(images, segmentation_maps, return_tensors="pt")
|
encoding = image_processor(images, segmentation_maps, return_tensors="pt")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
encoding["pixel_values"].shape,
|
encoding["pixel_values"].shape,
|
||||||
(
|
(
|
||||||
2,
|
2,
|
||||||
self.image_processor_tester.num_channels,
|
self.image_processor_tester.num_channels,
|
||||||
self.image_processor_tester.size["height"],
|
self.image_processor_tester.size["height"],
|
||||||
self.image_processor_tester.size["width"],
|
self.image_processor_tester.size["width"],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
encoding["labels"].shape,
|
encoding["labels"].shape,
|
||||||
(
|
(
|
||||||
2,
|
2,
|
||||||
self.image_processor_tester.size["height"],
|
self.image_processor_tester.size["height"],
|
||||||
self.image_processor_tester.size["width"],
|
self.image_processor_tester.size["width"],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||||
|
|
||||||
# Copied from transformers.tests.models.beit.test_image_processing_beit.BeitImageProcessingTest.test_reduce_labels
|
|
||||||
def test_reduce_labels(self):
|
def test_reduce_labels(self):
|
||||||
# Initialize image_processor
|
for image_processing_class in self.image_processor_list:
|
||||||
image_processor = self.image_processing_class(**self.image_processor_dict)
|
image_processor = image_processing_class(**self.image_processor_dict)
|
||||||
|
|
||||||
# ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150
|
# ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150
|
||||||
image, map = prepare_semantic_single_inputs()
|
image, map = prepare_semantic_single_inputs()
|
||||||
encoding = image_processor(image, map, return_tensors="pt")
|
encoding = image_processor(image, map, return_tensors="pt")
|
||||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
labels_no_reduce = encoding["labels"].clone()
|
||||||
self.assertTrue(encoding["labels"].max().item() <= 150)
|
self.assertTrue(labels_no_reduce.min().item() >= 0)
|
||||||
|
self.assertTrue(labels_no_reduce.max().item() <= 150)
|
||||||
|
# Get the first non-zero label coords and value, for comparison when do_reduce_labels is True
|
||||||
|
non_zero_positions = (labels_no_reduce > 0).nonzero()
|
||||||
|
first_non_zero_coords = tuple(non_zero_positions[0].tolist())
|
||||||
|
first_non_zero_value = labels_no_reduce[first_non_zero_coords].item()
|
||||||
|
|
||||||
image_processor.do_reduce_labels = True
|
image_processor.do_reduce_labels = True
|
||||||
encoding = image_processor(image, map, return_tensors="pt")
|
encoding = image_processor(image, map, return_tensors="pt")
|
||||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||||
|
# Compare with non-reduced label to see if it's reduced by 1
|
||||||
|
self.assertEqual(encoding["labels"][first_non_zero_coords].item(), first_non_zero_value - 1)
|
||||||
|
|
||||||
|
def test_slow_fast_equivalence(self):
|
||||||
|
if not self.test_slow_image_processor or not self.test_fast_image_processor:
|
||||||
|
self.skipTest(reason="Skipping slow/fast equivalence test")
|
||||||
|
|
||||||
|
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||||
|
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
|
||||||
|
|
||||||
|
dummy_image, dummy_map = prepare_semantic_single_inputs()
|
||||||
|
|
||||||
|
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
||||||
|
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
||||||
|
|
||||||
|
image_encoding_slow = image_processor_slow(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
|
||||||
|
image_encoding_fast = image_processor_fast(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(image_encoding_slow.pixel_values, image_encoding_fast.pixel_values, atol=1e-1))
|
||||||
|
self.assertLessEqual(
|
||||||
|
torch.mean(torch.abs(image_encoding_slow.pixel_values - image_encoding_fast.pixel_values)).item(), 1e-3
|
||||||
|
)
|
||||||
|
self.assertTrue(torch.allclose(image_encoding_slow.labels, image_encoding_fast.labels, atol=1e-1))
|
||||||
|
|
||||||
|
def test_slow_fast_equivalence_batched(self):
|
||||||
|
if not self.test_slow_image_processor or not self.test_fast_image_processor:
|
||||||
|
self.skipTest(reason="Skipping slow/fast equivalence test")
|
||||||
|
|
||||||
|
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||||
|
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
|
||||||
|
|
||||||
|
if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
|
||||||
|
self.skipTest(
|
||||||
|
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
|
||||||
|
)
|
||||||
|
|
||||||
|
dummy_images, dummy_maps = prepare_semantic_batch_inputs()
|
||||||
|
|
||||||
|
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
||||||
|
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
||||||
|
|
||||||
|
encoding_slow = image_processor_slow(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
|
||||||
|
encoding_fast = image_processor_fast(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
|
||||||
|
self.assertLessEqual(
|
||||||
|
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user