Add automatic-mask-generation pipeline for Segment Anything Model (SAM) (#22840)

* cleanup

* updates

* more refactoring

* make style

* update inits

* support other inputs in base

* update based on review

Co-authored-by: Nicolas Patry <patry.nicolas@gmail.com>

* Update tests/pipelines/test_pipelines_automatic_mask_generation.py

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* update

* fixup

* TODO x and y to refactor, _h _w refactored here

* update docstring

* more nits

* style on these

* more doc fix

* rename variables

* update

* updates

* style

* update

* fix `_mask_to_rle_pytorch`

* styling

* fix ask to rle, wrong outputs

* add device arg

* update

* more updates, fix tets

* udpate

* update docstrings

* styling

* fixup

* add notebook on the docs

* update orginal sizes

* fix docstring

* updat condition on point_per-batch

* updates tests

* fix CI  test

* extend is required, append does not work!

* fixup

* fix CI tests

* whit pixels left

* address doc comments

* fix doc

* slow pipeline tests

* update auto init

* add revision

* make fixup

* update p!ipoeline tag when calling tests

* alphabeitcal order in inits

* fix copies

* last style nits

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* reformat docstring

* more reformat

* address most of the comments

* Update src/transformers/pipelines/mask_generation.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* final refactor

* Update src/transformers/models/sam/image_processing_sam.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fixup and fix slow tests

* revert

---------

Co-authored-by: Nicolas Patry <patry.nicolas@gmail.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Arthur 2023-04-20 19:27:24 +02:00 committed by GitHub
parent e5f3487190
commit f143037789
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 928 additions and 18 deletions

View File

@ -64,6 +64,7 @@ scores = outputs.iou_scores
Resources:
- [Demo notebook](https://github.com/huggingface/notebooks/blob/main/examples/segment_anything.ipynb) for using the model
- [Demo notebook](https://github.com/huggingface/notebooks/blob/main/examples/automatic_mask_generation.ipynb) for using automatic mask generation pipeline.
## SamConfig

View File

@ -1012,6 +1012,7 @@ else:
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
"MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
"MODEL_FOR_MASKED_LM_MAPPING",
"MODEL_FOR_MASK_GENERATION_MAPPING",
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"MODEL_FOR_OBJECT_DETECTION_MAPPING",
@ -4650,6 +4651,7 @@ if TYPE_CHECKING:
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_MASK_GENERATION_MAPPING,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,

View File

@ -52,6 +52,7 @@ else:
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
"MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
"MODEL_FOR_MASKED_LM_MAPPING",
"MODEL_FOR_MASK_GENERATION_MAPPING",
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"MODEL_FOR_OBJECT_DETECTION_MAPPING",
@ -213,6 +214,7 @@ if TYPE_CHECKING:
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_MASK_GENERATION_MAPPING,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,

View File

@ -977,7 +977,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
]
)
MODEL_FOR_AUTOMATIC_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
[
("sam", "SamModel"),
]
@ -1058,9 +1058,11 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_F
MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)
MODEL_FOR_AUTOMATIC_MASK_GENERATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_AUTOMATIC_MASK_GENERATION_MAPPING_NAMES
)
MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
class AutoModelForMaskGeneration(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
class AutoModel(_BaseAutoModelClass):

View File

@ -13,7 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for SAM."""
from typing import Dict, List, Optional, Tuple, Union
import math
from copy import deepcopy
from itertools import product
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
@ -26,16 +29,20 @@ from ...image_utils import (
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
)
from ...utils import TensorType, is_torch_available, logging, requires_backends
from ...utils import TensorType, is_torch_available, is_torchvision_available, logging, requires_backends
if is_torch_available():
import torch
import torch.nn.functional as F
if is_torchvision_available():
from torchvision.ops.boxes import batched_nms
logger = logging.get_logger(__name__)
@ -354,10 +361,14 @@ class SamImageProcessor(BaseImageProcessor):
images = [self.pad_image(image=image, pad_size=pad_size) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images]
data = {"pixel_values": images, "original_sizes": original_sizes, "reshaped_input_sizes": reshaped_input_sizes}
encoded_outputs = BatchFeature(data=data, tensor_type=return_tensors)
encoded_outputs = BatchFeature(
data={
"pixel_values": images,
"original_sizes": original_sizes,
"reshaped_input_sizes": reshaped_input_sizes,
},
tensor_type=return_tensors,
)
return encoded_outputs
def post_process_masks(
@ -392,11 +403,453 @@ class SamImageProcessor(BaseImageProcessor):
for i, original_size in enumerate(original_sizes):
interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
interpolated_mask = F.interpolate(
interpolated_mask, [*original_size.numpy()], mode="bilinear", align_corners=False
)
interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
if binarize:
interpolated_mask = interpolated_mask > mask_threshold
output_masks.append(interpolated_mask)
return output_masks
def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh):
"""
Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks.
Args:
all_masks (`List[torch.Tensor]`):
List of all predicted segmentation masks
all_scores (`List[torch.Tensor]`):
List of all predicted iou scores
all_boxes (`List[torch.Tensor]`):
List of all bounding boxes of the predicted masks
crops_nms_thresh (`float`):
Threshold for NMS (Non Maximum Suppression) algorithm.
"""
return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)
def generate_crop_boxes(
self,
image,
target_size,
crop_n_layers: int = 0,
overlap_ratio: float = 512 / 1500,
points_per_crop: Optional[int] = 32,
crop_n_points_downscale_factor: Optional[List[int]] = 1,
device: Optional["torch.device"] = None,
):
"""
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
Args:
image (`np.array`):
Input original image
target_size (`int`):
Target size of the resized image
crop_n_layers (`int`, *optional*, defaults to 0):
If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where
each layer has 2**i_layer number of image crops.
overlap_ratio (`float`, *optional*, defaults to 512/1500):
Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
the image length. Later layers with more crops scale down this overlap.
points_per_crop (`int`, *optional*, defaults to 32):
Number of points to sample from each crop.
crop_n_points_downscale_factor (`List[int]`, *optional*, defaults to 1):
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
device (`torch.device`, *optional*, defaults to None):
Device to use for the computation. If None, cpu will be used.
"""
return _generate_crop_boxes(
image, target_size, crop_n_layers, overlap_ratio, points_per_crop, crop_n_points_downscale_factor, device
)
def filter_masks(
self,
masks,
iou_scores,
original_size,
cropped_box_image,
pred_iou_thresh=0.88,
stability_score_thresh=0.95,
mask_threshold=0,
stability_score_offset=1,
):
"""
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
bounding boxes and pad the predicted masks if necessary.
Args:
masks (`torch.Tensor`):
Input masks.
iou_scores (`torch.Tensor`):
List of IoU scores.
original_size (`Tuple[int,int]`):
Size of the orginal image.
cropped_box_image (`np.array`):
The cropped image.
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
The threshold for the iou scores.
stability_score_thresh (`float`, *optional*, defaults to 0.95):
The threshold for the stability score.
mask_threshold (`float`, *optional*, defaults to 0):
The threshold for the predicted masks.
stability_score_offset (`float`, *optional*, defaults to 1):
The offset for the stability score used in the `_compute_stability_score` method.
"""
requires_backends(self, ["torch"])
original_height, original_width = original_size
iou_scores = iou_scores.flatten(0, 1)
masks = masks.flatten(0, 1)
if masks.shape[0] != iou_scores.shape[0]:
raise ValueError("masks and iou_scores must have the same batch size.")
if masks.device != iou_scores.device:
iou_scores = iou_scores.to(masks.device)
batch_size = masks.shape[0]
keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device)
if pred_iou_thresh > 0.0:
keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
# compute stability score
if stability_score_thresh > 0.0:
stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset)
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
scores = iou_scores[keep_mask]
masks = masks[keep_mask]
# binarize masks
masks = masks > mask_threshold
converted_boxes = _batched_mask_to_box(masks)
keep_mask = ~_is_box_near_crop_edge(
converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
)
scores = scores[keep_mask]
masks = masks[keep_mask]
converted_boxes = converted_boxes[keep_mask]
masks = _pad_masks(masks, cropped_box_image, original_height, original_width)
# conversion to rle is necessary to run non-maximum suppresion
masks = _mask_to_rle_pytorch(masks)
return masks, scores, converted_boxes
def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
# One mask is always contained inside the other.
# Save memory by preventing unnecesary cast to torch.int64
intersections = (
(masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
)
unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
stability_scores = intersections / unions
return stability_scores
def _build_point_grid(n_per_side: int) -> np.ndarray:
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
offset = 1 / (2 * n_per_side)
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
return points
def _normalize_coordinates(
target_size: int, coords: np.ndarray, original_size: Tuple[int, int], is_bounding_box=False
) -> np.ndarray:
"""
Expects a numpy array of length 2 in the final dimension. Requires the original image size in (height, width)
format.
"""
old_height, old_width = original_size
scale = target_size * 1.0 / max(old_height, old_width)
new_height, new_width = old_height * scale, old_width * scale
new_width = int(new_width + 0.5)
new_height = int(new_height + 0.5)
coords = deepcopy(coords).astype(float)
if is_bounding_box:
coords = coords.reshape(-1, 2, 2)
coords[..., 0] = coords[..., 0] * (new_width / old_width)
coords[..., 1] = coords[..., 1] * (new_height / old_height)
if is_bounding_box:
coords = coords.reshape(-1, 4)
return coords
def _generate_crop_boxes(
image,
target_size: int, # Is it tuple here?
crop_n_layers: int = 0,
overlap_ratio: float = 512 / 1500,
points_per_crop: Optional[int] = 32,
crop_n_points_downscale_factor: Optional[List[int]] = 1,
device: Optional["torch.device"] = None,
) -> Tuple[List[List[int]], List[int]]:
"""
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
Args:
image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]):
Image to generate crops for.
target_size (`int`):
Size of the smallest crop.
crop_n_layers (`int`, *optional*):
If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers
to run, where each layer has 2**i_layer number of image crops.
overlap_ratio (`int`, *optional*):
Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the
image length. Later layers with more crops scale down this overlap.
points_per_crop (`int`, *optional*):
Number of points to sample per crop.
crop_n_points_downscale_factor (`int`, *optional*):
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
device (`torch.device`, *optional*):
Device to run the crop generation on. Defaults to CPU.
"""
if device is None:
device = torch.device("cpu")
if isinstance(image, list):
raise ValueError("Only one image is allowed for crop generation.")
image = to_numpy_array(image)
original_size = get_image_size(image)
points_grid = []
for i in range(crop_n_layers + 1):
n_points = int(points_per_crop / (crop_n_points_downscale_factor**i))
points_grid.append(_build_point_grid(n_points))
crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size)
cropped_images, point_grid_per_crop = _generate_crop_images(
crop_boxes, image, points_grid, layer_idxs, target_size, original_size
)
crop_boxes = torch.tensor(crop_boxes, dtype=torch.float32, device=device)
point_grid_per_crop = np.array([point_grid_per_crop])
points_per_crop = torch.tensor(point_grid_per_crop, device=device)
points_per_crop = points_per_crop.permute(0, 2, 1, 3)
input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.long, device=device)
return crop_boxes, points_per_crop, cropped_images, input_labels
def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size):
"""
Generates 2 ** (layers idx + 1) crops for each crop_n_layers. Crops are in the XYWH format : The XYWH format
consists of the following required indices:
- X: X coordinate of the top left of the bounding box
- Y: Y coordinate of the top left of the bounding box
- W: width of the bounding box
- H: height of the bounding box
"""
crop_boxes, layer_idxs = [], []
im_height, im_width = original_size
short_side = min(im_height, im_width)
# Original image
crop_boxes.append([0, 0, im_width, im_height])
layer_idxs.append(0)
for i_layer in range(crop_n_layers):
n_crops_per_side = 2 ** (i_layer + 1)
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side))
crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side))
crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)]
crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)]
for left, top in product(crop_box_x0, crop_box_y0):
box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)]
crop_boxes.append(box)
layer_idxs.append(i_layer + 1)
return crop_boxes, layer_idxs
def _generate_crop_images(crop_boxes, image, points_grid, layer_idxs, target_size, original_size):
"""
Takes as an input bounding boxes that are used to crop the image. Based in the crops, the corresponding points are
also passed.
"""
cropped_images = []
total_points_per_crop = []
for i, crop_box in enumerate(crop_boxes):
left, top, right, bottom = crop_box
channel_dim = infer_channel_dimension_format(image)
if channel_dim == ChannelDimension.LAST:
cropped_im = image[top:bottom, left:right, :]
else:
cropped_im = image[:, top:bottom, left:right]
cropped_images.append(cropped_im)
cropped_im_size = get_image_size(cropped_im)
points_scale = np.array(cropped_im_size)[None, ::-1]
points = points_grid[layer_idxs[i]] * points_scale
normalized_points = _normalize_coordinates(target_size, points, original_size)
total_points_per_crop.append(normalized_points)
return cropped_images, total_points_per_crop
def _pad_masks(masks, crop_box: List[int], orig_height: int, orig_width: int):
left, top, right, bottom = crop_box
if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
return masks
# Coordinate transform masks
pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
pad = (left, pad_x - left, top, pad_y - top)
return torch.nn.functional.pad(masks, pad, value=0)
def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
left, top, _, _ = crop_box
offset = torch.tensor([[left, top, left, top]], device=boxes.device)
# Check if boxes has a channel dimension
if len(boxes.shape) == 3:
offset = offset.unsqueeze(1)
boxes = (boxes + offset).float()
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
return torch.any(near_crop_edge, dim=1)
def _batched_mask_to_box(masks: "torch.Tensor"):
"""
Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
corresponds the following required indices:
- LEFT: left hand side of the bounding box
- TOP: top of the bounding box
- RIGHT: right of the bounding box
- BOTTOM: bottom of the bounding box
Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
is channel_1 x channel_2 x ... x 4.
Args:
- masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`)
"""
# torch.max below raises an error on empty inputs, just skip in this case
if torch.numel(masks) == 0:
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
# Normalize shape to Cxheightxwidth
shape = masks.shape
height, width = shape[-2:]
# Get top and bottom edges
in_height, _ = torch.max(masks, dim=-1)
in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :]
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
in_height_coords = in_height_coords + height * (~in_height)
top_edges, _ = torch.min(in_height_coords, dim=-1)
# Get left and right edges
in_width, _ = torch.max(masks, dim=-2)
in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :]
right_edges, _ = torch.max(in_width_coords, dim=-1)
in_width_coords = in_width_coords + width * (~in_width)
left_edges, _ = torch.min(in_width_coords, dim=-1)
# If the mask is empty the right edge will be to the left of the left edge.
# Replace these boxes with [0, 0, 0, 0]
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
out = out * (~empty_filter).unsqueeze(-1)
# Return to original shape
out = out.reshape(*shape[:-2], 4)
return out
def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
"""
Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools.
"""
# Put in fortran order and flatten height and width
batch_size, height, width = input_mask.shape
input_mask = input_mask.permute(0, 2, 1).flatten(1)
# Compute change indices
diff = input_mask[:, 1:] ^ input_mask[:, :-1]
change_indices = diff.nonzero()
# Encode run length
out = []
for i in range(batch_size):
cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
counts = [] if input_mask[i, 0] == 0 else [0]
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
out.append({"size": [height, width], "counts": counts})
return out
def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
"""Compute a binary mask from an uncompressed RLE."""
height, width = rle["size"]
mask = np.empty(height * width, dtype=bool)
idx = 0
parity = False
for count in rle["counts"]:
mask[idx : idx + count] = parity
idx += count
parity = not parity
mask = mask.reshape(width, height)
return mask.transpose() # Reshape to original shape
def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
"""
Perform NMS (Non Maxium Suppression) on the outputs.
Args:
rle_masks (`torch.Tensor`):
binary masks in the RLE format
iou_scores (`torch.Tensor` of shape (nb_masks, 1)):
iou_scores predicted by the model
mask_boxes (`torch.Tensor`):
The bounding boxes corresponding to segmentation masks
amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
NMS threshold.
"""
keep_by_nms = batched_nms(
boxes=mask_boxes.float(),
scores=iou_scores,
idxs=torch.zeros(mask_boxes.shape[0]),
iou_threshold=amg_crops_nms_thresh,
)
iou_scores = iou_scores[keep_by_nms]
rle_masks = [rle_masks[i] for i in keep_by_nms]
mask_boxes = mask_boxes[keep_by_nms]
masks = [_rle_to_mask(rle) for rle in rle_masks]
return masks, iou_scores, rle_masks, mask_boxes

View File

@ -63,6 +63,7 @@ from .fill_mask import FillMaskPipeline
from .image_classification import ImageClassificationPipeline
from .image_segmentation import ImageSegmentationPipeline
from .image_to_text import ImageToTextPipeline
from .mask_generation import MaskGenerationPipeline
from .object_detection import ObjectDetectionPipeline
from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
@ -124,6 +125,7 @@ if is_torch_available():
AutoModelForImageClassification,
AutoModelForImageSegmentation,
AutoModelForMaskedLM,
AutoModelForMaskGeneration,
AutoModelForObjectDetection,
AutoModelForQuestionAnswering,
AutoModelForSemanticSegmentation,
@ -384,6 +386,13 @@ SUPPORTED_TASKS = {
"default": {"model": {"pt": ("MCG-NJU/videomae-base-finetuned-kinetics", "4800870")}},
"type": "video",
},
"mask-generation": {
"impl": MaskGenerationPipeline,
"tf": (),
"pt": (AutoModelForMaskGeneration,) if is_torch_available() else (),
"default": {"model": {"pt": ("facebook/sam-vit-huge", "997b15")}},
"type": "multimodal",
},
}
NO_FEATURE_EXTRACTOR_TASKS = set()
@ -536,6 +545,7 @@ def pipeline(
- `"image-classification"`: will return a [`ImageClassificationPipeline`].
- `"image-segmentation"`: will return a [`ImageSegmentationPipeline`].
- `"image-to-text"`: will return a [`ImageToTextPipeline`].
- `"mask-generation"`: will return a [`MaskGenerationPipeline`].
- `"object-detection"`: will return a [`ObjectDetectionPipeline`].
- `"question-answering"`: will return a [`QuestionAnsweringPipeline`].
- `"summarization"`: will return a [`SummarizationPipeline`].

View File

@ -97,6 +97,8 @@ def _pad(items, key, padding_value, padding_side):
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
elif dim == 3:
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
elif dim == 4:
tensor = torch.zeros((batch_size, max_length, shape[-2], shape[-1]), dtype=dtype) + padding_value
for i, item in enumerate(items):
if dim == 2:
@ -109,6 +111,12 @@ def _pad(items, key, padding_value, padding_side):
tensor[i, -len(item[key][0]) :, :] = item[key][0].clone()
else:
tensor[i, : len(item[key][0]), :] = item[key][0].clone()
elif dim == 4:
if padding_side == "left":
tensor[i, -len(item[key][0]) :, :, :] = item[key][0].clone()
else:
tensor[i, : len(item[key][0]), :, :] = item[key][0].clone()
return tensor
else:
return [item[key] for item in items]

View File

@ -81,11 +81,11 @@ class ImageSegmentationPipeline(Pipeline):
)
def _sanitize_parameters(self, **kwargs):
preprocessor_kwargs = {}
preprocess_kwargs = {}
postprocess_kwargs = {}
if "subtask" in kwargs:
postprocess_kwargs["subtask"] = kwargs["subtask"]
preprocessor_kwargs["subtask"] = kwargs["subtask"]
preprocess_kwargs["subtask"] = kwargs["subtask"]
if "threshold" in kwargs:
postprocess_kwargs["threshold"] = kwargs["threshold"]
if "mask_threshold" in kwargs:
@ -93,7 +93,7 @@ class ImageSegmentationPipeline(Pipeline):
if "overlap_mask_area_threshold" in kwargs:
postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]
return preprocessor_kwargs, {}, postprocess_kwargs
return preprocess_kwargs, {}, postprocess_kwargs
def __call__(self, images, **kwargs) -> Union[Predictions, List[Prediction]]:
"""

View File

@ -0,0 +1,286 @@
from collections import defaultdict
from typing import Optional
from ..image_utils import load_image
from ..utils import (
add_end_docstrings,
is_torch_available,
logging,
requires_backends,
)
from .base import PIPELINE_INIT_ARGS, ChunkPipeline
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_MASK_GENERATION_MAPPING
logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS)
class MaskGenerationPipeline(ChunkPipeline):
"""
Automatic mask generation for images using `SamForMaskGeneration`. This pipeline predicts binary masks for an
image, given an image. It is a `ChunkPipeline` because you can seperate the points in a mini-batch in order to
avoid OOM issues. Use the `points_per_batch` argument to control the number of points that will be processed at the
same time. Default is `64`.
The pipeline works in 3 steps:
1. `preprocess`: A grid of 1024 points evenly separated is generated along with bounding boxes and point
labels.
For more details on how the points and bounding boxes are created, check the `_generate_crop_boxes`
function. The image is also preprocessed using the `image_processor`. This function `yields` a minibatch of
`points_per_batch`.
2. `forward`: feeds the outputs of `preprocess` to the model. The image embedding is computed only once.
Calls both `self.model.get_image_embeddings` and makes sure that the gradients are not computed, and the
tensors and models are on the same device.
3. `postprocess`: The most important part of the automatic mask generation happens here. Three steps
are induced:
- image_processor.postprocess_masks (run on each minibatch loop): takes in the raw output masks,
resizes them according
to the image size, and transforms there to binary masks.
- image_processor.filter_masks (on each minibatch loop): uses both `pred_iou_thresh` and
`stability_scores`. Also
applies a variety of filters based on non maximum suppression to remove bad masks.
- image_processor.postprocess_masks_for_amg applies the NSM on the mask to only keep relevant ones.
Arguments:
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
[`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
[`PreTrainedTokenizer`].
feature_extractor ([`SequenceFeatureExtractor`]):
The feature extractor that will be used by the pipeline to encode the input.
points_per_batch (*optional*, int, default to 64):
Sets the number of points run simultaneously by the model. Higher numbers may be faster but use more GPU
memory.
output_bboxes_mask (`bool`, *optional*, default to `False`):
Whether or not to output the bounding box predictions.
output_rle_masks (`bool`, *optional*, default to `False`):
Whether or not to output the masks in `RLE` format
Example:
```python
>>> from transformers import pipeline
>>> generator = pipeline(model="facebook/sam-vit-h", task="mask-generation")
>>> outputs = generator(
... "http://images.cocodataset.org/val2017/000000039769.jpg",
... )
>>> outputs = generator(
... "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", points_per_batch=128
... )
```
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
This segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
`"mask-generation"`.
See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=mask-generation).
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
requires_backends(self, "vision")
requires_backends(self, "torch")
if self.framework != "pt":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
self.check_model_type(MODEL_FOR_MASK_GENERATION_MAPPING)
def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {}
postprocess_kwargs = {}
forward_params = {}
# preprocess args
if "points_per_batch" in kwargs:
preprocess_kwargs["points_per_batch"] = kwargs["points_per_batch"]
if "points_per_crop" in kwargs:
preprocess_kwargs["points_per_crop"] = kwargs["points_per_crop"]
if "crops_n_layers" in kwargs:
preprocess_kwargs["crops_n_layers"] = kwargs["crops_n_layers"]
if "crop_overlap_ratio" in kwargs:
preprocess_kwargs["crop_overlap_ratio"] = kwargs["crop_overlap_ratio"]
if "crop_n_points_downscale_factor" in kwargs:
preprocess_kwargs["crop_n_points_downscale_factor"] = kwargs["crop_n_points_downscale_factor"]
# postprocess args
if "pred_iou_thresh" in kwargs:
forward_params["pred_iou_thresh"] = kwargs["pred_iou_thresh"]
if "stability_score_offset" in kwargs:
forward_params["stability_score_offset"] = kwargs["stability_score_offset"]
if "mask_threshold" in kwargs:
forward_params["mask_threshold"] = kwargs["mask_threshold"]
if "stability_score_thresh" in kwargs:
forward_params["stability_score_thresh"] = kwargs["stability_score_thresh"]
if "crops_nms_thresh" in kwargs:
postprocess_kwargs["crops_nms_thresh"] = kwargs["crops_nms_thresh"]
if "output_rle_mask" in kwargs:
postprocess_kwargs["output_rle_mask"] = kwargs["output_rle_mask"]
if "output_bboxes_mask" in kwargs:
postprocess_kwargs["output_bboxes_mask"] = kwargs["output_bboxes_mask"]
return preprocess_kwargs, forward_params, postprocess_kwargs
def __call__(self, image, *args, num_workers=None, batch_size=None, **kwargs):
"""
Generates binary segmentation masks
Args:
inputs (`np.ndarray` or `bytes` or `str` or `dict`):
Image or list of images.
mask_threshold (`float`, *optional*, defaults to 0.0):
Threshold to use when turning the predicted masks into binary values.
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
A filtering threshold in `[0,1]` applied on the model's predicted mask quality.
stability_score_thresh (`float`, *optional*, defaults to 0.95):
A filtering threshold in `[0,1]`, using the stability of the mask under changes to the cutoff used to
binarize the model's mask predictions.
stability_score_offset (`int`, *optional*, defaults to 1):
The amount to shift the cutoff when calculated the stability score.
crops_nms_thresh (`float`, *optional*, defaults to 0.7):
The box IoU cutoff used by non-maximal suppression to filter duplicate masks.
crops_n_layers (`int`, *optional*, defaults to 0):
If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of
layers to run, where each layer has 2**i_layer number of image crops.
crop_overlap_ratio (`float`, *optional*, defaults to `512 / 1500`):
Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
the image length. Later layers with more crops scale down this overlap.
crop_n_points_downscale_factor (`int`, *optional*, defaults to `1`):
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
Return:
`Dict`: A dictionary with the following keys:
- **mask** (`PIL.Image`) -- A binary mask of the detected object as a PIL Image of shape `(width,
height)` of the original image. Returns a mask filled with zeros if no object is found.
- **score** (*optional* `float`) -- Optionally, when the model is capable of estimating a confidence of
the "object" described by the label and the mask.
"""
return super().__call__(image, *args, num_workers=num_workers, batch_size=batch_size, **kwargs)
def preprocess(
self,
image,
points_per_batch=64,
crops_n_layers: int = 0,
crop_overlap_ratio: float = 512 / 1500,
points_per_crop: Optional[int] = 32,
crop_n_points_downscale_factor: Optional[int] = 1,
):
image = load_image(image)
target_size = self.image_processor.size["longest_edge"]
crop_boxes, grid_points, cropped_images, input_labels = self.image_processor.generate_crop_boxes(
image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor
)
model_inputs = self.image_processor(images=cropped_images, return_tensors="pt")
with self.device_placement():
if self.framework == "pt":
inference_context = self.get_inference_context()
with inference_context():
model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
image_embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values"))
model_inputs["image_embeddings"] = image_embeddings
n_points = grid_points.shape[1]
points_per_batch = points_per_batch if points_per_batch is not None else n_points
if points_per_batch <= 0:
raise ValueError(
"Cannot have points_per_batch<=0. Must be >=1 to returned batched outputs. "
"To return all points at once, set points_per_batch to None"
)
for i in range(0, n_points, points_per_batch):
batched_points = grid_points[:, i : i + points_per_batch, :, :]
labels = input_labels[:, i : i + points_per_batch]
is_last = i == n_points - points_per_batch
yield {
"input_points": batched_points,
"input_labels": labels,
"input_boxes": crop_boxes,
"is_last": is_last,
**model_inputs,
}
def _forward(
self,
model_inputs,
pred_iou_thresh=0.88,
stability_score_thresh=0.95,
mask_threshold=0,
stability_score_offset=1,
):
input_boxes = model_inputs.pop("input_boxes")
is_last = model_inputs.pop("is_last")
original_sizes = model_inputs.pop("original_sizes").tolist()
reshaped_input_sizes = model_inputs.pop("reshaped_input_sizes").tolist()
model_outputs = self.model(**model_inputs)
# post processing happens here in order to avoid CPU GPU copies of ALL the masks
low_resolution_masks = model_outputs["pred_masks"]
masks = self.image_processor.post_process_masks(
low_resolution_masks, original_sizes, reshaped_input_sizes, mask_threshold, binarize=False
)
iou_scores = model_outputs["iou_scores"]
masks, iou_scores, boxes = self.image_processor.filter_masks(
masks[0],
iou_scores[0],
original_sizes[0],
input_boxes[0],
pred_iou_thresh,
stability_score_thresh,
mask_threshold,
stability_score_offset,
)
return {
"masks": masks,
"is_last": is_last,
"boxes": boxes,
"iou_scores": iou_scores,
}
def postprocess(
self,
model_outputs,
output_rle_mask=False,
output_bboxes_mask=False,
crops_nms_thresh=0.7,
):
all_scores = []
all_masks = []
all_boxes = []
for model_output in model_outputs:
all_scores.append(model_output.pop("iou_scores"))
all_masks.extend(model_output.pop("masks"))
all_boxes.append(model_output.pop("boxes"))
all_scores = torch.cat(all_scores)
all_boxes = torch.cat(all_boxes)
output_masks, iou_scores, rle_mask, bounding_boxes = self.image_processor.post_process_for_mask_generation(
all_masks, all_scores, all_boxes, crops_nms_thresh
)
extra = defaultdict(list)
for output in model_outputs:
for k, v in output.items():
extra[k].append(v)
optional = {}
if output_rle_mask:
optional["rle_mask"] = rle_mask
if output_bboxes_mask:
optional["bounding_boxes"] = bounding_boxes
return {"masks": output_masks, "scores": iou_scores, **optional, **extra}

