Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
Add post_process_semantic_segmentation method to DPTFeatureExtractor (#19107)

* add post-processing method for semantic segmentation
* add test for post-processing
commit e7fdfc720a (parent da6a1b6ca1)
@@ -37,6 +37,7 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi
 
 [[autodoc]] DPTFeatureExtractor
     - __call__
+    - post_process_semantic_segmentation
 
 
 ## DPTModel
@@ -14,13 +14,12 @@
 # limitations under the License.
 """Feature extractor class for DPT."""
 
-from typing import Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 from PIL import Image
 
 from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...file_utils import TensorType
 from ...image_utils import (
     IMAGENET_STANDARD_MEAN,
     IMAGENET_STANDARD_STD,
@@ -28,9 +27,12 @@ from ...image_utils import (
     ImageInput,
     is_torch_tensor,
 )
-from ...utils import logging
+from ...utils import TensorType, is_torch_available, logging
 
 
+if is_torch_available():
+    import torch
+
 logger = logging.get_logger(__name__)
 
 
@@ -200,3 +202,44 @@ class DPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
         encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
 
         return encoded_inputs
+
+    def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
+        """
+        Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
+
+        Args:
+            outputs ([`DPTForSemanticSegmentation`]):
+                Raw outputs of the model.
+            target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
+                List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
+                None, predictions will not be resized.
+        Returns:
+            semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
+            segmentation map of shape (height, width) corresponding to the `target_sizes` entry (if `target_sizes` is
+            specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
+        """
+        logits = outputs.logits
+
+        # Resize logits and compute semantic segmentation maps
+        if target_sizes is not None:
+            if len(logits) != len(target_sizes):
+                raise ValueError(
+                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+                )
+
+            if is_torch_tensor(target_sizes):
+                target_sizes = target_sizes.numpy()
+
+            semantic_segmentation = []
+
+            for idx in range(len(logits)):
+                resized_logits = torch.nn.functional.interpolate(
+                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+                )
+                semantic_map = resized_logits[0].argmax(dim=0)
+                semantic_segmentation.append(semantic_map)
+        else:
+            semantic_segmentation = logits.argmax(dim=1)
+            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
+
+        return semantic_segmentation
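For readers skimming the diff, here is a minimal shape-check sketch of the new method. It is not part of the commit: the `types.SimpleNamespace` object stands in for real `DPTForSemanticSegmentation` outputs, and the 2 x 150 x 480 x 480 random logits are an assumption loosely modeled on an ADE20K-style segmentation head.

from types import SimpleNamespace

import torch

from transformers import DPTFeatureExtractor

feature_extractor = DPTFeatureExtractor()

# Fake model outputs: batch of 2 images, assumed 150 classes, 480x480 logits.
outputs = SimpleNamespace(logits=torch.randn(2, 150, 480, 480))

# Without target_sizes, each returned map keeps the logits' spatial size.
maps = feature_extractor.post_process_semantic_segmentation(outputs)
print(maps[0].shape)  # torch.Size([480, 480])

# With target_sizes, each map is bilinearly resized to the requested (height, width)
# before the per-pixel argmax over classes.
maps = feature_extractor.post_process_semantic_segmentation(
    outputs, target_sizes=[(500, 300), (640, 480)]
)
print(maps[0].shape)  # torch.Size([500, 300])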
@@ -298,3 +298,24 @@ class DPTModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
 
         self.assertTrue(torch.allclose(outputs.logits[0, 0, :3, :3], expected_slice, atol=1e-4))
+
+    def test_post_processing_semantic_segmentation(self):
+        feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large-ade")
+        model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade").to(torch_device)
+
+        image = prepare_img()
+        inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
+
+        # forward pass
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        outputs.logits = outputs.logits.detach().cpu()
+
+        segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(500, 300)])
+        expected_shape = torch.Size((500, 300))
+        self.assertEqual(segmentation[0].shape, expected_shape)
+
+        segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)
+        expected_shape = torch.Size((480, 480))
+        self.assertEqual(segmentation[0].shape, expected_shape)
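Outside the unittest harness, calling the new API would look roughly like the sketch below. It is not part of the commit: the image path is a placeholder, and the original image size is passed as (height, width) in `target_sizes` so the predicted map is resized back to the input resolution.

import torch
from PIL import Image

from transformers import DPTFeatureExtractor, DPTForSemanticSegmentation

feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large-ade")
model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade")

image = Image.open("your_image.jpg")  # placeholder path
inputs = feature_extractor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# PIL reports size as (width, height); the method expects (height, width).
segmentation = feature_extractor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]
print(segmentation.shape)  # per-pixel class ids at the original image resolution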