mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Add automatic-mask-generation
pipeline for Segment Anything Model (SAM) (#22840)
* cleanup * updates * more refactoring * make style * update inits * support other inputs in base * update based on review Co-authored-by: Nicolas Patry <patry.nicolas@gmail.com> * Update tests/pipelines/test_pipelines_automatic_mask_generation.py Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> * update * fixup * TODO x and y to refactor, _h _w refactored here * update docstring * more nits * style on these * more doc fix * rename variables * update * updates * style * update * fix `_mask_to_rle_pytorch` * styling * fix ask to rle, wrong outputs * add device arg * update * more updates, fix tets * udpate * update docstrings * styling * fixup * add notebook on the docs * update orginal sizes * fix docstring * updat condition on point_per-batch * updates tests * fix CI test * extend is required, append does not work! * fixup * fix CI tests * whit pixels left * address doc comments * fix doc * slow pipeline tests * update auto init * add revision * make fixup * update p!ipoeline tag when calling tests * alphabeitcal order in inits * fix copies * last style nits * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * reformat docstring * more reformat * address most of the comments * Update src/transformers/pipelines/mask_generation.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * final refactor * Update src/transformers/models/sam/image_processing_sam.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fixup and fix slow tests * revert --------- Co-authored-by: Nicolas Patry <patry.nicolas@gmail.com> Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> Co-authored-by: younesbelkada <younesbelkada@gmail.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
e5f3487190
commit
f143037789
@ -64,6 +64,7 @@ scores = outputs.iou_scores
|
||||
Resources:
|
||||
|
||||
- [Demo notebook](https://github.com/huggingface/notebooks/blob/main/examples/segment_anything.ipynb) for using the model
|
||||
- [Demo notebook](https://github.com/huggingface/notebooks/blob/main/examples/automatic_mask_generation.ipynb) for using automatic mask generation pipeline.
|
||||
|
||||
## SamConfig
|
||||
|
||||
|
@ -1012,6 +1012,7 @@ else:
|
||||
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
|
||||
"MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
|
||||
"MODEL_FOR_MASKED_LM_MAPPING",
|
||||
"MODEL_FOR_MASK_GENERATION_MAPPING",
|
||||
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||
"MODEL_FOR_OBJECT_DETECTION_MAPPING",
|
||||
@ -4650,6 +4651,7 @@ if TYPE_CHECKING:
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_MASK_GENERATION_MAPPING,
|
||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
|
@ -52,6 +52,7 @@ else:
|
||||
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
|
||||
"MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
|
||||
"MODEL_FOR_MASKED_LM_MAPPING",
|
||||
"MODEL_FOR_MASK_GENERATION_MAPPING",
|
||||
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||
"MODEL_FOR_OBJECT_DETECTION_MAPPING",
|
||||
@ -213,6 +214,7 @@ if TYPE_CHECKING:
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_MASK_GENERATION_MAPPING,
|
||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
|
@ -977,7 +977,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_AUTOMATIC_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
|
||||
MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("sam", "SamModel"),
|
||||
]
|
||||
@ -1058,9 +1058,11 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_F
|
||||
|
||||
MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)
|
||||
|
||||
MODEL_FOR_AUTOMATIC_MASK_GENERATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, MODEL_FOR_AUTOMATIC_MASK_GENERATION_MAPPING_NAMES
|
||||
)
|
||||
MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
|
||||
|
||||
|
||||
class AutoModelForMaskGeneration(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
|
||||
|
||||
|
||||
class AutoModel(_BaseAutoModelClass):
|
||||
|
@ -13,7 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Image processor class for SAM."""
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -26,16 +29,20 @@ from ...image_utils import (
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
)
|
||||
from ...utils import TensorType, is_torch_available, logging, requires_backends
|
||||
from ...utils import TensorType, is_torch_available, is_torchvision_available, logging, requires_backends
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
if is_torchvision_available():
|
||||
from torchvision.ops.boxes import batched_nms
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@ -354,10 +361,14 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
images = [self.pad_image(image=image, pad_size=pad_size) for image in images]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
|
||||
data = {"pixel_values": images, "original_sizes": original_sizes, "reshaped_input_sizes": reshaped_input_sizes}
|
||||
encoded_outputs = BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
encoded_outputs = BatchFeature(
|
||||
data={
|
||||
"pixel_values": images,
|
||||
"original_sizes": original_sizes,
|
||||
"reshaped_input_sizes": reshaped_input_sizes,
|
||||
},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
return encoded_outputs
|
||||
|
||||
def post_process_masks(
|
||||
@ -392,11 +403,453 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
for i, original_size in enumerate(original_sizes):
|
||||
interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
|
||||
interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
|
||||
interpolated_mask = F.interpolate(
|
||||
interpolated_mask, [*original_size.numpy()], mode="bilinear", align_corners=False
|
||||
)
|
||||
interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
|
||||
if binarize:
|
||||
interpolated_mask = interpolated_mask > mask_threshold
|
||||
output_masks.append(interpolated_mask)
|
||||
|
||||
return output_masks
|
||||
|
||||
def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh):
|
||||
"""
|
||||
Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks.
|
||||
|
||||
Args:
|
||||
all_masks (`List[torch.Tensor]`):
|
||||
List of all predicted segmentation masks
|
||||
all_scores (`List[torch.Tensor]`):
|
||||
List of all predicted iou scores
|
||||
all_boxes (`List[torch.Tensor]`):
|
||||
List of all bounding boxes of the predicted masks
|
||||
crops_nms_thresh (`float`):
|
||||
Threshold for NMS (Non Maximum Suppression) algorithm.
|
||||
"""
|
||||
return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)
|
||||
|
||||
def generate_crop_boxes(
|
||||
self,
|
||||
image,
|
||||
target_size,
|
||||
crop_n_layers: int = 0,
|
||||
overlap_ratio: float = 512 / 1500,
|
||||
points_per_crop: Optional[int] = 32,
|
||||
crop_n_points_downscale_factor: Optional[List[int]] = 1,
|
||||
device: Optional["torch.device"] = None,
|
||||
):
|
||||
"""
|
||||
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
|
||||
|
||||
Args:
|
||||
image (`np.array`):
|
||||
Input original image
|
||||
target_size (`int`):
|
||||
Target size of the resized image
|
||||
crop_n_layers (`int`, *optional*, defaults to 0):
|
||||
If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where
|
||||
each layer has 2**i_layer number of image crops.
|
||||
overlap_ratio (`float`, *optional*, defaults to 512/1500):
|
||||
Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
|
||||
the image length. Later layers with more crops scale down this overlap.
|
||||
points_per_crop (`int`, *optional*, defaults to 32):
|
||||
Number of points to sample from each crop.
|
||||
crop_n_points_downscale_factor (`List[int]`, *optional*, defaults to 1):
|
||||
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
||||
device (`torch.device`, *optional*, defaults to None):
|
||||
Device to use for the computation. If None, cpu will be used.
|
||||
"""
|
||||
return _generate_crop_boxes(
|
||||
image, target_size, crop_n_layers, overlap_ratio, points_per_crop, crop_n_points_downscale_factor, device
|
||||
)
|
||||
|
||||
def filter_masks(
|
||||
self,
|
||||
masks,
|
||||
iou_scores,
|
||||
original_size,
|
||||
cropped_box_image,
|
||||
pred_iou_thresh=0.88,
|
||||
stability_score_thresh=0.95,
|
||||
mask_threshold=0,
|
||||
stability_score_offset=1,
|
||||
):
|
||||
"""
|
||||
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
|
||||
that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
|
||||
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
|
||||
bounding boxes and pad the predicted masks if necessary.
|
||||
|
||||
Args:
|
||||
masks (`torch.Tensor`):
|
||||
Input masks.
|
||||
iou_scores (`torch.Tensor`):
|
||||
List of IoU scores.
|
||||
original_size (`Tuple[int,int]`):
|
||||
Size of the orginal image.
|
||||
cropped_box_image (`np.array`):
|
||||
The cropped image.
|
||||
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
|
||||
The threshold for the iou scores.
|
||||
stability_score_thresh (`float`, *optional*, defaults to 0.95):
|
||||
The threshold for the stability score.
|
||||
mask_threshold (`float`, *optional*, defaults to 0):
|
||||
The threshold for the predicted masks.
|
||||
stability_score_offset (`float`, *optional*, defaults to 1):
|
||||
The offset for the stability score used in the `_compute_stability_score` method.
|
||||
|
||||
"""
|
||||
requires_backends(self, ["torch"])
|
||||
original_height, original_width = original_size
|
||||
iou_scores = iou_scores.flatten(0, 1)
|
||||
masks = masks.flatten(0, 1)
|
||||
|
||||
if masks.shape[0] != iou_scores.shape[0]:
|
||||
raise ValueError("masks and iou_scores must have the same batch size.")
|
||||
|
||||
if masks.device != iou_scores.device:
|
||||
iou_scores = iou_scores.to(masks.device)
|
||||
|
||||
batch_size = masks.shape[0]
|
||||
|
||||
keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device)
|
||||
|
||||
if pred_iou_thresh > 0.0:
|
||||
keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
|
||||
|
||||
# compute stability score
|
||||
if stability_score_thresh > 0.0:
|
||||
stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset)
|
||||
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
|
||||
|
||||
scores = iou_scores[keep_mask]
|
||||
masks = masks[keep_mask]
|
||||
|
||||
# binarize masks
|
||||
masks = masks > mask_threshold
|
||||
converted_boxes = _batched_mask_to_box(masks)
|
||||
|
||||
keep_mask = ~_is_box_near_crop_edge(
|
||||
converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
|
||||
)
|
||||
|
||||
scores = scores[keep_mask]
|
||||
masks = masks[keep_mask]
|
||||
converted_boxes = converted_boxes[keep_mask]
|
||||
|
||||
masks = _pad_masks(masks, cropped_box_image, original_height, original_width)
|
||||
# conversion to rle is necessary to run non-maximum suppresion
|
||||
masks = _mask_to_rle_pytorch(masks)
|
||||
|
||||
return masks, scores, converted_boxes
|
||||
|
||||
|
||||
def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
|
||||
# One mask is always contained inside the other.
|
||||
# Save memory by preventing unnecesary cast to torch.int64
|
||||
intersections = (
|
||||
(masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
|
||||
)
|
||||
unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
|
||||
stability_scores = intersections / unions
|
||||
return stability_scores
|
||||
|
||||
|
||||
def _build_point_grid(n_per_side: int) -> np.ndarray:
|
||||
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
|
||||
offset = 1 / (2 * n_per_side)
|
||||
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
|
||||
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
|
||||
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
|
||||
points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
|
||||
return points
|
||||
|
||||
|
||||
def _normalize_coordinates(
|
||||
target_size: int, coords: np.ndarray, original_size: Tuple[int, int], is_bounding_box=False
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Expects a numpy array of length 2 in the final dimension. Requires the original image size in (height, width)
|
||||
format.
|
||||
"""
|
||||
old_height, old_width = original_size
|
||||
|
||||
scale = target_size * 1.0 / max(old_height, old_width)
|
||||
new_height, new_width = old_height * scale, old_width * scale
|
||||
new_width = int(new_width + 0.5)
|
||||
new_height = int(new_height + 0.5)
|
||||
|
||||
coords = deepcopy(coords).astype(float)
|
||||
|
||||
if is_bounding_box:
|
||||
coords = coords.reshape(-1, 2, 2)
|
||||
|
||||
coords[..., 0] = coords[..., 0] * (new_width / old_width)
|
||||
coords[..., 1] = coords[..., 1] * (new_height / old_height)
|
||||
|
||||
if is_bounding_box:
|
||||
coords = coords.reshape(-1, 4)
|
||||
|
||||
return coords
|
||||
|
||||
|
||||
def _generate_crop_boxes(
|
||||
image,
|
||||
target_size: int, # Is it tuple here?
|
||||
crop_n_layers: int = 0,
|
||||
overlap_ratio: float = 512 / 1500,
|
||||
points_per_crop: Optional[int] = 32,
|
||||
crop_n_points_downscale_factor: Optional[List[int]] = 1,
|
||||
device: Optional["torch.device"] = None,
|
||||
) -> Tuple[List[List[int]], List[int]]:
|
||||
"""
|
||||
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
|
||||
|
||||
Args:
|
||||
image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]):
|
||||
Image to generate crops for.
|
||||
target_size (`int`):
|
||||
Size of the smallest crop.
|
||||
crop_n_layers (`int`, *optional*):
|
||||
If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers
|
||||
to run, where each layer has 2**i_layer number of image crops.
|
||||
overlap_ratio (`int`, *optional*):
|
||||
Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the
|
||||
image length. Later layers with more crops scale down this overlap.
|
||||
points_per_crop (`int`, *optional*):
|
||||
Number of points to sample per crop.
|
||||
crop_n_points_downscale_factor (`int`, *optional*):
|
||||
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
||||
device (`torch.device`, *optional*):
|
||||
Device to run the crop generation on. Defaults to CPU.
|
||||
"""
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
|
||||
if isinstance(image, list):
|
||||
raise ValueError("Only one image is allowed for crop generation.")
|
||||
image = to_numpy_array(image)
|
||||
original_size = get_image_size(image)
|
||||
|
||||
points_grid = []
|
||||
for i in range(crop_n_layers + 1):
|
||||
n_points = int(points_per_crop / (crop_n_points_downscale_factor**i))
|
||||
points_grid.append(_build_point_grid(n_points))
|
||||
|
||||
crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size)
|
||||
|
||||
cropped_images, point_grid_per_crop = _generate_crop_images(
|
||||
crop_boxes, image, points_grid, layer_idxs, target_size, original_size
|
||||
)
|
||||
|
||||
crop_boxes = torch.tensor(crop_boxes, dtype=torch.float32, device=device)
|
||||
point_grid_per_crop = np.array([point_grid_per_crop])
|
||||
points_per_crop = torch.tensor(point_grid_per_crop, device=device)
|
||||
points_per_crop = points_per_crop.permute(0, 2, 1, 3)
|
||||
|
||||
input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.long, device=device)
|
||||
|
||||
return crop_boxes, points_per_crop, cropped_images, input_labels
|
||||
|
||||
|
||||
def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size):
|
||||
"""
|
||||
Generates 2 ** (layers idx + 1) crops for each crop_n_layers. Crops are in the XYWH format : The XYWH format
|
||||
consists of the following required indices:
|
||||
- X: X coordinate of the top left of the bounding box
|
||||
- Y: Y coordinate of the top left of the bounding box
|
||||
- W: width of the bounding box
|
||||
- H: height of the bounding box
|
||||
"""
|
||||
crop_boxes, layer_idxs = [], []
|
||||
im_height, im_width = original_size
|
||||
short_side = min(im_height, im_width)
|
||||
|
||||
# Original image
|
||||
crop_boxes.append([0, 0, im_width, im_height])
|
||||
layer_idxs.append(0)
|
||||
for i_layer in range(crop_n_layers):
|
||||
n_crops_per_side = 2 ** (i_layer + 1)
|
||||
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
|
||||
|
||||
crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side))
|
||||
crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side))
|
||||
|
||||
crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)]
|
||||
crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)]
|
||||
|
||||
for left, top in product(crop_box_x0, crop_box_y0):
|
||||
box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)]
|
||||
crop_boxes.append(box)
|
||||
layer_idxs.append(i_layer + 1)
|
||||
|
||||
return crop_boxes, layer_idxs
|
||||
|
||||
|
||||
def _generate_crop_images(crop_boxes, image, points_grid, layer_idxs, target_size, original_size):
|
||||
"""
|
||||
Takes as an input bounding boxes that are used to crop the image. Based in the crops, the corresponding points are
|
||||
also passed.
|
||||
"""
|
||||
cropped_images = []
|
||||
total_points_per_crop = []
|
||||
for i, crop_box in enumerate(crop_boxes):
|
||||
left, top, right, bottom = crop_box
|
||||
|
||||
channel_dim = infer_channel_dimension_format(image)
|
||||
if channel_dim == ChannelDimension.LAST:
|
||||
cropped_im = image[top:bottom, left:right, :]
|
||||
else:
|
||||
cropped_im = image[:, top:bottom, left:right]
|
||||
|
||||
cropped_images.append(cropped_im)
|
||||
|
||||
cropped_im_size = get_image_size(cropped_im)
|
||||
points_scale = np.array(cropped_im_size)[None, ::-1]
|
||||
|
||||
points = points_grid[layer_idxs[i]] * points_scale
|
||||
normalized_points = _normalize_coordinates(target_size, points, original_size)
|
||||
total_points_per_crop.append(normalized_points)
|
||||
|
||||
return cropped_images, total_points_per_crop
|
||||
|
||||
|
||||
def _pad_masks(masks, crop_box: List[int], orig_height: int, orig_width: int):
|
||||
left, top, right, bottom = crop_box
|
||||
if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
|
||||
return masks
|
||||
# Coordinate transform masks
|
||||
pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
|
||||
pad = (left, pad_x - left, top, pad_y - top)
|
||||
return torch.nn.functional.pad(masks, pad, value=0)
|
||||
|
||||
|
||||
def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
|
||||
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
|
||||
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
|
||||
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
|
||||
|
||||
left, top, _, _ = crop_box
|
||||
offset = torch.tensor([[left, top, left, top]], device=boxes.device)
|
||||
# Check if boxes has a channel dimension
|
||||
if len(boxes.shape) == 3:
|
||||
offset = offset.unsqueeze(1)
|
||||
boxes = (boxes + offset).float()
|
||||
|
||||
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
|
||||
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
|
||||
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
|
||||
return torch.any(near_crop_edge, dim=1)
|
||||
|
||||
|
||||
def _batched_mask_to_box(masks: "torch.Tensor"):
|
||||
"""
|
||||
Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
|
||||
corresponds the following required indices:
|
||||
- LEFT: left hand side of the bounding box
|
||||
- TOP: top of the bounding box
|
||||
- RIGHT: right of the bounding box
|
||||
- BOTTOM: bottom of the bounding box
|
||||
|
||||
Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
|
||||
is channel_1 x channel_2 x ... x 4.
|
||||
|
||||
Args:
|
||||
- masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`)
|
||||
"""
|
||||
# torch.max below raises an error on empty inputs, just skip in this case
|
||||
|
||||
if torch.numel(masks) == 0:
|
||||
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
|
||||
|
||||
# Normalize shape to Cxheightxwidth
|
||||
shape = masks.shape
|
||||
height, width = shape[-2:]
|
||||
|
||||
# Get top and bottom edges
|
||||
in_height, _ = torch.max(masks, dim=-1)
|
||||
in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :]
|
||||
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
|
||||
in_height_coords = in_height_coords + height * (~in_height)
|
||||
top_edges, _ = torch.min(in_height_coords, dim=-1)
|
||||
|
||||
# Get left and right edges
|
||||
in_width, _ = torch.max(masks, dim=-2)
|
||||
in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :]
|
||||
right_edges, _ = torch.max(in_width_coords, dim=-1)
|
||||
in_width_coords = in_width_coords + width * (~in_width)
|
||||
left_edges, _ = torch.min(in_width_coords, dim=-1)
|
||||
|
||||
# If the mask is empty the right edge will be to the left of the left edge.
|
||||
# Replace these boxes with [0, 0, 0, 0]
|
||||
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
|
||||
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
|
||||
out = out * (~empty_filter).unsqueeze(-1)
|
||||
|
||||
# Return to original shape
|
||||
out = out.reshape(*shape[:-2], 4)
|
||||
return out
|
||||
|
||||
|
||||
def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
|
||||
"""
|
||||
Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools.
|
||||
"""
|
||||
# Put in fortran order and flatten height and width
|
||||
batch_size, height, width = input_mask.shape
|
||||
input_mask = input_mask.permute(0, 2, 1).flatten(1)
|
||||
|
||||
# Compute change indices
|
||||
diff = input_mask[:, 1:] ^ input_mask[:, :-1]
|
||||
change_indices = diff.nonzero()
|
||||
|
||||
# Encode run length
|
||||
out = []
|
||||
for i in range(batch_size):
|
||||
cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
|
||||
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
|
||||
counts = [] if input_mask[i, 0] == 0 else [0]
|
||||
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
|
||||
out.append({"size": [height, width], "counts": counts})
|
||||
return out
|
||||
|
||||
|
||||
def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
|
||||
"""Compute a binary mask from an uncompressed RLE."""
|
||||
height, width = rle["size"]
|
||||
mask = np.empty(height * width, dtype=bool)
|
||||
idx = 0
|
||||
parity = False
|
||||
for count in rle["counts"]:
|
||||
mask[idx : idx + count] = parity
|
||||
idx += count
|
||||
parity = not parity
|
||||
mask = mask.reshape(width, height)
|
||||
return mask.transpose() # Reshape to original shape
|
||||
|
||||
|
||||
def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
|
||||
"""
|
||||
Perform NMS (Non Maxium Suppression) on the outputs.
|
||||
|
||||
Args:
|
||||
rle_masks (`torch.Tensor`):
|
||||
binary masks in the RLE format
|
||||
iou_scores (`torch.Tensor` of shape (nb_masks, 1)):
|
||||
iou_scores predicted by the model
|
||||
mask_boxes (`torch.Tensor`):
|
||||
The bounding boxes corresponding to segmentation masks
|
||||
amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
|
||||
NMS threshold.
|
||||
"""
|
||||
keep_by_nms = batched_nms(
|
||||
boxes=mask_boxes.float(),
|
||||
scores=iou_scores,
|
||||
idxs=torch.zeros(mask_boxes.shape[0]),
|
||||
iou_threshold=amg_crops_nms_thresh,
|
||||
)
|
||||
|
||||
iou_scores = iou_scores[keep_by_nms]
|
||||
rle_masks = [rle_masks[i] for i in keep_by_nms]
|
||||
mask_boxes = mask_boxes[keep_by_nms]
|
||||
masks = [_rle_to_mask(rle) for rle in rle_masks]
|
||||
|
||||
return masks, iou_scores, rle_masks, mask_boxes
|
||||
|
@ -63,6 +63,7 @@ from .fill_mask import FillMaskPipeline
|
||||
from .image_classification import ImageClassificationPipeline
|
||||
from .image_segmentation import ImageSegmentationPipeline
|
||||
from .image_to_text import ImageToTextPipeline
|
||||
from .mask_generation import MaskGenerationPipeline
|
||||
from .object_detection import ObjectDetectionPipeline
|
||||
from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
|
||||
from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
|
||||
@ -124,6 +125,7 @@ if is_torch_available():
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForImageSegmentation,
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForMaskGeneration,
|
||||
AutoModelForObjectDetection,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForSemanticSegmentation,
|
||||
@ -384,6 +386,13 @@ SUPPORTED_TASKS = {
|
||||
"default": {"model": {"pt": ("MCG-NJU/videomae-base-finetuned-kinetics", "4800870")}},
|
||||
"type": "video",
|
||||
},
|
||||
"mask-generation": {
|
||||
"impl": MaskGenerationPipeline,
|
||||
"tf": (),
|
||||
"pt": (AutoModelForMaskGeneration,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": ("facebook/sam-vit-huge", "997b15")}},
|
||||
"type": "multimodal",
|
||||
},
|
||||
}
|
||||
|
||||
NO_FEATURE_EXTRACTOR_TASKS = set()
|
||||
@ -536,6 +545,7 @@ def pipeline(
|
||||
- `"image-classification"`: will return a [`ImageClassificationPipeline`].
|
||||
- `"image-segmentation"`: will return a [`ImageSegmentationPipeline`].
|
||||
- `"image-to-text"`: will return a [`ImageToTextPipeline`].
|
||||
- `"mask-generation"`: will return a [`MaskGenerationPipeline`].
|
||||
- `"object-detection"`: will return a [`ObjectDetectionPipeline`].
|
||||
- `"question-answering"`: will return a [`QuestionAnsweringPipeline`].
|
||||
- `"summarization"`: will return a [`SummarizationPipeline`].
|
||||
|
@ -97,6 +97,8 @@ def _pad(items, key, padding_value, padding_side):
|
||||
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
|
||||
elif dim == 3:
|
||||
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
|
||||
elif dim == 4:
|
||||
tensor = torch.zeros((batch_size, max_length, shape[-2], shape[-1]), dtype=dtype) + padding_value
|
||||
|
||||
for i, item in enumerate(items):
|
||||
if dim == 2:
|
||||
@ -109,6 +111,12 @@ def _pad(items, key, padding_value, padding_side):
|
||||
tensor[i, -len(item[key][0]) :, :] = item[key][0].clone()
|
||||
else:
|
||||
tensor[i, : len(item[key][0]), :] = item[key][0].clone()
|
||||
elif dim == 4:
|
||||
if padding_side == "left":
|
||||
tensor[i, -len(item[key][0]) :, :, :] = item[key][0].clone()
|
||||
else:
|
||||
tensor[i, : len(item[key][0]), :, :] = item[key][0].clone()
|
||||
|
||||
return tensor
|
||||
else:
|
||||
return [item[key] for item in items]
|
||||
|
@ -81,11 +81,11 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
)
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocessor_kwargs = {}
|
||||
preprocess_kwargs = {}
|
||||
postprocess_kwargs = {}
|
||||
if "subtask" in kwargs:
|
||||
postprocess_kwargs["subtask"] = kwargs["subtask"]
|
||||
preprocessor_kwargs["subtask"] = kwargs["subtask"]
|
||||
preprocess_kwargs["subtask"] = kwargs["subtask"]
|
||||
if "threshold" in kwargs:
|
||||
postprocess_kwargs["threshold"] = kwargs["threshold"]
|
||||
if "mask_threshold" in kwargs:
|
||||
@ -93,7 +93,7 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
if "overlap_mask_area_threshold" in kwargs:
|
||||
postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]
|
||||
|
||||
return preprocessor_kwargs, {}, postprocess_kwargs
|
||||
return preprocess_kwargs, {}, postprocess_kwargs
|
||||
|
||||
def __call__(self, images, **kwargs) -> Union[Predictions, List[Prediction]]:
|
||||
"""
|
||||
|
286
src/transformers/pipelines/mask_generation.py
Normal file
286
src/transformers/pipelines/mask_generation.py
Normal file
@ -0,0 +1,286 @@
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
from ..image_utils import load_image
|
||||
from ..utils import (
|
||||
add_end_docstrings,
|
||||
is_torch_available,
|
||||
logging,
|
||||
requires_backends,
|
||||
)
|
||||
from .base import PIPELINE_INIT_ARGS, ChunkPipeline
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_MASK_GENERATION_MAPPING
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||
class MaskGenerationPipeline(ChunkPipeline):
|
||||
"""
|
||||
Automatic mask generation for images using `SamForMaskGeneration`. This pipeline predicts binary masks for an
|
||||
image, given an image. It is a `ChunkPipeline` because you can seperate the points in a mini-batch in order to
|
||||
avoid OOM issues. Use the `points_per_batch` argument to control the number of points that will be processed at the
|
||||
same time. Default is `64`.
|
||||
|
||||
The pipeline works in 3 steps:
|
||||
1. `preprocess`: A grid of 1024 points evenly separated is generated along with bounding boxes and point
|
||||
labels.
|
||||
For more details on how the points and bounding boxes are created, check the `_generate_crop_boxes`
|
||||
function. The image is also preprocessed using the `image_processor`. This function `yields` a minibatch of
|
||||
`points_per_batch`.
|
||||
|
||||
2. `forward`: feeds the outputs of `preprocess` to the model. The image embedding is computed only once.
|
||||
Calls both `self.model.get_image_embeddings` and makes sure that the gradients are not computed, and the
|
||||
tensors and models are on the same device.
|
||||
|
||||
3. `postprocess`: The most important part of the automatic mask generation happens here. Three steps
|
||||
are induced:
|
||||
- image_processor.postprocess_masks (run on each minibatch loop): takes in the raw output masks,
|
||||
resizes them according
|
||||
to the image size, and transforms there to binary masks.
|
||||
- image_processor.filter_masks (on each minibatch loop): uses both `pred_iou_thresh` and
|
||||
`stability_scores`. Also
|
||||
applies a variety of filters based on non maximum suppression to remove bad masks.
|
||||
- image_processor.postprocess_masks_for_amg applies the NSM on the mask to only keep relevant ones.
|
||||
|
||||
Arguments:
|
||||
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
|
||||
The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
|
||||
[`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
|
||||
tokenizer ([`PreTrainedTokenizer`]):
|
||||
The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
|
||||
[`PreTrainedTokenizer`].
|
||||
feature_extractor ([`SequenceFeatureExtractor`]):
|
||||
The feature extractor that will be used by the pipeline to encode the input.
|
||||
points_per_batch (*optional*, int, default to 64):
|
||||
Sets the number of points run simultaneously by the model. Higher numbers may be faster but use more GPU
|
||||
memory.
|
||||
output_bboxes_mask (`bool`, *optional*, default to `False`):
|
||||
Whether or not to output the bounding box predictions.
|
||||
output_rle_masks (`bool`, *optional*, default to `False`):
|
||||
Whether or not to output the masks in `RLE` format
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
>>> generator = pipeline(model="facebook/sam-vit-h", task="mask-generation")
|
||||
>>> outputs = generator(
|
||||
... "http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
... )
|
||||
|
||||
>>> outputs = generator(
|
||||
... "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", points_per_batch=128
|
||||
... )
|
||||
```
|
||||
|
||||
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
|
||||
|
||||
This segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
|
||||
`"mask-generation"`.
|
||||
|
||||
See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=mask-generation).
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
requires_backends(self, "vision")
|
||||
requires_backends(self, "torch")
|
||||
|
||||
if self.framework != "pt":
|
||||
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
|
||||
|
||||
self.check_model_type(MODEL_FOR_MASK_GENERATION_MAPPING)
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocess_kwargs = {}
|
||||
postprocess_kwargs = {}
|
||||
forward_params = {}
|
||||
# preprocess args
|
||||
if "points_per_batch" in kwargs:
|
||||
preprocess_kwargs["points_per_batch"] = kwargs["points_per_batch"]
|
||||
if "points_per_crop" in kwargs:
|
||||
preprocess_kwargs["points_per_crop"] = kwargs["points_per_crop"]
|
||||
if "crops_n_layers" in kwargs:
|
||||
preprocess_kwargs["crops_n_layers"] = kwargs["crops_n_layers"]
|
||||
if "crop_overlap_ratio" in kwargs:
|
||||
preprocess_kwargs["crop_overlap_ratio"] = kwargs["crop_overlap_ratio"]
|
||||
if "crop_n_points_downscale_factor" in kwargs:
|
||||
preprocess_kwargs["crop_n_points_downscale_factor"] = kwargs["crop_n_points_downscale_factor"]
|
||||
# postprocess args
|
||||
if "pred_iou_thresh" in kwargs:
|
||||
forward_params["pred_iou_thresh"] = kwargs["pred_iou_thresh"]
|
||||
if "stability_score_offset" in kwargs:
|
||||
forward_params["stability_score_offset"] = kwargs["stability_score_offset"]
|
||||
if "mask_threshold" in kwargs:
|
||||
forward_params["mask_threshold"] = kwargs["mask_threshold"]
|
||||
if "stability_score_thresh" in kwargs:
|
||||
forward_params["stability_score_thresh"] = kwargs["stability_score_thresh"]
|
||||
if "crops_nms_thresh" in kwargs:
|
||||
postprocess_kwargs["crops_nms_thresh"] = kwargs["crops_nms_thresh"]
|
||||
if "output_rle_mask" in kwargs:
|
||||
postprocess_kwargs["output_rle_mask"] = kwargs["output_rle_mask"]
|
||||
if "output_bboxes_mask" in kwargs:
|
||||
postprocess_kwargs["output_bboxes_mask"] = kwargs["output_bboxes_mask"]
|
||||
return preprocess_kwargs, forward_params, postprocess_kwargs
|
||||
|
||||
def __call__(self, image, *args, num_workers=None, batch_size=None, **kwargs):
|
||||
"""
|
||||
Generates binary segmentation masks
|
||||
|
||||
Args:
|
||||
inputs (`np.ndarray` or `bytes` or `str` or `dict`):
|
||||
Image or list of images.
|
||||
mask_threshold (`float`, *optional*, defaults to 0.0):
|
||||
Threshold to use when turning the predicted masks into binary values.
|
||||
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
|
||||
A filtering threshold in `[0,1]` applied on the model's predicted mask quality.
|
||||
stability_score_thresh (`float`, *optional*, defaults to 0.95):
|
||||
A filtering threshold in `[0,1]`, using the stability of the mask under changes to the cutoff used to
|
||||
binarize the model's mask predictions.
|
||||
stability_score_offset (`int`, *optional*, defaults to 1):
|
||||
The amount to shift the cutoff when calculated the stability score.
|
||||
crops_nms_thresh (`float`, *optional*, defaults to 0.7):
|
||||
The box IoU cutoff used by non-maximal suppression to filter duplicate masks.
|
||||
crops_n_layers (`int`, *optional*, defaults to 0):
|
||||
If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of
|
||||
layers to run, where each layer has 2**i_layer number of image crops.
|
||||
crop_overlap_ratio (`float`, *optional*, defaults to `512 / 1500`):
|
||||
Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
|
||||
the image length. Later layers with more crops scale down this overlap.
|
||||
crop_n_points_downscale_factor (`int`, *optional*, defaults to `1`):
|
||||
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
||||
|
||||
Return:
|
||||
`Dict`: A dictionary with the following keys:
|
||||
- **mask** (`PIL.Image`) -- A binary mask of the detected object as a PIL Image of shape `(width,
|
||||
height)` of the original image. Returns a mask filled with zeros if no object is found.
|
||||
- **score** (*optional* `float`) -- Optionally, when the model is capable of estimating a confidence of
|
||||
the "object" described by the label and the mask.
|
||||
|
||||
"""
|
||||
return super().__call__(image, *args, num_workers=num_workers, batch_size=batch_size, **kwargs)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
image,
|
||||
points_per_batch=64,
|
||||
crops_n_layers: int = 0,
|
||||
crop_overlap_ratio: float = 512 / 1500,
|
||||
points_per_crop: Optional[int] = 32,
|
||||
crop_n_points_downscale_factor: Optional[int] = 1,
|
||||
):
|
||||
image = load_image(image)
|
||||
target_size = self.image_processor.size["longest_edge"]
|
||||
crop_boxes, grid_points, cropped_images, input_labels = self.image_processor.generate_crop_boxes(
|
||||
image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor
|
||||
)
|
||||
model_inputs = self.image_processor(images=cropped_images, return_tensors="pt")
|
||||
|
||||
with self.device_placement():
|
||||
if self.framework == "pt":
|
||||
inference_context = self.get_inference_context()
|
||||
with inference_context():
|
||||
model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
|
||||
image_embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values"))
|
||||
model_inputs["image_embeddings"] = image_embeddings
|
||||
|
||||
n_points = grid_points.shape[1]
|
||||
points_per_batch = points_per_batch if points_per_batch is not None else n_points
|
||||
|
||||
if points_per_batch <= 0:
|
||||
raise ValueError(
|
||||
"Cannot have points_per_batch<=0. Must be >=1 to returned batched outputs. "
|
||||
"To return all points at once, set points_per_batch to None"
|
||||
)
|
||||
|
||||
for i in range(0, n_points, points_per_batch):
|
||||
batched_points = grid_points[:, i : i + points_per_batch, :, :]
|
||||
labels = input_labels[:, i : i + points_per_batch]
|
||||
is_last = i == n_points - points_per_batch
|
||||
yield {
|
||||
"input_points": batched_points,
|
||||
"input_labels": labels,
|
||||
"input_boxes": crop_boxes,
|
||||
"is_last": is_last,
|
||||
**model_inputs,
|
||||
}
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
model_inputs,
|
||||
pred_iou_thresh=0.88,
|
||||
stability_score_thresh=0.95,
|
||||
mask_threshold=0,
|
||||
stability_score_offset=1,
|
||||
):
|
||||
input_boxes = model_inputs.pop("input_boxes")
|
||||
is_last = model_inputs.pop("is_last")
|
||||
original_sizes = model_inputs.pop("original_sizes").tolist()
|
||||
reshaped_input_sizes = model_inputs.pop("reshaped_input_sizes").tolist()
|
||||
|
||||
model_outputs = self.model(**model_inputs)
|
||||
|
||||
# post processing happens here in order to avoid CPU GPU copies of ALL the masks
|
||||
low_resolution_masks = model_outputs["pred_masks"]
|
||||
masks = self.image_processor.post_process_masks(
|
||||
low_resolution_masks, original_sizes, reshaped_input_sizes, mask_threshold, binarize=False
|
||||
)
|
||||
iou_scores = model_outputs["iou_scores"]
|
||||
masks, iou_scores, boxes = self.image_processor.filter_masks(
|
||||
masks[0],
|
||||
iou_scores[0],
|
||||
original_sizes[0],
|
||||
input_boxes[0],
|
||||
pred_iou_thresh,
|
||||
stability_score_thresh,
|
||||
mask_threshold,
|
||||
stability_score_offset,
|
||||
)
|
||||
return {
|
||||
"masks": masks,
|
||||
"is_last": is_last,
|
||||
"boxes": boxes,
|
||||
"iou_scores": iou_scores,
|
||||
}
|
||||
|
||||
def postprocess(
|
||||
self,
|
||||
model_outputs,
|
||||
output_rle_mask=False,
|
||||
output_bboxes_mask=False,
|
||||
crops_nms_thresh=0.7,
|
||||
):
|
||||
all_scores = []
|
||||
all_masks = []
|
||||
all_boxes = []
|
||||
for model_output in model_outputs:
|
||||
all_scores.append(model_output.pop("iou_scores"))
|
||||
all_masks.extend(model_output.pop("masks"))
|
||||
all_boxes.append(model_output.pop("boxes"))
|
||||
|
||||
all_scores = torch.cat(all_scores)
|
||||
all_boxes = torch.cat(all_boxes)
|
||||
output_masks, iou_scores, rle_mask, bounding_boxes = self.image_processor.post_process_for_mask_generation(
|
||||
all_masks, all_scores, all_boxes, crops_nms_thresh
|
||||
)
|
||||
|
||||
extra = defaultdict(list)
|
||||
for output in model_outputs:
|
||||
for k, v in output.items():
|
||||
extra[k].append(v)
|
||||
|
||||
optional = {}
|
||||
if output_rle_mask:
|
||||
optional["rle_mask"] = rle_mask
|
||||
|
||||
if output_bboxes_mask:
|
||||
optional["bounding_boxes"] = bounding_boxes
|
||||
|
||||
return {"masks": output_masks, "scores": iou_scores, **optional, **extra}
|
@ -475,6 +475,9 @@ MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = None
|
||||
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = None
|
||||
|
||||
|
||||
MODEL_FOR_MASK_GENERATION_MAPPING = None
|
||||
|
||||
|
||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = None
|
||||
|
||||
|
||||
|
@ -464,8 +464,8 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=1e-4))
|
||||
|
||||
def test_inference_mask_generation_one_point_one_bb(self):
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-h")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-h")
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
142
tests/pipelines/test_pipelines_mask_generation.py
Normal file
142
tests/pipelines/test_pipelines_mask_generation.py
Normal file
@ -0,0 +1,142 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import hashlib
|
||||
import unittest
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import MODEL_FOR_MASK_GENERATION_MAPPING, is_vision_available, pipeline
|
||||
from transformers.pipelines import MaskGenerationPipeline
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
nested_simplify,
|
||||
require_tf,
|
||||
require_torch,
|
||||
require_vision,
|
||||
slow,
|
||||
)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def hashimage(image: Image) -> str:
|
||||
m = hashlib.md5(image.tobytes())
|
||||
return m.hexdigest()[:10]
|
||||
|
||||
|
||||
def mask_to_test_readable(mask: Image) -> Dict:
|
||||
npimg = np.array(mask)
|
||||
shape = npimg.shape
|
||||
return {"hash": hashimage(mask), "shape": shape}
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
@require_vision
|
||||
@require_torch
|
||||
class MaskGenerationPipelineTests(unittest.TestCase):
|
||||
model_mapping = dict(
|
||||
(list(MODEL_FOR_MASK_GENERATION_MAPPING.items()) if MODEL_FOR_MASK_GENERATION_MAPPING else [])
|
||||
)
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor):
|
||||
image_segmenter = MaskGenerationPipeline(model=model, image_processor=processor)
|
||||
return image_segmenter, [
|
||||
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||
]
|
||||
|
||||
@require_tf
|
||||
@unittest.skip("Image segmentation not implemented in TF")
|
||||
def test_small_model_tf(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_small_model_pt(self):
|
||||
image_segmenter = pipeline("mask-generation", model="facebook/sam-vit-huge")
|
||||
|
||||
outputs = image_segmenter("http://images.cocodataset.org/val2017/000000039769.jpg", points_per_batch=256)
|
||||
|
||||
# Shortening by hashing
|
||||
new_outupt = []
|
||||
for i, o in enumerate(outputs["masks"]):
|
||||
new_outupt += [{"mask": mask_to_test_readable(o), "scores": outputs["scores"][i]}]
|
||||
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
nested_simplify(new_outupt, decimals=4),
|
||||
[
|
||||
{'mask': {'hash': '115ad19f5f', 'shape': (480, 640)}, 'scores': 1.0444},
|
||||
{'mask': {'hash': '6affa964c6', 'shape': (480, 640)}, 'scores': 1.021},
|
||||
{'mask': {'hash': 'dfe28a0388', 'shape': (480, 640)}, 'scores': 1.0167},
|
||||
{'mask': {'hash': 'c0a5f4a318', 'shape': (480, 640)}, 'scores': 1.0132},
|
||||
{'mask': {'hash': 'fe8065c197', 'shape': (480, 640)}, 'scores': 1.0053},
|
||||
{'mask': {'hash': 'e2d0b7a0b7', 'shape': (480, 640)}, 'scores': 0.9967},
|
||||
{'mask': {'hash': '453c7844bd', 'shape': (480, 640)}, 'scores': 0.993},
|
||||
{'mask': {'hash': '3d44f2926d', 'shape': (480, 640)}, 'scores': 0.9909},
|
||||
{'mask': {'hash': '64033ddc3f', 'shape': (480, 640)}, 'scores': 0.9879},
|
||||
{'mask': {'hash': '801064ff79', 'shape': (480, 640)}, 'scores': 0.9834},
|
||||
{'mask': {'hash': '6172f276ef', 'shape': (480, 640)}, 'scores': 0.9716},
|
||||
{'mask': {'hash': 'b49e60e084', 'shape': (480, 640)}, 'scores': 0.9612},
|
||||
{'mask': {'hash': 'a811e775fd', 'shape': (480, 640)}, 'scores': 0.9599},
|
||||
{'mask': {'hash': 'a6a8ebcf4b', 'shape': (480, 640)}, 'scores': 0.9552},
|
||||
{'mask': {'hash': '9d8257e080', 'shape': (480, 640)}, 'scores': 0.9532},
|
||||
{'mask': {'hash': '32de6454a8', 'shape': (480, 640)}, 'scores': 0.9516},
|
||||
{'mask': {'hash': 'af3d4af2c8', 'shape': (480, 640)}, 'scores': 0.9499},
|
||||
{'mask': {'hash': '3c6db475fb', 'shape': (480, 640)}, 'scores': 0.9483},
|
||||
{'mask': {'hash': 'c290813fb9', 'shape': (480, 640)}, 'scores': 0.9464},
|
||||
{'mask': {'hash': 'b6f0b8f606', 'shape': (480, 640)}, 'scores': 0.943},
|
||||
{'mask': {'hash': '92ce16bfdf', 'shape': (480, 640)}, 'scores': 0.943},
|
||||
{'mask': {'hash': 'c749b25868', 'shape': (480, 640)}, 'scores': 0.9408},
|
||||
{'mask': {'hash': 'efb6cab859', 'shape': (480, 640)}, 'scores': 0.9335},
|
||||
{'mask': {'hash': '1ff2eafb30', 'shape': (480, 640)}, 'scores': 0.9326},
|
||||
{'mask': {'hash': '788b798e24', 'shape': (480, 640)}, 'scores': 0.9262},
|
||||
{'mask': {'hash': 'abea804f0e', 'shape': (480, 640)}, 'scores': 0.8999},
|
||||
{'mask': {'hash': '7b9e8ddb73', 'shape': (480, 640)}, 'scores': 0.8986},
|
||||
{'mask': {'hash': 'cd24047c8a', 'shape': (480, 640)}, 'scores': 0.8984},
|
||||
{'mask': {'hash': '6943e6bcbd', 'shape': (480, 640)}, 'scores': 0.8873},
|
||||
{'mask': {'hash': 'b5f47c9191', 'shape': (480, 640)}, 'scores': 0.8871}
|
||||
],
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_threshold(self):
|
||||
model_id = "facebook/sam-vit-huge"
|
||||
image_segmenter = pipeline("mask-generation", model=model_id)
|
||||
|
||||
outputs = image_segmenter(
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg", pred_iou_thresh=1, points_per_batch=256
|
||||
)
|
||||
|
||||
# Shortening by hashing
|
||||
new_outupt = []
|
||||
for i, o in enumerate(outputs["masks"]):
|
||||
new_outupt += [{"mask": mask_to_test_readable(o), "scores": outputs["scores"][i]}]
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(new_outupt, decimals=4),
|
||||
[
|
||||
{"mask": {"hash": "115ad19f5f", "shape": (480, 640)}, "scores": 1.0444},
|
||||
{"mask": {"hash": "6affa964c6", "shape": (480, 640)}, "scores": 1.0210},
|
||||
{"mask": {"hash": "dfe28a0388", "shape": (480, 640)}, "scores": 1.0167},
|
||||
{"mask": {"hash": "c0a5f4a318", "shape": (480, 640)}, "scores": 1.0132},
|
||||
{"mask": {"hash": "fe8065c197", "shape": (480, 640)}, "scores": 1.0053},
|
||||
],
|
||||
)
|
@ -98,6 +98,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
|
||||
),
|
||||
("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
|
||||
("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),
|
||||
("mask-generation", "MODEL_FOR_MASK_GENERATION_MAPPING_NAMES", "AutoModelForMaskGeneration"),
|
||||
]
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user