Fix MaskFormerImageProcessor.post_process_instance_segmentation (#21256)
* fix instance segmentation post processing
* add Mask2FormerImageProcessor
This commit is contained in:
parent
767939af52
commit
f424b09410
@@ -22,8 +22,8 @@ The abstract from the paper is the following:

of semantics defines a task. While only the semantics of each task differ, current research focuses on designing specialized architectures for each task. We present Masked-attention Mask Transformer (Mask2Former), a new architecture capable of addressing any image segmentation task (panoptic, instance or semantic). Its key components include masked attention, which extracts localized features by constraining cross-attention within predicted mask regions. In addition to reducing the research effort by at least three times, it outperforms the best specialized architectures by a significant margin on four popular datasets. Most notably, Mask2Former sets a new state-of-the-art for panoptic segmentation (57.8 PQ on COCO), instance segmentation (50.1 AP on COCO) and semantic segmentation (57.7 mIoU on ADE20K).*

Tips:

- Mask2Former uses the same preprocessing and postprocessing steps as [MaskFormer](maskformer). Use [`MaskFormerImageProcessor`] or [`AutoImageProcessor`] to prepare images and optional targets for the model.
- To get the final segmentation, depending on the task, you can call [`~MaskFormerImageProcessor.post_process_semantic_segmentation`] or [`~MaskFormerImageProcessor.post_process_instance_segmentation`] or [`~MaskFormerImageProcessor.post_process_panoptic_segmentation`]. All three tasks can be solved using [`Mask2FormerForUniversalSegmentation`] output, panoptic segmentation accepts an optional `label_ids_to_fuse` argument to fuse instances of the target object/s (e.g. sky) together.
- Mask2Former uses the same preprocessing and postprocessing steps as [MaskFormer](maskformer). Use [`Mask2FormerImageProcessor`] or [`AutoImageProcessor`] to prepare images and optional targets for the model.
- To get the final segmentation, depending on the task, you can call [`~Mask2FormerImageProcessor.post_process_semantic_segmentation`] or [`~Mask2FormerImageProcessor.post_process_instance_segmentation`] or [`~Mask2FormerImageProcessor.post_process_panoptic_segmentation`]. All three tasks can be solved using [`Mask2FormerForUniversalSegmentation`] output, panoptic segmentation accepts an optional `label_ids_to_fuse` argument to fuse instances of the target object/s (e.g. sky) together.

This model was contributed by [Shivalika Singh](https://huggingface.co/shivi) and [Alara Dirik](https://huggingface.co/adirik). The original code can be found [here](https://github.com/facebookresearch/Mask2Former).
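The tips above describe the full preprocess → forward → post-process loop. The following is a minimal sketch of that loop; the `facebook/mask2former-swin-small-coco-instance` checkpoint is the one referenced later in this diff, while the COCO image URL is purely illustrative and not part of this commit:

```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

checkpoint = "facebook/mask2former-swin-small-coco-instance"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # example image, not from this diff
image = Image.open(requests.get(url, stream=True).raw)

# Prepare the image and run the universal segmentation model.
inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Pick the post-processing method that matches the task; here, semantic segmentation
# resized back to the original (height, width) of the input image.
semantic_map = image_processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]
print(semantic_map.shape)  # (height, width) tensor of class ids
```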
@@ -55,3 +55,12 @@ The resource should ideally demonstrate something new instead of duplicating an

[[autodoc]] Mask2FormerForUniversalSegmentation
- forward

## Mask2FormerImageProcessor

[[autodoc]] Mask2FormerImageProcessor
- preprocess
- encode_inputs
- post_process_semantic_segmentation
- post_process_instance_segmentation
- post_process_panoptic_segmentation
@@ -799,6 +799,7 @@ else:
_import_structure["models.layoutlmv2"].extend(["LayoutLMv2FeatureExtractor", "LayoutLMv2ImageProcessor"])
_import_structure["models.layoutlmv3"].extend(["LayoutLMv3FeatureExtractor", "LayoutLMv3ImageProcessor"])
_import_structure["models.levit"].extend(["LevitFeatureExtractor", "LevitImageProcessor"])
_import_structure["models.mask2former"].append("Mask2FormerImageProcessor")
_import_structure["models.maskformer"].extend(["MaskFormerFeatureExtractor", "MaskFormerImageProcessor"])
_import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"])
_import_structure["models.mobilenet_v2"].extend(["MobileNetV2FeatureExtractor", "MobileNetV2ImageProcessor"])

@@ -4152,6 +4153,7 @@ if TYPE_CHECKING:
from .models.layoutlmv2 import LayoutLMv2FeatureExtractor, LayoutLMv2ImageProcessor
from .models.layoutlmv3 import LayoutLMv3FeatureExtractor, LayoutLMv3ImageProcessor
from .models.levit import LevitFeatureExtractor, LevitImageProcessor
from .models.mask2former import Mask2FormerImageProcessor
from .models.maskformer import MaskFormerFeatureExtractor, MaskFormerImageProcessor
from .models.mobilenet_v1 import MobileNetV1FeatureExtractor, MobileNetV1ImageProcessor
from .models.mobilenet_v2 import MobileNetV2FeatureExtractor, MobileNetV2ImageProcessor
@@ -62,7 +62,7 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
("layoutlmv2", "LayoutLMv2ImageProcessor"),
("layoutlmv3", "LayoutLMv3ImageProcessor"),
("levit", "LevitImageProcessor"),
("mask2former", "MaskFormerImageProcessor"),
("mask2former", "Mask2FormerImageProcessor"),
("maskformer", "MaskFormerImageProcessor"),
("mobilenet_v1", "MobileNetV1ImageProcessor"),
("mobilenet_v2", "MobileNetV2ImageProcessor"),
@@ -27,6 +27,13 @@ _import_structure = {
],
}

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_mask2former"] = ["Mask2FormerImageProcessor"]

try:
if not is_torch_available():

@@ -44,6 +51,14 @@ else:
if TYPE_CHECKING:
from .configuration_mask2former import MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, Mask2FormerConfig

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_mask2former import Mask2FormerImageProcessor

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
@@ -33,8 +33,8 @@ from huggingface_hub import hf_hub_download

from transformers import (
Mask2FormerConfig,
Mask2FormerForUniversalSegmentation,
Mask2FormerImageProcessor,
Mask2FormerModel,
MaskFormerImageProcessor,
SwinConfig,
)
from transformers.models.mask2former.modeling_mask2former import (

@@ -193,11 +193,11 @@ class OriginalMask2FormerConfigToOursConverter:


class OriginalMask2FormerConfigToFeatureExtractorConverter:
def __call__(self, original_config: object) -> MaskFormerImageProcessor:
def __call__(self, original_config: object) -> Mask2FormerImageProcessor:
model = original_config.MODEL
model_input = original_config.INPUT

return MaskFormerImageProcessor(
return Mask2FormerImageProcessor(
image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(),
image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(),
size=model_input.MIN_SIZE_TEST,

@@ -847,7 +847,7 @@ class OriginalMask2FormerCheckpointToOursConverter:
def test(
original_model,
our_model: Mask2FormerForUniversalSegmentation,
feature_extractor: MaskFormerImageProcessor,
feature_extractor: Mask2FormerImageProcessor,
tolerance: float,
):
with torch.no_grad():
src/transformers/models/mask2former/image_processing_mask2former.py (new file, 1149 lines added)
File diff suppressed because it is too large
@@ -49,6 +49,7 @@ logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "Mask2FormerConfig"
_CHECKPOINT_FOR_DOC = "facebook/mask2former-swin-small-coco-instance"
_IMAGE_PROCESSOR_FOR_DOC = "Mask2FormerImageProcessor"

MASK2FORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/mask2former-swin-small-coco-instance",

@@ -194,10 +195,10 @@ class Mask2FormerForUniversalSegmentationOutput(ModelOutput):
"""
Class for outputs of [`Mask2FormerForUniversalSegmentationOutput`].

This output can be directly passed to [`~MaskFormerImageProcessor.post_process_semantic_segmentation`] or
[`~MaskFormerImageProcessor.post_process_instance_segmentation`] or
[`~MaskFormerImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see
[`~MaskFormerImageProcessor] for details regarding usage.
This output can be directly passed to [`~Mask2FormerImageProcessor.post_process_semantic_segmentation`] or
[`~Mask2FormerImageProcessor.post_process_instance_segmentation`] or
[`~Mask2FormerImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see
[`~Mask2FormerImageProcessor] for details regarding usage.

Args:
loss (`torch.Tensor`, *optional*):
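As the docstring above says, this output class can be handed directly to the image processor's post-processing methods. A small sketch, modeled on the new tests in this commit, of panoptic post-processing with the optional `label_ids_to_fuse` argument; the random logits stand in for a real forward pass, and fusing label id 1 is an arbitrary illustration:

```python
import torch
from transformers import Mask2FormerImageProcessor
from transformers.models.mask2former.modeling_mask2former import Mask2FormerForUniversalSegmentationOutput

processor = Mask2FormerImageProcessor(num_labels=2)
outputs = Mask2FormerForUniversalSegmentationOutput(
    class_queries_logits=torch.randn(1, 3, 2 + 1),   # [batch, queries, classes + null class]
    masks_queries_logits=torch.randn(1, 3, 32, 32),  # [batch, queries, height, width]
)

# label_ids_to_fuse merges all instances of those class ids into a single segment.
panoptic = processor.post_process_panoptic_segmentation(
    outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0, label_ids_to_fuse={1}
)[0]
print(panoptic["segmentation"].shape, panoptic["segments_info"])
```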
@@ -1016,6 +1016,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
overlap_mask_area_threshold: float = 0.8,
target_sizes: Optional[List[Tuple[int, int]]] = None,
return_coco_annotation: Optional[bool] = False,
return_binary_maps: Optional[bool] = False,
) -> List[Dict]:
"""
Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into instance segmentation predictions. Only

@@ -1034,9 +1035,11 @@ class MaskFormerImageProcessor(BaseImageProcessor):
target_sizes (`List[Tuple]`, *optional*):
List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested
final size (height, width) of each prediction. If left to None, predictions will not be resized.
return_coco_annotation (`bool`, *optional*):
Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE)
format.
return_coco_annotation (`bool`, *optional*, defaults to `False`):
If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) format.
return_binary_maps (`bool`, *optional*, defaults to `False`):
If set to `True`, segmentation maps are returned as a concatenated tensor of binary segmentation maps
(one per detected instance).
Returns:
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
- **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or

@@ -1047,47 +1050,73 @@ class MaskFormerImageProcessor(BaseImageProcessor):
- **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
- **score** -- Prediction score of segment with `segment_id`.
"""
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
if return_coco_annotation and return_binary_maps:
raise ValueError("return_coco_annotation and return_binary_maps can not be both set to True.")

batch_size = class_queries_logits.shape[0]
num_labels = class_queries_logits.shape[-1] - 1
# [batch_size, num_queries, num_classes+1]
class_queries_logits = outputs.class_queries_logits
# [batch_size, num_queries, height, width]
masks_queries_logits = outputs.masks_queries_logits

mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]

# Predicted label and score of each query (batch_size, num_queries)
pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
device = masks_queries_logits.device
num_classes = class_queries_logits.shape[-1] - 1
num_queries = class_queries_logits.shape[-2]

# Loop over items in batch size
results: List[Dict[str, TensorType]] = []

for i in range(batch_size):
mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
for i in range(class_queries_logits.shape[0]):
mask_pred = masks_queries_logits[i]
mask_cls = class_queries_logits[i]

scores = torch.nn.functional.softmax(mask_cls, dim=-1)[:, :-1]
labels = torch.arange(num_classes, device=device).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)

scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)
labels_per_image = labels[topk_indices]

topk_indices = topk_indices // num_classes
mask_pred = mask_pred[topk_indices]
pred_masks = (mask_pred > 0).float()

# Calculate average mask prob
mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / (
pred_masks.flatten(1).sum(1) + 1e-6
)
pred_scores = scores_per_image * mask_scores_per_image
pred_classes = labels_per_image

# No mask found
if mask_probs_item.shape[0] <= 0:
height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
segmentation = torch.zeros((height, width)) - 1
results.append({"segmentation": segmentation, "segments_info": []})
continue
segmentation = torch.zeros(masks_queries_logits.shape[2:]) - 1
if target_sizes is not None:
segmentation = torch.zeros(target_sizes[i]) - 1
pred_masks = torch.nn.functional.interpolate(
pred_masks.unsqueeze(0), size=target_sizes[i], mode="nearest"
)[0]

# Get segmentation map and segment information of batch item
target_size = target_sizes[i] if target_sizes is not None else None
segmentation, segments = compute_segments(
mask_probs=mask_probs_item,
pred_scores=pred_scores_item,
pred_labels=pred_labels_item,
mask_threshold=mask_threshold,
overlap_mask_area_threshold=overlap_mask_area_threshold,
label_ids_to_fuse=[],
target_size=target_size,
)
instance_maps, segments = [], []
current_segment_id = 0
for j in range(num_queries):
score = pred_scores[j].item()

# Return segmentation map in run-length encoding (RLE) format
if return_coco_annotation:
segmentation = convert_segmentation_to_rle(segmentation)
if not torch.all(pred_masks[j] == 0) and score >= threshold:
segmentation[pred_masks[j] == 1] = current_segment_id
segments.append(
{
"id": current_segment_id,
"label_id": pred_classes[j].item(),
"was_fused": False,
"score": round(score, 6),
}
)
current_segment_id += 1
instance_maps.append(pred_masks[j])
# Return segmentation map in run-length encoding (RLE) format
if return_coco_annotation:
segmentation = convert_segmentation_to_rle(segmentation)

# Return a concatenated tensor of binary instance maps
if return_binary_maps and len(instance_maps) != 0:
segmentation = torch.stack(instance_maps, dim=0)

results.append({"segmentation": segmentation, "segments_info": segments})
return results
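The rewritten method above adds `return_binary_maps` alongside the existing `return_coco_annotation` output format. Below is a minimal sketch of the resulting output shapes, using random logits in place of real `MaskFormerForInstanceSegmentation` outputs; the processor settings are illustrative only:

```python
import torch
from transformers import MaskFormerImageProcessor
from transformers.models.maskformer.modeling_maskformer import MaskFormerForInstanceSegmentationOutput

processor = MaskFormerImageProcessor(num_labels=2)
outputs = MaskFormerForInstanceSegmentationOutput(
    class_queries_logits=torch.randn(1, 5, 2 + 1),   # [batch, queries, classes + null class]
    masks_queries_logits=torch.randn(1, 5, 32, 32),  # [batch, queries, height, width]
)

# Default: a (height, width) tensor of segment ids plus a list of segments_info dicts.
result = processor.post_process_instance_segmentation(outputs, threshold=0)[0]
print(result["segmentation"].shape, len(result["segments_info"]))

# New in this commit: one binary map per detected instance, stacked along dim 0.
binary = processor.post_process_instance_segmentation(outputs, threshold=0, return_binary_maps=True)[0]
print(binary["segmentation"].shape)

# COCO run-length encoding; combining it with return_binary_maps raises a ValueError.
rle = processor.post_process_instance_segmentation(outputs, threshold=0, return_coco_annotation=True)[0]
```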
@@ -269,6 +269,13 @@ class LevitImageProcessor(metaclass=DummyObject):
requires_backends(self, ["vision"])


class Mask2FormerImageProcessor(metaclass=DummyObject):
_backends = ["vision"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])


class MaskFormerFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
tests/models/mask2former/test_image_processing_mask2former.py (new file, 611 lines added)
@@ -0,0 +1,611 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

import numpy as np
from datasets import load_dataset

from huggingface_hub import hf_hub_download
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available

from ...test_image_processing_common import ImageProcessingSavingTestMixin, prepare_image_inputs


if is_torch_available():
import torch

if is_vision_available():
from transformers import Mask2FormerImageProcessor
from transformers.models.mask2former.image_processing_mask2former import binary_mask_to_rle
from transformers.models.mask2former.modeling_mask2former import Mask2FormerForUniversalSegmentationOutput

if is_vision_available():
from PIL import Image


class Mask2FormerImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
min_resolution=30,
max_resolution=400,
size=None,
do_resize=True,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
num_labels=10,
do_reduce_labels=True,
ignore_index=255,
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = {"shortest_edge": 32, "longest_edge": 1333} if size is None else size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.size_divisor = 0
# for the post_process_functions
self.batch_size = 2
self.num_queries = 3
self.num_classes = 2
self.height = 3
self.width = 4
self.num_labels = num_labels
self.do_reduce_labels = do_reduce_labels
self.ignore_index = ignore_index

def prepare_image_processor_dict(self):
return {
"do_resize": self.do_resize,
"size": self.size,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"size_divisor": self.size_divisor,
"num_labels": self.num_labels,
"do_reduce_labels": self.do_reduce_labels,
"ignore_index": self.ignore_index,
}

def get_expected_values(self, image_inputs, batched=False):
"""
This function computes the expected height and width when providing images to Mask2FormerImageProcessor,
assuming do_resize is set to True with a scalar size.
"""
if not batched:
image = image_inputs[0]
if isinstance(image, Image.Image):
w, h = image.size
else:
h, w = image.shape[1], image.shape[2]
if w < h:
expected_height = int(self.size["shortest_edge"] * h / w)
expected_width = self.size["shortest_edge"]
elif w > h:
expected_height = self.size["shortest_edge"]
expected_width = int(self.size["shortest_edge"] * w / h)
else:
expected_height = self.size["shortest_edge"]
expected_width = self.size["shortest_edge"]

else:
expected_values = []
for image in image_inputs:
expected_height, expected_width = self.get_expected_values([image])
expected_values.append((expected_height, expected_width))
expected_height = max(expected_values, key=lambda item: item[0])[0]
expected_width = max(expected_values, key=lambda item: item[1])[1]

return expected_height, expected_width

def get_fake_mask2former_outputs(self):
return Mask2FormerForUniversalSegmentationOutput(
# +1 for null class
class_queries_logits=torch.randn((self.batch_size, self.num_queries, self.num_classes + 1)),
masks_queries_logits=torch.randn((self.batch_size, self.num_queries, self.height, self.width)),
)


@require_torch
@require_vision
class Mask2FormerImageProcessingTest(ImageProcessingSavingTestMixin, unittest.TestCase):

image_processing_class = Mask2FormerImageProcessor if (is_vision_available() and is_torch_available()) else None

def setUp(self):
self.image_processor_tester = Mask2FormerImageProcessingTester(self)

@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()

def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "max_size"))
self.assertTrue(hasattr(image_processing, "ignore_index"))
self.assertTrue(hasattr(image_processing, "num_labels"))

def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 32, "longest_edge": 1333})
self.assertEqual(image_processor.size_divisor, 0)

image_processor = self.image_processing_class.from_dict(
self.image_processor_dict, size=42, max_size=84, size_divisibility=8
)
self.assertEqual(image_processor.size, {"shortest_edge": 42, "longest_edge": 84})
self.assertEqual(image_processor.size_divisor, 8)

def test_batch_feature(self):
pass

def test_call_pil(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)

# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values

expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs)

self.assertEqual(
encoded_images.shape,
(1, self.image_processor_tester.num_channels, expected_height, expected_width),
)

# Test batched
expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs, batched=True)

encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
expected_height,
expected_width,
),
)

def test_call_numpy(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)

# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values

expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs)

self.assertEqual(
encoded_images.shape,
(1, self.image_processor_tester.num_channels, expected_height, expected_width),
)

# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values

expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs, batched=True)

self.assertEqual(
encoded_images.shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
expected_height,
expected_width,
),
)

def test_call_pytorch(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)

# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values

expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs)

self.assertEqual(
encoded_images.shape,
(1, self.image_processor_tester.num_channels, expected_height, expected_width),
)

# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values

expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs, batched=True)

self.assertEqual(
encoded_images.shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
expected_height,
expected_width,
),
)

def test_equivalence_pad_and_create_pixel_mask(self):
# Initialize image_processings
image_processing_1 = self.image_processing_class(**self.image_processor_dict)
image_processing_2 = self.image_processing_class(
do_resize=False, do_normalize=False, do_rescale=False, num_labels=self.image_processor_tester.num_classes
)
# create random PyTorch tensors
image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)

# Test whether the method "pad_and_return_pixel_mask" and calling the image processor return the same tensors
encoded_images_with_method = image_processing_1.encode_inputs(image_inputs, return_tensors="pt")
encoded_images = image_processing_2(image_inputs, return_tensors="pt")

self.assertTrue(
torch.allclose(encoded_images_with_method["pixel_values"], encoded_images["pixel_values"], atol=1e-4)
)
self.assertTrue(
torch.allclose(encoded_images_with_method["pixel_mask"], encoded_images["pixel_mask"], atol=1e-4)
)

def comm_get_image_processing_inputs(
self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np"
):
image_processing = self.image_processing_class(**self.image_processor_dict)
# prepare image and target
num_labels = self.image_processor_tester.num_labels
annotations = None
instance_id_to_semantic_id = None
image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False)
if with_segmentation_maps:
high = num_labels
if is_instance_map:
labels_expanded = list(range(num_labels)) * 2
instance_id_to_semantic_id = {
instance_id: label_id for instance_id, label_id in enumerate(labels_expanded)
}
annotations = [
np.random.randint(0, high * 2, (img.size[1], img.size[0])).astype(np.uint8) for img in image_inputs
]
if segmentation_type == "pil":
annotations = [Image.fromarray(annotation) for annotation in annotations]

inputs = image_processing(
image_inputs,
annotations,
return_tensors="pt",
instance_id_to_semantic_id=instance_id_to_semantic_id,
pad_and_return_pixel_mask=True,
)

return inputs

def test_init_without_params(self):
pass

def test_with_size_divisor(self):
size_divisors = [8, 16, 32]
weird_input_sizes = [(407, 802), (582, 1094)]
for size_divisor in size_divisors:
image_processor_dict = {**self.image_processor_dict, **{"size_divisor": size_divisor}}
image_processing = self.image_processing_class(**image_processor_dict)
for weird_input_size in weird_input_sizes:
inputs = image_processing([np.ones((3, *weird_input_size))], return_tensors="pt")
pixel_values = inputs["pixel_values"]
# check if divisible
self.assertTrue((pixel_values.shape[-1] % size_divisor) == 0)
self.assertTrue((pixel_values.shape[-2] % size_divisor) == 0)

def test_call_with_segmentation_maps(self):
def common(is_instance_map=False, segmentation_type=None):
inputs = self.comm_get_image_processing_inputs(
with_segmentation_maps=True, is_instance_map=is_instance_map, segmentation_type=segmentation_type
)

mask_labels = inputs["mask_labels"]
class_labels = inputs["class_labels"]
pixel_values = inputs["pixel_values"]

# check the batch_size
for mask_label, class_label in zip(mask_labels, class_labels):
self.assertEqual(mask_label.shape[0], class_label.shape[0])
# this ensure padding has happened
self.assertEqual(mask_label.shape[1:], pixel_values.shape[2:])

common()
common(is_instance_map=True)
common(is_instance_map=False, segmentation_type="pil")
common(is_instance_map=True, segmentation_type="pil")

def test_integration_instance_segmentation(self):
# load 2 images and corresponding annotations from the hub
repo_id = "nielsr/image-segmentation-toy-data"
image1 = Image.open(
hf_hub_download(repo_id=repo_id, filename="instance_segmentation_image_1.png", repo_type="dataset")
)
image2 = Image.open(
hf_hub_download(repo_id=repo_id, filename="instance_segmentation_image_2.png", repo_type="dataset")
)
annotation1 = Image.open(
hf_hub_download(repo_id=repo_id, filename="instance_segmentation_annotation_1.png", repo_type="dataset")
)
annotation2 = Image.open(
hf_hub_download(repo_id=repo_id, filename="instance_segmentation_annotation_2.png", repo_type="dataset")
)

# get instance segmentations and instance-to-segmentation mappings
def get_instance_segmentation_and_mapping(annotation):
instance_seg = np.array(annotation)[:, :, 1]
class_id_map = np.array(annotation)[:, :, 0]
class_labels = np.unique(class_id_map)

# create mapping between instance IDs and semantic category IDs
inst2class = {}
for label in class_labels:
instance_ids = np.unique(instance_seg[class_id_map == label])
inst2class.update({i: label for i in instance_ids})

return instance_seg, inst2class

instance_seg1, inst2class1 = get_instance_segmentation_and_mapping(annotation1)
instance_seg2, inst2class2 = get_instance_segmentation_and_mapping(annotation2)

# create a image processor
image_processing = Mask2FormerImageProcessor(reduce_labels=True, ignore_index=255, size=(512, 512))

# prepare the images and annotations
inputs = image_processing(
[image1, image2],
[instance_seg1, instance_seg2],
instance_id_to_semantic_id=[inst2class1, inst2class2],
return_tensors="pt",
)

# verify the pixel values and pixel mask
self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 512))
self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 512))

# verify the class labels
self.assertEqual(len(inputs["class_labels"]), 2)
self.assertTrue(torch.allclose(inputs["class_labels"][0], torch.tensor([30, 55])))
self.assertTrue(torch.allclose(inputs["class_labels"][1], torch.tensor([4, 4, 23, 55])))

# verify the mask labels
self.assertEqual(len(inputs["mask_labels"]), 2)
self.assertEqual(inputs["mask_labels"][0].shape, (2, 512, 512))
self.assertEqual(inputs["mask_labels"][1].shape, (4, 512, 512))
self.assertEquals(inputs["mask_labels"][0].sum().item(), 41527.0)
self.assertEquals(inputs["mask_labels"][1].sum().item(), 26259.0)

def test_integration_semantic_segmentation(self):
# load 2 images and corresponding semantic annotations from the hub
repo_id = "nielsr/image-segmentation-toy-data"
image1 = Image.open(
hf_hub_download(repo_id=repo_id, filename="semantic_segmentation_image_1.png", repo_type="dataset")
)
image2 = Image.open(
hf_hub_download(repo_id=repo_id, filename="semantic_segmentation_image_2.png", repo_type="dataset")
)
annotation1 = Image.open(
hf_hub_download(repo_id=repo_id, filename="semantic_segmentation_annotation_1.png", repo_type="dataset")
)
annotation2 = Image.open(
hf_hub_download(repo_id=repo_id, filename="semantic_segmentation_annotation_2.png", repo_type="dataset")
)

# create a image processor
image_processing = Mask2FormerImageProcessor(reduce_labels=True, ignore_index=255, size=(512, 512))

# prepare the images and annotations
inputs = image_processing(
[image1, image2],
[annotation1, annotation2],
return_tensors="pt",
)

# verify the pixel values and pixel mask
self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 512))
self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 512))

# verify the class labels
self.assertEqual(len(inputs["class_labels"]), 2)
self.assertTrue(torch.allclose(inputs["class_labels"][0], torch.tensor([2, 4, 60])))
self.assertTrue(torch.allclose(inputs["class_labels"][1], torch.tensor([0, 3, 7, 8, 15, 28, 30, 143])))

# verify the mask labels
self.assertEqual(len(inputs["mask_labels"]), 2)
self.assertEqual(inputs["mask_labels"][0].shape, (3, 512, 512))
self.assertEqual(inputs["mask_labels"][1].shape, (8, 512, 512))
self.assertEquals(inputs["mask_labels"][0].sum().item(), 170200.0)
self.assertEquals(inputs["mask_labels"][1].sum().item(), 257036.0)

def test_integration_panoptic_segmentation(self):
# load 2 images and corresponding panoptic annotations from the hub
dataset = load_dataset("nielsr/ade20k-panoptic-demo")
image1 = dataset["train"][0]["image"]
image2 = dataset["train"][1]["image"]
segments_info1 = dataset["train"][0]["segments_info"]
segments_info2 = dataset["train"][1]["segments_info"]
annotation1 = dataset["train"][0]["label"]
annotation2 = dataset["train"][1]["label"]

def rgb_to_id(color):
if isinstance(color, np.ndarray) and len(color.shape) == 3:
if color.dtype == np.uint8:
color = color.astype(np.int32)
return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
return int(color[0] + 256 * color[1] + 256 * 256 * color[2])

def create_panoptic_map(annotation, segments_info):
annotation = np.array(annotation)
# convert RGB to segment IDs per pixel
# 0 is the "ignore" label, for which we don't need to make binary masks
panoptic_map = rgb_to_id(annotation)

# create mapping between segment IDs and semantic classes
inst2class = {segment["id"]: segment["category_id"] for segment in segments_info}

return panoptic_map, inst2class

panoptic_map1, inst2class1 = create_panoptic_map(annotation1, segments_info1)
panoptic_map2, inst2class2 = create_panoptic_map(annotation2, segments_info2)

# create a image processor
image_processing = Mask2FormerImageProcessor(ignore_index=0, do_resize=False)

# prepare the images and annotations
pixel_values_list = [np.moveaxis(np.array(image1), -1, 0), np.moveaxis(np.array(image2), -1, 0)]
inputs = image_processing.encode_inputs(
pixel_values_list,
[panoptic_map1, panoptic_map2],
instance_id_to_semantic_id=[inst2class1, inst2class2],
return_tensors="pt",
)

# verify the pixel values and pixel mask
self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 711))
self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 711))

# verify the class labels
self.assertEqual(len(inputs["class_labels"]), 2)
# fmt: off
expected_class_labels = torch.tensor([4, 17, 32, 42, 42, 42, 42, 42, 42, 42, 32, 12, 12, 12, 12, 12, 42, 42, 12, 12, 12, 42, 12, 12, 12, 12, 12, 3, 12, 12, 12, 12, 42, 42, 42, 12, 42, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 5, 12, 12, 12, 12, 12, 12, 12, 0, 43, 43, 43, 96, 43, 104, 43, 31, 125, 31, 125, 138, 87, 125, 149, 138, 125, 87, 87]) # noqa: E231
# fmt: on
self.assertTrue(torch.allclose(inputs["class_labels"][0], torch.tensor(expected_class_labels)))
# fmt: off
expected_class_labels = torch.tensor([19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 67, 82, 19, 19, 17, 19, 19, 19, 19, 19, 19, 19, 19, 19, 12, 12, 42, 12, 12, 12, 12, 3, 14, 12, 12, 12, 12, 12, 12, 12, 12, 14, 5, 12, 12, 0, 115, 43, 43, 115, 43, 43, 43, 8, 8, 8, 138, 138, 125, 143]) # noqa: E231
# fmt: on
self.assertTrue(torch.allclose(inputs["class_labels"][1], expected_class_labels))

# verify the mask labels
self.assertEqual(len(inputs["mask_labels"]), 2)
self.assertEqual(inputs["mask_labels"][0].shape, (79, 512, 711))
self.assertEqual(inputs["mask_labels"][1].shape, (61, 512, 711))
self.assertEquals(inputs["mask_labels"][0].sum().item(), 315193.0)
self.assertEquals(inputs["mask_labels"][1].sum().item(), 350747.0)

def test_binary_mask_to_rle(self):
fake_binary_mask = np.zeros((20, 50))
fake_binary_mask[0, 20:] = 1
fake_binary_mask[1, :15] = 1
fake_binary_mask[5, :10] = 1

rle = binary_mask_to_rle(fake_binary_mask)
self.assertEqual(len(rle), 4)
self.assertEqual(rle[0], 21)
self.assertEqual(rle[1], 45)

def test_post_process_semantic_segmentation(self):
fature_extractor = self.image_processing_class(num_labels=self.image_processor_tester.num_classes)
outputs = self.image_processor_tester.get_fake_mask2former_outputs()

segmentation = fature_extractor.post_process_semantic_segmentation(outputs)

self.assertEqual(len(segmentation), self.image_processor_tester.batch_size)
self.assertEqual(segmentation[0].shape, (384, 384))

target_sizes = [(1, 4) for i in range(self.image_processor_tester.batch_size)]
segmentation = fature_extractor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)

self.assertEqual(segmentation[0].shape, target_sizes[0])

def test_post_process_instance_segmentation(self):
feature_extractor = self.image_processing_class(num_labels=self.image_processor_tester.num_classes)
outputs = self.image_processor_tester.get_fake_mask2former_outputs()
segmentation = feature_extractor.post_process_instance_segmentation(outputs, threshold=0)

self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size)
for el in segmentation:
self.assertTrue("segmentation" in el)
self.assertTrue("segments_info" in el)
self.assertEqual(type(el["segments_info"]), list)
self.assertEqual(el["segmentation"].shape, (384, 384))

segmentation = feature_extractor.post_process_instance_segmentation(
outputs, threshold=0, return_binary_maps=True
)

self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size)
for el in segmentation:
self.assertTrue("segmentation" in el)
self.assertTrue("segments_info" in el)
self.assertEqual(type(el["segments_info"]), list)
self.assertEqual(len(el["segmentation"].shape), 3)
self.assertEqual(el["segmentation"].shape[1:], (384, 384))

def test_post_process_panoptic_segmentation(self):
image_processing = self.image_processing_class(num_labels=self.image_processor_tester.num_classes)
outputs = self.image_processor_tester.get_fake_mask2former_outputs()
segmentation = image_processing.post_process_panoptic_segmentation(outputs, threshold=0)

self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size)
for el in segmentation:
self.assertTrue("segmentation" in el)
self.assertTrue("segments_info" in el)
self.assertEqual(type(el["segments_info"]), list)
self.assertEqual(el["segmentation"].shape, (384, 384))

def test_post_process_label_fusing(self):
image_processor = self.image_processing_class(num_labels=self.image_processor_tester.num_classes)
outputs = self.image_processor_tester.get_fake_mask2former_outputs()

segmentation = image_processor.post_process_panoptic_segmentation(
outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0
)
unfused_segments = [el["segments_info"] for el in segmentation]

fused_segmentation = image_processor.post_process_panoptic_segmentation(
outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0, label_ids_to_fuse={1}
)
fused_segments = [el["segments_info"] for el in fused_segmentation]

for el_unfused, el_fused in zip(unfused_segments, fused_segments):
if len(el_unfused) == 0:
self.assertEqual(len(el_unfused), len(el_fused))
continue

# Get number of segments to be fused
fuse_targets = [1 for el in el_unfused if el["label_id"] in {1}]
num_to_fuse = 0 if len(fuse_targets) == 0 else sum(fuse_targets) - 1
# Expected number of segments after fusing
expected_num_segments = max([el["id"] for el in el_unfused]) - num_to_fuse
num_segments_fused = max([el["id"] for el in el_fused])
self.assertEqual(num_segments_fused, expected_num_segments)
@@ -34,7 +34,7 @@ if is_torch_available():
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerModel

if is_vision_available():
from transformers import MaskFormerImageProcessor
from transformers import Mask2FormerImageProcessor

if is_vision_available():
from PIL import Image

@@ -325,7 +325,7 @@ class Mask2FormerModelIntegrationTest(unittest.TestCase):

@cached_property
def default_feature_extractor(self):
return MaskFormerImageProcessor.from_pretrained(self.model_checkpoints) if is_vision_available() else None
return Mask2FormerImageProcessor.from_pretrained(self.model_checkpoints) if is_vision_available() else None

def test_inference_no_head(self):
model = Mask2FormerModel.from_pretrained(self.model_checkpoints).to(torch_device)
@@ -576,6 +576,34 @@ class MaskFormerImageProcessingTest(ImageProcessingSavingTestMixin, unittest.Tes

self.assertEqual(segmentation[0].shape, target_sizes[0])

def test_post_process_instance_segmentation(self):
feature_extractor = self.image_processing_class(num_labels=self.image_processor_tester.num_classes)
outputs = self.image_processor_tester.get_fake_maskformer_outputs()
segmentation = feature_extractor.post_process_instance_segmentation(outputs, threshold=0)

self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size)
for el in segmentation:
self.assertTrue("segmentation" in el)
self.assertTrue("segments_info" in el)
self.assertEqual(type(el["segments_info"]), list)
self.assertEqual(
el["segmentation"].shape, (self.image_processor_tester.height, self.image_processor_tester.width)
)

segmentation = feature_extractor.post_process_instance_segmentation(
outputs, threshold=0, return_binary_maps=True
)

self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size)
for el in segmentation:
self.assertTrue("segmentation" in el)
self.assertTrue("segments_info" in el)
self.assertEqual(type(el["segments_info"]), list)
self.assertEqual(len(el["segmentation"].shape), 3)
self.assertEqual(
el["segmentation"].shape[1:], (self.image_processor_tester.height, self.image_processor_tester.width)
)

def test_post_process_panoptic_segmentation(self):
image_processing = self.image_processing_class(num_labels=self.image_processor_tester.num_classes)
outputs = self.image_processor_tester.get_fake_maskformer_outputs()