mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Add reduce_labels
to Mobilevit fast processor
This commit is contained in:
parent
d7ac282524
commit
23231d2e35
@ -103,31 +103,38 @@ class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):
|
|||||||
do_rescale: bool,
|
do_rescale: bool,
|
||||||
do_center_crop: bool,
|
do_center_crop: bool,
|
||||||
do_normalize: bool,
|
do_normalize: bool,
|
||||||
size: SizeDict,
|
size: Optional[SizeDict],
|
||||||
interpolation: Optional["F.InterpolationMode"],
|
interpolation: Optional["F.InterpolationMode"],
|
||||||
rescale_factor: float,
|
rescale_factor: Optional[float],
|
||||||
crop_size: SizeDict,
|
crop_size: Optional[SizeDict],
|
||||||
image_mean: Optional[Union[float, list[float]]],
|
image_mean: Optional[Union[float, list[float]]],
|
||||||
image_std: Optional[Union[float, list[float]]],
|
image_std: Optional[Union[float, list[float]]],
|
||||||
|
disable_grouping: bool,
|
||||||
return_tensors: Optional[Union[str, TensorType]],
|
return_tensors: Optional[Union[str, TensorType]],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
|
processed_images = []
|
||||||
|
|
||||||
if do_reduce_labels:
|
if do_reduce_labels:
|
||||||
images = self.reduce_label(images)
|
images = self.reduce_label(images)
|
||||||
|
|
||||||
# Group images by size for batched resizing
|
# Group images by shape for more efficient batch processing
|
||||||
grouped_images, grouped_images_index = group_images_by_shape(images)
|
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
|
||||||
resized_images_grouped = {}
|
resized_images_grouped = {}
|
||||||
|
|
||||||
|
# Process each group of images with the same shape
|
||||||
for shape, stacked_images in grouped_images.items():
|
for shape, stacked_images in grouped_images.items():
|
||||||
if do_resize:
|
if do_resize:
|
||||||
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
|
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
|
||||||
resized_images_grouped[shape] = stacked_images
|
resized_images_grouped[shape] = stacked_images
|
||||||
|
|
||||||
|
# Reorder images to original sequence
|
||||||
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
|
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
|
||||||
|
|
||||||
# Group images by size for further processing
|
# Group again after resizing (in case resize produced different sizes)
|
||||||
# Needed in case do_resize is False, or resize returns images with different sizes
|
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
|
||||||
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
|
|
||||||
processed_images_grouped = {}
|
processed_images_grouped = {}
|
||||||
|
|
||||||
for shape, stacked_images in grouped_images.items():
|
for shape, stacked_images in grouped_images.items():
|
||||||
if do_center_crop:
|
if do_center_crop:
|
||||||
stacked_images = self.center_crop(stacked_images, crop_size)
|
stacked_images = self.center_crop(stacked_images, crop_size)
|
||||||
@ -138,7 +145,10 @@ class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):
|
|||||||
processed_images_grouped[shape] = stacked_images
|
processed_images_grouped[shape] = stacked_images
|
||||||
|
|
||||||
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
|
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||||
|
|
||||||
|
# Stack all processed images if return_tensors is specified
|
||||||
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||||
|
|
||||||
return processed_images
|
return processed_images
|
||||||
|
|
||||||
def _preprocess_images(
|
def _preprocess_images(
|
||||||
@ -170,12 +180,7 @@ class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):
|
|||||||
|
|
||||||
kwargs["do_normalize"] = False
|
kwargs["do_normalize"] = False
|
||||||
kwargs["do_rescale"] = False
|
kwargs["do_rescale"] = False
|
||||||
kwargs["interpolation"] = (
|
kwargs["interpolation"] = pil_torch_interpolation_mapping[PILImageResampling.NEAREST]
|
||||||
pil_torch_interpolation_mapping[PILImageResampling.NEAREST]
|
|
||||||
if PILImageResampling.NEAREST in pil_torch_interpolation_mapping
|
|
||||||
else kwargs.get("interpolation")
|
|
||||||
)
|
|
||||||
kwargs["input_data_format"] = ChannelDimension.FIRST
|
|
||||||
processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs)
|
processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs)
|
||||||
|
|
||||||
processed_segmentation_maps = processed_segmentation_maps.squeeze(1)
|
processed_segmentation_maps = processed_segmentation_maps.squeeze(1)
|
||||||
@ -233,15 +238,15 @@ class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):
|
|||||||
images=images,
|
images=images,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
data = {"pixel_values": images}
|
|
||||||
|
|
||||||
if segmentation_maps is not None:
|
if segmentation_maps is not None:
|
||||||
segmentation_maps = self._preprocess_segmentation_maps(
|
segmentation_maps = self._preprocess_segmentation_maps(
|
||||||
segmentation_maps=segmentation_maps,
|
segmentation_maps=segmentation_maps,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
data["labels"] = segmentation_maps
|
return BatchFeature(data={"pixel_values": images, "labels": segmentation_maps})
|
||||||
return BatchFeature(data=data)
|
|
||||||
|
return BatchFeature(data={"pixel_values": images})
|
||||||
|
|
||||||
# Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.post_process_semantic_segmentation with Beit->MobileNetV2
|
# Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.post_process_semantic_segmentation with Beit->MobileNetV2
|
||||||
def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
|
def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
|
||||||
|
@ -14,9 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Fast Image processor class for MobileViT."""
|
"""Fast Image processor class for MobileViT."""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from ...image_processing_utils import BatchFeature
|
from ...image_processing_utils import BatchFeature
|
||||||
from ...image_processing_utils_fast import (
|
from ...image_processing_utils_fast import (
|
||||||
@ -27,23 +25,46 @@ from ...image_processing_utils_fast import (
|
|||||||
)
|
)
|
||||||
from ...image_utils import (
|
from ...image_utils import (
|
||||||
ChannelDimension,
|
ChannelDimension,
|
||||||
|
ImageInput,
|
||||||
PILImageResampling,
|
PILImageResampling,
|
||||||
|
SizeDict,
|
||||||
is_torch_tensor,
|
is_torch_tensor,
|
||||||
make_list_of_images,
|
make_list_of_images,
|
||||||
pil_torch_interpolation_mapping,
|
pil_torch_interpolation_mapping,
|
||||||
validate_kwargs,
|
validate_kwargs,
|
||||||
)
|
)
|
||||||
from ...processing_utils import Unpack
|
from ...processing_utils import Unpack
|
||||||
from ...utils import auto_docstring
|
from ...utils import (
|
||||||
|
TensorType,
|
||||||
|
auto_docstring,
|
||||||
|
is_torch_available,
|
||||||
|
is_torchvision_available,
|
||||||
|
is_torchvision_v2_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if is_torchvision_available():
|
||||||
|
if is_torchvision_v2_available():
|
||||||
|
from torchvision.transforms.v2 import functional as F
|
||||||
|
else:
|
||||||
|
from torchvision.transforms import functional as F
|
||||||
|
|
||||||
|
|
||||||
class MobileVitFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
class MobileVitFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||||
"""
|
"""
|
||||||
do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`):
|
do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`):
|
||||||
Whether to flip the color channels from RGB to BGR or vice versa.
|
Whether to flip the color channels from RGB to BGR or vice versa.
|
||||||
|
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
|
||||||
|
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
|
||||||
|
is used for background, and background itself is not included in all classes of a dataset (e.g.
|
||||||
|
ADE20k). The background label will be replaced by 255.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
do_flip_channel_order: Optional[bool]
|
do_flip_channel_order: Optional[bool]
|
||||||
|
do_reduce_labels: Optional[bool]
|
||||||
|
|
||||||
|
|
||||||
@auto_docstring
|
@auto_docstring
|
||||||
@ -58,28 +79,44 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
|
|||||||
do_normalize = None
|
do_normalize = None
|
||||||
do_convert_rgb = None
|
do_convert_rgb = None
|
||||||
do_flip_channel_order = True
|
do_flip_channel_order = True
|
||||||
|
do_reduce_labels = False
|
||||||
valid_kwargs = MobileVitFastImageProcessorKwargs
|
valid_kwargs = MobileVitFastImageProcessorKwargs
|
||||||
|
|
||||||
def __init__(self, **kwargs: Unpack[MobileVitFastImageProcessorKwargs]):
|
def __init__(self, **kwargs: Unpack[MobileVitFastImageProcessorKwargs]):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
# Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.reduce_label
|
||||||
|
def reduce_label(self, labels: list["torch.Tensor"]):
|
||||||
|
for idx in range(len(labels)):
|
||||||
|
label = labels[idx]
|
||||||
|
label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label)
|
||||||
|
label = label - 1
|
||||||
|
label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label)
|
||||||
|
labels[idx] = label
|
||||||
|
|
||||||
|
return label
|
||||||
|
|
||||||
def _preprocess(
|
def _preprocess(
|
||||||
self,
|
self,
|
||||||
images,
|
images: list["torch.Tensor"],
|
||||||
|
do_reduce_labels: bool,
|
||||||
do_resize: bool,
|
do_resize: bool,
|
||||||
size: Optional[dict],
|
size: Optional[SizeDict],
|
||||||
interpolation: Optional[str],
|
interpolation: Optional["F.InterpolationMode"],
|
||||||
do_rescale: bool,
|
do_rescale: bool,
|
||||||
rescale_factor: Optional[float],
|
rescale_factor: Optional[float],
|
||||||
do_center_crop: bool,
|
do_center_crop: bool,
|
||||||
crop_size: Optional[dict],
|
crop_size: Optional[SizeDict],
|
||||||
do_flip_channel_order: bool,
|
do_flip_channel_order: bool,
|
||||||
disable_grouping: bool,
|
disable_grouping: bool,
|
||||||
return_tensors: Optional[str],
|
return_tensors: Optional[Union[str, TensorType]],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
) -> BatchFeature:
|
||||||
processed_images = []
|
processed_images = []
|
||||||
|
|
||||||
|
if do_reduce_labels:
|
||||||
|
images = self.reduce_label(images)
|
||||||
|
|
||||||
# Group images by shape for more efficient batch processing
|
# Group images by shape for more efficient batch processing
|
||||||
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
|
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
|
||||||
resized_images_grouped = {}
|
resized_images_grouped = {}
|
||||||
@ -119,6 +156,16 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
|
|||||||
|
|
||||||
return processed_images
|
return processed_images
|
||||||
|
|
||||||
|
def _preprocess_images(
|
||||||
|
self,
|
||||||
|
images,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Preprocesses images."""
|
||||||
|
kwargs["do_reduce_labels"] = False
|
||||||
|
processed_images = self._preprocess(images=images, **kwargs)
|
||||||
|
return processed_images
|
||||||
|
|
||||||
def _preprocess_segmentation_maps(
|
def _preprocess_segmentation_maps(
|
||||||
self,
|
self,
|
||||||
segmentation_maps,
|
segmentation_maps,
|
||||||
@ -149,8 +196,8 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
|
|||||||
@auto_docstring
|
@auto_docstring
|
||||||
def preprocess(
|
def preprocess(
|
||||||
self,
|
self,
|
||||||
images,
|
images: ImageInput,
|
||||||
segmentation_maps=None,
|
segmentation_maps: Optional[ImageInput] = None,
|
||||||
**kwargs: Unpack[MobileVitFastImageProcessorKwargs],
|
**kwargs: Unpack[MobileVitFastImageProcessorKwargs],
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
r"""
|
r"""
|
||||||
@ -192,7 +239,7 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
|
|||||||
kwargs.pop("default_to_square")
|
kwargs.pop("default_to_square")
|
||||||
kwargs.pop("data_format")
|
kwargs.pop("data_format")
|
||||||
|
|
||||||
images = self._preprocess(
|
images = self._preprocess_images(
|
||||||
images=images,
|
images=images,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
@ -207,6 +254,21 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
|
|||||||
return BatchFeature(data={"pixel_values": images})
|
return BatchFeature(data={"pixel_values": images})
|
||||||
|
|
||||||
def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
|
def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
|
||||||
|
"""
|
||||||
|
Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs ([`MobileNetV2ForSemanticSegmentation`]):
|
||||||
|
Raw outputs of the model.
|
||||||
|
target_sizes (`list[Tuple]` of length `batch_size`, *optional*):
|
||||||
|
List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
|
||||||
|
predictions will not be resized.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic
|
||||||
|
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
|
||||||
|
specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
|
||||||
|
"""
|
||||||
logits = outputs.logits
|
logits = outputs.logits
|
||||||
|
|
||||||
# Resize logits and compute semantic segmentation maps
|
# Resize logits and compute semantic segmentation maps
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import requests
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
from transformers.testing_utils import require_torch, require_vision
|
from transformers.testing_utils import require_torch, require_vision
|
||||||
@ -89,23 +90,14 @@ class MobileNetV2ImageProcessingTester:
|
|||||||
|
|
||||||
|
|
||||||
def prepare_semantic_single_inputs():
|
def prepare_semantic_single_inputs():
|
||||||
dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test", trust_remote_code=True)
|
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||||
|
example = ds[0]
|
||||||
image = Image.open(dataset[0]["file"])
|
return example["image"], example["map"]
|
||||||
map = Image.open(dataset[1]["file"])
|
|
||||||
|
|
||||||
return image, map
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_semantic_batch_inputs():
|
def prepare_semantic_batch_inputs():
|
||||||
dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test", trust_remote_code=True)
|
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||||
|
return list(ds["image"][:2]), list(ds["map"][:2])
|
||||||
image1 = Image.open(dataset[0]["file"])
|
|
||||||
map1 = Image.open(dataset[1]["file"])
|
|
||||||
image2 = Image.open(dataset[2]["file"])
|
|
||||||
map2 = Image.open(dataset[3]["file"])
|
|
||||||
|
|
||||||
return [image1, image2], [map1, map2]
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@ -275,41 +267,21 @@ class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
|
|||||||
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||||
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
|
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
|
||||||
|
|
||||||
dummy_image, dummy_map = prepare_semantic_single_inputs()
|
# Test with single image
|
||||||
|
dummy_image = Image.open(
|
||||||
|
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
|
||||||
|
)
|
||||||
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
||||||
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
||||||
|
|
||||||
image_encoding_slow = image_processor_slow(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
|
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
|
||||||
image_encoding_fast = image_processor_fast(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
|
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
|
||||||
|
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(image_encoding_slow.pixel_values, image_encoding_fast.pixel_values, atol=1e-1))
|
# Test with single image and segmentation map
|
||||||
self.assertLessEqual(
|
image, segmentation_map = prepare_semantic_single_inputs()
|
||||||
torch.mean(torch.abs(image_encoding_slow.pixel_values - image_encoding_fast.pixel_values)).item(), 1e-3
|
|
||||||
)
|
|
||||||
self.assertTrue(torch.allclose(image_encoding_slow.labels, image_encoding_fast.labels, atol=1e-1))
|
|
||||||
|
|
||||||
def test_slow_fast_equivalence_batched(self):
|
encoding_slow = image_processor_slow(image, segmentation_map, return_tensors="pt")
|
||||||
if not self.test_slow_image_processor or not self.test_fast_image_processor:
|
encoding_fast = image_processor_fast(image, segmentation_map, return_tensors="pt")
|
||||||
self.skipTest(reason="Skipping slow/fast equivalence test")
|
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
|
||||||
|
torch.testing.assert_close(encoding_slow.labels, encoding_fast.labels, atol=1e-1, rtol=1e-3)
|
||||||
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
|
||||||
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
|
|
||||||
|
|
||||||
if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
|
|
||||||
self.skipTest(
|
|
||||||
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
|
|
||||||
)
|
|
||||||
|
|
||||||
dummy_images, dummy_maps = prepare_semantic_batch_inputs()
|
|
||||||
|
|
||||||
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
|
||||||
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
|
||||||
|
|
||||||
encoding_slow = image_processor_slow(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
|
|
||||||
encoding_fast = image_processor_fast(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
|
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
|
|
||||||
self.assertLessEqual(
|
|
||||||
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
|
|
||||||
)
|
|
||||||
|
@ -248,6 +248,22 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||||
|
|
||||||
|
def test_reduce_labels(self):
|
||||||
|
for image_processing_class in self.image_processor_list:
|
||||||
|
# Initialize image_processing
|
||||||
|
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||||
|
|
||||||
|
# ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150
|
||||||
|
image, map = prepare_semantic_single_inputs()
|
||||||
|
encoding = image_processing(image, map, return_tensors="pt")
|
||||||
|
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||||
|
self.assertTrue(encoding["labels"].max().item() <= 150)
|
||||||
|
|
||||||
|
image_processing.do_reduce_labels = True
|
||||||
|
encoding = image_processing(image, map, return_tensors="pt")
|
||||||
|
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||||
|
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||||
|
|
||||||
@require_vision
|
@require_vision
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_slow_fast_equivalence(self):
|
def test_slow_fast_equivalence(self):
|
||||||
@ -275,19 +291,3 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
encoding_fast = image_processor_fast(image, segmentation_map, return_tensors="pt")
|
encoding_fast = image_processor_fast(image, segmentation_map, return_tensors="pt")
|
||||||
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
|
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
|
||||||
torch.testing.assert_close(encoding_slow.labels, encoding_fast.labels, atol=1e-1, rtol=1e-3)
|
torch.testing.assert_close(encoding_slow.labels, encoding_fast.labels, atol=1e-1, rtol=1e-3)
|
||||||
|
|
||||||
def test_reduce_labels(self):
|
|
||||||
for image_processing_class in self.image_processor_list:
|
|
||||||
# Initialize image_processing
|
|
||||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
|
||||||
|
|
||||||
# ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150
|
|
||||||
image, map = prepare_semantic_single_inputs()
|
|
||||||
encoding = image_processing(image, map, return_tensors="pt")
|
|
||||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
|
||||||
self.assertTrue(encoding["labels"].max().item() <= 150)
|
|
||||||
|
|
||||||
image_processing.do_reduce_labels = True
|
|
||||||
encoding = image_processing(image, map, return_tensors="pt")
|
|
||||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
|
||||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user