Add post_process_semantic_segmentation method to SegFormer (#19072)

* add post_process_semantic_segmentation method to SegformerFeatureExtractor
* add test for semantic segmentation post-processing
Alara Dirik 2022-09-21 11:40:35 +03:00 committed by GitHub
parent ef6741fe65
commit 9e95706648
3 changed files with 75 additions and 2 deletions


@@ -93,6 +93,7 @@ SegFormer's results on the segmentation datasets like ADE20k, refer to the [pape
[[autodoc]] SegformerFeatureExtractor
- __call__
- post_process_semantic_segmentation
## SegformerModel


@@ -14,7 +14,7 @@
# limitations under the License.
"""Feature extractor class for SegFormer."""
from typing import Optional, Union
from typing import List, Optional, Tuple, Union
import numpy as np
from PIL import Image
@@ -27,9 +27,12 @@ from ...image_utils import (
ImageInput,
is_torch_tensor,
)
from ...utils import TensorType, logging
from ...utils import TensorType, is_torch_available, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
@@ -211,3 +214,45 @@ class SegformerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMi
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs

    def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
        """
        Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports
        PyTorch.

        Args:
            outputs ([`SegformerForSemanticSegmentation`]):
                Raw outputs of the model.
            target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
                List of tuples corresponding to the requested final size (height, width) of each prediction. If left
                to `None`, predictions will not be resized.

        Returns:
            semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
            segmentation map of shape (height, width) corresponding to the `target_sizes` entry (if `target_sizes` is
            specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
        """
        logits = outputs.logits

        # Resize logits and compute semantic segmentation maps
        if target_sizes is not None:
            if len(logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )

            if is_torch_tensor(target_sizes):
                target_sizes = target_sizes.numpy()

            semantic_segmentation = []

            for idx in range(len(logits)):
                resized_logits = torch.nn.functional.interpolate(
                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
                )
                semantic_map = resized_logits[0].argmax(dim=0)
                semantic_segmentation.append(semantic_map)
        else:
            semantic_segmentation = logits.argmax(dim=1)
            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]

        return semantic_segmentation
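
For reference, a minimal usage sketch of the new method. It assumes the `nvidia/segformer-b0-finetuned-ade-512-512` checkpoint also used in the integration test below; the image path is a placeholder.

import torch
from PIL import Image
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

image = Image.open("path/to/image.jpg")  # placeholder path
inputs = feature_extractor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Upsample the low-resolution logits back to the original image size and take the class id per pixel
segmentation = feature_extractor.post_process_semantic_segmentation(
    outputs=outputs, target_sizes=[image.size[::-1]]  # PIL size is (width, height), so reverse to (height, width)
)
print(segmentation[0].shape)  # torch.Size([height, width]), one class id per pixel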


@@ -395,3 +395,30 @@ class SegformerModelIntegrationTest(unittest.TestCase):
]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-1))

    @slow
    def test_post_processing_semantic_segmentation(self):
        # only resize + normalize
        feature_extractor = SegformerFeatureExtractor(
            image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False
        )
        model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512").to(
            torch_device
        )

        image = prepare_img()
        encoded_inputs = feature_extractor(images=image, return_tensors="pt")
        pixel_values = encoded_inputs.pixel_values.to(torch_device)

        with torch.no_grad():
            outputs = model(pixel_values)

        outputs.logits = outputs.logits.detach().cpu()

        segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(500, 300)])
        expected_shape = torch.Size((500, 300))
        self.assertEqual(segmentation[0].shape, expected_shape)

        segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)
        expected_shape = torch.Size((128, 128))
        self.assertEqual(segmentation[0].shape, expected_shape)
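
Note on the second assertion: without `target_sizes`, the maps keep the resolution of the raw logits, which for SegFormer is 1/4 of the input size in each dimension, hence 128x128 for the 512x512 input used here.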