View File

@ -475,6 +475,9 @@ MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = None
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = None
MODEL_FOR_MASK_GENERATION_MAPPING = None
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = None

View File

@ -464,8 +464,8 @@ class SamModelIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=1e-4))
def test_inference_mask_generation_one_point_one_bb(self):
model = SamModel.from_pretrained("facebook/sam-vit-h")
processor = SamProcessor.from_pretrained("facebook/sam-vit-h")
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model.to(torch_device)
model.eval()

View File

@ -0,0 +1,142 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import unittest
from typing import Dict
import numpy as np
from transformers import MODEL_FOR_MASK_GENERATION_MAPPING, is_vision_available, pipeline
from transformers.pipelines import MaskGenerationPipeline
from transformers.testing_utils import (
is_pipeline_test,
nested_simplify,
require_tf,
require_torch,
require_vision,
slow,
)
if is_vision_available():
from PIL import Image
def hashimage(image: Image) -> str:
m = hashlib.md5(image.tobytes())
return m.hexdigest()[:10]
def mask_to_test_readable(mask: Image) -> Dict:
npimg = np.array(mask)
shape = npimg.shape
return {"hash": hashimage(mask), "shape": shape}
@is_pipeline_test
@require_vision
@require_torch
class MaskGenerationPipelineTests(unittest.TestCase):
model_mapping = dict(
(list(MODEL_FOR_MASK_GENERATION_MAPPING.items()) if MODEL_FOR_MASK_GENERATION_MAPPING else [])
)
def get_test_pipeline(self, model, tokenizer, processor):
image_segmenter = MaskGenerationPipeline(model=model, image_processor=processor)
return image_segmenter, [
"./tests/fixtures/tests_samples/COCO/000000039769.png",
"./tests/fixtures/tests_samples/COCO/000000039769.png",
]
@require_tf
@unittest.skip("Image segmentation not implemented in TF")
def test_small_model_tf(self):
pass
@slow
@require_torch
def test_small_model_pt(self):
image_segmenter = pipeline("mask-generation", model="facebook/sam-vit-huge")
outputs = image_segmenter("http://images.cocodataset.org/val2017/000000039769.jpg", points_per_batch=256)
# Shortening by hashing
new_outupt = []
for i, o in enumerate(outputs["masks"]):
new_outupt += [{"mask": mask_to_test_readable(o), "scores": outputs["scores"][i]}]
# fmt: off
self.assertEqual(
nested_simplify(new_outupt, decimals=4),
[
{'mask': {'hash': '115ad19f5f', 'shape': (480, 640)}, 'scores': 1.0444},
{'mask': {'hash': '6affa964c6', 'shape': (480, 640)}, 'scores': 1.021},
{'mask': {'hash': 'dfe28a0388', 'shape': (480, 640)}, 'scores': 1.0167},
{'mask': {'hash': 'c0a5f4a318', 'shape': (480, 640)}, 'scores': 1.0132},
{'mask': {'hash': 'fe8065c197', 'shape': (480, 640)}, 'scores': 1.0053},
{'mask': {'hash': 'e2d0b7a0b7', 'shape': (480, 640)}, 'scores': 0.9967},
{'mask': {'hash': '453c7844bd', 'shape': (480, 640)}, 'scores': 0.993},
{'mask': {'hash': '3d44f2926d', 'shape': (480, 640)}, 'scores': 0.9909},
{'mask': {'hash': '64033ddc3f', 'shape': (480, 640)}, 'scores': 0.9879},
{'mask': {'hash': '801064ff79', 'shape': (480, 640)}, 'scores': 0.9834},
{'mask': {'hash': '6172f276ef', 'shape': (480, 640)}, 'scores': 0.9716},
{'mask': {'hash': 'b49e60e084', 'shape': (480, 640)}, 'scores': 0.9612},
{'mask': {'hash': 'a811e775fd', 'shape': (480, 640)}, 'scores': 0.9599},
{'mask': {'hash': 'a6a8ebcf4b', 'shape': (480, 640)}, 'scores': 0.9552},
{'mask': {'hash': '9d8257e080', 'shape': (480, 640)}, 'scores': 0.9532},
{'mask': {'hash': '32de6454a8', 'shape': (480, 640)}, 'scores': 0.9516},
{'mask': {'hash': 'af3d4af2c8', 'shape': (480, 640)}, 'scores': 0.9499},
{'mask': {'hash': '3c6db475fb', 'shape': (480, 640)}, 'scores': 0.9483},
{'mask': {'hash': 'c290813fb9', 'shape': (480, 640)}, 'scores': 0.9464},
{'mask': {'hash': 'b6f0b8f606', 'shape': (480, 640)}, 'scores': 0.943},
{'mask': {'hash': '92ce16bfdf', 'shape': (480, 640)}, 'scores': 0.943},
{'mask': {'hash': 'c749b25868', 'shape': (480, 640)}, 'scores': 0.9408},
{'mask': {'hash': 'efb6cab859', 'shape': (480, 640)}, 'scores': 0.9335},
{'mask': {'hash': '1ff2eafb30', 'shape': (480, 640)}, 'scores': 0.9326},
{'mask': {'hash': '788b798e24', 'shape': (480, 640)}, 'scores': 0.9262},
{'mask': {'hash': 'abea804f0e', 'shape': (480, 640)}, 'scores': 0.8999},
{'mask': {'hash': '7b9e8ddb73', 'shape': (480, 640)}, 'scores': 0.8986},
{'mask': {'hash': 'cd24047c8a', 'shape': (480, 640)}, 'scores': 0.8984},
{'mask': {'hash': '6943e6bcbd', 'shape': (480, 640)}, 'scores': 0.8873},
{'mask': {'hash': 'b5f47c9191', 'shape': (480, 640)}, 'scores': 0.8871}
],
)
# fmt: on
@require_torch
@slow
def test_threshold(self):
model_id = "facebook/sam-vit-huge"
image_segmenter = pipeline("mask-generation", model=model_id)
outputs = image_segmenter(
"http://images.cocodataset.org/val2017/000000039769.jpg", pred_iou_thresh=1, points_per_batch=256
)
# Shortening by hashing
new_outupt = []
for i, o in enumerate(outputs["masks"]):
new_outupt += [{"mask": mask_to_test_readable(o), "scores": outputs["scores"][i]}]
self.assertEqual(
nested_simplify(new_outupt, decimals=4),
[
{"mask": {"hash": "115ad19f5f", "shape": (480, 640)}, "scores": 1.0444},
{"mask": {"hash": "6affa964c6", "shape": (480, 640)}, "scores": 1.0210},
{"mask": {"hash": "dfe28a0388", "shape": (480, 640)}, "scores": 1.0167},
{"mask": {"hash": "c0a5f4a318", "shape": (480, 640)}, "scores": 1.0132},
{"mask": {"hash": "fe8065c197", "shape": (480, 640)}, "scores": 1.0053},
],
)

View File

@ -98,6 +98,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
),
("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),
("mask-generation", "MODEL_FOR_MASK_GENERATION_MAPPING_NAMES", "AutoModelForMaskGeneration"),
]