Beit postprocessing (#19099)
* add post_process_semantic_segmentation method to BeiTFeatureExtractor
Parent: 261301d388
Commit: c81ebd1c39
docs/source/en/model_doc/beit.mdx
@@ -82,6 +82,7 @@ contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code
 
 [[autodoc]] BeitFeatureExtractor
     - __call__
+    - post_process_semantic_segmentation
 
 ## BeitModel
 
src/transformers/models/beit/feature_extraction_beit.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Feature extractor class for BEiT."""
 
-from typing import Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 from PIL import Image
@@ -27,9 +27,12 @@ from ...image_utils import (
     ImageInput,
     is_torch_tensor,
 )
-from ...utils import TensorType, logging
+from ...utils import TensorType, is_torch_available, logging
 
 
+if is_torch_available():
+    import torch
+
 logger = logging.get_logger(__name__)
 
 
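The guarded import above keeps torch a soft dependency: the feature extractor itself only needs numpy and PIL, and torch is imported only when it is installed so that the new post-processing method can return `torch.Tensor`s. A minimal sketch of the same pattern, assuming only the public `transformers.utils.is_torch_available` helper (the `to_tensor_if_possible` function is illustrative, not part of this diff):

```python
import numpy as np

from transformers.utils import is_torch_available

# Import torch lazily so the module still imports on torch-free installs.
if is_torch_available():
    import torch


def to_tensor_if_possible(array: np.ndarray):
    """Return a torch.Tensor when torch is installed, else the numpy array unchanged."""
    if is_torch_available():
        return torch.from_numpy(array)
    return array
```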
@@ -222,3 +225,44 @@ class BeitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
         encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
 
         return encoded_inputs
+
+    def post_process_semantic_segmentation(self, outputs, target_sizes: Union[TensorType, List[Tuple]] = None):
+        """
+        Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
+
+        Args:
+            outputs ([`BeitForSemanticSegmentation`]):
+                Raw outputs of the model.
+            target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*):
+                Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left
+                to `None`, predictions will not be resized.
+
+        Returns:
+            semantic_segmentation: `torch.Tensor` of shape `(batch_size, height, width)` or `List[torch.Tensor]` of
+            length `batch_size`, where each item is a semantic segmentation map of the corresponding `target_sizes`
+            entry (if `target_sizes` is specified). Each entry of each `torch.Tensor` corresponds to a semantic
+            class id.
+        """
+        logits = outputs.logits
+
+        if target_sizes is not None and len(logits) != len(target_sizes):
+            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
+
+        if target_sizes is not None and is_torch_tensor(target_sizes) and target_sizes.shape[1] != 2:
+            raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
+
+        semantic_segmentation = logits.argmax(dim=1)
+
+        # Resize semantic segmentation maps
+        if target_sizes is not None:
+            if is_torch_tensor(target_sizes):
+                target_sizes = target_sizes.numpy()
+
+            resized_maps = []
+            semantic_segmentation = semantic_segmentation.numpy()
+
+            for idx in range(len(semantic_segmentation)):
+                resized = self.resize(image=semantic_segmentation[idx], size=target_sizes[idx])
+                resized_maps.append(resized)
+
+            semantic_segmentation = [torch.Tensor(np.array(image)) for image in resized_maps]
+
+        return semantic_segmentation
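For context, here is a minimal usage sketch of the new method (the checkpoint name and test image URL are illustrative, not taken from this diff):

```python
import requests
import torch
from PIL import Image

from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation

# Illustrative checkpoint: any BEiT checkpoint fine-tuned for semantic segmentation works.
feature_extractor = BeitFeatureExtractor.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Resize each predicted map back to the original (height, width) of its image.
segmentation_maps = feature_extractor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)
print(segmentation_maps[0].shape)  # one (height, width) map of class ids per image
```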