Beit postprocessing (#19099)

* add post_process_semantic_segmentation method to BeiTFeatureExtractor
This commit is contained in:
Alara Dirik 2022-09-20 10:41:56 +03:00 committed by GitHub
parent 261301d388
commit c81ebd1c39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 2 deletions

View File

@ -82,6 +82,7 @@ contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code
[[autodoc]] BeitFeatureExtractor
- __call__
- post_process_semantic_segmentation
## BeitModel

View File

@ -14,7 +14,7 @@
# limitations under the License.
"""Feature extractor class for BEiT."""
from typing import Optional, Union
from typing import List, Optional, Tuple, Union
import numpy as np
from PIL import Image
@ -27,9 +27,12 @@ from ...image_utils import (
ImageInput,
is_torch_tensor,
)
from ...utils import TensorType, logging
from ...utils import TensorType, is_torch_available, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
@ -222,3 +225,44 @@ class BeitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
def post_process_semantic_segmentation(self, outputs, target_sizes: Union[TensorType, List[Tuple]] = None):
"""
Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
Args:
outputs ([`BeitForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*):
Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to
None, predictions will not be resized.
Returns:
semantic_segmentation: `torch.Tensor` of shape `(batch_size, 2)` or `List[torch.Tensor]` of length
`batch_size`, where each item is a semantic segmentation map of of the corresponding target_sizes entry (if
`target_sizes` is specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
"""
logits = outputs.logits
if len(logits) != len(target_sizes):
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
if target_sizes is not None and target_sizes.shape[1] != 2:
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
semantic_segmentation = logits.argmax(dim=1)
# Resize semantic segmentation maps
if target_sizes is not None:
if is_torch_tensor(target_sizes):
target_sizes = target_sizes.numpy()
resized_maps = []
semantic_segmentation = semantic_segmentation.numpy()
for idx in range(len(semantic_segmentation)):
resized = self.resize(image=semantic_segmentation[idx], size=target_sizes[idx])
resized_maps.append(resized)
semantic_segmentation = [torch.Tensor(np.array(image)) for image in resized_maps]
return semantic_segmentation