mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Maskformer post-processing fixes and improvements (#19172)
- Improves MaskFormer docs, corrects minor typos - Restructures MaskFormerFeatureExtractor.post_process_panoptic_segmentation for better readability, adds target_sizes argument for optional resizing - Adds post_process_semantic_segmentation and post_process_instance_segmentation methods. - Adds a deprecation warning to post_process_segmentation method in favour of post_process_instance_segmentation
This commit is contained in:
parent
6268694e27
commit
07e94bf159
@ -58,6 +58,7 @@ This model was contributed by [francesco](https://huggingface.co/francesco). The
|
||||
- encode_inputs
|
||||
- post_process_segmentation
|
||||
- post_process_semantic_segmentation
|
||||
- post_process_instance_segmentation
|
||||
- post_process_panoptic_segmentation
|
||||
|
||||
## MaskFormerModel
|
||||
|
@ -35,6 +35,153 @@ if is_torch_available():
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def binary_mask_to_rle(mask):
|
||||
"""
|
||||
Converts given binary mask of shape (height, width) to the run-length encoding (RLE) format.
|
||||
|
||||
Args:
|
||||
mask (`torch.Tensor` or `numpy.array`):
|
||||
A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target
|
||||
segment_id or class_id.
|
||||
|
||||
Returns:
|
||||
`List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE
|
||||
format.
|
||||
"""
|
||||
if is_torch_tensor(mask):
|
||||
mask = mask.numpy()
|
||||
|
||||
pixels = mask.flatten()
|
||||
pixels = np.concatenate([[0], pixels, [0]])
|
||||
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
|
||||
runs[1::2] -= runs[::2]
|
||||
return [x for x in runs]
|
||||
|
||||
|
||||
def convert_segmentation_to_rle(segmentation):
|
||||
"""
|
||||
Converts given segmentation map of shape (height, width) to the run-length encoding (RLE) format.
|
||||
|
||||
Args:
|
||||
segmentation (`torch.Tensor` or `numpy.array`):
|
||||
A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
|
||||
|
||||
Returns:
|
||||
`List[List]`: A list of lists, where each list is the run-length encoding of a segment / class id.
|
||||
"""
|
||||
segment_ids = torch.unique(segmentation)
|
||||
|
||||
run_length_encodings = []
|
||||
for idx in segment_ids:
|
||||
mask = torch.where(segmentation == idx, 1, 0)
|
||||
rle = binary_mask_to_rle(mask)
|
||||
run_length_encodings.append(rle)
|
||||
|
||||
return run_length_encodings
|
||||
|
||||
|
||||
def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
|
||||
"""
|
||||
Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and
|
||||
`labels`.
|
||||
|
||||
Args:
|
||||
masks (`torch.Tensor`):
|
||||
A tensor of shape `(num_queries, height, width)`.
|
||||
scores (`torch.Tensor`):
|
||||
A tensor of shape `(num_queries)`.
|
||||
labels (`torch.Tensor`):
|
||||
A tensor of shape `(num_queries)`.
|
||||
object_mask_threshold (`float`):
|
||||
A number between 0 and 1 used to binarize the masks.
|
||||
|
||||
Raises:
|
||||
`ValueError`: Raised when the first dimension doesn't match in all input tensors.
|
||||
|
||||
Returns:
|
||||
`Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the region
|
||||
< `object_mask_threshold`.
|
||||
"""
|
||||
if not (masks.shape[0] == scores.shape[0] == labels.shape[0]):
|
||||
raise ValueError("mask, scores and labels must have the same shape!")
|
||||
|
||||
to_keep = labels.ne(num_labels) & (scores > object_mask_threshold)
|
||||
return masks[to_keep], scores[to_keep], labels[to_keep]
|
||||
|
||||
|
||||
def check_segment_validity(mask_labels, mask_probs, k, overlap_mask_area_threshold=0.8):
|
||||
# Get the mask associated with the k class
|
||||
mask_k = mask_labels == k
|
||||
mask_k_area = mask_k.sum()
|
||||
|
||||
# Compute the area of all the stuff in query k
|
||||
original_area = (mask_probs[k] >= 0.5).sum()
|
||||
mask_exists = mask_k_area > 0 and original_area > 0
|
||||
|
||||
# Eliminate disconnected tiny segments
|
||||
if mask_exists:
|
||||
area_ratio = mask_k_area / original_area
|
||||
if not area_ratio.item() > overlap_mask_area_threshold:
|
||||
mask_exists = False
|
||||
|
||||
return mask_exists, mask_k
|
||||
|
||||
|
||||
def compute_segments(
|
||||
mask_probs,
|
||||
pred_scores,
|
||||
pred_labels,
|
||||
overlap_mask_area_threshold: float = 0.8,
|
||||
label_ids_to_fuse: Optional[Set[int]] = None,
|
||||
target_size: Tuple[int, int] = None,
|
||||
):
|
||||
height = mask_probs.shape[1] if target_size is None else target_size[0]
|
||||
width = mask_probs.shape[2] if target_size is None else target_size[1]
|
||||
|
||||
segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
|
||||
segments: List[Dict] = []
|
||||
|
||||
if target_size is not None:
|
||||
mask_probs = interpolate(mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False)[0]
|
||||
|
||||
current_segment_id = 0
|
||||
|
||||
# Weigh each mask by its prediction score
|
||||
mask_probs *= pred_scores.view(-1, 1, 1)
|
||||
mask_labels = mask_probs.argmax(0) # [height, width]
|
||||
|
||||
# Keep track of instances of each class
|
||||
stuff_memory_list: Dict[str, int] = {}
|
||||
for k in range(pred_labels.shape[0]):
|
||||
pred_class = pred_labels[k].item()
|
||||
should_fuse = pred_class in label_ids_to_fuse
|
||||
|
||||
# Check if mask exists and large enough to be a segment
|
||||
mask_exists, mask_k = check_segment_validity(mask_labels, mask_probs, k, overlap_mask_area_threshold)
|
||||
|
||||
if mask_exists:
|
||||
if pred_class in stuff_memory_list:
|
||||
current_segment_id = stuff_memory_list[pred_class]
|
||||
else:
|
||||
current_segment_id += 1
|
||||
|
||||
# Add current object segment to final segmentation map
|
||||
segmentation[mask_k] = current_segment_id
|
||||
segment_score = round(pred_scores[k].item(), 6)
|
||||
segments.append(
|
||||
{
|
||||
"id": current_segment_id,
|
||||
"label_id": pred_class,
|
||||
"was_fused": should_fuse,
|
||||
"score": segment_score,
|
||||
}
|
||||
)
|
||||
if should_fuse:
|
||||
stuff_memory_list[pred_class] = current_segment_id
|
||||
|
||||
return segmentation, segments
|
||||
|
||||
|
||||
class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
r"""
|
||||
Constructs a MaskFormer feature extractor. The feature extractor can be used to prepare image(s) and optional
|
||||
@ -488,6 +635,12 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
|
||||
`torch.Tensor`:
|
||||
A tensor of shape (`batch_size, num_class_labels, height, width`).
|
||||
"""
|
||||
logger.warning(
|
||||
"`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
|
||||
" `post_process_instance_segmentation`",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
# class_queries_logits has shape [BATCH, QUERIES, CLASSES + 1]
|
||||
class_queries_logits = outputs.class_queries_logits
|
||||
# masks_queries_logits has shape [BATCH, QUERIES, HEIGHT, WIDTH]
|
||||
@ -512,59 +665,141 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
|
||||
|
||||
return segmentation
|
||||
|
||||
def remove_low_and_no_objects(self, masks, scores, labels, object_mask_threshold, num_labels):
|
||||
"""
|
||||
Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores`
|
||||
and `labels`.
|
||||
|
||||
Args:
|
||||
masks (`torch.Tensor`):
|
||||
A tensor of shape `(num_queries, height, width)`.
|
||||
scores (`torch.Tensor`):
|
||||
A tensor of shape `(num_queries)`.
|
||||
labels (`torch.Tensor`):
|
||||
A tensor of shape `(num_queries)`.
|
||||
object_mask_threshold (`float`):
|
||||
A number between 0 and 1 used to binarize the masks.
|
||||
|
||||
Raises:
|
||||
`ValueError`: Raised when the first dimension doesn't match in all input tensors.
|
||||
|
||||
Returns:
|
||||
`Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the
|
||||
region < `object_mask_threshold`.
|
||||
"""
|
||||
if not (masks.shape[0] == scores.shape[0] == labels.shape[0]):
|
||||
raise ValueError("mask, scores and labels must have the same shape!")
|
||||
|
||||
to_keep = labels.ne(num_labels) & (scores > object_mask_threshold)
|
||||
|
||||
return masks[to_keep], scores[to_keep], labels[to_keep]
|
||||
|
||||
def post_process_semantic_segmentation(
|
||||
self, outputs: "MaskFormerForInstanceSegmentationOutput", target_size: Tuple[int, int] = None
|
||||
self, outputs, target_sizes: Optional[List[Tuple[int, int]]] = None
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into semantic segmentation predictions. Only
|
||||
Converts the output of [`MaskFormerForInstanceSegmentation`] into semantic segmentation maps. Only supports
|
||||
PyTorch.
|
||||
|
||||
Args:
|
||||
outputs ([`MaskFormerForInstanceSegmentation`]):
|
||||
Raw outputs of the model.
|
||||
target_sizes (`List[Tuple[int, int]]`, *optional*, defaults to `None`):
|
||||
List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested
|
||||
final size (height, width) of each prediction. If left to None, predictions will not be resized.
|
||||
Returns:
|
||||
`List[torch.Tensor]`:
|
||||
A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
|
||||
corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
|
||||
`torch.Tensor` correspond to a semantic class id.
|
||||
"""
|
||||
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
|
||||
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
|
||||
|
||||
# Remove the null class `[..., :-1]`
|
||||
masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
|
||||
masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
|
||||
|
||||
# Semantic segmentation logits of shape (batch_size, num_classes, height, width)
|
||||
segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
|
||||
batch_size = class_queries_logits.shape[0]
|
||||
|
||||
# Resize logits and compute semantic segmentation maps
|
||||
if target_sizes is not None:
|
||||
if batch_size != len(target_sizes):
|
||||
raise ValueError(
|
||||
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
|
||||
)
|
||||
|
||||
semantic_segmentation = []
|
||||
for idx in range(batch_size):
|
||||
resized_logits = torch.nn.functional.interpolate(
|
||||
segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
|
||||
)
|
||||
semantic_map = resized_logits[0].argmax(dim=0)
|
||||
semantic_segmentation.append(semantic_map)
|
||||
else:
|
||||
semantic_segmentation = segmentation.argmax(dim=1)
|
||||
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
|
||||
|
||||
return semantic_segmentation
|
||||
|
||||
def post_process_instance_segmentation(
|
||||
self,
|
||||
outputs,
|
||||
threshold: float = 0.5,
|
||||
overlap_mask_area_threshold: float = 0.8,
|
||||
target_sizes: Optional[List[Tuple[int, int]]] = None,
|
||||
return_coco_annotation: Optional[bool] = False,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into instance segmentation predictions. Only
|
||||
supports PyTorch.
|
||||
|
||||
Args:
|
||||
outputs ([`MaskFormerForInstanceSegmentationOutput`]):
|
||||
The outputs from [`MaskFormerForInstanceSegmentation`].
|
||||
|
||||
outputs ([`MaskFormerForInstanceSegmentation`]):
|
||||
Raw outputs of the model.
|
||||
threshold (`float`, *optional*, defaults to 0.5):
|
||||
The probability score threshold to keep predicted instance masks.
|
||||
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
|
||||
The overlap mask area threshold to merge or discard small disconnected parts within each binary
|
||||
instance mask.
|
||||
target_sizes (`List[Tuple]`, *optional*):
|
||||
List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested
|
||||
final size (height, width) of each prediction. If left to None, predictions will not be resized.
|
||||
return_coco_annotation (`bool`, *optional*):
|
||||
Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE)
|
||||
format.
|
||||
Returns:
|
||||
`torch.Tensor`: A tensor of shape `batch_size, height, width`.
|
||||
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
|
||||
- **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or
|
||||
`List[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to
|
||||
`True`. Set to `None` if no mask if found above `threshold`.
|
||||
- **segments_info** -- A dictionary that contains additional information on each segment.
|
||||
- **id** -- An integer representing the `segment_id`.
|
||||
- **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
|
||||
- **score** -- Prediction score of segment with `segment_id`.
|
||||
"""
|
||||
segmentation = self.post_process_segmentation(outputs, target_size)
|
||||
semantic_segmentation = segmentation.argmax(dim=1)
|
||||
return semantic_segmentation
|
||||
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
|
||||
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
|
||||
|
||||
batch_size = class_queries_logits.shape[0]
|
||||
num_labels = class_queries_logits.shape[-1] - 1
|
||||
|
||||
mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
|
||||
|
||||
# Predicted label and score of each query (batch_size, num_queries)
|
||||
pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
|
||||
|
||||
# Loop over items in batch size
|
||||
results: List[Dict[str, Tensor]] = []
|
||||
|
||||
for i in range(batch_size):
|
||||
mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
|
||||
mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
|
||||
)
|
||||
|
||||
# No mask found
|
||||
if mask_probs_item.shape[0] <= 0:
|
||||
segmentation = None
|
||||
segments: List[Dict] = []
|
||||
continue
|
||||
|
||||
# Get segmentation map and segment information of batch item
|
||||
target_size = target_sizes[i] if target_sizes is not None else None
|
||||
segmentation, segments = compute_segments(
|
||||
mask_probs_item,
|
||||
pred_scores_item,
|
||||
pred_labels_item,
|
||||
overlap_mask_area_threshold,
|
||||
target_size,
|
||||
)
|
||||
|
||||
# Return segmentation map in run-length encoding (RLE) format
|
||||
if return_coco_annotation:
|
||||
segmentation = convert_segmentation_to_rle(segmentation)
|
||||
|
||||
results.append({"segmentation": segmentation, "segments_info": segments})
|
||||
return results
|
||||
|
||||
def post_process_panoptic_segmentation(
|
||||
self,
|
||||
outputs: "MaskFormerForInstanceSegmentationOutput",
|
||||
object_mask_threshold: float = 0.8,
|
||||
outputs,
|
||||
threshold: float = 0.5,
|
||||
overlap_mask_area_threshold: float = 0.8,
|
||||
label_ids_to_fuse: Optional[Set[int]] = None,
|
||||
target_sizes: Optional[List[Tuple[int, int]]] = None,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image panoptic segmentation
|
||||
@ -573,94 +808,72 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
|
||||
Args:
|
||||
outputs ([`MaskFormerForInstanceSegmentationOutput`]):
|
||||
The outputs from [`MaskFormerForInstanceSegmentation`].
|
||||
object_mask_threshold (`float`, *optional*, defaults to 0.8):
|
||||
The object mask threshold.
|
||||
threshold (`float`, *optional*, defaults to 0.5):
|
||||
The probability score threshold to keep predicted instance masks.
|
||||
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
|
||||
The overlap mask area threshold to use.
|
||||
The overlap mask area threshold to merge or discard small disconnected parts within each binary
|
||||
instance mask.
|
||||
label_ids_to_fuse (`Set[int]`, *optional*):
|
||||
The labels in this state will have all their instances be fused together. For instance we could say
|
||||
there can only be one sky in an image, but several persons, so the label ID for sky would be in that
|
||||
set, but not the one for person.
|
||||
target_sizes (`List[Tuple]`, *optional*):
|
||||
List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested
|
||||
final size (height, width) of each prediction in batch. If left to None, predictions will not be
|
||||
resized.
|
||||
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
|
||||
- **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`.
|
||||
- **segments** -- a dictionary with the following keys
|
||||
- **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`, set
|
||||
to `None` if no mask if found above `threshold`. If `target_sizes` is specified, segmentation is resized
|
||||
to the corresponding `target_sizes` entry.
|
||||
- **segments_info** -- A dictionary that contains additional information on each segment.
|
||||
- **id** -- an integer representing the `segment_id`.
|
||||
- **label_id** -- an integer representing the segment's label.
|
||||
- **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
|
||||
- **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
|
||||
Multiple instances of the same class / label were fused and assigned a single `segment_id`.
|
||||
- **score** -- Prediction score of segment with `segment_id`.
|
||||
"""
|
||||
|
||||
if label_ids_to_fuse is None:
|
||||
logger.warning("`label_ids_to_fuse` unset. No instance will be fused.")
|
||||
label_ids_to_fuse = set()
|
||||
# class_queries_logits has shape [BATCH, QUERIES, CLASSES + 1]
|
||||
class_queries_logits = outputs.class_queries_logits
|
||||
# keep track of the number of labels, subtract -1 for null class
|
||||
|
||||
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
|
||||
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
|
||||
|
||||
batch_size = class_queries_logits.shape[0]
|
||||
num_labels = class_queries_logits.shape[-1] - 1
|
||||
# masks_queries_logits has shape [BATCH, QUERIES, HEIGHT, WIDTH]
|
||||
masks_queries_logits = outputs.masks_queries_logits
|
||||
# since all images are padded, they all have the same spatial dimensions
|
||||
_, _, height, width = masks_queries_logits.shape
|
||||
# for each query, the best scores and their indeces
|
||||
|
||||
mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
|
||||
|
||||
# Predicted label and score of each query (batch_size, num_queries)
|
||||
pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
|
||||
# pred_scores and pred_labels shape = [BATH,NUM_QUERIES]
|
||||
mask_probs = masks_queries_logits.sigmoid()
|
||||
# mask probs has shape [BATCH, QUERIES, HEIGHT, WIDTH]
|
||||
# now, we need to iterate over the batch size to correctly process the segmentation we got from the queries using our thresholds. Even if the original predicted masks have the same shape across the batch, they won't after thresholding so batch-wise operations are impossible
|
||||
|
||||
# Loop over items in batch size
|
||||
results: List[Dict[str, Tensor]] = []
|
||||
for mask_probs, pred_scores, pred_labels in zip(mask_probs, pred_scores, pred_labels):
|
||||
mask_probs, pred_scores, pred_labels = self.remove_low_and_no_objects(
|
||||
mask_probs, pred_scores, pred_labels, object_mask_threshold, num_labels
|
||||
|
||||
for i in range(batch_size):
|
||||
mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
|
||||
mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
|
||||
)
|
||||
we_detect_something = mask_probs.shape[0] > 0
|
||||
|
||||
segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
|
||||
segments: List[Dict] = []
|
||||
# No mask found
|
||||
if mask_probs_item.shape[0] <= 0:
|
||||
segmentation = None
|
||||
segments: List[Dict] = []
|
||||
continue
|
||||
|
||||
if we_detect_something:
|
||||
current_segment_id = 0
|
||||
# weight each mask by its score
|
||||
mask_probs *= pred_scores.view(-1, 1, 1)
|
||||
# find out for each pixel what is the most likely class to be there
|
||||
mask_labels = mask_probs.argmax(0)
|
||||
# mask_labels shape = [H,W] where each pixel has a class label
|
||||
stuff_memory_list: Dict[str, int] = {}
|
||||
# this is a map between stuff and segments id, the used it to keep track of the instances of one class
|
||||
for k in range(pred_labels.shape[0]):
|
||||
pred_class = pred_labels[k].item()
|
||||
# check if pred_class should be fused. For example, class "sky" cannot have more then one instance
|
||||
should_fuse = pred_class in label_ids_to_fuse
|
||||
# get the mask associated with the k class
|
||||
mask_k = mask_labels == k
|
||||
# create the area, since bool we just need to sum :)
|
||||
mask_k_area = mask_k.sum()
|
||||
# this is the area of all the stuff in query k
|
||||
original_area = (mask_probs[k] >= 0.5).sum()
|
||||
# Get segmentation map and segment information of batch item
|
||||
target_size = target_sizes[i] if target_sizes is not None else None
|
||||
segmentation, segments = compute_segments(
|
||||
mask_probs_item,
|
||||
pred_scores_item,
|
||||
pred_labels_item,
|
||||
overlap_mask_area_threshold,
|
||||
label_ids_to_fuse,
|
||||
target_size,
|
||||
)
|
||||
|
||||
mask_exists = mask_k_area > 0 and original_area > 0
|
||||
|
||||
if mask_exists:
|
||||
# find out how much of the all area mask_k is using
|
||||
area_ratio = mask_k_area / original_area
|
||||
mask_k_is_overlapping_enough = area_ratio.item() > overlap_mask_area_threshold
|
||||
|
||||
if mask_k_is_overlapping_enough:
|
||||
# merge stuff regions
|
||||
if pred_class in stuff_memory_list:
|
||||
current_segment_id = stuff_memory_list[pred_class]
|
||||
else:
|
||||
current_segment_id += 1
|
||||
# then we update out mask with the current segment
|
||||
segmentation[mask_k] = current_segment_id
|
||||
segments.append(
|
||||
{
|
||||
"id": current_segment_id,
|
||||
"label_id": pred_class,
|
||||
"was_fused": should_fuse,
|
||||
}
|
||||
)
|
||||
if should_fuse:
|
||||
stuff_memory_list[pred_class] = current_segment_id
|
||||
results.append({"segmentation": segmentation, "segments": segments})
|
||||
results.append({"segmentation": segmentation, "segments_info": segments})
|
||||
return results
|
||||
|
@ -259,7 +259,8 @@ class MaskFormerForInstanceSegmentationOutput(ModelOutput):
|
||||
"""
|
||||
Class for outputs of [`MaskFormerForInstanceSegmentation`].
|
||||
|
||||
This output can be directly passed to [`~MaskFormerFeatureExtractor.post_process_segmentation`] or
|
||||
This output can be directly passed to [`~MaskFormerFeatureExtractor.post_process_semantic_segmentation`] or or
|
||||
[`~MaskFormerFeatureExtractor.post_process_instance_segmentation`] or
|
||||
[`~MaskFormerFeatureExtractor.post_process_panoptic_segmentation`] depending on the task. Please, see
|
||||
[`~MaskFormerFeatureExtractor] for details regarding usage.
|
||||
|
||||
@ -267,11 +268,11 @@ class MaskFormerForInstanceSegmentationOutput(ModelOutput):
|
||||
loss (`torch.Tensor`, *optional*):
|
||||
The computed loss, returned when labels are present.
|
||||
class_queries_logits (`torch.FloatTensor`):
|
||||
A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
|
||||
query.
|
||||
masks_queries_logits (`torch.FloatTensor`):
|
||||
A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
|
||||
query. Note the `+ 1` is needed because we incorporate the null class.
|
||||
masks_queries_logits (`torch.FloatTensor`):
|
||||
A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
|
||||
query.
|
||||
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Last hidden states (final feature map) of the last stage of the encoder model (backbone).
|
||||
pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
@ -2547,8 +2548,8 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
|
||||
>>> masks_queries_logits = outputs.masks_queries_logits
|
||||
|
||||
>>> # you can pass them to feature_extractor for postprocessing
|
||||
>>> output = feature_extractor.post_process_segmentation(outputs)
|
||||
>>> output = feature_extractor.post_process_semantic_segmentation(outputs)
|
||||
>>> output = feature_extractor.post_process_instance_segmentation(outputs)
|
||||
>>> output = feature_extractor.post_process_panoptic_segmentation(outputs)
|
||||
```
|
||||
"""
|
||||
|
@ -29,6 +29,7 @@ if is_torch_available():
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import MaskFormerFeatureExtractor
|
||||
from transformers.models.maskformer.feature_extraction_maskformer import binary_mask_to_rle
|
||||
from transformers.models.maskformer.modeling_maskformer import MaskFormerForInstanceSegmentationOutput
|
||||
|
||||
if is_vision_available():
|
||||
@ -344,6 +345,17 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
|
||||
common(is_instance_map=False, segmentation_type="pil")
|
||||
common(is_instance_map=True, segmentation_type="pil")
|
||||
|
||||
def test_binary_mask_to_rle(self):
|
||||
fake_binary_mask = np.zeros((20, 50))
|
||||
fake_binary_mask[0, 20:] = 1
|
||||
fake_binary_mask[1, :15] = 1
|
||||
fake_binary_mask[5, :10] = 1
|
||||
|
||||
rle = binary_mask_to_rle(fake_binary_mask)
|
||||
self.assertEqual(len(rle), 4)
|
||||
self.assertEqual(rle[0], 21)
|
||||
self.assertEqual(rle[1], 45)
|
||||
|
||||
def test_post_process_segmentation(self):
|
||||
fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
|
||||
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
|
||||
@ -373,31 +385,30 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
|
||||
|
||||
segmentation = fature_extractor.post_process_semantic_segmentation(outputs)
|
||||
|
||||
self.assertEqual(len(segmentation), self.feature_extract_tester.batch_size)
|
||||
self.assertEqual(
|
||||
segmentation.shape,
|
||||
segmentation[0].shape,
|
||||
(
|
||||
self.feature_extract_tester.batch_size,
|
||||
self.feature_extract_tester.height,
|
||||
self.feature_extract_tester.width,
|
||||
),
|
||||
)
|
||||
|
||||
target_size = (1, 4)
|
||||
target_sizes = [(1, 4) for i in range(self.feature_extract_tester.batch_size)]
|
||||
segmentation = fature_extractor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)
|
||||
|
||||
segmentation = fature_extractor.post_process_semantic_segmentation(outputs, target_size=target_size)
|
||||
|
||||
self.assertEqual(segmentation.shape, (self.feature_extract_tester.batch_size, *target_size))
|
||||
self.assertEqual(segmentation[0].shape, target_sizes[0])
|
||||
|
||||
def test_post_process_panoptic_segmentation(self):
|
||||
fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
|
||||
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
|
||||
segmentation = fature_extractor.post_process_panoptic_segmentation(outputs, object_mask_threshold=0)
|
||||
segmentation = fature_extractor.post_process_panoptic_segmentation(outputs, threshold=0)
|
||||
|
||||
self.assertTrue(len(segmentation) == self.feature_extract_tester.batch_size)
|
||||
for el in segmentation:
|
||||
self.assertTrue("segmentation" in el)
|
||||
self.assertTrue("segments" in el)
|
||||
self.assertEqual(type(el["segments"]), list)
|
||||
self.assertTrue("segments_info" in el)
|
||||
self.assertEqual(type(el["segments_info"]), list)
|
||||
self.assertEqual(
|
||||
el["segmentation"].shape, (self.feature_extract_tester.height, self.feature_extract_tester.width)
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user